Add files based on the initial commit
This view is limited to 50 files because it contains too many changes.
- .gitignore +43 -0
- LICENSE +201 -0
- README_ditto-talkinghead.md +232 -0
- core/atomic_components/audio2motion.py +196 -0
- core/atomic_components/avatar_registrar.py +102 -0
- core/atomic_components/cfg.py +111 -0
- core/atomic_components/condition_handler.py +168 -0
- core/atomic_components/decode_f3d.py +22 -0
- core/atomic_components/loader.py +133 -0
- core/atomic_components/motion_stitch.py +491 -0
- core/atomic_components/putback.py +60 -0
- core/atomic_components/source2info.py +155 -0
- core/atomic_components/warp_f3d.py +22 -0
- core/atomic_components/wav2feat.py +110 -0
- core/atomic_components/writer.py +36 -0
- core/aux_models/blaze_face.py +351 -0
- core/aux_models/face_mesh.py +101 -0
- core/aux_models/hubert_stream.py +29 -0
- core/aux_models/insightface_det.py +245 -0
- core/aux_models/insightface_landmark106.py +100 -0
- core/aux_models/landmark203.py +58 -0
- core/aux_models/mediapipe_landmark478.py +118 -0
- core/aux_models/modules/__init__.py +5 -0
- core/aux_models/modules/hubert_stream.py +21 -0
- core/aux_models/modules/landmark106.py +83 -0
- core/aux_models/modules/landmark203.py +42 -0
- core/aux_models/modules/landmark478.py +35 -0
- core/aux_models/modules/retinaface.py +215 -0
- core/models/appearance_extractor.py +29 -0
- core/models/decoder.py +30 -0
- core/models/lmdm.py +140 -0
- core/models/modules/LMDM.py +154 -0
- core/models/modules/__init__.py +6 -0
- core/models/modules/appearance_feature_extractor.py +74 -0
- core/models/modules/convnextv2.py +150 -0
- core/models/modules/dense_motion.py +104 -0
- core/models/modules/lmdm_modules/model.py +398 -0
- core/models/modules/lmdm_modules/rotary_embedding_torch.py +132 -0
- core/models/modules/lmdm_modules/utils.py +96 -0
- core/models/modules/motion_extractor.py +25 -0
- core/models/modules/spade_generator.py +87 -0
- core/models/modules/stitching_network.py +65 -0
- core/models/modules/util.py +452 -0
- core/models/modules/warping_network.py +87 -0
- core/models/motion_extractor.py +49 -0
- core/models/stitch_network.py +30 -0
- core/models/warp_network.py +35 -0
- core/utils/blend/__init__.py +4 -0
- core/utils/blend/blend.pyx +38 -0
- core/utils/blend/blend.pyxbld +11 -0
.gitignore
ADDED
@@ -0,0 +1,43 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*__pycache__
**/__pycache__/
*.py[cod]
**/*.py[cod]
*$py.class

# Model weights
checkpoints
**/*.pth
**/*.onnx
**/*.pt
**/*.pth.tar

.idea
.vscode
.DS_Store
*.DS_Store

*.swp
tmp*

*build
*.egg-info/
*.mp4

log/*
*.mp4
*.png
*.jpg
*.wav
*.pth
*.pyc
*.jpeg

# Folders to ignore
example/
ToDo/

!example/audio.wav
!example/image.png

LICENSE
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
README_ditto-talkinghead.md
ADDED
@@ -0,0 +1,232 @@
<h2 align='center'>Ditto: Motion-Space Diffusion for Controllable Realtime Talking Head Synthesis</h2>

<div align='center'>
    <a href=""><strong>Tianqi Li</strong></a>
    ·
    <a href=""><strong>Ruobing Zheng</strong></a><sup>†</sup>
    ·
    <a href=""><strong>Minghui Yang</strong></a>
    ·
    <a href=""><strong>Jingdong Chen</strong></a>
    ·
    <a href=""><strong>Ming Yang</strong></a>
</div>
<div align='center'>
    Ant Group
</div>
<br>
<div align='center'>
    <a href='https://arxiv.org/abs/2411.19509'><img src='https://img.shields.io/badge/Paper-arXiv-red'></a>
    <a href='https://digital-avatar.github.io/ai/Ditto/'><img src='https://img.shields.io/badge/Project-Page-blue'></a>
    <a href='https://huggingface.co/digital-avatar/ditto-talkinghead'><img src='https://img.shields.io/badge/Model-HuggingFace-yellow'></a>
    <a href='https://github.com/antgroup/ditto-talkinghead'><img src='https://img.shields.io/badge/Code-GitHub-purple'></a>
    <!-- <a href='https://github.com/antgroup/ditto-talkinghead'><img src='https://img.shields.io/github/stars/antgroup/ditto-talkinghead?style=social'></a> -->
    <a href='https://colab.research.google.com/drive/19SUi1TiO32IS-Crmsu9wrkNspWE8tFbs?usp=sharing'><img src='https://img.shields.io/badge/Demo-Colab-orange'></a>
</div>
<br>
<div align="center">
    <video style="width: 95%; object-fit: cover;" controls loop src="https://github.com/user-attachments/assets/ef1a0b08-bff3-4997-a6dd-62a7f51cdb40" muted="false"></video>
    <p>
    ✨ For more results, visit our <a href="https://digital-avatar.github.io/ai/Ditto/"><strong>Project Page</strong></a> ✨
    </p>
</div>


## 📌 Updates
* [2025.07.11] 🔥 The [PyTorch model](#-pytorch-model) is now available.
* [2025.07.07] 🔥 Ditto is accepted by ACM MM 2025.
* [2025.01.21] 🔥 We update the [Colab](https://colab.research.google.com/drive/19SUi1TiO32IS-Crmsu9wrkNspWE8tFbs?usp=sharing) demo, welcome to try it.
* [2025.01.10] 🔥 We release our inference [codes](https://github.com/antgroup/ditto-talkinghead) and [models](https://huggingface.co/digital-avatar/ditto-talkinghead).
* [2024.11.29] 🔥 Our [paper](https://arxiv.org/abs/2411.19509) is in public on arxiv.



## 🛠️ Installation

Tested Environment
- System: Centos 7.2
- GPU: A100
- Python: 3.10
- tensorRT: 8.6.1


Clone the codes from [GitHub](https://github.com/antgroup/ditto-talkinghead):
```bash
git clone https://github.com/antgroup/ditto-talkinghead
cd ditto-talkinghead
```

### Conda
Create `conda` environment:
```bash
conda env create -f environment.yaml
conda activate ditto
```

### Pip
If you have problems creating a conda environment, you can also refer to our [Colab](https://colab.research.google.com/drive/19SUi1TiO32IS-Crmsu9wrkNspWE8tFbs?usp=sharing).
After correctly installing `pytorch`, `cuda` and `cudnn`, you only need to install a few packages using pip:
```bash
pip install \
    tensorrt==8.6.1 \
    librosa \
    tqdm \
    filetype \
    imageio \
    opencv_python_headless \
    scikit-image \
    cython \
    cuda-python \
    imageio-ffmpeg \
    colored \
    polygraphy \
    numpy==2.0.1
```

If you don't use `conda`, you may also need to install `ffmpeg` according to the [official website](https://www.ffmpeg.org/download.html).


## 📥 Download Checkpoints

Download checkpoints from [HuggingFace](https://huggingface.co/digital-avatar/ditto-talkinghead) and put them in `checkpoints` dir:
```bash
git lfs install
git clone https://huggingface.co/digital-avatar/ditto-talkinghead checkpoints
```

The `checkpoints` should be like:
```text
./checkpoints/
├── ditto_cfg
│   ├── v0.4_hubert_cfg_trt.pkl
│   └── v0.4_hubert_cfg_trt_online.pkl
├── ditto_onnx
│   ├── appearance_extractor.onnx
│   ├── blaze_face.onnx
│   ├── decoder.onnx
│   ├── face_mesh.onnx
│   ├── hubert.onnx
│   ├── insightface_det.onnx
│   ├── landmark106.onnx
│   ├── landmark203.onnx
│   ├── libgrid_sample_3d_plugin.so
│   ├── lmdm_v0.4_hubert.onnx
│   ├── motion_extractor.onnx
│   ├── stitch_network.onnx
│   └── warp_network.onnx
└── ditto_trt_Ampere_Plus
    ├── appearance_extractor_fp16.engine
    ├── blaze_face_fp16.engine
    ├── decoder_fp16.engine
    ├── face_mesh_fp16.engine
    ├── hubert_fp32.engine
    ├── insightface_det_fp16.engine
    ├── landmark106_fp16.engine
    ├── landmark203_fp16.engine
    ├── lmdm_v0.4_hubert_fp32.engine
    ├── motion_extractor_fp32.engine
    ├── stitch_network_fp16.engine
    └── warp_network_fp16.engine
```

- The `ditto_cfg/v0.4_hubert_cfg_trt_online.pkl` is online config
- The `ditto_cfg/v0.4_hubert_cfg_trt.pkl` is offline config


## 🚀 Inference

Run `inference.py`:

```shell
python inference.py \
    --data_root "<path-to-trt-model>" \
    --cfg_pkl "<path-to-cfg-pkl>" \
    --audio_path "<path-to-input-audio>" \
    --source_path "<path-to-input-image>" \
    --output_path "<path-to-output-mp4>"
```

For example:

```shell
python inference.py \
    --data_root "./checkpoints/ditto_trt_Ampere_Plus" \
    --cfg_pkl "./checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl" \
    --audio_path "./example/audio.wav" \
    --source_path "./example/image.png" \
    --output_path "./tmp/result.mp4"
```

❗Note:

We have provided the tensorRT model with `hardware-compatibility-level=Ampere_Plus` (`checkpoints/ditto_trt_Ampere_Plus/`). If your GPU does not support it, please execute the `cvt_onnx_to_trt.py` script to convert from the general onnx model (`checkpoints/ditto_onnx/`) to the tensorRT model.

```bash
python scripts/cvt_onnx_to_trt.py --onnx_dir "./checkpoints/ditto_onnx" --trt_dir "./checkpoints/ditto_trt_custom"
```

Then run `inference.py` with `--data_root=./checkpoints/ditto_trt_custom`.


## ⚡ PyTorch Model
*Based on community interest and to better support further development, we are now open-sourcing the PyTorch version of the model.*


We have added the PyTorch model and corresponding configuration files to the [HuggingFace](https://huggingface.co/digital-avatar/ditto-talkinghead). Please refer to [Download Checkpoints](#-download-checkpoints) to prepare the model files.

The `checkpoints` should be like:
```text
./checkpoints/
├── ditto_cfg
│   ├── ...
│   └── v0.4_hubert_cfg_pytorch.pkl
├── ...
└── ditto_pytorch
    ├── aux_models
    │   ├── 2d106det.onnx
    │   ├── det_10g.onnx
    │   ├── face_landmarker.task
    │   ├── hubert_streaming_fix_kv.onnx
    │   └── landmark203.onnx
    └── models
        ├── appearance_extractor.pth
        ├── decoder.pth
        ├── lmdm_v0.4_hubert.pth
        ├── motion_extractor.pth
        ├── stitch_network.pth
        └── warp_network.pth
```

To run inference, execute the following command:

```shell
python inference.py \
    --data_root "./checkpoints/ditto_pytorch" \
    --cfg_pkl "./checkpoints/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl" \
    --audio_path "./example/audio.wav" \
    --source_path "./example/image.png" \
    --output_path "./tmp/result.mp4"
```


## 📧 Acknowledgement
Our implementation is based on [S2G-MDDiffusion](https://github.com/thuhcsi/S2G-MDDiffusion) and [LivePortrait](https://github.com/KwaiVGI/LivePortrait). Thanks for their remarkable contribution and released code! If we missed any open-source projects or related articles, we would like to complement the acknowledgement of this specific work immediately.

## ⚖️ License
This repository is released under the Apache-2.0 license as found in the [LICENSE](LICENSE) file.

## 📚 Citation
If you find this codebase useful for your research, please use the following entry.
```BibTeX
@article{li2024ditto,
    title={Ditto: Motion-Space Diffusion for Controllable Realtime Talking Head Synthesis},
    author={Li, Tianqi and Zheng, Ruobing and Yang, Minghui and Chen, Jingdong and Yang, Ming},
    journal={arXiv preprint arXiv:2411.19509},
    year={2024}
}
```


## 🌟 Star History

[](https://www.star-history.com/#antgroup/ditto-talkinghead&Date)
core/atomic_components/audio2motion.py
ADDED
@@ -0,0 +1,196 @@
import numpy as np
from ..models.lmdm import LMDM


"""
lmdm_cfg = {
    "model_path": "",
    "device": "cuda",
    "motion_feat_dim": 265,
    "audio_feat_dim": 1024+35,
    "seq_frames": 80,
}
"""


def _cvt_LP_motion_info(inp, mode, ignore_keys=()):
    ks_shape_map = [
        ['scale', (1, 1), 1],
        ['pitch', (1, 66), 66],
        ['yaw', (1, 66), 66],
        ['roll', (1, 66), 66],
        ['t', (1, 3), 3],
        ['exp', (1, 63), 63],
        ['kp', (1, 63), 63],
    ]

    def _dic2arr(_dic):
        arr = []
        for k, _, ds in ks_shape_map:
            if k not in _dic or k in ignore_keys:
                continue
            v = _dic[k].reshape(ds)
            if k == 'scale':
                v = v - 1
            arr.append(v)
        arr = np.concatenate(arr, -1)  # (133)
        return arr

    def _arr2dic(_arr):
        dic = {}
        s = 0
        for k, ds, ss in ks_shape_map:
            if k in ignore_keys:
                continue
            v = _arr[s:s + ss].reshape(ds)
            if k == 'scale':
                v = v + 1
            dic[k] = v
            s += ss
            if s >= len(_arr):
                break
        return dic

    if mode == 'dic2arr':
        assert isinstance(inp, dict)
        return _dic2arr(inp)    # (dim)
    elif mode == 'arr2dic':
        assert inp.shape[0] >= 265, f"{inp.shape}"
        return _arr2dic(inp)    # {k: (1, dim)}
    else:
        raise ValueError()


class Audio2Motion:
    def __init__(
        self,
        lmdm_cfg,
    ):
        self.lmdm = LMDM(**lmdm_cfg)

    def setup(
        self,
        x_s_info,
        overlap_v2=10,
        fix_kp_cond=0,
        fix_kp_cond_dim=None,
        sampling_timesteps=50,
        online_mode=False,
        v_min_max_for_clip=None,
        smo_k_d=3,
    ):
        self.smo_k_d = smo_k_d
        self.overlap_v2 = overlap_v2
        self.seq_frames = self.lmdm.seq_frames
        self.valid_clip_len = self.seq_frames - self.overlap_v2

        # for fuse
        self.online_mode = online_mode
        if self.online_mode:
            self.fuse_length = min(self.overlap_v2, self.valid_clip_len)
        else:
            self.fuse_length = self.overlap_v2
        self.fuse_alpha = np.arange(self.fuse_length, dtype=np.float32).reshape(1, -1, 1) / self.fuse_length

        self.fix_kp_cond = fix_kp_cond
        self.fix_kp_cond_dim = fix_kp_cond_dim
        self.sampling_timesteps = sampling_timesteps

        self.v_min_max_for_clip = v_min_max_for_clip
        if self.v_min_max_for_clip is not None:
            self.v_min = self.v_min_max_for_clip[0][None]    # [dim, 1]
            self.v_max = self.v_min_max_for_clip[1][None]

        kp_source = _cvt_LP_motion_info(x_s_info, mode='dic2arr', ignore_keys={'kp'})[None]
        self.s_kp_cond = kp_source.copy().reshape(1, -1)
        self.kp_cond = self.s_kp_cond.copy()

        self.lmdm.setup(sampling_timesteps)

        self.clip_idx = 0

    def _fuse(self, res_kp_seq, pred_kp_seq):
        ## ========================
        ## offline fuse mode
        ## last clip: -------
        ## fuse part:   *****
        ## curr clip:   -------
        ## output:      ^^
        #
        ## online fuse mode
        ## last clip: -------
        ## fuse part:      **
        ## curr clip:   -------
        ## output:      ^^
        ## ========================

        fuse_r1_s = res_kp_seq.shape[1] - self.fuse_length
        fuse_r1_e = res_kp_seq.shape[1]
        fuse_r2_s = self.seq_frames - self.valid_clip_len - self.fuse_length
        fuse_r2_e = self.seq_frames - self.valid_clip_len

        r1 = res_kp_seq[:, fuse_r1_s:fuse_r1_e]    # [1, fuse_len, dim]
        r2 = pred_kp_seq[:, fuse_r2_s: fuse_r2_e]    # [1, fuse_len, dim]
        r_fuse = r1 * (1 - self.fuse_alpha) + r2 * self.fuse_alpha

        res_kp_seq[:, fuse_r1_s:fuse_r1_e] = r_fuse    # fuse last
        res_kp_seq = np.concatenate([res_kp_seq, pred_kp_seq[:, fuse_r2_e:]], 1)    # len(res_kp_seq) + valid_clip_len

        return res_kp_seq

    def _update_kp_cond(self, res_kp_seq, idx):
        if self.fix_kp_cond == 0:  # never reset
            self.kp_cond = res_kp_seq[:, idx-1]
        elif self.fix_kp_cond > 0:
            if self.clip_idx % self.fix_kp_cond == 0:  # reset
                self.kp_cond = self.s_kp_cond.copy()  # reset all dims
                if self.fix_kp_cond_dim is not None:
                    ds, de = self.fix_kp_cond_dim
                    self.kp_cond[:, ds:de] = res_kp_seq[:, idx-1, ds:de]
            else:
                self.kp_cond = res_kp_seq[:, idx-1]

    def _smo(self, res_kp_seq, s, e):
        if self.smo_k_d <= 1:
            return res_kp_seq
        new_res_kp_seq = res_kp_seq.copy()
        n = res_kp_seq.shape[1]
        half_k = self.smo_k_d // 2
        for i in range(s, e):
            ss = max(0, i - half_k)
            ee = min(n, i + half_k + 1)
            res_kp_seq[:, i, :202] = np.mean(new_res_kp_seq[:, ss:ee, :202], axis=1)
        return res_kp_seq

    def __call__(self, aud_cond, res_kp_seq=None):
        """
        aud_cond: (1, seq_frames, dim)
        """

        pred_kp_seq = self.lmdm(self.kp_cond, aud_cond, self.sampling_timesteps)
        if res_kp_seq is None:
            res_kp_seq = pred_kp_seq    # [1, seq_frames, dim]
            res_kp_seq = self._smo(res_kp_seq, 0, res_kp_seq.shape[1])
        else:
            res_kp_seq = self._fuse(res_kp_seq, pred_kp_seq)    # len(res_kp_seq) + valid_clip_len
            res_kp_seq = self._smo(res_kp_seq, res_kp_seq.shape[1] - self.valid_clip_len - self.fuse_length, res_kp_seq.shape[1] - self.valid_clip_len + 1)

        self.clip_idx += 1

        idx = res_kp_seq.shape[1] - self.overlap_v2
        self._update_kp_cond(res_kp_seq, idx)

        return res_kp_seq

    def cvt_fmt(self, res_kp_seq):
        # res_kp_seq: [1, n, dim]
        if self.v_min_max_for_clip is not None:
            tmp_res_kp_seq = np.clip(res_kp_seq[0], self.v_min, self.v_max)
        else:
            tmp_res_kp_seq = res_kp_seq[0]

        x_d_info_list = []
        for i in range(tmp_res_kp_seq.shape[0]):
            x_d_info = _cvt_LP_motion_info(tmp_res_kp_seq[i], 'arr2dic')  # {k: (1, dim)}
            x_d_info_list.append(x_d_info)
        return x_d_info_list
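For reference (not part of this commit), a numpy-only sketch of the flat motion layout implied by the key/shape table in `_cvt_LP_motion_info`: dropping `kp`, as `Audio2Motion.setup` does, yields the 265-dim vector that matches `motion_feat_dim` in `lmdm_cfg`, and its first 202 dims (scale, pitch, yaw, roll, t) are exactly the slice smoothed by `_smo`.

```python
import numpy as np

# Standalone sketch: mirrors the key/shape table used by _cvt_LP_motion_info above.
ks_shape_map = [
    ("scale", (1, 1), 1),
    ("pitch", (1, 66), 66),
    ("yaw", (1, 66), 66),
    ("roll", (1, 66), 66),
    ("t", (1, 3), 3),
    ("exp", (1, 63), 63),
    ("kp", (1, 63), 63),
]

# Dummy LivePortrait-style motion dict with the expected per-key shapes.
x_s_info = {k: np.zeros(shape, dtype=np.float32) for k, shape, _ in ks_shape_map}

# dic2arr with 'kp' ignored, as in Audio2Motion.setup.
flat = np.concatenate(
    [x_s_info[k].reshape(dim) for k, _, dim in ks_shape_map if k != "kp"], -1
)
print(flat.shape)  # (265,) -- matches motion_feat_dim in lmdm_cfg
# The first 1 + 66 + 66 + 66 + 3 = 202 dims (scale, pitch, yaw, roll, t) are the
# res_kp_seq[:, i, :202] slice averaged by Audio2Motion._smo; exp (63) is left untouched.
```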
core/atomic_components/avatar_registrar.py
ADDED
@@ -0,0 +1,102 @@
import numpy as np

from .loader import load_source_frames
from .source2info import Source2Info


def _mean_filter(arr, k):
    n = arr.shape[0]
    half_k = k // 2
    res = []
    for i in range(n):
        s = max(0, i - half_k)
        e = min(n, i + half_k + 1)
        res.append(arr[s:e].mean(0))
    res = np.stack(res, 0)
    return res


def smooth_x_s_info_lst(x_s_info_list, ignore_keys=(), smo_k=13):
    keys = x_s_info_list[0].keys()
    N = len(x_s_info_list)
    smo_dict = {}
    for k in keys:
        _lst = [x_s_info_list[i][k] for i in range(N)]
        if k not in ignore_keys:
            _lst = np.stack(_lst, 0)
            _smo_lst = _mean_filter(_lst, smo_k)
        else:
            _smo_lst = _lst
        smo_dict[k] = _smo_lst

    smo_res = []
    for i in range(N):
        x_s_info = {k: smo_dict[k][i] for k in keys}
        smo_res.append(x_s_info)
    return smo_res


class AvatarRegistrar:
    """
    source image|video -> rgb_list -> source_info
    """
    def __init__(
        self,
        insightface_det_cfg,
        landmark106_cfg,
        landmark203_cfg,
        landmark478_cfg,
        appearance_extractor_cfg,
        motion_extractor_cfg,
    ):
        self.source2info = Source2Info(
            insightface_det_cfg,
            landmark106_cfg,
            landmark203_cfg,
            landmark478_cfg,
            appearance_extractor_cfg,
            motion_extractor_cfg,
        )

    def register(
        self,
        source_path,  # image | video
        max_dim=1920,
        n_frames=-1,
        **kwargs,
    ):
        """
        kwargs:
            crop_scale: 2.3
            crop_vx_ratio: 0
            crop_vy_ratio: -0.125
            crop_flag_do_rot: True
        """
        rgb_list, is_image_flag = load_source_frames(source_path, max_dim=max_dim, n_frames=n_frames)
        source_info = {
            "x_s_info_lst": [],
            "f_s_lst": [],
            "M_c2o_lst": [],
            "eye_open_lst": [],
            "eye_ball_lst": [],
        }
        keys = ["x_s_info", "f_s", "M_c2o", "eye_open", "eye_ball"]
        last_lmk = None
        for rgb in rgb_list:
            info = self.source2info(rgb, last_lmk, **kwargs)
            for k in keys:
                source_info[f"{k}_lst"].append(info[k])

            last_lmk = info["lmk203"]

        sc_f0 = source_info['x_s_info_lst'][0]['kp'].flatten()

        source_info["sc"] = sc_f0
        source_info["is_image_flag"] = is_image_flag
        source_info["img_rgb_lst"] = rgb_list

        return source_info

    def __call__(self, *args, **kwargs):
        return self.register(*args, **kwargs)
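`smooth_x_s_info_lst` stacks each key across the registered frames and runs `_mean_filter`, a sliding mean whose window (default `smo_k=13`) shrinks at the clip boundaries. For reference, a standalone toy illustration of that edge behavior (not part of the commit; scalar values stand in for the stacked motion arrays):

```python
import numpy as np

# Same windowed-mean logic as _mean_filter above, shown on a toy per-frame scalar
# so the shrinking window at both ends of the sequence is visible.
def mean_filter(arr, k):
    n, half_k = arr.shape[0], k // 2
    return np.stack(
        [arr[max(0, i - half_k):min(n, i + half_k + 1)].mean(0) for i in range(n)], 0
    )

frames = np.arange(6, dtype=np.float32)   # stand-in for one key of x_s_info over 6 frames
print(mean_filter(frames, k=3))           # [0.5 1.  2.  3.  4.  4.5]
```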
core/atomic_components/cfg.py
ADDED
@@ -0,0 +1,111 @@
import os
import pickle
import numpy as np


def load_pkl(pkl):
    with open(pkl, "rb") as f:
        return pickle.load(f)


def parse_cfg(cfg_pkl, data_root, replace_cfg=None):

    def _check_path(p):
        if os.path.isfile(p):
            return p
        else:
            return os.path.join(data_root, p)

    cfg = load_pkl(cfg_pkl)

    # ---
    # replace cfg for debug
    if isinstance(replace_cfg, dict):
        for k, v in replace_cfg.items():
            if not isinstance(v, dict):
                continue
            for kk, vv in v.items():
                cfg[k][kk] = vv
    # ---

    base_cfg = cfg["base_cfg"]
    audio2motion_cfg = cfg["audio2motion_cfg"]
    default_kwargs = cfg["default_kwargs"]

    for k in base_cfg:
        if k == "landmark478_cfg":
            for kk in ["task_path", "blaze_face_model_path", "face_mesh_model_path"]:
                if kk in base_cfg[k] and base_cfg[k][kk]:
                    base_cfg[k][kk] = _check_path(base_cfg[k][kk])
        else:
            base_cfg[k]["model_path"] = _check_path(base_cfg[k]["model_path"])

    audio2motion_cfg["model_path"] = _check_path(audio2motion_cfg["model_path"])

    avatar_registrar_cfg = {
        k: base_cfg[k]
        for k in [
            "insightface_det_cfg",
            "landmark106_cfg",
            "landmark203_cfg",
            "landmark478_cfg",
            "appearance_extractor_cfg",
            "motion_extractor_cfg",
        ]
    }

    stitch_network_cfg = base_cfg["stitch_network_cfg"]
    warp_network_cfg = base_cfg["warp_network_cfg"]
    decoder_cfg = base_cfg["decoder_cfg"]

    condition_handler_cfg = {
        k: audio2motion_cfg[k]
        for k in [
            "use_emo",
            "use_sc",
            "use_eye_open",
            "use_eye_ball",
            "seq_frames",
        ]
    }

    lmdm_cfg = {
        k: audio2motion_cfg[k]
        for k in [
            "model_path",
            "device",
            "motion_feat_dim",
            "audio_feat_dim",
            "seq_frames",
        ]
    }

    w2f_type = audio2motion_cfg["w2f_type"]
    wav2feat_cfg = {
        "w2f_cfg": base_cfg["hubert_cfg"] if w2f_type == "hubert" else base_cfg["wavlm_cfg"],
        "w2f_type": w2f_type,
    }

    return [
        avatar_registrar_cfg,
        condition_handler_cfg,
        lmdm_cfg,
        stitch_network_cfg,
        warp_network_cfg,
        decoder_cfg,
        wav2feat_cfg,
        default_kwargs,
    ]


def print_cfg(**kwargs):
    for k, v in kwargs.items():
        if k == "ch_info":
            print(k, type(v))
        elif k == "ctrl_info":
            print(k, type(v), len(v))
        else:
            if isinstance(v, np.ndarray):
                print(k, type(v), v.shape)
            else:
                print(k, type(v), v)
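For reference (not part of this commit), a sketch of how `parse_cfg` is expected to be called, assuming the repository root is on `PYTHONPATH`; the checkpoint paths are the ones from the README example. It resolves relative model paths against `data_root` and returns eight configs in a fixed order.

```python
# Hypothetical call site for parse_cfg, shown only to document the return order.
from core.atomic_components.cfg import parse_cfg

(
    avatar_registrar_cfg,
    condition_handler_cfg,
    lmdm_cfg,
    stitch_network_cfg,
    warp_network_cfg,
    decoder_cfg,
    wav2feat_cfg,
    default_kwargs,
) = parse_cfg(
    cfg_pkl="./checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl",   # paths from the README
    data_root="./checkpoints/ditto_trt_Ampere_Plus",
    replace_cfg=None,   # optional nested dict of overrides, applied before path resolution
)
```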
core/atomic_components/condition_handler.py
ADDED
@@ -0,0 +1,168 @@
import numpy as np
from scipy.special import softmax
import copy


def _get_emo_avg(idx=6):
    emo_avg = np.zeros(8, dtype=np.float32)
    if isinstance(idx, (list, tuple)):
        for i in idx:
            emo_avg[i] = 8
    else:
        emo_avg[idx] = 8
    emo_avg = softmax(emo_avg)
    #emo_avg = None
    # 'Angry', 'Disgust', 'Fear', 'Happy', 'Neutral', 'Sad', 'Surprise', 'Contempt'
    return emo_avg


def _mirror_index(index, size):
    turn = index // size
    res = index % size
    if turn % 2 == 0:
        return res
    else:
        return size - res - 1


class ConditionHandler:
    """
    aud_feat, emo_seq, eye_seq, sc_seq -> cond_seq
    """
    def __init__(
        self,
        use_emo=True,
        use_sc=True,
        use_eye_open=True,
        use_eye_ball=True,
        seq_frames=80,
    ):
        self.use_emo = use_emo
        self.use_sc = use_sc
        self.use_eye_open = use_eye_open
        self.use_eye_ball = use_eye_ball

        self.seq_frames = seq_frames

    def setup(self, setup_info, emo, eye_f0_mode=False, ch_info=None):
        """
        emo: int | [int] | [[int]] | numpy
        """
        if ch_info is None:
            source_info = copy.deepcopy(setup_info)
        else:
            source_info = ch_info

        self.eye_f0_mode = eye_f0_mode
        self.x_s_info_0 = source_info['x_s_info_lst'][0]

        if self.use_sc:
            self.sc = source_info["sc"]    # 63
            self.sc_seq = np.stack([self.sc] * self.seq_frames, 0)

        if self.use_eye_open:
            self.eye_open_lst = np.concatenate(source_info["eye_open_lst"], 0)    # [n, 2]
            self.num_eye_open = len(self.eye_open_lst)
            if self.num_eye_open == 1 or self.eye_f0_mode:
                self.eye_open_seq = np.stack([self.eye_open_lst[0]] * self.seq_frames, 0)
            else:
                self.eye_open_seq = None

        if self.use_eye_ball:
            self.eye_ball_lst = np.concatenate(source_info["eye_ball_lst"], 0)    # [n, 6]
            self.num_eye_ball = len(self.eye_ball_lst)
            if self.num_eye_ball == 1 or self.eye_f0_mode:
                self.eye_ball_seq = np.stack([self.eye_ball_lst[0]] * self.seq_frames, 0)
            else:
                self.eye_ball_seq = None

        if self.use_emo:
            self.emo_lst = self._parse_emo_seq(emo)
            self.num_emo = len(self.emo_lst)
            if self.num_emo == 1:
                self.emo_seq = np.concatenate([self.emo_lst] * self.seq_frames, 0)
            else:
                self.emo_seq = None

    @staticmethod
    def _parse_emo_seq(emo, seq_len=-1):
        if isinstance(emo, np.ndarray) and emo.ndim == 2 and emo.shape[1] == 8:
            # emo arr, e.g. real
            emo_seq = emo    # [m, 8]
        elif isinstance(emo, int) and 0 <= emo < 8:
            # emo label, e.g. 4
            emo_seq = _get_emo_avg(emo).reshape(1, 8)    # [1, 8]
        elif isinstance(emo, (list, tuple)) and 0 < len(emo) < 8 and isinstance(emo[0], int):
            # emo labels, e.g. [3,4]
            emo_seq = _get_emo_avg(emo).reshape(1, 8)    # [1, 8]
        elif isinstance(emo, list) and emo and isinstance(emo[0], (list, tuple)):
            # emo label list, e.g. [[4], [3,4], [3],[3,4,5], ...]
            emo_seq = np.stack([_get_emo_avg(i) for i in emo], 0)    # [m, 8]
        else:
            raise ValueError(f"Unsupported emo type: {emo}")

        if seq_len > 0:
            if len(emo_seq) == seq_len:
                return emo_seq
            elif len(emo_seq) == 1:
                return np.concatenate([emo_seq] * seq_len, 0)
            elif len(emo_seq) > seq_len:
                return emo_seq[:seq_len]
            else:
                raise ValueError(f"emo len {len(emo_seq)} can not match seq len ({seq_len})")
        else:
            return emo_seq

    def __call__(self, aud_feat, idx, emo=None):
        """
        aud_feat: [n, 1024]
        idx: int, <0 means pad (first clip buffer)
        """

        frame_num = len(aud_feat)
        more_cond = [aud_feat]
        if self.use_emo:
            if emo is not None:
                emo_seq = self._parse_emo_seq(emo, frame_num)
            elif self.emo_seq is not None and len(self.emo_seq) == frame_num:
                emo_seq = self.emo_seq
            else:
                emo_idx_list = [max(i, 0) % self.num_emo for i in range(idx, idx + frame_num)]
                emo_seq = self.emo_lst[emo_idx_list]
            more_cond.append(emo_seq)

        if self.use_eye_open:
            if self.eye_open_seq is not None and len(self.eye_open_seq) == frame_num:
                eye_open_seq = self.eye_open_seq
            else:
                if self.eye_f0_mode:
                    eye_idx_list = [0] * frame_num
                else:
                    eye_idx_list = [_mirror_index(max(i, 0), self.num_eye_open) for i in range(idx, idx + frame_num)]
                eye_open_seq = self.eye_open_lst[eye_idx_list]
            more_cond.append(eye_open_seq)

        if self.use_eye_ball:
            if self.eye_ball_seq is not None and len(self.eye_ball_seq) == frame_num:
                eye_ball_seq = self.eye_ball_seq
            else:
                if self.eye_f0_mode:
                    eye_idx_list = [0] * frame_num
                else:
                    eye_idx_list = [_mirror_index(max(i, 0), self.num_eye_ball) for i in range(idx, idx + frame_num)]
                eye_ball_seq = self.eye_ball_lst[eye_idx_list]
            more_cond.append(eye_ball_seq)

        if self.use_sc:
            if len(self.sc_seq) == frame_num:
                sc_seq = self.sc_seq
            else:
                sc_seq = np.stack([self.sc] * frame_num, 0)
            more_cond.append(sc_seq)

        if len(more_cond) > 1:
            cond_seq = np.concatenate(more_cond, -1)    # [n, dim_cond]
        else:
            cond_seq = aud_feat

        return cond_seq
ADDED
@@ -0,0 +1,22 @@
|
from ..models.decoder import Decoder


"""
# __init__
decoder_cfg = {
    "model_path": "",
    "device": "cuda",
}
"""

class DecodeF3D:
    def __init__(
        self,
        decoder_cfg,
    ):
        self.decoder = Decoder(**decoder_cfg)

    def __call__(self, f_s):
        out = self.decoder(f_s)
        return out
core/atomic_components/loader.py
ADDED
@@ -0,0 +1,133 @@
import filetype
import imageio
import cv2


def is_image(file_path):
    return filetype.is_image(file_path)


def is_video(file_path):
    return filetype.is_video(file_path)


def check_resize(h, w, max_dim=1920, division=2):
    rsz_flag = False
    # ajust the size of the image according to the maximum dimension
    if max_dim > 0 and max(h, w) > max_dim:
        rsz_flag = True
        if h > w:
            new_h = max_dim
            new_w = int(round(w * max_dim / h))
        else:
            new_w = max_dim
            new_h = int(round(h * max_dim / w))
    else:
        new_h = h
        new_w = w

    # ensure that the image dimensions are multiples of n
    if new_h % division != 0:
        new_h = new_h - (new_h % division)
        rsz_flag = True
    if new_w % division != 0:
        new_w = new_w - (new_w % division)
        rsz_flag = True

    return new_h, new_w, rsz_flag


def load_image(image_path, max_dim=-1):
    img = cv2.imread(image_path, cv2.IMREAD_COLOR)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    h, w = img.shape[:2]
    new_h, new_w, rsz_flag = check_resize(h, w, max_dim)
    if rsz_flag:
        img = cv2.resize(img, (new_w, new_h))
    return img


def load_video(video_path, n_frames=-1, max_dim=-1):
    reader = imageio.get_reader(video_path, "ffmpeg")

    new_h, new_w, rsz_flag = None, None, None

    ret = []
    for idx, frame_rgb in enumerate(reader):
        if n_frames > 0 and idx >= n_frames:
            break

        if rsz_flag is None:
            h, w = frame_rgb.shape[:2]
            new_h, new_w, rsz_flag = check_resize(h, w, max_dim)

        if rsz_flag:
            frame_rgb = cv2.resize(frame_rgb, (new_w, new_h))

        ret.append(frame_rgb)

    reader.close()
    return ret


def load_source_frames(source_path, max_dim=-1, n_frames=-1):
    if is_image(source_path):
        rgb = load_image(source_path, max_dim)
        rgb_list = [rgb]
        is_image_flag = True
    elif is_video(source_path):
        rgb_list = load_video(source_path, n_frames, max_dim)
        is_image_flag = False
    else:
        raise ValueError(f"Unsupported source type: {source_path}")
    return rgb_list, is_image_flag


def _mirror_index(index, size):
    turn = index // size
    res = index % size
    if turn % 2 == 0:
        return res
    else:
        return size - res - 1


class LoopLoader:
    def __init__(self, item_list, max_iter_num=-1, mirror_loop=True):
        self.item_list = item_list
        self.idx = 0
        self.item_num = len(self.item_list)
        self.max_iter_num = max_iter_num if max_iter_num > 0 else self.item_num
        self.mirror_loop = mirror_loop

    def __len__(self):
        return self.max_iter_num

    def __iter__(self):
        return self

    def __next__(self):
        if self.idx >= self.max_iter_num:
            raise StopIteration

        if self.mirror_loop:
            idx = _mirror_index(self.idx, self.item_num)
        else:
            idx = self.idx % self.item_num
        item = self.item_list[idx]

        self.idx += 1
        return item

    def __call__(self):
        return self.__iter__()

    def reset(self, max_iter_num=-1):
        self.frame_idx = 0
        self.max_iter_num = max_iter_num if max_iter_num > 0 else self.item_num
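For reference, a small sketch of the two loader helpers (not part of the commit), assuming the repository root is on `PYTHONPATH` and the pip dependencies listed in the README (`filetype`, `imageio`, `opencv_python_headless`) are installed:

```python
from core.atomic_components.loader import LoopLoader, check_resize

# check_resize caps the longer side at max_dim and rounds both sides down to a
# multiple of `division` (2 by default).
print(check_resize(2160, 3840, max_dim=1920))    # (1080, 1920, True)

# LoopLoader ping-pongs over the source items, so a short source video can be
# stretched to any number of output frames without a jump at the loop point.
frames = ["f0", "f1", "f2"]
print(list(LoopLoader(frames, max_iter_num=7)))  # ['f0', 'f1', 'f2', 'f2', 'f1', 'f0', 'f0']
```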
core/atomic_components/motion_stitch.py
ADDED
@@ -0,0 +1,491 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import copy
import random
import numpy as np
from scipy.special import softmax

from ..models.stitch_network import StitchNetwork


"""
# __init__
stitch_network_cfg = {
    "model_path": "",
    "device": "cuda",
}

# __call__
kwargs:
    fade_alpha
    fade_out_keys

    delta_pitch
    delta_yaw
    delta_roll

"""


def ctrl_motion(x_d_info, **kwargs):
    # pose + offset
    for kk in ["delta_pitch", "delta_yaw", "delta_roll"]:
        if kk in kwargs:
            k = kk[6:]
            x_d_info[k] = bin66_to_degree(x_d_info[k]) + kwargs[kk]

    # pose * alpha
    for kk in ["alpha_pitch", "alpha_yaw", "alpha_roll"]:
        if kk in kwargs:
            k = kk[6:]
            x_d_info[k] = x_d_info[k] * kwargs[kk]

    # exp + offset
    if "delta_exp" in kwargs:
        k = "exp"
        x_d_info[k] = x_d_info[k] + kwargs["delta_exp"]

    return x_d_info


def fade(x_d_info, dst, alpha, keys=None):
    if keys is None:
        keys = x_d_info.keys()
    for k in keys:
        if k == 'kp':
            continue
        x_d_info[k] = x_d_info[k] * alpha + dst[k] * (1 - alpha)
    return x_d_info


def ctrl_vad(x_d_info, dst, alpha):
    exp = x_d_info["exp"]
    exp_dst = dst["exp"]

    _lip = [6, 12, 14, 17, 19, 20]
    _a1 = np.zeros((21, 3), dtype=np.float32)
    _a1[_lip] = alpha
    _a1 = _a1.reshape(1, -1)
    x_d_info["exp"] = exp * alpha + exp_dst * (1 - alpha)

    return x_d_info


def _mix_s_d_info(
    x_s_info,
    x_d_info,
    use_d_keys=("exp", "pitch", "yaw", "roll", "t"),
    d0=None,
):
    if d0 is not None:
        if isinstance(use_d_keys, dict):
            x_d_info = {
                k: x_s_info[k] + (v - d0[k]) * use_d_keys.get(k, 1)
                for k, v in x_d_info.items()
            }
        else:
            x_d_info = {k: x_s_info[k] + (v - d0[k]) for k, v in x_d_info.items()}

    for k, v in x_s_info.items():
        if k not in x_d_info or k not in use_d_keys:
            x_d_info[k] = v

    if isinstance(use_d_keys, dict) and d0 is None:
        for k, alpha in use_d_keys.items():
            x_d_info[k] *= alpha
    return x_d_info


def _set_eye_blink_idx(N, blink_n=15, open_n=-1):
    """
    open_n:
        -1: no blink
        0: random open_n
        >0: fix open_n
        list: loop open_n
    """
    OPEN_MIN = 60
    OPEN_MAX = 100

    idx = [0] * N
    if isinstance(open_n, int):
        if open_n < 0:  # no blink
            return idx
        elif open_n > 0:  # fix open_n
            open_ns = [open_n]
        else:  # open_n == 0: random open_n, 60-100
            open_ns = []
    elif isinstance(open_n, list):
        open_ns = open_n  # loop open_n
    else:
        raise ValueError()

    blink_idx = list(range(blink_n))

    start_n = open_ns[0] if open_ns else random.randint(OPEN_MIN, OPEN_MAX)
    end_n = open_ns[-1] if open_ns else random.randint(OPEN_MIN, OPEN_MAX)
    max_i = N - max(end_n, blink_n)
    cur_i = start_n
    cur_n_i = 1
    while cur_i < max_i:
        idx[cur_i : cur_i + blink_n] = blink_idx

        if open_ns:
            cur_n = open_ns[cur_n_i % len(open_ns)]
            cur_n_i += 1
        else:
            cur_n = random.randint(OPEN_MIN, OPEN_MAX)

        cur_i = cur_i + blink_n + cur_n

    return idx


def _fix_exp_for_x_d_info(x_d_info, x_s_info, delta_eye=None, drive_eye=True):
    _eye = [11, 13, 15, 16, 18]
    _lip = [6, 12, 14, 17, 19, 20]
    alpha = np.zeros((21, 3), dtype=x_d_info["exp"].dtype)
    alpha[_lip] = 1
    if delta_eye is None and drive_eye:  # use d eye
        alpha[_eye] = 1
    alpha = alpha.reshape(1, -1)
    x_d_info["exp"] = x_d_info["exp"] * alpha + x_s_info["exp"] * (1 - alpha)

    if delta_eye is not None and drive_eye:
        alpha = np.zeros((21, 3), dtype=x_d_info["exp"].dtype)
        alpha[_eye] = 1
        alpha = alpha.reshape(1, -1)
        x_d_info["exp"] = (delta_eye + x_s_info["exp"]) * alpha + x_d_info["exp"] * (
            1 - alpha
        )

    return x_d_info


def _fix_exp_for_x_d_info_v2(x_d_info, x_s_info, delta_eye, a1, a2, a3):
    x_d_info["exp"] = x_d_info["exp"] * a1 + x_s_info["exp"] * a2 + delta_eye * a3
    return x_d_info


def bin66_to_degree(pred):
    if pred.ndim > 1 and pred.shape[1] == 66:
        idx = np.arange(66).astype(np.float32)
        pred = softmax(pred, axis=1)
        degree = np.sum(pred * idx, axis=1) * 3 - 97.5
        return degree
    return pred


def _eye_delta(exp, dx=0, dy=0):
    if dx > 0:
        exp[0, 33] += dx * 0.0007
        exp[0, 45] += dx * 0.001
    else:
        exp[0, 33] += dx * 0.001
        exp[0, 45] += dx * 0.0007

    exp[0, 34] += dy * -0.001
    exp[0, 46] += dy * -0.001
    return exp


def _fix_gaze(pose_s, x_d_info):
    x_ratio = 0.26
    y_ratio = 0.28

    yaw_s, pitch_s = pose_s
    yaw_d = bin66_to_degree(x_d_info['yaw']).item()
    pitch_d = bin66_to_degree(x_d_info['pitch']).item()

    delta_yaw = yaw_d - yaw_s
    delta_pitch = pitch_d - pitch_s

    dx = delta_yaw * x_ratio
    dy = delta_pitch * y_ratio

    x_d_info['exp'] = _eye_delta(x_d_info['exp'], dx, dy)
    return x_d_info


def get_rotation_matrix(pitch_, yaw_, roll_):
    """ the input is in degree
    """
    # transform to radian
    pitch = pitch_ / 180 * np.pi
    yaw = yaw_ / 180 * np.pi
    roll = roll_ / 180 * np.pi

    if pitch.ndim == 1:
        pitch = pitch[:, None]
    if yaw.ndim == 1:
        yaw = yaw[:, None]
    if roll.ndim == 1:
        roll = roll[:, None]

    # calculate the euler matrix
    bs = pitch.shape[0]
    ones = np.ones((bs, 1), dtype=np.float32)
    zeros = np.zeros((bs, 1), dtype=np.float32)
    x, y, z = pitch, yaw, roll

    rot_x = np.concatenate([
        ones, zeros, zeros,
        zeros, np.cos(x), -np.sin(x),
        zeros, np.sin(x), np.cos(x)
    ], axis=1).reshape(bs, 3, 3)

    rot_y = np.concatenate([
        np.cos(y), zeros, np.sin(y),
        zeros, ones, zeros,
        -np.sin(y), zeros, np.cos(y)
    ], axis=1).reshape(bs, 3, 3)

    rot_z = np.concatenate([
        np.cos(z), -np.sin(z), zeros,
        np.sin(z), np.cos(z), zeros,
        zeros, zeros, ones
    ], axis=1).reshape(bs, 3, 3)

    rot = np.matmul(np.matmul(rot_z, rot_y), rot_x)
    return np.transpose(rot, (0, 2, 1))


def transform_keypoint(kp_info: dict):
    """
    transform the implicit keypoints with the pose, shift, and expression deformation
    kp: BxNx3
    """
    kp = kp_info['kp']  # (bs, k, 3)
    pitch, yaw, roll = kp_info['pitch'], kp_info['yaw'], kp_info['roll']

    t, exp = kp_info['t'], kp_info['exp']
    scale = kp_info['scale']

    pitch = bin66_to_degree(pitch)
    yaw = bin66_to_degree(yaw)
    roll = bin66_to_degree(roll)

    bs = kp.shape[0]
    if kp.ndim == 2:
        num_kp = kp.shape[1] // 3  # Bx(num_kpx3)
    else:
        num_kp = kp.shape[1]  # Bxnum_kpx3

    rot_mat = get_rotation_matrix(pitch, yaw, roll)  # (bs, 3, 3)

    # Eqn.2: s * (R * x_c,s + exp) + t
    kp_transformed = np.matmul(kp.reshape(bs, num_kp, 3), rot_mat) + exp.reshape(bs, num_kp, 3)
    kp_transformed *= scale[..., None]  # (bs, k, 3) * (bs, 1, 1) = (bs, k, 3)
    kp_transformed[:, :, 0:2] += t[:, None, 0:2]  # remove z, only apply tx ty

    return kp_transformed


class MotionStitch:
    def __init__(
        self,
        stitch_network_cfg,
    ):
        self.stitch_net = StitchNetwork(**stitch_network_cfg)

    def set_Nd(self, N_d=-1):
        # only for offline (make start|end eye open)
        if N_d == self.N_d:
            return

        self.N_d = N_d
        if self.drive_eye and self.delta_eye_arr is not None:
            N = 3000 if self.N_d == -1 else self.N_d
            self.delta_eye_idx_list = _set_eye_blink_idx(
                N, len(self.delta_eye_arr), self.delta_eye_open_n
            )

    def setup(
        self,
        N_d=-1,
        use_d_keys=None,
        relative_d=True,
        drive_eye=None,  # use d eye or s eye
        delta_eye_arr=None,  # fix eye
        delta_eye_open_n=-1,  # int|list
        fade_out_keys=("exp",),
        fade_type="",  # "" | "d0" | "s"
        flag_stitching=True,
        is_image_flag=True,
        x_s_info=None,
        d0=None,
        ch_info=None,
        overall_ctrl_info=None,
    ):
        self.is_image_flag = is_image_flag
        if use_d_keys is None:
            if self.is_image_flag:
                self.use_d_keys = ("exp", "pitch", "yaw", "roll", "t")
            else:
                self.use_d_keys = ("exp", )
        else:
            self.use_d_keys = use_d_keys

        if drive_eye is None:
            if self.is_image_flag:
                self.drive_eye = True
            else:
                self.drive_eye = False
        else:
            self.drive_eye = drive_eye

        self.N_d = N_d
        self.relative_d = relative_d
        self.delta_eye_arr = delta_eye_arr
        self.delta_eye_open_n = delta_eye_open_n
        self.fade_out_keys = fade_out_keys
        self.fade_type = fade_type
        self.flag_stitching = flag_stitching

        _eye = [11, 13, 15, 16, 18]
        _lip = [6, 12, 14, 17, 19, 20]
        _a1 = np.zeros((21, 3), dtype=np.float32)
        _a1[_lip] = 1
        _a2 = 0
        if self.drive_eye:
            if self.delta_eye_arr is None:
                _a1[_eye] = 1
            else:
                _a2 = np.zeros((21, 3), dtype=np.float32)
                _a2[_eye] = 1
                _a2 = _a2.reshape(1, -1)
        _a1 = _a1.reshape(1, -1)

        self.fix_exp_a1 = _a1 * (1 - _a2)
        self.fix_exp_a2 = (1 - _a1) + _a1 * _a2
        self.fix_exp_a3 = _a2

        if self.drive_eye and self.delta_eye_arr is not None:
            N = 3000 if self.N_d == -1 else self.N_d
            self.delta_eye_idx_list = _set_eye_blink_idx(
                N, len(self.delta_eye_arr), self.delta_eye_open_n
            )

        self.pose_s = None
        self.x_s = None
        self.fade_dst = None
        if self.is_image_flag and x_s_info is not None:
            yaw_s = bin66_to_degree(x_s_info['yaw']).item()
            pitch_s = bin66_to_degree(x_s_info['pitch']).item()
            self.pose_s = [yaw_s, pitch_s]
            self.x_s = transform_keypoint(x_s_info)

            if self.fade_type == "s":
                self.fade_dst = copy.deepcopy(x_s_info)

        if ch_info is not None:
            self.scale_a = ch_info['x_s_info_lst'][0]['scale'].item()
            if x_s_info is not None:
                self.scale_b = x_s_info['scale'].item()
                self.scale_ratio = self.scale_a / self.scale_b
                self._set_scale_ratio(self.scale_ratio)
            else:
                self.scale_ratio = None
        else:
            self.scale_ratio = 1

        self.overall_ctrl_info = overall_ctrl_info

        self.d0 = d0
        self.idx = 0

    def _set_scale_ratio(self, scale_ratio=1):
        if scale_ratio == 1:
            return
        if isinstance(self.use_d_keys, dict):
            self.use_d_keys = {k: v * (scale_ratio if k in {"exp", "pitch", "yaw", "roll"} else 1) for k, v in self.use_d_keys.items()}
        else:
            self.use_d_keys = {k: scale_ratio if k in {"exp", "pitch", "yaw", "roll"} else 1 for k in self.use_d_keys}

    @staticmethod
    def _merge_kwargs(default_kwargs, run_kwargs):
        if default_kwargs is None:
            return run_kwargs

        for k, v in default_kwargs.items():
            if k not in run_kwargs:
                run_kwargs[k] = v
        return run_kwargs

    def __call__(self, x_s_info, x_d_info, **kwargs):
        # return x_s, x_d

        kwargs = self._merge_kwargs(self.overall_ctrl_info, kwargs)

        if self.scale_ratio is None:
            self.scale_b = x_s_info['scale'].item()
            self.scale_ratio = self.scale_a / self.scale_b
            self._set_scale_ratio(self.scale_ratio)

        if self.relative_d and self.d0 is None:
            self.d0 = copy.deepcopy(x_d_info)

        x_d_info = _mix_s_d_info(
            x_s_info,
            x_d_info,
            self.use_d_keys,
            self.d0,
        )

        delta_eye = 0
        if self.drive_eye and self.delta_eye_arr is not None:
            delta_eye = self.delta_eye_arr[
                self.delta_eye_idx_list[self.idx % len(self.delta_eye_idx_list)]
            ][None]
        x_d_info = _fix_exp_for_x_d_info_v2(
            x_d_info,
            x_s_info,
            delta_eye,
            self.fix_exp_a1,
            self.fix_exp_a2,
            self.fix_exp_a3,
        )

        if kwargs.get("vad_alpha", 1) < 1:
            x_d_info = ctrl_vad(x_d_info, x_s_info, kwargs.get("vad_alpha", 1))

        x_d_info = ctrl_motion(x_d_info, **kwargs)

        if self.fade_type == "d0" and self.fade_dst is None:
            self.fade_dst = copy.deepcopy(x_d_info)

        # fade
        if "fade_alpha" in kwargs and self.fade_type in ["d0", "s"]:
            fade_alpha = kwargs["fade_alpha"]
            fade_keys = kwargs.get("fade_out_keys", self.fade_out_keys)
            if self.fade_type == "d0":
                fade_dst = self.fade_dst
            elif self.fade_type == "s":
                if self.fade_dst is not None:
                    fade_dst = self.fade_dst
                else:
                    fade_dst = copy.deepcopy(x_s_info)
                    if self.is_image_flag:
                        self.fade_dst = fade_dst
            x_d_info = fade(x_d_info, fade_dst, fade_alpha, fade_keys)

        if self.drive_eye:
            if self.pose_s is None:
                yaw_s = bin66_to_degree(x_s_info['yaw']).item()
                pitch_s = bin66_to_degree(x_s_info['pitch']).item()
                self.pose_s = [yaw_s, pitch_s]
            x_d_info = _fix_gaze(self.pose_s, x_d_info)

        if self.x_s is not None:
            x_s = self.x_s
        else:
            x_s = transform_keypoint(x_s_info)
            if self.is_image_flag:
                self.x_s = x_s

        x_d = transform_keypoint(x_d_info)

        if self.flag_stitching:
            x_d = self.stitch_net(x_s, x_d)

        self.idx += 1

        return x_s, x_d
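
To make the control flow of `MotionStitch` concrete, here is a minimal driving-loop sketch. It is illustrative only: the checkpoint path is a placeholder, and `x_s_info` / `driving_motion_frames` are assumed to be the motion dicts produced by the source-info and audio-to-motion components elsewhere in this commit.

# Illustrative usage sketch; path and motion dicts are assumptions, not repository code.
stitcher = MotionStitch({"model_path": "./checkpoints/stitch_network.onnx", "device": "cuda"})
stitcher.setup(N_d=-1, is_image_flag=True, x_s_info=x_s_info)   # x_s_info from Source2Info
for x_d_info in driving_motion_frames:                          # per-frame driving motion
    x_s, x_d = stitcher(x_s_info, x_d_info, fade_alpha=1.0)     # stitched keypoints per frame
    # x_s and x_d then feed the warping network (WarpF3D below)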
core/atomic_components/putback.py
ADDED
@@ -0,0 +1,60 @@
import cv2
import numpy as np
from ..utils.blend import blend_images_cy
from ..utils.get_mask import get_mask


class PutBackNumpy:
    def __init__(
        self,
        mask_template_path=None,
    ):
        if mask_template_path is None:
            mask = get_mask(512, 512, 0.9, 0.9)
            self.mask_ori_float = np.concatenate([mask] * 3, 2)
        else:
            mask = cv2.imread(mask_template_path, cv2.IMREAD_COLOR)
            self.mask_ori_float = mask.astype(np.float32) / 255.0

    def __call__(self, frame_rgb, render_image, M_c2o):
        h, w = frame_rgb.shape[:2]
        mask_warped = cv2.warpAffine(
            self.mask_ori_float, M_c2o[:2, :], dsize=(w, h), flags=cv2.INTER_LINEAR
        ).clip(0, 1)
        frame_warped = cv2.warpAffine(
            render_image, M_c2o[:2, :], dsize=(w, h), flags=cv2.INTER_LINEAR
        )
        result = mask_warped * frame_warped + (1 - mask_warped) * frame_rgb
        result = np.clip(result, 0, 255)
        result = result.astype(np.uint8)
        return result


class PutBack:
    def __init__(
        self,
        mask_template_path=None,
    ):
        if mask_template_path is None:
            mask = get_mask(512, 512, 0.9, 0.9)
            mask = np.concatenate([mask] * 3, 2)
        else:
            mask = cv2.imread(mask_template_path, cv2.IMREAD_COLOR).astype(np.float32) / 255.0

        self.mask_ori_float = np.ascontiguousarray(mask)[:, :, 0]
        self.result_buffer = None

    def __call__(self, frame_rgb, render_image, M_c2o):
        h, w = frame_rgb.shape[:2]
        mask_warped = cv2.warpAffine(
            self.mask_ori_float, M_c2o[:2, :], dsize=(w, h), flags=cv2.INTER_LINEAR
        ).clip(0, 1)
        frame_warped = cv2.warpAffine(
            render_image, M_c2o[:2, :], dsize=(w, h), flags=cv2.INTER_LINEAR
        )
        self.result_buffer = np.empty((h, w, 3), dtype=np.uint8)

        # Use Cython implementation for blending
        blend_images_cy(mask_warped, frame_warped, frame_rgb, self.result_buffer)

        return self.result_buffer
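
A short sketch of how `PutBack` is meant to be driven: the rendered face crop is warped back into the full frame with `M_c2o` and blended under the soft mask. The variable names, sizes, and the identity matrix below are placeholders, not values taken from this commit.

# Hedged usage sketch with placeholder inputs.
import numpy as np

put_back = PutBack(mask_template_path=None)
frame_rgb = np.zeros((720, 1280, 3), dtype=np.uint8)     # placeholder original frame
render_512 = np.zeros((512, 512, 3), dtype=np.uint8)     # placeholder generator output (crop)
M_c2o = np.eye(3, dtype=np.float32)                      # normally source_info["M_c2o"]
composited = put_back(frame_rgb, render_512, M_c2o)      # (720, 1280, 3) uint8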
core/atomic_components/source2info.py
ADDED
@@ -0,0 +1,155 @@
import numpy as np
import cv2

from ..aux_models.insightface_det import InsightFaceDet
from ..aux_models.insightface_landmark106 import Landmark106
from ..aux_models.landmark203 import Landmark203
from ..aux_models.mediapipe_landmark478 import Landmark478
from ..models.appearance_extractor import AppearanceExtractor
from ..models.motion_extractor import MotionExtractor

from ..utils.crop import crop_image
from ..utils.eye_info import EyeAttrUtilsByMP


"""
insightface_det_cfg = {
    "model_path": "",
    "device": "cuda",
    "force_ori_type": False,
}
landmark106_cfg = {
    "model_path": "",
    "device": "cuda",
    "force_ori_type": False,
}
landmark203_cfg = {
    "model_path": "",
    "device": "cuda",
    "force_ori_type": False,
}
landmark478_cfg = {
    "blaze_face_model_path": "",
    "face_mesh_model_path": "",
    "device": "cuda",
    "force_ori_type": False,
    "task_path": "",
}
appearance_extractor_cfg = {
    "model_path": "",
    "device": "cuda",
}
motion_extractor_cfg = {
    "model_path": "",
    "device": "cuda",
}
"""


class Source2Info:
    def __init__(
        self,
        insightface_det_cfg,
        landmark106_cfg,
        landmark203_cfg,
        landmark478_cfg,
        appearance_extractor_cfg,
        motion_extractor_cfg,
    ):
        self.insightface_det = InsightFaceDet(**insightface_det_cfg)
        self.landmark106 = Landmark106(**landmark106_cfg)
        self.landmark203 = Landmark203(**landmark203_cfg)
        self.landmark478 = Landmark478(**landmark478_cfg)

        self.appearance_extractor = AppearanceExtractor(**appearance_extractor_cfg)
        self.motion_extractor = MotionExtractor(**motion_extractor_cfg)

    def _crop(self, img, last_lmk=None, **kwargs):
        # img_rgb -> det -> landmark106 -> landmark203 -> crop

        if last_lmk is None:  # det for first frame or image
            det, _ = self.insightface_det(img)
            boxes = det[np.argsort(-(det[:, 2] - det[:, 0]) * (det[:, 3] - det[:, 1]))]
            if len(boxes) == 0:
                return None
            lmk_for_track = self.landmark106(img, boxes[0])  # 106
        else:  # track for video frames
            lmk_for_track = last_lmk  # 203

        crop_dct = crop_image(
            img,
            lmk_for_track,
            dsize=self.landmark203.dsize,
            scale=1.5,
            vy_ratio=-0.1,
            pt_crop_flag=False,
        )
        lmk203 = self.landmark203(crop_dct["img_crop"], crop_dct["M_c2o"])

        ret_dct = crop_image(
            img,
            lmk203,
            dsize=512,
            scale=kwargs.get("crop_scale", 2.3),
            vx_ratio=kwargs.get("crop_vx_ratio", 0),
            vy_ratio=kwargs.get("crop_vy_ratio", -0.125),
            flag_do_rot=kwargs.get("crop_flag_do_rot", True),
            pt_crop_flag=False,
        )

        img_crop = ret_dct["img_crop"]
        M_c2o = ret_dct["M_c2o"]

        return img_crop, M_c2o, lmk203

    @staticmethod
    def _img_crop_to_bchw256(img_crop):
        rgb_256 = cv2.resize(img_crop, (256, 256), interpolation=cv2.INTER_AREA)
        rgb_256_bchw = (rgb_256.astype(np.float32) / 255.0)[None].transpose(0, 3, 1, 2)
        return rgb_256_bchw

    def _get_kp_info(self, img):
        # rgb_256_bchw_norm01
        kp_info = self.motion_extractor(img)
        return kp_info

    def _get_f3d(self, img):
        # rgb_256_bchw_norm01
        fs = self.appearance_extractor(img)
        return fs

    def _get_eye_info(self, img):
        # rgb uint8
        lmk478 = self.landmark478(img)  # [1, 478, 3]
        attr = EyeAttrUtilsByMP(lmk478)
        lr_open = attr.LR_open().reshape(-1, 2)  # [1, 2]
        lr_ball = attr.LR_ball_move().reshape(-1, 6)  # [1, 3, 2] -> [1, 6]
        return [lr_open, lr_ball]

    def __call__(self, img, last_lmk=None, **kwargs):
        """
        img: rgb, uint8
        last_lmk: last frame lmk203, for video tracking
        kwargs: optional crop cfg
            crop_scale: 2.3
            crop_vx_ratio: 0
            crop_vy_ratio: -0.125
            crop_flag_do_rot: True
        """
        img_crop, M_c2o, lmk203 = self._crop(img, last_lmk=last_lmk, **kwargs)

        eye_open, eye_ball = self._get_eye_info(img_crop)

        rgb_256_bchw = self._img_crop_to_bchw256(img_crop)
        kp_info = self._get_kp_info(rgb_256_bchw)
        fs = self._get_f3d(rgb_256_bchw)

        source_info = {
            "x_s_info": kp_info,
            "f_s": fs,
            "M_c2o": M_c2o,
            "eye_open": eye_open,  # [1, 2]
            "eye_ball": eye_ball,  # [1, 6]
            "lmk203": lmk203,  # for track
        }
        return source_info
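
The call pattern for `Source2Info` differs between the first frame (face detection) and later frames (tracking from the previous 203-point landmarks). A sketch, under the assumption that the six cfg dicts have been filled in as in the docstring above and that RGB frames are available:

# Sketch only; cfg dicts, frames, and paths are assumptions.
s2i = Source2Info(insightface_det_cfg, landmark106_cfg, landmark203_cfg,
                  landmark478_cfg, appearance_extractor_cfg, motion_extractor_cfg)
info = s2i(first_frame_rgb)                 # first frame: last_lmk=None -> run the detector
last_lmk = info["lmk203"]
for frame_rgb in remaining_frames:          # later frames: track from the previous lmk203
    info = s2i(frame_rgb, last_lmk=last_lmk, crop_scale=2.3)
    last_lmk = info["lmk203"]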
core/atomic_components/warp_f3d.py
ADDED
@@ -0,0 +1,22 @@
from ..models.warp_network import WarpNetwork


"""
# __init__
warp_network_cfg = {
    "model_path": "",
    "device": "cuda",
}
"""


class WarpF3D:
    def __init__(
        self,
        warp_network_cfg,
    ):
        self.warp_net = WarpNetwork(**warp_network_cfg)

    def __call__(self, f_s, x_s, x_d):
        out = self.warp_net(f_s, x_s, x_d)
        return out
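
For orientation, a hedged sketch of where `WarpF3D` sits between the other atomic components: it takes the appearance feature `f_s` from `Source2Info` and the stitched source/driving keypoints from `MotionStitch`. The checkpoint path and variable names are placeholders.

# Hypothetical wiring; not taken verbatim from this commit.
warper = WarpF3D({"model_path": "./checkpoints/warp_network.onnx", "device": "cuda"})
warped_f3d = warper(source_info["f_s"], x_s, x_d)   # output then goes to the decoder (decode_f3d)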
core/atomic_components/wav2feat.py
ADDED
@@ -0,0 +1,110 @@
import librosa
import numpy as np
import math

from ..aux_models.hubert_stream import HubertStreaming

"""
wavlm_cfg = {
    "model_path": "",
    "device": "cuda",
    "force_ori_type": False,
}
hubert_cfg = {
    "model_path": "",
    "device": "cuda",
    "force_ori_type": False,
}
"""


class Wav2Feat:
    def __init__(self, w2f_cfg, w2f_type="hubert"):
        self.w2f_type = w2f_type.lower()
        if self.w2f_type == "hubert":
            self.w2f = Wav2FeatHubert(hubert_cfg=w2f_cfg)
            self.feat_dim = 1024
            self.support_streaming = True
        else:
            raise ValueError(f"Unsupported w2f_type: {w2f_type}")

    def __call__(
        self,
        audio,
        sr=16000,
        norm_mean_std=None,  # for s2g
        chunksize=(3, 5, 2),  # for hubert
    ):
        if self.w2f_type == "hubert":
            feat = self.w2f(audio, chunksize=chunksize)
        elif self.w2f_type == "s2g":
            feat = self.w2f(audio, sr=sr, norm_mean_std=norm_mean_std)
        else:
            raise ValueError(f"Unsupported w2f_type: {self.w2f_type}")
        return feat

    def wav2feat(
        self,
        audio,
        sr=16000,
        norm_mean_std=None,  # for s2g
        chunksize=(3, 5, 2),
    ):
        # for offline
        if self.w2f_type == "hubert":
            feat = self.w2f.wav2feat(audio, sr=sr, chunksize=chunksize)
        elif self.w2f_type == "s2g":
            feat = self.w2f(audio, sr=sr, norm_mean_std=norm_mean_std)
        else:
            raise ValueError(f"Unsupported w2f_type: {self.w2f_type}")
        return feat


class Wav2FeatHubert:
    def __init__(
        self,
        hubert_cfg,
    ):
        self.hubert = HubertStreaming(**hubert_cfg)

    def __call__(self, audio_chunk, chunksize=(3, 5, 2)):
        """
        audio_chunk: int(sum(chunksize) * 0.04 * 16000) + 80  # 6480
        """
        valid_feat_s = -sum(chunksize[1:]) * 2  # -7
        valid_feat_e = -chunksize[2] * 2  # -2

        encoding_chunk = self.hubert(audio_chunk)
        valid_encoding = encoding_chunk[valid_feat_s:valid_feat_e]
        valid_feat = valid_encoding.reshape(chunksize[1], 2, 1024).mean(1)  # [5, 1024]
        return valid_feat

    def wav2feat(self, audio, sr, chunksize=(3, 5, 2)):
        # for offline
        if sr != 16000:
            audio_16k = librosa.resample(audio, orig_sr=sr, target_sr=16000)
        else:
            audio_16k = audio

        num_f = math.ceil(len(audio_16k) / 16000 * 25)
        split_len = int(sum(chunksize) * 0.04 * 16000) + 80  # 6480

        speech_pad = np.concatenate([
            np.zeros((split_len - int(sum(chunksize[1:]) * 0.04 * 16000),), dtype=audio_16k.dtype),
            audio_16k,
            np.zeros((split_len,), dtype=audio_16k.dtype),
        ], 0)

        i = 0
        res_lst = []
        while i < num_f:
            sss = int(i * 0.04 * 16000)
            eee = sss + split_len
            audio_chunk = speech_pad[sss:eee]
            valid_feat = self.__call__(audio_chunk, chunksize)
            res_lst.append(valid_feat)
            i += chunksize[1]

        ret = np.concatenate(res_lst, 0)
        ret = ret[:num_f]
        return ret
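
The chunking arithmetic in `wav2feat` is easy to misread, so here is a standalone check of the numbers it relies on. With `chunksize=(3, 5, 2)` a window spans 10 video frames at 25 fps (0.04 s each), and each call keeps features for the 5 centre frames; this sketch only verifies the values and loads no model.

# Standalone arithmetic check for chunksize=(3, 5, 2); no model is loaded.
import math

chunksize = (3, 5, 2)
split_len = int(sum(chunksize) * 0.04 * 16000) + 80              # samples per window -> 6480
left_pad = split_len - int(sum(chunksize[1:]) * 0.04 * 16000)    # history padding before the audio
num_f = math.ceil(3.0 * 25)                                      # a 3 s clip -> 75 feature frames
assert split_len == 6480 and left_pad == 2000
print(split_len, left_pad, num_f)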
core/atomic_components/writer.py
ADDED
@@ -0,0 +1,36 @@
import imageio
import os


class VideoWriterByImageIO:
    def __init__(self, video_path, fps=25, **kwargs):
        video_format = kwargs.get("format", "mp4")  # default is mp4 format
        codec = kwargs.get("vcodec", "libx264")  # default is libx264 encoding
        quality = kwargs.get("quality")  # video quality
        pixelformat = kwargs.get("pixelformat", "yuv420p")  # video pixel format
        macro_block_size = kwargs.get("macro_block_size", 2)
        ffmpeg_params = ["-crf", str(kwargs.get("crf", 18))]

        os.makedirs(os.path.dirname(video_path), exist_ok=True)

        writer = imageio.get_writer(
            video_path,
            fps=fps,
            format=video_format,
            codec=codec,
            quality=quality,
            ffmpeg_params=ffmpeg_params,
            pixelformat=pixelformat,
            macro_block_size=macro_block_size,
        )
        self.writer = writer

    def __call__(self, img, fmt="bgr"):
        if fmt == "bgr":
            frame = img[..., ::-1]
        else:
            frame = img
        self.writer.append_data(frame)

    def close(self):
        self.writer.close()
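
Typical use of the writer above; the output path and frame source are placeholders, and `fps=25` matches the frame rate assumed by the audio feature extractor in this commit.

# Hedged usage sketch with a placeholder path and dummy frames.
import numpy as np

writer = VideoWriterByImageIO("./tmp/output.mp4", fps=25, crf=18)
for _ in range(25):                                     # placeholder: one second of black frames
    writer(np.zeros((512, 512, 3), dtype=np.uint8), fmt="rgb")   # frames already in RGB order
writer.close()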
core/aux_models/blaze_face.py
ADDED
@@ -0,0 +1,351 @@
import numpy as np
import cv2
from ..utils.load_model import load_model


def intersect(box_a, box_b):
    """We resize both arrays to [A,B,2] without new malloc:
    [A,2] -> [A,1,2] -> [A,B,2]
    [B,2] -> [1,B,2] -> [A,B,2]
    Then we compute the area of intersect between box_a and box_b.
    Args:
        box_a: (array) bounding boxes, Shape: [A,4].
        box_b: (array) bounding boxes, Shape: [B,4].
    Return:
        (array) intersection area, Shape: [A,B].
    """
    A = box_a.shape[0]
    B = box_b.shape[0]
    max_xy = np.minimum(
        np.expand_dims(box_a[:, 2:], axis=1).repeat(B, axis=1),
        np.expand_dims(box_b[:, 2:], axis=0).repeat(A, axis=0),
    )
    min_xy = np.maximum(
        np.expand_dims(box_a[:, :2], axis=1).repeat(B, axis=1),
        np.expand_dims(box_b[:, :2], axis=0).repeat(A, axis=0),
    )
    inter = np.clip((max_xy - min_xy), a_min=0, a_max=None)
    return inter[:, :, 0] * inter[:, :, 1]


def jaccard(box_a, box_b):
    """Compute the jaccard overlap of two sets of boxes. The jaccard overlap
    is simply the intersection over union of two boxes. Here we operate on
    ground truth boxes and default boxes.
    E.g.:
        A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
    Args:
        box_a: (array) Ground truth bounding boxes, Shape: [num_objects,4]
        box_b: (array) Prior boxes from priorbox layers, Shape: [num_priors,4]
    Return:
        jaccard overlap: (array) Shape: [box_a.size(0), box_b.size(0)]
    """
    inter = intersect(box_a, box_b)
    area_a = (
        ((box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1]))
        .reshape(-1, 1)
        .repeat(box_b.shape[0], axis=1)
    )  # [A,B]
    area_b = (
        ((box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1]))
        .reshape(1, -1)
        .repeat(box_a.shape[0], axis=0)
    )  # [A,B]
    union = area_a + area_b - inter
    return inter / union  # [A,B]


def overlap_similarity(box, other_boxes):
    """Computes the IOU between a bounding box and set of other boxes."""
    box = np.expand_dims(box, axis=0)  # Equivalent to unsqueeze(0) in PyTorch
    iou = jaccard(box, other_boxes)
    return np.squeeze(iou, axis=0)  # Equivalent to squeeze(0) in PyTorch


class BlazeFace:
    def __init__(self, model_path, device="cuda"):
        self.anchor_options = {
            "num_layers": 4,
            "min_scale": 0.1484375,
            "max_scale": 0.75,
            "input_size_height": 128,
            "input_size_width": 128,
            "anchor_offset_x": 0.5,
            "anchor_offset_y": 0.5,
            "strides": [8, 16, 16, 16],
            "aspect_ratios": [1.0],
            "reduce_boxes_in_lowest_layer": False,
            "interpolated_scale_aspect_ratio": 1.0,
            "fixed_anchor_size": True,
        }
        self.num_classes = 1
        self.num_anchors = 896
        self.num_coords = 16
        self.x_scale = 128.0
        self.y_scale = 128.0
        self.h_scale = 128.0
        self.w_scale = 128.0
        self.min_score_thresh = 0.5
        self.min_suppression_threshold = 0.3
        self.anchors = self.generate_anchors(self.anchor_options)
        self.anchors = np.array(self.anchors)
        assert len(self.anchors) == 896
        self.model, self.model_type = load_model(model_path, device=device)
        self.output_names = ["regressors", "classificators"]

    def __call__(self, image: np.ndarray):
        """
        image: RGB image
        """
        image = cv2.resize(image, (128, 128))
        image = image[np.newaxis, :, :, :].astype(np.float32)
        image = image / 127.5 - 1.0
        outputs = {}
        if self.model_type == "onnx":
            out_list = self.model.run(None, {"input": image})
            for i, name in enumerate(self.output_names):
                outputs[name] = out_list[i]
        elif self.model_type == "tensorrt":
            self.model.setup({"input": image})
            self.model.infer()
            for name in self.output_names:
                outputs[name] = self.model.buffer[name][0]
        else:
            raise ValueError(f"Unsupported model type: {self.model_type}")
        boxes = self.postprocess(outputs["regressors"], outputs["classificators"])
        return boxes

    def calculate_scale(self, min_scale, max_scale, stride_index, num_strides):
        return min_scale + (max_scale - min_scale) * stride_index / (num_strides - 1.0)

    def generate_anchors(self, options):
        strides_size = len(options["strides"])
        assert options["num_layers"] == strides_size

        anchors = []
        layer_id = 0
        while layer_id < strides_size:
            anchor_height = []
            anchor_width = []
            aspect_ratios = []
            scales = []

            # For same strides, we merge the anchors in the same order.
            last_same_stride_layer = layer_id
            while (last_same_stride_layer < strides_size) and (
                options["strides"][last_same_stride_layer]
                == options["strides"][layer_id]
            ):
                scale = self.calculate_scale(
                    options["min_scale"],
                    options["max_scale"],
                    last_same_stride_layer,
                    strides_size,
                )

                if (
                    last_same_stride_layer == 0
                    and options["reduce_boxes_in_lowest_layer"]
                ):
                    # For first layer, it can be specified to use predefined anchors.
                    aspect_ratios.append(1.0)
                    aspect_ratios.append(2.0)
                    aspect_ratios.append(0.5)
                    scales.append(0.1)
                    scales.append(scale)
                    scales.append(scale)
                else:
                    for aspect_ratio in options["aspect_ratios"]:
                        aspect_ratios.append(aspect_ratio)
                        scales.append(scale)

                    if options["interpolated_scale_aspect_ratio"] > 0.0:
                        scale_next = (
                            1.0
                            if last_same_stride_layer == strides_size - 1
                            else self.calculate_scale(
                                options["min_scale"],
                                options["max_scale"],
                                last_same_stride_layer + 1,
                                strides_size,
                            )
                        )
                        scales.append(np.sqrt(scale * scale_next))
                        aspect_ratios.append(options["interpolated_scale_aspect_ratio"])

                last_same_stride_layer += 1

            for i in range(len(aspect_ratios)):
                ratio_sqrts = np.sqrt(aspect_ratios[i])
                anchor_height.append(scales[i] / ratio_sqrts)
                anchor_width.append(scales[i] * ratio_sqrts)

            stride = options["strides"][layer_id]
            feature_map_height = int(np.ceil(options["input_size_height"] / stride))
            feature_map_width = int(np.ceil(options["input_size_width"] / stride))

            for y in range(feature_map_height):
                for x in range(feature_map_width):
                    for anchor_id in range(len(anchor_height)):
                        x_center = (x + options["anchor_offset_x"]) / feature_map_width
                        y_center = (y + options["anchor_offset_y"]) / feature_map_height

                        new_anchor = [x_center, y_center, 0, 0]
                        if options["fixed_anchor_size"]:
                            new_anchor[2] = 1.0
                            new_anchor[3] = 1.0
                        else:
                            new_anchor[2] = anchor_width[anchor_id]
                            new_anchor[3] = anchor_height[anchor_id]
                        anchors.append(new_anchor)

            layer_id = last_same_stride_layer

        return anchors

    def _tensors_to_detections(self, raw_box_tensor, raw_score_tensor, anchors):
        """The output of the neural network is a tensor of shape (b, 896, 16)
        containing the bounding box regressor predictions, as well as a tensor
        of shape (b, 896, 1) with the classification confidences.

        This function converts these two "raw" tensors into proper detections.
        Returns a list of (num_detections, 17) tensors, one for each image in
        the batch.

        This is based on the source code from:
        mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc
        mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.proto
        """
        assert raw_box_tensor.ndim == 3
        assert raw_box_tensor.shape[1] == self.num_anchors
        assert raw_box_tensor.shape[2] == self.num_coords

        assert raw_score_tensor.ndim == 3
        assert raw_score_tensor.shape[1] == self.num_anchors
        assert raw_score_tensor.shape[2] == self.num_classes

        assert raw_box_tensor.shape[0] == raw_score_tensor.shape[0]

        detection_boxes = self._decode_boxes(raw_box_tensor, anchors)

        raw_score_tensor = np.clip(raw_score_tensor, -50, 100)
        detection_scores = 1 / (1 + np.exp(-raw_score_tensor))
        mask = detection_scores >= self.min_score_thresh
        mask = mask[0, :, 0]
        boxes = detection_boxes[0, mask, :]
        scores = detection_scores[0, mask, :]
        return np.concatenate((boxes, scores), axis=-1)

    def _decode_boxes(self, raw_boxes, anchors):
        """Converts the predictions into actual coordinates using
        the anchor boxes. Processes the entire batch at once.
        """
        boxes = np.zeros_like(raw_boxes)

        x_center = raw_boxes[..., 0] / self.x_scale * anchors[:, 2] + anchors[:, 0]
        y_center = raw_boxes[..., 1] / self.y_scale * anchors[:, 3] + anchors[:, 1]

        w = raw_boxes[..., 2] / self.w_scale * anchors[:, 2]
        h = raw_boxes[..., 3] / self.h_scale * anchors[:, 3]

        boxes[..., 0] = self.x_scale * (x_center - w / 2.0)  # xmin
        boxes[..., 1] = self.y_scale * (y_center - h / 2.0)  # ymin
        boxes[..., 2] = self.w_scale * (x_center + w / 2.0)  # xmax
        boxes[..., 3] = self.h_scale * (y_center + h / 2.0)  # ymax

        for k in range(6):
            offset = 4 + k * 2
            keypoint_x = (
                raw_boxes[..., offset] / self.x_scale * anchors[:, 2] + anchors[:, 0]
            )
            keypoint_y = (
                raw_boxes[..., offset + 1] / self.y_scale * anchors[:, 3]
                + anchors[:, 1]
            )
            boxes[..., offset] = keypoint_x
            boxes[..., offset + 1] = keypoint_y

        return boxes

    def _weighted_non_max_suppression(self, detections):
        """The alternative NMS method as mentioned in the BlazeFace paper:

        "We replace the suppression algorithm with a blending strategy that
        estimates the regression parameters of a bounding box as a weighted
        mean between the overlapping predictions."

        The original MediaPipe code assigns the score of the most confident
        detection to the weighted detection, but we take the average score
        of the overlapping detections.

        The input detections should be a NumPy array of shape (count, 17).

        Returns a list of NumPy arrays, one for each detected face.

        This is based on the source code from:
        mediapipe/calculators/util/non_max_suppression_calculator.cc
        mediapipe/calculators/util/non_max_suppression_calculator.proto
        """
        if len(detections) == 0:
            return []

        output_detections = []

        # Sort the detections from highest to lowest score.
        remaining = np.argsort(detections[:, 16])[::-1]

        while len(remaining) > 0:
            detection = detections[remaining[0]]

            # Compute the overlap between the first box and the other
            # remaining boxes. (Note that the other_boxes also include
            # the first_box.)
            first_box = detection[:4]
            other_boxes = detections[remaining, :4]
            ious = overlap_similarity(first_box, other_boxes)

            # If two detections don't overlap enough, they are considered
            # to be from different faces.
            mask = ious > self.min_suppression_threshold
            overlapping = remaining[mask]
            remaining = remaining[~mask]

            # Take an average of the coordinates from the overlapping
            # detections, weighted by their confidence scores.
            weighted_detection = detection.copy()
            if len(overlapping) > 1:
                coordinates = detections[overlapping, :16]
                scores = detections[overlapping, 16:17]
                total_score = scores.sum()
                weighted = (coordinates * scores).sum(axis=0) / total_score
                weighted_detection[:16] = weighted
                weighted_detection[16] = total_score / len(overlapping)

            output_detections.append(weighted_detection)

        return output_detections

    def postprocess(self, raw_boxes, scores):
        detections = self._tensors_to_detections(raw_boxes, scores, self.anchors)

        detections = self._weighted_non_max_suppression(detections)
        detections = np.array(detections)
        return detections


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="")
    parser.add_argument("--image", type=str, default=None)
    args = parser.parse_args()

    blaze_face = BlazeFace(args.model)
    image = cv2.imread(args.image)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # __call__ already resizes to 128x128, adds the batch dimension and normalizes,
    # so the raw RGB image is passed directly instead of being preprocessed twice.
    boxes = blaze_face(image)
    print(boxes)
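
Each detection row returned by `postprocess` packs 17 values: 4 box coordinates, 6 (x, y) facial keypoints and a confidence score, all expressed in the detector's 128x128 input space. A reading sketch, with a placeholder model path and a dummy image:

# Hedged sketch; model path and image are placeholders.
import numpy as np

blaze = BlazeFace("./checkpoints/blaze_face.onnx", device="cuda")
detections = blaze(np.zeros((480, 640, 3), dtype=np.uint8))      # RGB input, any size
if len(detections) > 0:
    xmin, ymin, xmax, ymax = detections[0, :4]       # box in 128x128 detector space
    keypoints = detections[0, 4:16].reshape(6, 2)    # six facial keypoints
    score = detections[0, 16]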
core/aux_models/face_mesh.py
ADDED
@@ -0,0 +1,101 @@
import cv2
import numpy as np

from ..utils.load_model import load_model


class FaceMesh:
    def __init__(self, model_path, device="cuda"):
        self.model, self.model_type = load_model(model_path, device=device)
        self.input_size = (256, 256)  # (w, h)
        self.output_names = [
            "Identity",
            "Identity_1",
            "Identity_2",
        ]  # Identity is the mesh

    def project_landmarks(self, points, roi):
        width, height = self.input_size
        points /= (width, height, width)
        sin, cos = np.sin(roi[4]), np.cos(roi[4])
        matrix = np.array([[cos, sin, 0.0], [-sin, cos, 0.0], [1.0, 1.0, 1.0]])
        points -= (0.5, 0.5, 0.0)
        rotated = np.matmul(points * (1, 1, 0), matrix)
        points *= (0, 0, 1)
        points += rotated
        points *= (roi[2], roi[3], roi[2])
        points += (roi[0], roi[1], 0.0)
        return points

    def __call__(self, image, roi):
        """
        image: np.ndarray, RGB, (H, W, C), [0, 255]
        roi: np.ndarray, (cx, cy, w, h, rotation), rotation in radian
        """
        cx, cy, w, h = roi[:4]
        w_half, h_half = w / 2, h / 2
        pts = [
            (cx - w_half, cy - h_half),
            (cx + w_half, cy - h_half),
            (cx + w_half, cy + h_half),
            (cx - w_half, cy + h_half),
        ]
        rotation = roi[4]
        s, c = np.sin(rotation), np.cos(rotation)
        t = np.array(pts) - (cx, cy)
        r = np.array([[c, s], [-s, c]])
        src_pts = np.matmul(t, r) + (cx, cy)
        src_pts = src_pts.astype(np.float32)

        dst_pts = np.array(
            [
                [0.0, 0.0],
                [self.input_size[0], 0.0],
                [self.input_size[0], self.input_size[1]],
                [0.0, self.input_size[1]],
            ]
        ).astype(np.float32)
        M = cv2.getPerspectiveTransform(src_pts, dst_pts)
        roi_image = cv2.warpPerspective(
            image, M, self.input_size, flags=cv2.INTER_LINEAR
        )
        # cv2.imwrite('test.jpg', cv2.cvtColor(roi_image, cv2.COLOR_RGB2BGR))
        roi_image = roi_image / 255.0
        roi_image = roi_image.astype(np.float32)
        roi_image = roi_image[np.newaxis, :, :, :]

        outputs = {}
        if self.model_type == "onnx":
            out_list = self.model.run(None, {"input": roi_image})
            for i, name in enumerate(self.output_names):
                outputs[name] = out_list[i]
        elif self.model_type == "tensorrt":
            self.model.setup({"input": roi_image})
            self.model.infer()
            for name in self.output_names:
                outputs[name] = self.model.buffer[name][0]
        else:
            raise ValueError(f"Unsupported model type: {self.model_type}")
        points = outputs["Identity"].reshape(1434 // 3, 3)
        points = self.project_landmarks(points, roi)
        return points


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, help="model path")
    parser.add_argument("--image", type=str, help="image path")
    parser.add_argument("--device", type=str, default="cuda", help="device")
    args = parser.parse_args()

    face_mesh = FaceMesh(args.model, args.device)
    image = cv2.imread(args.image, cv2.IMREAD_COLOR)
    image = cv2.resize(image, (256, 256))
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    roi = np.array([128, 128, 256, 256, np.pi / 2])
    mesh = face_mesh(image, roi)
    print(mesh.shape)
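
`FaceMesh` expects an oriented ROI rather than a plain bounding box. The sketch below builds a square, unrotated ROI around an assumed face centre (all values are placeholders) and reads back the 478-point mesh, which `project_landmarks` returns in original-image coordinates.

# Hedged sketch; model path, image, and ROI values are placeholders.
import numpy as np

mesher = FaceMesh("./checkpoints/face_mesh.onnx", device="cuda")
img_rgb = np.zeros((512, 512, 3), dtype=np.uint8)         # placeholder RGB frame
roi = np.array([256.0, 256.0, 300.0, 300.0, 0.0])         # (cx, cy, w, h, rotation in radians)
mesh = mesher(img_rgb, roi)                               # (478, 3) landmarks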
core/aux_models/hubert_stream.py
ADDED
@@ -0,0 +1,29 @@
from ..utils.load_model import load_model


class HubertStreaming:
    def __init__(self, model_path, device="cuda", **kwargs):
        kwargs["model_file"] = model_path
        kwargs["module_name"] = "HubertStreamingONNX"
        kwargs["package_name"] = "..aux_models.modules"

        self.model, self.model_type = load_model(model_path, device=device, **kwargs)
        self.device = device

    def forward_chunk(self, audio_chunk):
        if self.model_type == "onnx":
            output = self.model.run(None, {"input_values": audio_chunk.reshape(1, -1)})[0]
        elif self.model_type == "tensorrt":
            self.model.setup({"input_values": audio_chunk.reshape(1, -1)})
            self.model.infer()
            output = self.model.buffer["encoding_out"][0]
        else:
            raise ValueError(f"Unsupported model type: {self.model_type}")
        return output

    def __call__(self, audio_chunk):
        if self.model_type == "ori":
            output = self.model.forward_chunk(audio_chunk)
        else:
            output = self.forward_chunk(audio_chunk)
        return output
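
A minimal streaming sketch, assuming the 6480-sample window that `Wav2FeatHubert` earlier in this commit feeds into this wrapper; the ONNX path is a placeholder. The returned encoding is the per-chunk feature sequence that `Wav2FeatHubert.__call__` then slices and pools to 25 fps.

# Hedged sketch; path and dummy audio are assumptions.
import numpy as np

hubert = HubertStreaming("./checkpoints/hubert_streaming.onnx", device="cuda")
chunk = np.zeros(6480, dtype=np.float32)     # one window of 16 kHz audio
encoding = hubert(chunk)                     # (T, 1024) HuBERT features for the chunk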
core/aux_models/insightface_det.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import division
|
2 |
+
import numpy as np
|
3 |
+
import cv2
|
4 |
+
|
5 |
+
from ..utils.load_model import load_model
|
6 |
+
|
7 |
+
|
8 |
+
def distance2bbox(points, distance, max_shape=None):
|
9 |
+
"""Decode distance prediction to bounding box.
|
10 |
+
|
11 |
+
Args:
|
12 |
+
points (Tensor): Shape (n, 2), [x, y].
|
13 |
+
distance (Tensor): Distance from the given point to 4
|
14 |
+
boundaries (left, top, right, bottom).
|
15 |
+
max_shape (tuple): Shape of the image.
|
16 |
+
|
17 |
+
Returns:
|
18 |
+
Tensor: Decoded bboxes.
|
19 |
+
"""
|
20 |
+
x1 = points[:, 0] - distance[:, 0]
|
21 |
+
y1 = points[:, 1] - distance[:, 1]
|
22 |
+
x2 = points[:, 0] + distance[:, 2]
|
23 |
+
y2 = points[:, 1] + distance[:, 3]
|
24 |
+
if max_shape is not None:
|
25 |
+
x1 = x1.clamp(min=0, max=max_shape[1])
|
26 |
+
y1 = y1.clamp(min=0, max=max_shape[0])
|
27 |
+
x2 = x2.clamp(min=0, max=max_shape[1])
|
28 |
+
y2 = y2.clamp(min=0, max=max_shape[0])
|
29 |
+
return np.stack([x1, y1, x2, y2], axis=-1)
|
30 |
+
|
31 |
+
|
32 |
+
def distance2kps(points, distance, max_shape=None):
|
33 |
+
"""Decode distance prediction to bounding box.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
points (Tensor): Shape (n, 2), [x, y].
|
37 |
+
distance (Tensor): Distance from the given point to 4
|
38 |
+
boundaries (left, top, right, bottom).
|
39 |
+
max_shape (tuple): Shape of the image.
|
40 |
+
|
41 |
+
Returns:
|
42 |
+
Tensor: Decoded bboxes.
|
43 |
+
"""
|
44 |
+
preds = []
|
45 |
+
for i in range(0, distance.shape[1], 2):
|
46 |
+
px = points[:, i%2] + distance[:, i]
|
47 |
+
py = points[:, i%2+1] + distance[:, i+1]
|
48 |
+
if max_shape is not None:
|
49 |
+
px = px.clamp(min=0, max=max_shape[1])
|
50 |
+
py = py.clamp(min=0, max=max_shape[0])
|
51 |
+
preds.append(px)
|
52 |
+
preds.append(py)
|
53 |
+
return np.stack(preds, axis=-1)
|
54 |
+
|
55 |
+
|
56 |
+
class InsightFaceDet:
|
57 |
+
def __init__(self, model_path, device="cuda", **kwargs):
|
58 |
+
kwargs["model_file"] = model_path
|
59 |
+
kwargs["module_name"] = "RetinaFace"
|
60 |
+
kwargs["package_name"] = "..aux_models.modules"
|
61 |
+
|
62 |
+
self.model, self.model_type = load_model(model_path, device=device, **kwargs)
|
63 |
+
self.device = device
|
64 |
+
|
65 |
+
if self.model_type != "ori":
|
66 |
+
self._init_vars()
|
67 |
+
|
68 |
+
def _init_vars(self):
|
69 |
+
        self.center_cache = {}

        self.nms_thresh = 0.4
        self.det_thresh = 0.5

        self.input_size = (512, 512)
        self.input_mean = 127.5
        self.input_std = 128.0
        self._anchor_ratio = 1.0
        self.fmc = 3
        self._feat_stride_fpn = [8, 16, 32]
        self._num_anchors = 2
        self.use_kps = True

        self.output_names = [
            "scores1",
            "scores2",
            "scores3",
            "boxes1",
            "boxes2",
            "boxes3",
            "kps1",
            "kps2",
            "kps3",
        ]

    def _run_model(self, blob):
        if self.model_type == "onnx":
            net_outs = self.model.run(None, {"image": blob})
        elif self.model_type == "tensorrt":
            self.model.setup({"image": blob})
            self.model.infer()
            net_outs = [self.model.buffer[name][0] for name in self.output_names]
        else:
            raise ValueError(f"Unsupported model type: {self.model_type}")
        return net_outs

    def _forward(self, img, threshold):
        """
        img: np.ndarray, shape (h, w, 3)
        """
        scores_list = []
        bboxes_list = []
        kpss_list = []
        input_size = tuple(img.shape[0:2][::-1])
        blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
        # (1, 3, 512, 512)
        net_outs = self._run_model(blob)

        input_height = blob.shape[2]
        input_width = blob.shape[3]
        fmc = self.fmc
        for idx, stride in enumerate(self._feat_stride_fpn):
            scores = net_outs[idx]
            bbox_preds = net_outs[idx+fmc]
            bbox_preds = bbox_preds * stride
            if self.use_kps:
                kps_preds = net_outs[idx+fmc*2] * stride
            height = input_height // stride
            width = input_width // stride
            # K = height * width
            key = (height, width, stride)
            if key in self.center_cache:
                anchor_centers = self.center_cache[key]
            else:
                # solution-3:
                anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32)
                anchor_centers = (anchor_centers * stride).reshape((-1, 2))
                if self._num_anchors > 1:
                    anchor_centers = np.stack([anchor_centers]*self._num_anchors, axis=1).reshape((-1, 2))
                if len(self.center_cache) < 100:
                    self.center_cache[key] = anchor_centers

            pos_inds = np.where(scores >= threshold)[0]
            bboxes = distance2bbox(anchor_centers, bbox_preds)
            pos_scores = scores[pos_inds]
            pos_bboxes = bboxes[pos_inds]
            scores_list.append(pos_scores)
            bboxes_list.append(pos_bboxes)
            if self.use_kps:
                kpss = distance2kps(anchor_centers, kps_preds)
                kpss = kpss.reshape((kpss.shape[0], -1, 2))
                pos_kpss = kpss[pos_inds]
                kpss_list.append(pos_kpss)
        return scores_list, bboxes_list, kpss_list

    def detect(self, img, input_size=None, max_num=0, metric='default', det_thresh=None):
        input_size = self.input_size if input_size is None else input_size
        det_thresh = self.det_thresh if det_thresh is None else det_thresh

        im_ratio = float(img.shape[0]) / img.shape[1]
        model_ratio = float(input_size[1]) / input_size[0]
        if im_ratio > model_ratio:
            new_height = input_size[1]
            new_width = int(new_height / im_ratio)
        else:
            new_width = input_size[0]
            new_height = int(new_width * im_ratio)
        det_scale = float(new_height) / img.shape[0]
        resized_img = cv2.resize(img, (new_width, new_height))
        det_img = np.zeros((input_size[1], input_size[0], 3), dtype=np.uint8)
        det_img[:new_height, :new_width, :] = resized_img

        scores_list, bboxes_list, kpss_list = self._forward(det_img, det_thresh)

        scores = np.vstack(scores_list)
        scores_ravel = scores.ravel()
        order = scores_ravel.argsort()[::-1]
        bboxes = np.vstack(bboxes_list) / det_scale
        if self.use_kps:
            kpss = np.vstack(kpss_list) / det_scale
        pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False)
        pre_det = pre_det[order, :]
        keep = self.nms(pre_det)
        det = pre_det[keep, :]
        if self.use_kps:
            kpss = kpss[order, :, :]
            kpss = kpss[keep, :, :]
        else:
            kpss = None
        if max_num > 0 and det.shape[0] > max_num:
            area = (det[:, 2] - det[:, 0]) * (det[:, 3] - det[:, 1])
            img_center = img.shape[0] // 2, img.shape[1] // 2
            offsets = np.vstack([
                (det[:, 0] + det[:, 2]) / 2 - img_center[1],
                (det[:, 1] + det[:, 3]) / 2 - img_center[0]
            ])
            offset_dist_squared = np.sum(np.power(offsets, 2.0), 0)
            if metric == 'max':
                values = area
            else:
                values = area - offset_dist_squared * 2.0  # some extra weight on the centering
            bindex = np.argsort(values)[::-1]
            bindex = bindex[0:max_num]
            det = det[bindex, :]
            if kpss is not None:
                kpss = kpss[bindex, :]
        return det, kpss

    def nms(self, dets):
        thresh = self.nms_thresh
        x1 = dets[:, 0]
        y1 = dets[:, 1]
        x2 = dets[:, 2]
        y2 = dets[:, 3]
        scores = dets[:, 4]

        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
        order = scores.argsort()[::-1]

        keep = []
        while order.size > 0:
            i = order[0]
            keep.append(i)
            xx1 = np.maximum(x1[i], x1[order[1:]])
            yy1 = np.maximum(y1[i], y1[order[1:]])
            xx2 = np.minimum(x2[i], x2[order[1:]])
            yy2 = np.minimum(y2[i], y2[order[1:]])

            w = np.maximum(0.0, xx2 - xx1 + 1)
            h = np.maximum(0.0, yy2 - yy1 + 1)
            inter = w * h
            ovr = inter / (areas[i] + areas[order[1:]] - inter)

            inds = np.where(ovr <= thresh)[0]
            order = order[inds + 1]

        return keep

    def __call__(self, img, **kwargs):
        if self.model_type == "ori":
            det, kpss = self.model.detect(img, **kwargs)
        else:
            det, kpss = self.detect(img, **kwargs)

        return det, kpss
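Note: the anchor-centre grid built inside _forward() above is the only non-obvious part of the decode path. A minimal standalone sketch (illustrative stride and feature-map sizes, not values taken from the model) traces the shapes:

import numpy as np

# One FPN level with stride 8 on a tiny 4x4 feature map, 2 anchors per cell
# (the detector itself uses a 512x512 input and strides 8/16/32).
stride, height, width, num_anchors = 8, 4, 4, 2
anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32)  # (4, 4, 2) as (x, y)
anchor_centers = (anchor_centers * stride).reshape((-1, 2))                              # (16, 2) pixel coordinates
anchor_centers = np.stack([anchor_centers] * num_anchors, axis=1).reshape((-1, 2))       # (32, 2), each centre repeated per anchor
print(anchor_centers[:4])  # [[0. 0.] [0. 0.] [8. 0.] [8. 0.]]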
core/aux_models/insightface_landmark106.py
ADDED
@@ -0,0 +1,100 @@
from __future__ import division
import numpy as np
import torch
import cv2
from skimage import transform as trans

from ..utils.load_model import load_model


def transform(data, center, output_size, scale, rotation):
    scale_ratio = scale
    rot = float(rotation) * np.pi / 180.0

    t1 = trans.SimilarityTransform(scale=scale_ratio)
    cx = center[0] * scale_ratio
    cy = center[1] * scale_ratio
    t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy))
    t3 = trans.SimilarityTransform(rotation=rot)
    t4 = trans.SimilarityTransform(translation=(output_size / 2,
                                                output_size / 2))
    t = t1 + t2 + t3 + t4
    M = t.params[0:2]
    cropped = cv2.warpAffine(data,
                             M, (output_size, output_size),
                             borderValue=0.0)
    return cropped, M


def trans_points2d(pts, M):
    new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
    for i in range(pts.shape[0]):
        pt = pts[i]
        new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
        new_pt = np.dot(M, new_pt)
        new_pts[i] = new_pt[0:2]

    return new_pts


class Landmark106:
    def __init__(self, model_path, device="cuda", **kwargs):
        kwargs["model_file"] = model_path
        kwargs["module_name"] = "Landmark106"
        kwargs["package_name"] = "..aux_models.modules"

        self.model, self.model_type = load_model(model_path, device=device, **kwargs)
        self.device = device

        if self.model_type != "ori":
            self._init_vars()

    def _init_vars(self):
        self.input_mean = 0.0
        self.input_std = 1.0
        self.input_size = (192, 192)
        self.lmk_num = 106

        self.output_names = ["fc1"]

    def _run_model(self, blob):
        if self.model_type == "onnx":
            pred = self.model.run(None, {"data": blob})[0]
        elif self.model_type == "tensorrt":
            self.model.setup({"data": blob})
            self.model.infer()
            pred = self.model.buffer[self.output_names[0]][0]
        else:
            raise ValueError(f"Unsupported model type: {self.model_type}")
        return pred

    def get(self, img, bbox):
        w, h = (bbox[2] - bbox[0]), (bbox[3] - bbox[1])
        center = (bbox[2] + bbox[0]) / 2, (bbox[3] + bbox[1]) / 2
        rotate = 0
        _scale = self.input_size[0] / (max(w, h) * 1.5)

        aimg, M = transform(img, center, self.input_size[0], _scale, rotate)
        input_size = tuple(aimg.shape[0:2][::-1])

        blob = cv2.dnn.blobFromImage(aimg, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)

        pred = self._run_model(blob)

        pred = pred.reshape((-1, 2))
        if self.lmk_num < pred.shape[0]:
            pred = pred[self.lmk_num*-1:, :]
        pred[:, 0:2] += 1
        pred[:, 0:2] *= (self.input_size[0] // 2)

        IM = cv2.invertAffineTransform(M)
        pred = trans_points2d(pred, IM)
        return pred

    def __call__(self, img, bbox):
        if self.model_type == "ori":
            pred = self.model.get(img, bbox)
        else:
            pred = self.get(img, bbox)

        return pred
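Note: a minimal usage sketch for the Landmark106 wrapper above. The checkpoint path, image path and bounding box are hypothetical placeholders (the bbox would normally come from the face detector in insightface_det.py), and the import assumes the repository root is on PYTHONPATH so that core is importable.

import cv2
import numpy as np
from core.aux_models.insightface_landmark106 import Landmark106

lmk106 = Landmark106("checkpoints/landmark106.onnx", device="cuda")   # hypothetical path

img = cv2.imread("some_face.jpg")                        # HxWx3 image as loaded by OpenCV
bbox = np.array([120, 80, 360, 340], dtype=np.float32)   # x1, y1, x2, y2 from a face detector
pts = lmk106(img, bbox)                                  # (106, 2) landmarks in original image coordinates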
core/aux_models/landmark203.py
ADDED
@@ -0,0 +1,58 @@
import numpy as np
from ..utils.load_model import load_model


def _transform_pts(pts, M):
    """ conduct similarity or affine transformation to the pts
    pts: Nx2 ndarray
    M: 2x3 matrix or 3x3 matrix
    return: Nx2
    """
    return pts @ M[:2, :2].T + M[:2, 2]


class Landmark203:
    def __init__(self, model_path, device="cuda", **kwargs):
        kwargs["model_file"] = model_path
        kwargs["module_name"] = "Landmark203"
        kwargs["package_name"] = "..aux_models.modules"

        self.model, self.model_type = load_model(model_path, device=device, **kwargs)
        self.device = device

        self.output_names = ["landmarks"]
        self.dsize = 224

    def _run_model(self, inp):
        if self.model_type == "onnx":
            out_pts = self.model.run(None, {"input": inp})[0]
        elif self.model_type == "tensorrt":
            self.model.setup({"input": inp})
            self.model.infer()
            out_pts = self.model.buffer[self.output_names[0]][0]
        else:
            raise ValueError(f"Unsupported model type: {self.model_type}")
        return out_pts

    def run(self, img_crop_rgb, M_c2o=None):
        # img_crop_rgb: 224x224

        inp = (img_crop_rgb.astype(np.float32) / 255.).transpose(2, 0, 1)[None, ...]  # HxWx3 (BGR) -> 1x3xHxW (RGB!)

        out_pts = self._run_model(inp)

        # 2d landmarks 203 points
        lmk = out_pts[0].reshape(-1, 2) * self.dsize  # scale to 0-224
        if M_c2o is not None:
            lmk = _transform_pts(lmk, M=M_c2o)

        return lmk

    def __call__(self, img_crop_rgb, M_c2o=None):
        if self.model_type == "ori":
            lmk = self.model.run(img_crop_rgb, M_c2o)
        else:
            lmk = self.run(img_crop_rgb, M_c2o)

        return lmk
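Note: _transform_pts above maps crop-space landmarks back to the original image using the crop-to-original matrix M_c2o; for each landmark p = (x, y),

$$p' = A\,p + b,\qquad A = M_{c2o}[:2,:2],\quad b = M_{c2o}[:2,2].$$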
core/aux_models/mediapipe_landmark478.py
ADDED
@@ -0,0 +1,118 @@
from enum import Enum
import numpy as np

from ..utils.load_model import load_model
from .blaze_face import BlazeFace
from .face_mesh import FaceMesh


class SizeMode(Enum):
    DEFAULT = 0
    SQUARE_LONG = 1
    SQUARE_SHORT = 2


def _select_roi_size(
    bbox: np.ndarray, image_size, size_mode: SizeMode  # x1, y1, x2, y2  # w, h
):
    """Return the size of an ROI based on bounding box, image size and mode"""
    width, height = bbox[2] - bbox[0], bbox[3] - bbox[1]
    image_width, image_height = image_size
    if size_mode == SizeMode.SQUARE_LONG:
        long_size = max(width, height)
        width, height = long_size, long_size
    elif size_mode == SizeMode.SQUARE_SHORT:
        short_side = min(width, height)
        width, height = short_side, short_side
    return width, height


def bbox_to_roi(
    bbox: np.ndarray,
    image_size,  # w, h
    rotation_keypoints=None,
    scale=(1.0, 1.0),  # w, h
    size_mode: SizeMode = SizeMode.SQUARE_LONG,
):
    PI = np.pi
    TWO_PI = 2 * np.pi
    # select ROI dimensions
    width, height = _select_roi_size(bbox, image_size, size_mode)
    scale_x, scale_y = scale
    # calculate ROI size and -centre
    width, height = width * scale_x, height * scale_y
    cx = (bbox[0] + bbox[2]) / 2
    cy = (bbox[1] + bbox[3]) / 2
    # calculate rotation if required
    if rotation_keypoints is None or len(rotation_keypoints) < 2:
        return np.array([cx, cy, width, height, 0])
    x0, y0 = rotation_keypoints[0]
    x1, y1 = rotation_keypoints[1]
    angle = -np.atan2(y0 - y1, x1 - x0)
    # normalise to [0, 2*PI]
    rotation = angle - TWO_PI * np.floor((angle + PI) / TWO_PI)
    return np.array([cx, cy, width, height, rotation])


class Landmark478:
    def __init__(self, blaze_face_model_path="", face_mesh_model_path="", device="cuda", **kwargs):
        if kwargs.get("force_ori_type", False):
            assert "task_path" in kwargs
            kwargs["module_name"] = "Landmark478"
            kwargs["package_name"] = "..aux_models.modules"
            self.model, self.model_type = load_model("", device=device, **kwargs)
        else:
            self.blaze_face = BlazeFace(blaze_face_model_path, device)
            self.face_mesh = FaceMesh(face_mesh_model_path, device)
            self.model_type = ""

    def get(self, image):
        bboxes = self.blaze_face(image)
        if len(bboxes) == 0:
            return None
        bbox = bboxes[0]
        scale = (image.shape[1] / 128.0, image.shape[0] / 128.0)

        # The first 4 numbers describe the bounding box corners:
        #
        # ymin, xmin, ymax, xmax
        # These are normalized coordinates (between 0 and 1).
        # The next 12 numbers are the x,y-coordinates of the 6 facial landmark keypoints:
        #
        # right_eye_x, right_eye_y
        # left_eye_x, left_eye_y
        # nose_x, nose_y
        # mouth_x, mouth_y
        # right_ear_x, right_ear_y
        # left_ear_x, left_ear_y
        # Tip: these are labeled as seen from the perspective of the person, so their right is your left.
        # The final number is the confidence score that this detection really is a face.

        bbox[0] = bbox[0] * scale[1]
        bbox[1] = bbox[1] * scale[0]
        bbox[2] = bbox[2] * scale[1]
        bbox[3] = bbox[3] * scale[0]
        left_eye = (bbox[4], bbox[5])
        right_eye = (bbox[6], bbox[7])

        roi = bbox_to_roi(
            bbox,
            (image.shape[1], image.shape[0]),
            rotation_keypoints=[left_eye, right_eye],
            scale=(1.5, 1.5),
            size_mode=SizeMode.SQUARE_LONG,
        )

        mesh = self.face_mesh(image, roi)
        mesh = mesh / (image.shape[1], image.shape[0], image.shape[1])
        return mesh

    def __call__(self, image):
        if self.model_type == "ori":
            det = self.model.detect_from_npimage(image.copy())
            lmk = self.model.mplmk_to_nplmk(det)
            return lmk
        else:
            lmk = self.get(image)
            lmk = lmk.reshape(1, -1, 3).astype(np.float32)
            return lmk
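Note: a minimal usage sketch of the ONNX/TensorRT path of Landmark478 above. The weight paths are hypothetical, and whether the frame should be RGB or BGR depends on how BlazeFace/FaceMesh were exported, so the colour conversion below is an assumption; get() returns None when no face is found, in which case __call__ would fail on the reshape.

import cv2
from core.aux_models.mediapipe_landmark478 import Landmark478

lmk478 = Landmark478(
    blaze_face_model_path="checkpoints/blaze_face.onnx",   # hypothetical paths
    face_mesh_model_path="checkpoints/face_mesh.onnx",
    device="cuda",
)
img = cv2.cvtColor(cv2.imread("some_face.jpg"), cv2.COLOR_BGR2RGB)
lmk = lmk478(img)   # (1, N, 3) float32 landmarks, normalised by the image size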
core/aux_models/modules/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .retinaface import RetinaFace
from .landmark106 import Landmark106
from .landmark203 import Landmark203
from .landmark478 import Landmark478
from .hubert_stream import HubertStreamingONNX
core/aux_models/modules/hubert_stream.py
ADDED
@@ -0,0 +1,21 @@
import onnxruntime


class HubertStreamingONNX:
    def __init__(self, model_file, device="cuda"):
        if device == "cuda":
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]

        self.session = onnxruntime.InferenceSession(model_file, providers=providers)

    def forward_chunk(self, input_values):
        encoding_out = self.session.run(
            None,
            {"input_values": input_values.reshape(1, -1)}
        )[0]
        return encoding_out
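Note: a minimal sketch of streaming audio through HubertStreamingONNX. The model path and the chunk length are illustrative assumptions, not values prescribed by this repository.

import numpy as np
from core.aux_models.modules.hubert_stream import HubertStreamingONNX

hubert = HubertStreamingONNX("checkpoints/hubert_streaming.onnx", device="cuda")  # hypothetical path

chunk = np.zeros(6400, dtype=np.float32)   # e.g. 0.4 s of 16 kHz mono audio, flattened to (1, -1) internally
feat = hubert.forward_chunk(chunk)         # first ONNX output; its shape depends on the exported model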
core/aux_models/modules/landmark106.py
ADDED
@@ -0,0 +1,83 @@
# insightface
from __future__ import division
import onnxruntime
import cv2
import numpy as np
from skimage import transform as trans


def transform(data, center, output_size, scale, rotation):
    scale_ratio = scale
    rot = float(rotation) * np.pi / 180.0

    t1 = trans.SimilarityTransform(scale=scale_ratio)
    cx = center[0] * scale_ratio
    cy = center[1] * scale_ratio
    t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy))
    t3 = trans.SimilarityTransform(rotation=rot)
    t4 = trans.SimilarityTransform(translation=(output_size / 2,
                                                output_size / 2))
    t = t1 + t2 + t3 + t4
    M = t.params[0:2]
    cropped = cv2.warpAffine(data,
                             M, (output_size, output_size),
                             borderValue=0.0)
    return cropped, M


def trans_points2d(pts, M):
    new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
    for i in range(pts.shape[0]):
        pt = pts[i]
        new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
        new_pt = np.dot(M, new_pt)
        new_pts[i] = new_pt[0:2]

    return new_pts


class Landmark106:
    def __init__(self, model_file, device="cuda"):
        if device == "cuda":
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]
        self.session = onnxruntime.InferenceSession(model_file, providers=providers)

        self.input_mean = 0.0
        self.input_std = 1.0
        self.input_size = (192, 192)
        input_cfg = self.session.get_inputs()[0]
        input_name = input_cfg.name
        outputs = self.session.get_outputs()
        output_names = []
        for out in outputs:
            output_names.append(out.name)
        self.input_name = input_name
        self.output_names = output_names
        self.lmk_num = 106

    def get(self, img, bbox):
        w, h = (bbox[2] - bbox[0]), (bbox[3] - bbox[1])
        center = (bbox[2] + bbox[0]) / 2, (bbox[3] + bbox[1]) / 2
        rotate = 0
        _scale = self.input_size[0] / (max(w, h) * 1.5)

        aimg, M = transform(img, center, self.input_size[0], _scale, rotate)
        input_size = tuple(aimg.shape[0:2][::-1])

        blob = cv2.dnn.blobFromImage(aimg, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)

        pred = self.session.run(self.output_names, {self.input_name: blob})[0][0]

        pred = pred.reshape((-1, 2))
        if self.lmk_num < pred.shape[0]:
            pred = pred[self.lmk_num*-1:, :]
        pred[:, 0:2] += 1
        pred[:, 0:2] *= (self.input_size[0] // 2)

        IM = cv2.invertAffineTransform(M)
        pred = trans_points2d(pred, IM)
        return pred
core/aux_models/modules/landmark203.py
ADDED
@@ -0,0 +1,42 @@
import onnxruntime
import numpy as np


def _transform_pts(pts, M):
    """ conduct similarity or affine transformation to the pts
    pts: Nx2 ndarray
    M: 2x3 matrix or 3x3 matrix
    return: Nx2
    """
    return pts @ M[:2, :2].T + M[:2, 2]


class Landmark203:
    def __init__(self, model_file, device="cuda"):
        if device == "cuda":
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]
        self.session = onnxruntime.InferenceSession(model_file, providers=providers)

        self.dsize = 224

    def _run(self, inp):
        out = self.session.run(None, {'input': inp})
        return out

    def run(self, img_crop_rgb, M_c2o=None):
        # img_crop_rgb: 224x224

        inp = (img_crop_rgb.astype(np.float32) / 255.).transpose(2, 0, 1)[None, ...]  # HxWx3 (BGR) -> 1x3xHxW (RGB!)

        out_lst = self._run(inp)
        out_pts = out_lst[2]

        # 2d landmarks 203 points
        lmk = out_pts[0].reshape(-1, 2) * self.dsize  # scale to 0-224
        if M_c2o is not None:
            lmk = _transform_pts(lmk, M=M_c2o)

        return lmk
core/aux_models/modules/landmark478.py
ADDED
@@ -0,0 +1,35 @@
import numpy as np
import mediapipe as mp
from mediapipe.tasks.python import vision, BaseOptions


class Landmark478:
    def __init__(self, task_path):
        base_options = BaseOptions(model_asset_path=task_path)
        options = vision.FaceLandmarkerOptions(
            base_options=base_options,
            output_face_blendshapes=True,
            output_facial_transformation_matrixes=True,
            num_faces=1,
        )
        detector = vision.FaceLandmarker.create_from_options(options)
        self.detector = detector

    def detect_from_imp(self, imp):
        image = mp.Image.create_from_file(imp)
        detection_result = self.detector.detect(image)
        return detection_result

    def detect_from_npimage(self, img):
        image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)
        detection_result = self.detector.detect(image)
        return detection_result

    @staticmethod
    def mplmk_to_nplmk(results):
        face_landmarks_list = results.face_landmarks
        np_lms = []
        for face_lms in face_landmarks_list:
            lms = [[lm.x, lm.y, lm.z] for lm in face_lms]
            np_lms.append(lms)
        return np.array(np_lms).astype(np.float32)
core/aux_models/modules/retinaface.py
ADDED
@@ -0,0 +1,215 @@
# insightface
from __future__ import division
import onnxruntime
import cv2
import numpy as np


def distance2bbox(points, distance, max_shape=None):
    """Decode distance prediction to bounding box.

    Args:
        points (Tensor): Shape (n, 2), [x, y].
        distance (Tensor): Distance from the given point to 4
            boundaries (left, top, right, bottom).
        max_shape (tuple): Shape of the image.

    Returns:
        Tensor: Decoded bboxes.
    """
    x1 = points[:, 0] - distance[:, 0]
    y1 = points[:, 1] - distance[:, 1]
    x2 = points[:, 0] + distance[:, 2]
    y2 = points[:, 1] + distance[:, 3]
    if max_shape is not None:
        x1 = x1.clamp(min=0, max=max_shape[1])
        y1 = y1.clamp(min=0, max=max_shape[0])
        x2 = x2.clamp(min=0, max=max_shape[1])
        y2 = y2.clamp(min=0, max=max_shape[0])
    return np.stack([x1, y1, x2, y2], axis=-1)


def distance2kps(points, distance, max_shape=None):
    """Decode distance prediction to keypoints.

    Args:
        points (Tensor): Shape (n, 2), [x, y].
        distance (Tensor): Distance from the given point to 4
            boundaries (left, top, right, bottom).
        max_shape (tuple): Shape of the image.

    Returns:
        Tensor: Decoded keypoints.
    """
    preds = []
    for i in range(0, distance.shape[1], 2):
        px = points[:, i % 2] + distance[:, i]
        py = points[:, i % 2 + 1] + distance[:, i + 1]
        if max_shape is not None:
            px = px.clamp(min=0, max=max_shape[1])
            py = py.clamp(min=0, max=max_shape[0])
        preds.append(px)
        preds.append(py)
    return np.stack(preds, axis=-1)


class RetinaFace:
    def __init__(self, model_file, device="cuda"):
        if device == "cuda":
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]
        self.session = onnxruntime.InferenceSession(model_file, providers=providers)

        self.center_cache = {}
        self.nms_thresh = 0.4
        self.det_thresh = 0.5
        self._init_vars()

    def _init_vars(self):
        self.input_size = (512, 512)
        input_cfg = self.session.get_inputs()[0]
        input_name = input_cfg.name
        outputs = self.session.get_outputs()
        output_names = []
        for o in outputs:
            output_names.append(o.name)
        self.input_name = input_name
        self.output_names = output_names
        self.input_mean = 127.5
        self.input_std = 128.0
        self._anchor_ratio = 1.0
        self.fmc = 3
        self._feat_stride_fpn = [8, 16, 32]
        self._num_anchors = 2
        self.use_kps = True

    def forward(self, img, threshold):
        scores_list = []
        bboxes_list = []
        kpss_list = []
        input_size = tuple(img.shape[0:2][::-1])
        blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
        net_outs = self.session.run(self.output_names, {self.input_name: blob})

        input_height = blob.shape[2]
        input_width = blob.shape[3]
        fmc = self.fmc
        for idx, stride in enumerate(self._feat_stride_fpn):
            scores = net_outs[idx]
            bbox_preds = net_outs[idx+fmc]
            bbox_preds = bbox_preds * stride
            if self.use_kps:
                kps_preds = net_outs[idx+fmc*2] * stride
            height = input_height // stride
            width = input_width // stride
            # K = height * width
            key = (height, width, stride)
            if key in self.center_cache:
                anchor_centers = self.center_cache[key]
            else:
                # solution-3:
                anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32)
                anchor_centers = (anchor_centers * stride).reshape((-1, 2))
                if self._num_anchors > 1:
                    anchor_centers = np.stack([anchor_centers]*self._num_anchors, axis=1).reshape((-1, 2))
                if len(self.center_cache) < 100:
                    self.center_cache[key] = anchor_centers

            pos_inds = np.where(scores >= threshold)[0]
            bboxes = distance2bbox(anchor_centers, bbox_preds)
            pos_scores = scores[pos_inds]
            pos_bboxes = bboxes[pos_inds]
            scores_list.append(pos_scores)
            bboxes_list.append(pos_bboxes)
            if self.use_kps:
                kpss = distance2kps(anchor_centers, kps_preds)
                kpss = kpss.reshape((kpss.shape[0], -1, 2))
                pos_kpss = kpss[pos_inds]
                kpss_list.append(pos_kpss)
        return scores_list, bboxes_list, kpss_list

    def detect(self, img, input_size=None, max_num=0, metric='default', det_thresh=None):
        input_size = self.input_size if input_size is None else input_size
        det_thresh = self.det_thresh if det_thresh is None else det_thresh

        im_ratio = float(img.shape[0]) / img.shape[1]
        model_ratio = float(input_size[1]) / input_size[0]
        if im_ratio > model_ratio:
            new_height = input_size[1]
            new_width = int(new_height / im_ratio)
        else:
            new_width = input_size[0]
            new_height = int(new_width * im_ratio)
        det_scale = float(new_height) / img.shape[0]
        resized_img = cv2.resize(img, (new_width, new_height))
        det_img = np.zeros((input_size[1], input_size[0], 3), dtype=np.uint8)
        det_img[:new_height, :new_width, :] = resized_img

        scores_list, bboxes_list, kpss_list = self.forward(det_img, det_thresh)

        scores = np.vstack(scores_list)
        scores_ravel = scores.ravel()
        order = scores_ravel.argsort()[::-1]
        bboxes = np.vstack(bboxes_list) / det_scale
        if self.use_kps:
            kpss = np.vstack(kpss_list) / det_scale
        pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False)
        pre_det = pre_det[order, :]
        keep = self.nms(pre_det)
        det = pre_det[keep, :]
        if self.use_kps:
            kpss = kpss[order, :, :]
            kpss = kpss[keep, :, :]
        else:
            kpss = None
        if max_num > 0 and det.shape[0] > max_num:
            area = (det[:, 2] - det[:, 0]) * (det[:, 3] - det[:, 1])
            img_center = img.shape[0] // 2, img.shape[1] // 2
            offsets = np.vstack([
                (det[:, 0] + det[:, 2]) / 2 - img_center[1],
                (det[:, 1] + det[:, 3]) / 2 - img_center[0]
            ])
            offset_dist_squared = np.sum(np.power(offsets, 2.0), 0)
            if metric == 'max':
                values = area
            else:
                values = area - offset_dist_squared * 2.0  # some extra weight on the centering
            bindex = np.argsort(values)[::-1]
            bindex = bindex[0:max_num]
            det = det[bindex, :]
            if kpss is not None:
                kpss = kpss[bindex, :]
        return det, kpss

    def nms(self, dets):
        thresh = self.nms_thresh
        x1 = dets[:, 0]
        y1 = dets[:, 1]
        x2 = dets[:, 2]
        y2 = dets[:, 3]
        scores = dets[:, 4]

        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
        order = scores.argsort()[::-1]

        keep = []
        while order.size > 0:
            i = order[0]
            keep.append(i)
            xx1 = np.maximum(x1[i], x1[order[1:]])
            yy1 = np.maximum(y1[i], y1[order[1:]])
            xx2 = np.minimum(x2[i], x2[order[1:]])
            yy2 = np.minimum(y2[i], y2[order[1:]])

            w = np.maximum(0.0, xx2 - xx1 + 1)
            h = np.maximum(0.0, yy2 - yy1 + 1)
            inter = w * h
            ovr = inter / (areas[i] + areas[order[1:]] - inter)

            inds = np.where(ovr <= thresh)[0]
            order = order[inds + 1]

        return keep
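Note: distance2bbox and distance2kps at the top of this file implement the usual insightface distance decoding. With anchor centre (c_x, c_y) and predicted per-side distances (already multiplied by the stride before decoding):

$$x_1 = c_x - d_l,\quad y_1 = c_y - d_t,\quad x_2 = c_x + d_r,\quad y_2 = c_y + d_b,$$
$$k_x^{(i)} = c_x + d_{2i},\qquad k_y^{(i)} = c_y + d_{2i+1}.$$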
core/models/appearance_extractor.py
ADDED
@@ -0,0 +1,29 @@
import numpy as np
import torch
from ..utils.load_model import load_model


class AppearanceExtractor:
    def __init__(self, model_path, device="cuda"):
        kwargs = {
            "module_name": "AppearanceFeatureExtractor",
        }
        self.model, self.model_type = load_model(model_path, device=device, **kwargs)
        self.device = device

    def __call__(self, image):
        """
        image: np.ndarray, shape (1, 3, 256, 256), float32, range [0, 1]
        """
        if self.model_type == "onnx":
            pred = self.model.run(None, {"image": image})[0]
        elif self.model_type == "tensorrt":
            self.model.setup({"image": image})
            self.model.infer()
            pred = self.model.buffer["pred"][0].copy()
        elif self.model_type == 'pytorch':
            with torch.no_grad(), torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=True):
                pred = self.model(torch.from_numpy(image).to(self.device)).float().cpu().numpy()
        else:
            raise ValueError(f"Unsupported model type: {self.model_type}")
        return pred
core/models/decoder.py
ADDED
@@ -0,0 +1,30 @@
import numpy as np
import torch
from ..utils.load_model import load_model


class Decoder:
    def __init__(self, model_path, device="cuda"):
        kwargs = {
            "module_name": "SPADEDecoder",
        }
        self.model, self.model_type = load_model(model_path, device=device, **kwargs)
        self.device = device

    def __call__(self, feature):
        if self.model_type == "onnx":
            pred = self.model.run(None, {"feature": feature})[0]
        elif self.model_type == "tensorrt":
            self.model.setup({"feature": feature})
            self.model.infer()
            pred = self.model.buffer["output"][0].copy()
        elif self.model_type == 'pytorch':
            with torch.no_grad(), torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=True):
                pred = self.model(torch.from_numpy(feature).to(self.device)).float().cpu().numpy()
        else:
            raise ValueError(f"Unsupported model type: {self.model_type}")

        pred = np.transpose(pred[0], [1, 2, 0]).clip(0, 1) * 255  # [h, w, c]

        return pred
core/models/lmdm.py
ADDED
@@ -0,0 +1,140 @@
import numpy as np
import torch
from ..utils.load_model import load_model


def make_beta(n_timestep, cosine_s=8e-3):
    timesteps = (
        torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
    )
    alphas = timesteps / (1 + cosine_s) * np.pi / 2
    alphas = torch.cos(alphas).pow(2)
    alphas = alphas / alphas[0]
    betas = 1 - alphas[1:] / alphas[:-1]
    betas = np.clip(betas, a_min=0, a_max=0.999)
    return betas.numpy()


class LMDM:
    def __init__(self, model_path, device="cuda", **kwargs):
        kwargs["module_name"] = "LMDM"

        self.model, self.model_type = load_model(model_path, device=device, **kwargs)
        self.device = device

        self.motion_feat_dim = kwargs.get("motion_feat_dim", 265)
        self.audio_feat_dim = kwargs.get("audio_feat_dim", 1024+35)
        self.seq_frames = kwargs.get("seq_frames", 80)

        if self.model_type == "pytorch":
            pass
        else:
            self._init_np()

    def setup(self, sampling_timesteps):
        if self.model_type == "pytorch":
            self.model.setup(sampling_timesteps)
        else:
            self._setup_np(sampling_timesteps)

    def _init_np(self):
        self.sampling_timesteps = None
        self.n_timestep = 1000

        betas = torch.Tensor(make_beta(n_timestep=self.n_timestep))
        alphas = 1.0 - betas
        self.alphas_cumprod = torch.cumprod(alphas, axis=0).cpu().numpy()

    def _setup_np(self, sampling_timesteps=50):
        if self.sampling_timesteps == sampling_timesteps:
            return

        self.sampling_timesteps = sampling_timesteps

        total_timesteps = self.n_timestep
        eta = 1
        shape = (1, self.seq_frames, self.motion_feat_dim)

        times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)  # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
        times = list(reversed(times.int().tolist()))
        self.time_pairs = list(zip(times[:-1], times[1:]))  # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]

        self.time_cond_list = []
        self.alpha_next_sqrt_list = []
        self.sigma_list = []
        self.c_list = []
        self.noise_list = []

        for time, time_next in self.time_pairs:
            time_cond = np.full((1,), time, dtype=np.int64)
            self.time_cond_list.append(time_cond)
            if time_next < 0:
                continue

            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]

            sigma = eta * np.sqrt((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha))
            c = np.sqrt(1 - alpha_next - sigma ** 2)
            noise = np.random.randn(*shape).astype(np.float32)

            self.alpha_next_sqrt_list.append(np.sqrt(alpha_next))
            self.sigma_list.append(sigma)
            self.c_list.append(c)
            self.noise_list.append(noise)

    def _one_step(self, x, cond_frame, cond, time_cond):
        if self.model_type == "onnx":
            pred = self.model.run(None, {"x": x, "cond_frame": cond_frame, "cond": cond, "time_cond": time_cond})
            pred_noise, x_start = pred[0], pred[1]
        elif self.model_type == "tensorrt":
            self.model.setup({"x": x, "cond_frame": cond_frame, "cond": cond, "time_cond": time_cond})
            self.model.infer()
            pred_noise, x_start = self.model.buffer["pred_noise"][0], self.model.buffer["x_start"][0]
        elif self.model_type == "pytorch":
            with torch.no_grad():
                pred_noise, x_start = self.model(x, cond_frame, cond, time_cond)
        else:
            raise ValueError(f"Unsupported model type: {self.model_type}")

        return pred_noise, x_start

    def _call_np(self, kp_cond, aud_cond, sampling_timesteps):
        self._setup_np(sampling_timesteps)

        cond_frame = kp_cond
        cond = aud_cond

        x = np.random.randn(1, self.seq_frames, self.motion_feat_dim).astype(np.float32)

        x_start = None
        i = 0
        for _, time_next in self.time_pairs:
            time_cond = self.time_cond_list[i]
            pred_noise, x_start = self._one_step(x, cond_frame, cond, time_cond)
            if time_next < 0:
                x = x_start
                continue

            alpha_next_sqrt = self.alpha_next_sqrt_list[i]
            c = self.c_list[i]
            sigma = self.sigma_list[i]
            noise = self.noise_list[i]
            x = x_start * alpha_next_sqrt + c * pred_noise + sigma * noise

            i += 1

        return x

    def __call__(self, kp_cond, aud_cond, sampling_timesteps):
        if self.model_type == "pytorch":
            pred_kp_seq = self.model.ddim_sample(
                torch.from_numpy(kp_cond).to(self.device),
                torch.from_numpy(aud_cond).to(self.device),
                sampling_timesteps,
            ).cpu().numpy()
        else:
            pred_kp_seq = self._call_np(kp_cond, aud_cond, sampling_timesteps)
        return pred_kp_seq
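Note: make_beta above is the cosine noise schedule; with s = 8e-3 and T = n_timestep,

$$\bar{\alpha}_t = \frac{\cos^2\!\big(\tfrac{t/T + s}{1 + s} \cdot \tfrac{\pi}{2}\big)}{\cos^2\!\big(\tfrac{s}{1 + s} \cdot \tfrac{\pi}{2}\big)},\qquad \beta_t = 1 - \frac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}},\qquad \beta_t \leftarrow \min(\beta_t,\, 0.999),$$

and _init_np()/_setup_np() then recover the cumulative products (up to the 0.999 clip this is again the same ᾱ_t) and precompute the per-step DDIM coefficients used in _call_np.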
core/models/modules/LMDM.py
ADDED
@@ -0,0 +1,154 @@
# Latent Motion Diffusion Model
import torch
import torch.nn as nn
from .lmdm_modules.model import MotionDecoder
from .lmdm_modules.utils import extract, make_beta_schedule


class LMDM(nn.Module):
    def __init__(
        self,
        motion_feat_dim=265,
        audio_feat_dim=1024+35,
        seq_frames=80,
        checkpoint='',
        device='cuda',
        clip_denoised=False,  # clip denoised (-1,1)
        multi_cond_frame=False,
    ):
        super().__init__()

        self.motion_feat_dim = motion_feat_dim
        self.audio_feat_dim = audio_feat_dim
        self.seq_frames = seq_frames
        self.device = device

        self.n_timestep = 1000
        self.clip_denoised = clip_denoised
        self.guidance_weight = 2

        self.model = MotionDecoder(
            nfeats=motion_feat_dim,
            seq_len=seq_frames,
            latent_dim=512,
            ff_size=1024,
            num_layers=8,
            num_heads=8,
            dropout=0.1,
            cond_feature_dim=audio_feat_dim,
            multi_cond_frame=multi_cond_frame,
        )

        self.init_diff()

        self.sampling_timesteps = None

    def init_diff(self):
        n_timestep = self.n_timestep
        betas = torch.Tensor(
            make_beta_schedule(schedule="cosine", n_timestep=n_timestep)
        )
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)

        self.register_buffer("alphas_cumprod", alphas_cumprod)
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1)
        )
        self.register_buffer("sqrt_recip1m_alphas_cumprod", torch.sqrt(1.0 / (1.0 - alphas_cumprod)))

    def predict_noise_from_start(self, x_t, t, x0):
        a = extract(self.sqrt_recip1m_alphas_cumprod, t, x_t.shape)
        b = extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        return (a * x_t - x0 / b)

    def maybe_clip(self, x):
        if self.clip_denoised:
            return torch.clamp(x, min=-1., max=1.)
        else:
            return x

    def model_predictions(self, x, cond_frame, cond, t):
        weight = self.guidance_weight
        x_start = self.model.guided_forward(x, cond_frame, cond, t, weight)
        x_start = self.maybe_clip(x_start)
        pred_noise = self.predict_noise_from_start(x, t, x_start)
        return pred_noise, x_start

    @torch.no_grad()
    def forward(self, x, cond_frame, cond, time_cond):
        pred_noise, x_start = self.model_predictions(x, cond_frame, cond, time_cond)
        return pred_noise, x_start

    def load_model(self, ckpt_path):
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.eval()
        return self

    def setup(self, sampling_timesteps=50):
        if self.sampling_timesteps == sampling_timesteps:
            return

        self.sampling_timesteps = sampling_timesteps

        total_timesteps = self.n_timestep
        device = self.device
        eta = 1
        shape = (1, self.seq_frames, self.motion_feat_dim)

        times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)  # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
        times = list(reversed(times.int().tolist()))
        self.time_pairs = list(zip(times[:-1], times[1:]))  # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]

        self.time_cond_list = []
        self.alpha_next_sqrt_list = []
        self.sigma_list = []
        self.c_list = []
        self.noise_list = []

        for time, time_next in self.time_pairs:
            time_cond = torch.full((1,), time, device=device, dtype=torch.long)
            self.time_cond_list.append(time_cond)
            if time_next < 0:
                continue
            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]

            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()
            noise = torch.randn(shape, device=device)

            self.alpha_next_sqrt_list.append(alpha_next.sqrt())
            self.sigma_list.append(sigma)
            self.c_list.append(c)
            self.noise_list.append(noise)

    @torch.no_grad()
    def ddim_sample(self, kp_cond, aud_cond, sampling_timesteps):
        self.setup(sampling_timesteps)

        cond_frame = kp_cond
        cond = aud_cond

        shape = (1, self.seq_frames, self.motion_feat_dim)
        x = torch.randn(shape, device=self.device)

        x_start = None
        i = 0
        for _, time_next in self.time_pairs:
            time_cond = self.time_cond_list[i]
            pred_noise, x_start = self.model_predictions(x, cond_frame, cond, time_cond)
            if time_next < 0:
                x = x_start
                continue

            alpha_next_sqrt = self.alpha_next_sqrt_list[i]
            c = self.c_list[i]
            sigma = self.sigma_list[i]
            noise = self.noise_list[i]
            x = x_start * alpha_next_sqrt + c * pred_noise + sigma * noise

            i += 1
        return x  # pred_kp_seq
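Note: the quantities cached in setup() and the update in ddim_sample() are a DDIM step with eta = 1. For consecutive schedule times t -> t',

$$\sigma = \eta\sqrt{\frac{(1 - \bar{\alpha}_t/\bar{\alpha}_{t'})(1 - \bar{\alpha}_{t'})}{1 - \bar{\alpha}_t}},\qquad c = \sqrt{1 - \bar{\alpha}_{t'} - \sigma^2},$$
$$x_{t'} = \sqrt{\bar{\alpha}_{t'}}\,\hat{x}_0 + c\,\hat{\epsilon} + \sigma z,\qquad \hat{\epsilon} = \frac{x_t - \sqrt{\bar{\alpha}_t}\,\hat{x}_0}{\sqrt{1 - \bar{\alpha}_t}},$$

where x̂_0 is the guided MotionDecoder prediction (guidance weight 2) and ε̂ is recovered by predict_noise_from_start(); note that the per-step noise z is drawn once in setup() and reused across subsequent sampling calls.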
core/models/modules/__init__.py
ADDED
@@ -0,0 +1,6 @@
from .appearance_feature_extractor import AppearanceFeatureExtractor
from .motion_extractor import MotionExtractor
from .warping_network import WarpingNetwork
from .spade_generator import SPADEDecoder
from .stitching_network import StitchingNetwork
from .LMDM import LMDM
core/models/modules/appearance_feature_extractor.py
ADDED
@@ -0,0 +1,74 @@
# coding: utf-8

"""
Appearance extractor (F) defined in the paper, which maps the source image s to a 3D appearance feature volume.
"""

import torch
from torch import nn
from .util import SameBlock2d, DownBlock2d, ResBlock3d


class AppearanceFeatureExtractor(nn.Module):

    def __init__(
        self,
        image_channel=3,
        block_expansion=64,
        num_down_blocks=2,
        max_features=512,
        reshape_channel=32,
        reshape_depth=16,
        num_resblocks=6,
    ):
        super(AppearanceFeatureExtractor, self).__init__()
        self.image_channel = image_channel
        self.block_expansion = block_expansion
        self.num_down_blocks = num_down_blocks
        self.max_features = max_features
        self.reshape_channel = reshape_channel
        self.reshape_depth = reshape_depth

        self.first = SameBlock2d(
            image_channel, block_expansion, kernel_size=(3, 3), padding=(1, 1)
        )

        down_blocks = []
        for i in range(num_down_blocks):
            in_features = min(max_features, block_expansion * (2**i))
            out_features = min(max_features, block_expansion * (2 ** (i + 1)))
            down_blocks.append(
                DownBlock2d(
                    in_features, out_features, kernel_size=(3, 3), padding=(1, 1)
                )
            )
        self.down_blocks = nn.ModuleList(down_blocks)

        self.second = nn.Conv2d(
            in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1
        )

        self.resblocks_3d = torch.nn.Sequential()
        for i in range(num_resblocks):
            self.resblocks_3d.add_module(
                "3dr" + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1)
            )

    def forward(self, source_image):
        out = self.first(source_image)  # Bx3x256x256 -> Bx64x256x256

        for i in range(len(self.down_blocks)):
            out = self.down_blocks[i](out)
        out = self.second(out)
        bs, c, h, w = out.shape  # -> Bx512x64x64

        f_s = out.view(
            bs, self.reshape_channel, self.reshape_depth, h, w
        )  # -> Bx32x16x64x64
        f_s = self.resblocks_3d(f_s)  # -> Bx32x16x64x64
        return f_s

    def load_model(self, ckpt_path):
        self.load_state_dict(torch.load(ckpt_path, map_location=lambda storage, loc: storage))
        self.eval()
        return self
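Note: a quick shape check of the module above under its default hyper-parameters; the import assumes the repository root is on PYTHONPATH so that core is importable, and that the helper blocks in .util downsample as their names suggest.

import torch
from core.models.modules.appearance_feature_extractor import AppearanceFeatureExtractor

net = AppearanceFeatureExtractor().eval()
with torch.no_grad():
    f_s = net(torch.randn(1, 3, 256, 256))   # source image, Bx3x256x256
print(f_s.shape)                              # expected: torch.Size([1, 32, 16, 64, 64]) - the 3D appearance volume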
core/models/modules/convnextv2.py
ADDED
@@ -0,0 +1,150 @@
# coding: utf-8

"""
This module is adapted from ConvNeXtV2 for the extraction of implicit keypoints, poses, and expression deformation.
"""

import torch
import torch.nn as nn
# from timm.models.layers import trunc_normal_, DropPath
from .util import LayerNorm, DropPath, trunc_normal_, GRN

__all__ = ['convnextv2_tiny']


class Block(nn.Module):
    """ ConvNeXtV2 Block.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
    """

    def __init__(self, dim, drop_path=0.):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.grn = GRN(4 * dim)
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        x = input + self.drop_path(x)
        return x


class ConvNeXtV2(nn.Module):
    """ ConvNeXt V2

    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
    """

    def __init__(
        self,
        in_chans=3,
        depths=[3, 3, 9, 3],
        dims=[96, 192, 384, 768],
        drop_path_rate=0.,
        **kwargs
    ):
        super().__init__()
        self.depths = depths
        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple residual blocks
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer

        # NOTE: the output semantic items
        num_bins = kwargs.get('num_bins', 66)
        num_kp = kwargs.get('num_kp', 24)  # the number of implicit keypoints
        self.fc_kp = nn.Linear(dims[-1], 3 * num_kp)  # implicit keypoints

        # print('dims[-1]: ', dims[-1])
        self.fc_scale = nn.Linear(dims[-1], 1)  # scale
        self.fc_pitch = nn.Linear(dims[-1], num_bins)  # pitch bins
        self.fc_yaw = nn.Linear(dims[-1], num_bins)  # yaw bins
        self.fc_roll = nn.Linear(dims[-1], num_bins)  # roll bins
        self.fc_t = nn.Linear(dims[-1], 3)  # translation
        self.fc_exp = nn.Linear(dims[-1], 3 * num_kp)  # expression / delta

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        return self.norm(x.mean([-2, -1]))  # global average pooling, (N, C, H, W) -> (N, C)

    def forward(self, x):
        x = self.forward_features(x)

        # implicit keypoints
        kp = self.fc_kp(x)

        # pose and expression deformation
        pitch = self.fc_pitch(x)
        yaw = self.fc_yaw(x)
        roll = self.fc_roll(x)
        t = self.fc_t(x)
        exp = self.fc_exp(x)
        scale = self.fc_scale(x)

        # ret_dct = {
        #     'pitch': pitch,
        #     'yaw': yaw,
        #     'roll': roll,
        #     't': t,
        #     'exp': exp,
        #     'scale': scale,
        #     'kp': kp,  # canonical keypoint
        # }
        # return ret_dct
        return pitch, yaw, roll, t, exp, scale, kp


def convnextv2_tiny(**kwargs):
    model = ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
    return model
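Note: the backbone returns a flat tuple rather than a dict (see the commented-out ret_dct above). A small sketch of the head dimensions with the defaults num_bins=66 and num_kp=24, again assuming core is importable and that the helper layers in .util match the reference ConvNeXtV2 implementation:

import torch
from core.models.modules.convnextv2 import convnextv2_tiny

net = convnextv2_tiny().eval()
with torch.no_grad():
    pitch, yaw, roll, t, exp, scale, kp = net(torch.randn(1, 3, 256, 256))
# pitch/yaw/roll: (1, 66) rotation-bin logits; t: (1, 3); exp and kp: (1, 72); scale: (1, 1)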
core/models/modules/dense_motion.py
ADDED
@@ -0,0 +1,104 @@
# coding: utf-8

"""
The module that predicts a dense motion field from the sparse motion representation given by kp_source and kp_driving
"""

from torch import nn
import torch.nn.functional as F
import torch
from .util import Hourglass, make_coordinate_grid, kp2gaussian


class DenseMotionNetwork(nn.Module):
    def __init__(self, block_expansion, num_blocks, max_features, num_kp, feature_channel, reshape_depth, compress, estimate_occlusion_map=True):
        super(DenseMotionNetwork, self).__init__()
        self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(compress+1), max_features=max_features, num_blocks=num_blocks)  # ~60+G

        self.mask = nn.Conv3d(self.hourglass.out_filters, num_kp + 1, kernel_size=7, padding=3)  # 65G! NOTE: computation cost is large
        self.compress = nn.Conv3d(feature_channel, compress, kernel_size=1)  # 0.8G
        self.norm = nn.BatchNorm3d(compress, affine=True)
        self.num_kp = num_kp
        self.flag_estimate_occlusion_map = estimate_occlusion_map

        if self.flag_estimate_occlusion_map:
            self.occlusion = nn.Conv2d(self.hourglass.out_filters*reshape_depth, 1, kernel_size=7, padding=3)
        else:
            self.occlusion = None

    def create_sparse_motions(self, feature, kp_driving, kp_source):
        bs, _, d, h, w = feature.shape  # (bs, 4, 16, 64, 64)
        identity_grid = make_coordinate_grid((d, h, w), ref=kp_source)  # (16, 64, 64, 3)
        identity_grid = identity_grid.view(1, 1, d, h, w, 3)  # (1, 1, d=16, h=64, w=64, 3)
        coordinate_grid = identity_grid - kp_driving.view(bs, self.num_kp, 1, 1, 1, 3)

        k = coordinate_grid.shape[1]

        # NOTE: a first-order flow term is missing here
        driving_to_source = coordinate_grid + kp_source.view(bs, self.num_kp, 1, 1, 1, 3)  # (bs, num_kp, d, h, w, 3)

        # adding background feature
        identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1, 1)
        sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1)  # (bs, 1+num_kp, d, h, w, 3)
        return sparse_motions

    def create_deformed_feature(self, feature, sparse_motions):
        bs, _, d, h, w = feature.shape
        feature_repeat = feature.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp+1, 1, 1, 1, 1, 1)  # (bs, num_kp+1, 1, c, d, h, w)
        feature_repeat = feature_repeat.view(bs * (self.num_kp+1), -1, d, h, w)  # (bs*(num_kp+1), c, d, h, w)
        sparse_motions = sparse_motions.view((bs * (self.num_kp+1), d, h, w, -1))  # (bs*(num_kp+1), d, h, w, 3)
        sparse_deformed = F.grid_sample(feature_repeat, sparse_motions, align_corners=False)
        sparse_deformed = sparse_deformed.view((bs, self.num_kp+1, -1, d, h, w))  # (bs, num_kp+1, c, d, h, w)

        return sparse_deformed

    def create_heatmap_representations(self, feature, kp_driving, kp_source):
        spatial_size = feature.shape[3:]  # (d=16, h=64, w=64)
        gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=0.01)  # (bs, num_kp, d, h, w)
        gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=0.01)  # (bs, num_kp, d, h, w)
        heatmap = gaussian_driving - gaussian_source  # (bs, num_kp, d, h, w)

        # adding background feature
        zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.dtype).to(heatmap.device)
        heatmap = torch.cat([zeros, heatmap], dim=1)
        heatmap = heatmap.unsqueeze(2)  # (bs, 1+num_kp, 1, d, h, w)
        return heatmap

    def forward(self, feature, kp_driving, kp_source):
        bs, _, d, h, w = feature.shape  # (bs, 32, 16, 64, 64)

        feature = self.compress(feature)  # (bs, 4, 16, 64, 64)
        feature = self.norm(feature)  # (bs, 4, 16, 64, 64)
        feature = F.relu(feature)  # (bs, 4, 16, 64, 64)

        out_dict = dict()

        # 1. deform 3d feature
        sparse_motion = self.create_sparse_motions(feature, kp_driving, kp_source)  # (bs, 1+num_kp, d, h, w, 3)
        deformed_feature = self.create_deformed_feature(feature, sparse_motion)  # (bs, 1+num_kp, c=4, d=16, h=64, w=64)

        # 2. (bs, 1+num_kp, d, h, w)
        heatmap = self.create_heatmap_representations(deformed_feature, kp_driving, kp_source)  # (bs, 1+num_kp, 1, d, h, w)

        input = torch.cat([heatmap, deformed_feature], dim=2)  # (bs, 1+num_kp, c=5, d=16, h=64, w=64)
        input = input.view(bs, -1, d, h, w)  # (bs, (1+num_kp)*c=105, d=16, h=64, w=64)

        prediction = self.hourglass(input)

        mask = self.mask(prediction)
        mask = F.softmax(mask, dim=1)  # (bs, 1+num_kp, d=16, h=64, w=64)
        out_dict['mask'] = mask
        mask = mask.unsqueeze(2)  # (bs, num_kp+1, 1, d, h, w)
        sparse_motion = sparse_motion.permute(0, 1, 5, 2, 3, 4)  # (bs, num_kp+1, 3, d, h, w)
        deformation = (sparse_motion * mask).sum(dim=1)  # (bs, 3, d, h, w); the mask takes effect here
        deformation = deformation.permute(0, 2, 3, 4, 1)  # (bs, d, h, w, 3)

        out_dict['deformation'] = deformation

        if self.flag_estimate_occlusion_map:
            bs, _, d, h, w = prediction.shape
            prediction_reshape = prediction.view(bs, -1, h, w)
            occlusion_map = torch.sigmoid(self.occlusion(prediction_reshape))  # Bx1x64x64
            out_dict['occlusion_map'] = occlusion_map

        return out_dict
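A minimal sketch (not part of the commit) of how DenseMotionNetwork is driven. The hyper-parameters mirror the ones hard-coded in warping_network.py later in this commit; the random feature volume and keypoints are stand-ins for the real appearance features and motion keypoints.

import torch
from core.models.modules.dense_motion import DenseMotionNetwork

net = DenseMotionNetwork(block_expansion=32, num_blocks=5, max_features=1024, num_kp=21,
                         feature_channel=32, reshape_depth=16, compress=4).eval()
feature_3d = torch.randn(1, 32, 16, 64, 64)   # (bs, c, d, h, w), see comments in forward()
kp_source = torch.rand(1, 21, 3) * 2 - 1      # keypoints in the [-1, 1] grid convention
kp_driving = torch.rand(1, 21, 3) * 2 - 1
with torch.no_grad():
    out = net(feature_3d, kp_driving=kp_driving, kp_source=kp_source)
print(out['deformation'].shape)    # (1, 16, 64, 64, 3) sampling grid
print(out['occlusion_map'].shape)  # (1, 1, 64, 64)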
core/models/modules/lmdm_modules/model.py
ADDED
@@ -0,0 +1,398 @@
1 |
+
from typing import Callable, Optional, Union
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from einops import rearrange
|
5 |
+
from einops.layers.torch import Rearrange
|
6 |
+
from torch import Tensor
|
7 |
+
from torch.nn import functional as F
|
8 |
+
|
9 |
+
from .rotary_embedding_torch import RotaryEmbedding
|
10 |
+
from .utils import PositionalEncoding, SinusoidalPosEmb, prob_mask_like
|
11 |
+
|
12 |
+
|
13 |
+
class DenseFiLM(nn.Module):
|
14 |
+
"""Feature-wise linear modulation (FiLM) generator."""
|
15 |
+
|
16 |
+
def __init__(self, embed_channels):
|
17 |
+
super().__init__()
|
18 |
+
self.embed_channels = embed_channels
|
19 |
+
self.block = nn.Sequential(
|
20 |
+
nn.Mish(), nn.Linear(embed_channels, embed_channels * 2)
|
21 |
+
)
|
22 |
+
|
23 |
+
def forward(self, position):
|
24 |
+
pos_encoding = self.block(position)
|
25 |
+
pos_encoding = rearrange(pos_encoding, "b c -> b 1 c")
|
26 |
+
scale_shift = pos_encoding.chunk(2, dim=-1)
|
27 |
+
return scale_shift
|
28 |
+
|
29 |
+
|
30 |
+
def featurewise_affine(x, scale_shift):
|
31 |
+
scale, shift = scale_shift
|
32 |
+
return (scale + 1) * x + shift
|
33 |
+
|
34 |
+
|
35 |
+
class TransformerEncoderLayer(nn.Module):
|
36 |
+
def __init__(
|
37 |
+
self,
|
38 |
+
d_model: int,
|
39 |
+
nhead: int,
|
40 |
+
dim_feedforward: int = 2048,
|
41 |
+
dropout: float = 0.1,
|
42 |
+
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
|
43 |
+
layer_norm_eps: float = 1e-5,
|
44 |
+
batch_first: bool = False,
|
45 |
+
norm_first: bool = True,
|
46 |
+
device=None,
|
47 |
+
dtype=None,
|
48 |
+
rotary=None,
|
49 |
+
) -> None:
|
50 |
+
super().__init__()
|
51 |
+
self.self_attn = nn.MultiheadAttention(
|
52 |
+
d_model, nhead, dropout=dropout, batch_first=batch_first
|
53 |
+
)
|
54 |
+
# Implementation of Feedforward model
|
55 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
56 |
+
self.dropout = nn.Dropout(dropout)
|
57 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
58 |
+
|
59 |
+
self.norm_first = norm_first
|
60 |
+
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
|
61 |
+
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
|
62 |
+
self.dropout1 = nn.Dropout(dropout)
|
63 |
+
self.dropout2 = nn.Dropout(dropout)
|
64 |
+
self.activation = activation
|
65 |
+
|
66 |
+
self.rotary = rotary
|
67 |
+
self.use_rotary = rotary is not None
|
68 |
+
|
69 |
+
def forward(
|
70 |
+
self,
|
71 |
+
src: Tensor,
|
72 |
+
src_mask: Optional[Tensor] = None,
|
73 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
74 |
+
) -> Tensor:
|
75 |
+
x = src
|
76 |
+
if self.norm_first:
|
77 |
+
x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
|
78 |
+
x = x + self._ff_block(self.norm2(x))
|
79 |
+
else:
|
80 |
+
x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
|
81 |
+
x = self.norm2(x + self._ff_block(x))
|
82 |
+
|
83 |
+
return x
|
84 |
+
|
85 |
+
# self-attention block
|
86 |
+
def _sa_block(
|
87 |
+
self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]
|
88 |
+
) -> Tensor:
|
89 |
+
qk = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
|
90 |
+
x = self.self_attn(
|
91 |
+
qk,
|
92 |
+
qk,
|
93 |
+
x,
|
94 |
+
attn_mask=attn_mask,
|
95 |
+
key_padding_mask=key_padding_mask,
|
96 |
+
need_weights=False,
|
97 |
+
)[0]
|
98 |
+
return self.dropout1(x)
|
99 |
+
|
100 |
+
# feed forward block
|
101 |
+
def _ff_block(self, x: Tensor) -> Tensor:
|
102 |
+
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
103 |
+
return self.dropout2(x)
|
104 |
+
|
105 |
+
|
106 |
+
class FiLMTransformerDecoderLayer(nn.Module):
|
107 |
+
def __init__(
|
108 |
+
self,
|
109 |
+
d_model: int,
|
110 |
+
nhead: int,
|
111 |
+
dim_feedforward=2048,
|
112 |
+
dropout=0.1,
|
113 |
+
activation=F.relu,
|
114 |
+
layer_norm_eps=1e-5,
|
115 |
+
batch_first=False,
|
116 |
+
norm_first=True,
|
117 |
+
device=None,
|
118 |
+
dtype=None,
|
119 |
+
rotary=None,
|
120 |
+
):
|
121 |
+
super().__init__()
|
122 |
+
self.self_attn = nn.MultiheadAttention(
|
123 |
+
d_model, nhead, dropout=dropout, batch_first=batch_first
|
124 |
+
)
|
125 |
+
self.multihead_attn = nn.MultiheadAttention(
|
126 |
+
d_model, nhead, dropout=dropout, batch_first=batch_first
|
127 |
+
)
|
128 |
+
# Feedforward
|
129 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
130 |
+
self.dropout = nn.Dropout(dropout)
|
131 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
132 |
+
|
133 |
+
self.norm_first = norm_first
|
134 |
+
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
|
135 |
+
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
|
136 |
+
self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
|
137 |
+
self.dropout1 = nn.Dropout(dropout)
|
138 |
+
self.dropout2 = nn.Dropout(dropout)
|
139 |
+
self.dropout3 = nn.Dropout(dropout)
|
140 |
+
self.activation = activation
|
141 |
+
|
142 |
+
self.film1 = DenseFiLM(d_model)
|
143 |
+
self.film2 = DenseFiLM(d_model)
|
144 |
+
self.film3 = DenseFiLM(d_model)
|
145 |
+
|
146 |
+
self.rotary = rotary
|
147 |
+
self.use_rotary = rotary is not None
|
148 |
+
|
149 |
+
# x, cond, t
|
150 |
+
def forward(
|
151 |
+
self,
|
152 |
+
tgt,
|
153 |
+
memory,
|
154 |
+
t,
|
155 |
+
tgt_mask=None,
|
156 |
+
memory_mask=None,
|
157 |
+
tgt_key_padding_mask=None,
|
158 |
+
memory_key_padding_mask=None,
|
159 |
+
):
|
160 |
+
x = tgt
|
161 |
+
if self.norm_first:
|
162 |
+
# self-attention -> film -> residual
|
163 |
+
x_1 = self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask)
|
164 |
+
x = x + featurewise_affine(x_1, self.film1(t))
|
165 |
+
# cross-attention -> film -> residual
|
166 |
+
x_2 = self._mha_block(
|
167 |
+
self.norm2(x), memory, memory_mask, memory_key_padding_mask
|
168 |
+
)
|
169 |
+
x = x + featurewise_affine(x_2, self.film2(t))
|
170 |
+
# feedforward -> film -> residual
|
171 |
+
x_3 = self._ff_block(self.norm3(x))
|
172 |
+
x = x + featurewise_affine(x_3, self.film3(t))
|
173 |
+
else:
|
174 |
+
x = self.norm1(
|
175 |
+
x
|
176 |
+
+ featurewise_affine(
|
177 |
+
self._sa_block(x, tgt_mask, tgt_key_padding_mask), self.film1(t)
|
178 |
+
)
|
179 |
+
)
|
180 |
+
x = self.norm2(
|
181 |
+
x
|
182 |
+
+ featurewise_affine(
|
183 |
+
self._mha_block(x, memory, memory_mask, memory_key_padding_mask),
|
184 |
+
self.film2(t),
|
185 |
+
)
|
186 |
+
)
|
187 |
+
x = self.norm3(x + featurewise_affine(self._ff_block(x), self.film3(t)))
|
188 |
+
return x
|
189 |
+
|
190 |
+
# self-attention block
|
191 |
+
# qkv
|
192 |
+
def _sa_block(self, x, attn_mask, key_padding_mask):
|
193 |
+
qk = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
|
194 |
+
x = self.self_attn(
|
195 |
+
qk,
|
196 |
+
qk,
|
197 |
+
x,
|
198 |
+
attn_mask=attn_mask,
|
199 |
+
key_padding_mask=key_padding_mask,
|
200 |
+
need_weights=False,
|
201 |
+
)[0]
|
202 |
+
return self.dropout1(x)
|
203 |
+
|
204 |
+
# multihead attention block
|
205 |
+
# qkv
|
206 |
+
def _mha_block(self, x, mem, attn_mask, key_padding_mask):
|
207 |
+
q = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
|
208 |
+
k = self.rotary.rotate_queries_or_keys(mem) if self.use_rotary else mem
|
209 |
+
x = self.multihead_attn(
|
210 |
+
q,
|
211 |
+
k,
|
212 |
+
mem,
|
213 |
+
attn_mask=attn_mask,
|
214 |
+
key_padding_mask=key_padding_mask,
|
215 |
+
need_weights=False,
|
216 |
+
)[0]
|
217 |
+
return self.dropout2(x)
|
218 |
+
|
219 |
+
# feed forward block
|
220 |
+
def _ff_block(self, x):
|
221 |
+
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
222 |
+
return self.dropout3(x)
|
223 |
+
|
224 |
+
|
225 |
+
class DecoderLayerStack(nn.Module):
|
226 |
+
def __init__(self, stack):
|
227 |
+
super().__init__()
|
228 |
+
self.stack = stack
|
229 |
+
|
230 |
+
def forward(self, x, cond, t):
|
231 |
+
for layer in self.stack:
|
232 |
+
x = layer(x, cond, t)
|
233 |
+
return x
|
234 |
+
|
235 |
+
|
236 |
+
class MotionDecoder(nn.Module):
|
237 |
+
def __init__(
|
238 |
+
self,
|
239 |
+
nfeats: int,
|
240 |
+
seq_len: int = 100, # 4 seconds, 25 fps
|
241 |
+
latent_dim: int = 256,
|
242 |
+
ff_size: int = 1024,
|
243 |
+
num_layers: int = 4,
|
244 |
+
num_heads: int = 4,
|
245 |
+
dropout: float = 0.1,
|
246 |
+
cond_feature_dim: int = 4800,
|
247 |
+
activation: Callable[[Tensor], Tensor] = F.gelu,
|
248 |
+
use_rotary=True,
|
249 |
+
multi_cond_frame=False,
|
250 |
+
**kwargs
|
251 |
+
) -> None:
|
252 |
+
|
253 |
+
super().__init__()
|
254 |
+
|
255 |
+
self.multi_cond_frame = multi_cond_frame
|
256 |
+
|
257 |
+
output_feats = nfeats
|
258 |
+
|
259 |
+
# positional embeddings
|
260 |
+
self.rotary = None
|
261 |
+
self.abs_pos_encoding = nn.Identity()
|
262 |
+
# if rotary, replace absolute embedding with a rotary embedding instance (absolute becomes an identity)
|
263 |
+
if use_rotary:
|
264 |
+
self.rotary = RotaryEmbedding(dim=latent_dim)
|
265 |
+
else:
|
266 |
+
self.abs_pos_encoding = PositionalEncoding(
|
267 |
+
latent_dim, dropout, batch_first=True
|
268 |
+
)
|
269 |
+
|
270 |
+
# time embedding processing
|
271 |
+
self.time_mlp = nn.Sequential(
|
272 |
+
SinusoidalPosEmb(latent_dim), # learned?
|
273 |
+
nn.Linear(latent_dim, latent_dim * 4),
|
274 |
+
nn.Mish(),
|
275 |
+
)
|
276 |
+
|
277 |
+
self.to_time_cond = nn.Sequential(nn.Linear(latent_dim * 4, latent_dim),)
|
278 |
+
|
279 |
+
self.to_time_tokens = nn.Sequential(
|
280 |
+
nn.Linear(latent_dim * 4, latent_dim * 2), # 2 time tokens
|
281 |
+
Rearrange("b (r d) -> b r d", r=2),
|
282 |
+
)
|
283 |
+
|
284 |
+
# null embeddings for guidance dropout
|
285 |
+
self.null_cond_embed = nn.Parameter(torch.randn(1, seq_len, latent_dim))
|
286 |
+
self.null_cond_hidden = nn.Parameter(torch.randn(1, latent_dim))
|
287 |
+
|
288 |
+
self.norm_cond = nn.LayerNorm(latent_dim)
|
289 |
+
|
290 |
+
# input projection
|
291 |
+
if self.multi_cond_frame:
|
292 |
+
self.input_projection = nn.Linear(nfeats * 2 + 1, latent_dim)
|
293 |
+
else:
|
294 |
+
self.input_projection = nn.Linear(nfeats * 2, latent_dim)
|
295 |
+
self.cond_encoder = nn.Sequential()
|
296 |
+
for _ in range(2):
|
297 |
+
self.cond_encoder.append(
|
298 |
+
TransformerEncoderLayer(
|
299 |
+
d_model=latent_dim,
|
300 |
+
nhead=num_heads,
|
301 |
+
dim_feedforward=ff_size,
|
302 |
+
dropout=dropout,
|
303 |
+
activation=activation,
|
304 |
+
batch_first=True,
|
305 |
+
rotary=self.rotary,
|
306 |
+
)
|
307 |
+
)
|
308 |
+
# conditional projection
|
309 |
+
self.cond_projection = nn.Linear(cond_feature_dim, latent_dim)
|
310 |
+
self.non_attn_cond_projection = nn.Sequential(
|
311 |
+
nn.LayerNorm(latent_dim),
|
312 |
+
nn.Linear(latent_dim, latent_dim),
|
313 |
+
nn.SiLU(),
|
314 |
+
nn.Linear(latent_dim, latent_dim),
|
315 |
+
)
|
316 |
+
# decoder
|
317 |
+
decoderstack = nn.ModuleList([])
|
318 |
+
for _ in range(num_layers):
|
319 |
+
decoderstack.append(
|
320 |
+
FiLMTransformerDecoderLayer(
|
321 |
+
latent_dim,
|
322 |
+
num_heads,
|
323 |
+
dim_feedforward=ff_size,
|
324 |
+
dropout=dropout,
|
325 |
+
activation=activation,
|
326 |
+
batch_first=True,
|
327 |
+
rotary=self.rotary,
|
328 |
+
)
|
329 |
+
)
|
330 |
+
|
331 |
+
self.seqTransDecoder = DecoderLayerStack(decoderstack)
|
332 |
+
|
333 |
+
self.final_layer = nn.Linear(latent_dim, output_feats)
|
334 |
+
|
335 |
+
self.epsilon = 0.00001
|
336 |
+
|
337 |
+
def guided_forward(self, x, cond_frame, cond_embed, times, guidance_weight):
|
338 |
+
unc = self.forward(x, cond_frame, cond_embed, times, cond_drop_prob=1)
|
339 |
+
conditioned = self.forward(x, cond_frame, cond_embed, times, cond_drop_prob=0)
|
340 |
+
|
341 |
+
return unc + (conditioned - unc) * guidance_weight
|
342 |
+
|
343 |
+
def forward(
|
344 |
+
self, x: Tensor, cond_frame: Tensor, cond_embed: Tensor, times: Tensor, cond_drop_prob: float = 0.0
|
345 |
+
):
|
346 |
+
batch_size, device = x.shape[0], x.device
|
347 |
+
|
348 |
+
# concat last frame, project to latent space
|
349 |
+
# cond_frame: [b, dim] | [b, n, dim+1]
|
350 |
+
if self.multi_cond_frame:
|
351 |
+
# [b, n, dim+1] (+1 mask)
|
352 |
+
x = torch.cat([x, cond_frame], dim=-1)
|
353 |
+
else:
|
354 |
+
# [b, dim]
|
355 |
+
x = torch.cat([x, cond_frame.unsqueeze(1).repeat(1, x.shape[1], 1)], dim=-1)
|
356 |
+
x = self.input_projection(x)
|
357 |
+
# add the positional embeddings of the input sequence to provide temporal information
|
358 |
+
x = self.abs_pos_encoding(x)
|
359 |
+
|
360 |
+
# create audio conditional embedding with conditional dropout
|
361 |
+
keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device=device)
|
362 |
+
keep_mask_embed = rearrange(keep_mask, "b -> b 1 1")
|
363 |
+
keep_mask_hidden = rearrange(keep_mask, "b -> b 1")
|
364 |
+
|
365 |
+
cond_tokens = self.cond_projection(cond_embed)
|
366 |
+
# encode tokens
|
367 |
+
cond_tokens = self.abs_pos_encoding(cond_tokens)
|
368 |
+
cond_tokens = self.cond_encoder(cond_tokens)
|
369 |
+
|
370 |
+
null_cond_embed = self.null_cond_embed.to(cond_tokens.dtype)
|
371 |
+
cond_tokens = torch.where(keep_mask_embed, cond_tokens, null_cond_embed)
|
372 |
+
|
373 |
+
mean_pooled_cond_tokens = cond_tokens.mean(dim=-2)
|
374 |
+
cond_hidden = self.non_attn_cond_projection(mean_pooled_cond_tokens)
|
375 |
+
|
376 |
+
# create the diffusion timestep embedding, add the extra audio projection
|
377 |
+
t_hidden = self.time_mlp(times)
|
378 |
+
|
379 |
+
# project to attention and FiLM conditioning
|
380 |
+
t = self.to_time_cond(t_hidden)
|
381 |
+
t_tokens = self.to_time_tokens(t_hidden)
|
382 |
+
|
383 |
+
# FiLM conditioning
|
384 |
+
null_cond_hidden = self.null_cond_hidden.to(t.dtype)
|
385 |
+
cond_hidden = torch.where(keep_mask_hidden, cond_hidden, null_cond_hidden)
|
386 |
+
t += cond_hidden
|
387 |
+
|
388 |
+
# cross-attention conditioning
|
389 |
+
c = torch.cat((cond_tokens, t_tokens), dim=-2)
|
390 |
+
cond_tokens = self.norm_cond(c)
|
391 |
+
|
392 |
+
# Pass through the transformer decoder
|
393 |
+
# attending to the conditional embedding
|
394 |
+
output = self.seqTransDecoder(x, cond_tokens, t)
|
395 |
+
|
396 |
+
output = self.final_layer(output)
|
397 |
+
|
398 |
+
return output
|
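A hedged usage sketch for the FiLM-conditioned MotionDecoder defined above (not part of the commit). All sizes are illustrative assumptions: the 265-d motion features, 80-frame window and 1024-d audio condition tokens are placeholders, not values read from the released checkpoints.

import torch
from core.models.modules.lmdm_modules.model import MotionDecoder

model = MotionDecoder(nfeats=265, seq_len=80, latent_dim=256, ff_size=1024,
                      num_layers=4, num_heads=4, cond_feature_dim=1024).eval()
x = torch.randn(2, 80, 265)            # noisy motion sequence at the current diffusion step
cond_frame = torch.randn(2, 265)       # last generated frame (multi_cond_frame=False path)
cond_embed = torch.randn(2, 80, 1024)  # audio condition tokens
times = torch.randint(0, 1000, (2,))   # diffusion timesteps
with torch.no_grad():
    plain = model(x, cond_frame, cond_embed, times, cond_drop_prob=0.0)
    guided = model.guided_forward(x, cond_frame, cond_embed, times, guidance_weight=2.0)
print(plain.shape, guided.shape)  # both (2, 80, 265)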
core/models/modules/lmdm_modules/rotary_embedding_torch.py
ADDED
@@ -0,0 +1,132 @@
from inspect import isfunction
from math import log, pi

import torch
from einops import rearrange, repeat
from torch import einsum, nn

# helper functions


def exists(val):
    return val is not None


def broadcat(tensors, dim=-1):
    num_tensors = len(tensors)
    shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
    assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
    shape_len = list(shape_lens)[0]

    dim = (dim + shape_len) if dim < 0 else dim
    dims = list(zip(*map(lambda t: list(t.shape), tensors)))

    expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
    assert all(
        [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
    ), "invalid dimensions for broadcastable concatenation"
    max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
    expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
    expanded_dims.insert(dim, (dim, dims[dim]))
    expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
    tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
    return torch.cat(tensors, dim=dim)


# rotary embedding helper functions


def rotate_half(x):
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, "... d r -> ... (d r)")


def apply_rotary_emb(freqs, t, start_index=0):
    freqs = freqs.to(t)
    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim
    assert (
        rot_dim <= t.shape[-1]
    ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
    t_left, t, t_right = (
        t[..., :start_index],
        t[..., start_index:end_index],
        t[..., end_index:],
    )
    t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
    return torch.cat((t_left, t, t_right), dim=-1)


# learned rotation helpers


def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
    if exists(freq_ranges):
        rotations = einsum("..., f -> ... f", rotations, freq_ranges)
        rotations = rearrange(rotations, "... r f -> ... (r f)")

    rotations = repeat(rotations, "... n -> ... (n r)", r=2)
    return apply_rotary_emb(rotations, t, start_index=start_index)


# classes


class RotaryEmbedding(nn.Module):
    def __init__(
        self,
        dim,
        custom_freqs=None,
        freqs_for="lang",
        theta=10000,
        max_freq=10,
        num_freqs=1,
        learned_freq=False,
    ):
        super().__init__()
        if exists(custom_freqs):
            freqs = custom_freqs
        elif freqs_for == "lang":
            freqs = 1.0 / (
                theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
            )
        elif freqs_for == "pixel":
            freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
        elif freqs_for == "constant":
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f"unknown modality {freqs_for}")

        self.cache = dict()

        if learned_freq:
            self.freqs = nn.Parameter(freqs)
        else:
            self.register_buffer("freqs", freqs)

    def rotate_queries_or_keys(self, t, seq_dim=-2):
        device = t.device
        seq_len = t.shape[seq_dim]
        freqs = self.forward(
            lambda: torch.arange(seq_len, device=device), cache_key=seq_len
        )
        return apply_rotary_emb(freqs, t)

    def forward(self, t, cache_key=None):
        if exists(cache_key) and cache_key in self.cache:
            return self.cache[cache_key]

        if isfunction(t):
            t = t()

        freqs = self.freqs

        freqs = torch.einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
        freqs = repeat(freqs, "... n -> ... (n r)", r=2)

        if exists(cache_key):
            self.cache[cache_key] = freqs

        return freqs
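A small sketch (not in the commit) of how the decoder layers use RotaryEmbedding: queries and keys are rotated position-wise before attention while values stay untouched.

import torch
from core.models.modules.lmdm_modules.rotary_embedding_torch import RotaryEmbedding

rotary = RotaryEmbedding(dim=256)       # MotionDecoder passes dim=latent_dim
x = torch.randn(2, 80, 256)             # (batch, sequence, feature)
qk = rotary.rotate_queries_or_keys(x)   # same shape, positions encoded by rotation
print(qk.shape)                         # torch.Size([2, 80, 256])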
core/models/modules/lmdm_modules/utils.py
ADDED
@@ -0,0 +1,96 @@
import math
import numpy as np
import torch
from torch import nn


# absolute positional embedding used for vanilla transformer sequential data
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=500, batch_first=False):
        super().__init__()
        self.batch_first = batch_first

        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)

        self.register_buffer("pe", pe)

    def forward(self, x):
        if self.batch_first:
            x = x + self.pe.permute(1, 0, 2)[:, : x.shape[1], :]
        else:
            x = x + self.pe[: x.shape[0], :]
        return self.dropout(x)


# very similar positional embedding used for diffusion timesteps
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


# dropout mask
def prob_mask_like(shape, prob, device):
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    else:
        return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob


def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def make_beta_schedule(
    schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
):
    if schedule == "linear":
        betas = (
            torch.linspace(
                linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64
            )
            ** 2
        )

    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        betas = np.clip(betas, a_min=0, a_max=0.999)

    elif schedule == "sqrt_linear":
        betas = torch.linspace(
            linear_start, linear_end, n_timestep, dtype=torch.float64
        )
    elif schedule == "sqrt":
        betas = (
            torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
            ** 0.5
        )
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()
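Quick illustration (not in the commit) of the schedule and guidance helpers above: a "linear" DDPM beta schedule and the Bernoulli keep-mask used for classifier-free guidance dropout. The numbers are arbitrary.

import torch
from core.models.modules.lmdm_modules.utils import make_beta_schedule, prob_mask_like

betas = make_beta_schedule("linear", n_timestep=1000)
print(betas.shape, betas[0], betas[-1])   # (1000,), ramps from 1e-4 to 2e-2

keep = prob_mask_like((4,), prob=1 - 0.2, device="cpu")  # ~20% of samples drop their condition
print(keep)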
core/models/modules/motion_extractor.py
ADDED
@@ -0,0 +1,25 @@
# coding: utf-8

"""
Motion extractor(M), which directly predicts the canonical keypoints, head pose and expression deformation of the input image
"""

from torch import nn
import torch

from .convnextv2 import convnextv2_tiny


class MotionExtractor(nn.Module):
    def __init__(self, num_kp=21, backbone="convnextv2_tiny"):
        super(MotionExtractor, self).__init__()
        self.detector = convnextv2_tiny(num_kp=num_kp, backbone=backbone)

    def forward(self, x):
        out = self.detector(x)
        return out  # pitch, yaw, roll, t, exp, scale, kp

    def load_model(self, ckpt_path):
        self.load_state_dict(torch.load(ckpt_path, map_location=lambda storage, loc: storage))
        self.eval()
        return self
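Hedged usage sketch (not in the commit): run the motion extractor on a dummy crop. The 256x256 crop size and the checkpoint path are assumptions for illustration only.

import torch
from core.models.modules.motion_extractor import MotionExtractor

extractor = MotionExtractor(num_kp=21).eval()
# extractor.load_model("checkpoints/motion_extractor.pth")  # hypothetical path
with torch.no_grad():
    pitch, yaw, roll, t, exp, scale, kp = extractor(torch.randn(1, 3, 256, 256))
print(pitch.shape, yaw.shape, roll.shape)  # (1, 66) bin logits per angle
print(kp.shape)                            # (1, 63): 21 implicit keypoints x 3, still flattened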
core/models/modules/spade_generator.py
ADDED
@@ -0,0 +1,87 @@
# coding: utf-8

"""
Spade decoder(G) defined in the paper, which takes the warped feature as input and generates the animated image.
"""

import torch
from torch import nn
import torch.nn.functional as F
from .util import SPADEResnetBlock


class SPADEDecoder(nn.Module):
    def __init__(
        self,
        upscale=2,
        max_features=512,
        block_expansion=64,
        out_channels=64,
        num_down_blocks=2,
    ):
        for i in range(num_down_blocks):
            input_channels = min(max_features, block_expansion * (2 ** (i + 1)))
        self.upscale = upscale
        super().__init__()
        norm_G = "spadespectralinstance"
        label_num_channels = input_channels  # 256

        self.fc = nn.Conv2d(input_channels, 2 * input_channels, 3, padding=1)
        self.G_middle_0 = SPADEResnetBlock(
            2 * input_channels, 2 * input_channels, norm_G, label_num_channels
        )
        self.G_middle_1 = SPADEResnetBlock(
            2 * input_channels, 2 * input_channels, norm_G, label_num_channels
        )
        self.G_middle_2 = SPADEResnetBlock(
            2 * input_channels, 2 * input_channels, norm_G, label_num_channels
        )
        self.G_middle_3 = SPADEResnetBlock(
            2 * input_channels, 2 * input_channels, norm_G, label_num_channels
        )
        self.G_middle_4 = SPADEResnetBlock(
            2 * input_channels, 2 * input_channels, norm_G, label_num_channels
        )
        self.G_middle_5 = SPADEResnetBlock(
            2 * input_channels, 2 * input_channels, norm_G, label_num_channels
        )
        self.up_0 = SPADEResnetBlock(
            2 * input_channels, input_channels, norm_G, label_num_channels
        )
        self.up_1 = SPADEResnetBlock(
            input_channels, out_channels, norm_G, label_num_channels
        )
        self.up = nn.Upsample(scale_factor=2)

        if self.upscale is None or self.upscale <= 1:
            self.conv_img = nn.Conv2d(out_channels, 3, 3, padding=1)
        else:
            self.conv_img = nn.Sequential(
                nn.Conv2d(out_channels, 3 * (2 * 2), kernel_size=3, padding=1),
                nn.PixelShuffle(upscale_factor=2),
            )

    def forward(self, feature):
        seg = feature  # Bx256x64x64
        x = self.fc(feature)  # Bx512x64x64
        x = self.G_middle_0(x, seg)
        x = self.G_middle_1(x, seg)
        x = self.G_middle_2(x, seg)
        x = self.G_middle_3(x, seg)
        x = self.G_middle_4(x, seg)
        x = self.G_middle_5(x, seg)

        x = self.up(x)  # Bx512x64x64 -> Bx512x128x128
        x = self.up_0(x, seg)  # Bx512x128x128 -> Bx256x128x128
        x = self.up(x)  # Bx256x128x128 -> Bx256x256x256
        x = self.up_1(x, seg)  # Bx256x256x256 -> Bx64x256x256

        x = self.conv_img(F.leaky_relu(x, 2e-1))  # Bx64x256x256 -> Bx3xHxW
        x = torch.sigmoid(x)  # Bx3xHxW

        return x

    def load_model(self, ckpt_path):
        self.load_state_dict(torch.load(ckpt_path, map_location=lambda storage, loc: storage))
        self.eval()
        return self
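Minimal shape walk-through (not in the commit) for the decoder above with its default arguments: a 256-channel 64x64 warped feature map becomes a 512x512 RGB image because upscale=2 selects the PixelShuffle head.

import torch
from core.models.modules.spade_generator import SPADEDecoder

decoder = SPADEDecoder(upscale=2).eval()
with torch.no_grad():
    img = decoder(torch.randn(1, 256, 64, 64))
print(img.shape)  # torch.Size([1, 3, 512, 512]), values in (0, 1) after the sigmoid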
core/models/modules/stitching_network.py
ADDED
@@ -0,0 +1,65 @@
# coding: utf-8

"""
Stitching module(S) and two retargeting modules(R) defined in the paper.

- The stitching module pastes the animated portrait back into the original image space without pixel misalignment, such as in
the stitching region.

- The eyes retargeting module is designed to address the issue of incomplete eye closure during cross-id reenactment, especially
when a person with small eyes drives a person with larger eyes.

- The lip retargeting module is designed similarly to the eye retargeting module, and can also normalize the input by ensuring that
the lips are in a closed state, which facilitates better animation driving.
"""
import torch
from torch import nn


def remove_ddp_dumplicate_key(state_dict):
    from collections import OrderedDict
    state_dict_new = OrderedDict()
    for key in state_dict.keys():
        state_dict_new[key.replace('module.', '')] = state_dict[key]
    return state_dict_new


class StitchingNetwork(nn.Module):
    def __init__(self, input_size=126, hidden_sizes=[128, 128, 64], output_size=65):
        super(StitchingNetwork, self).__init__()
        layers = []
        for i in range(len(hidden_sizes)):
            if i == 0:
                layers.append(nn.Linear(input_size, hidden_sizes[i]))
            else:
                layers.append(nn.Linear(hidden_sizes[i - 1], hidden_sizes[i]))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(hidden_sizes[-1], output_size))
        self.mlp = nn.Sequential(*layers)

    def _forward(self, x):
        return self.mlp(x)

    def load_model(self, ckpt_path):
        checkpoint = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
        self.load_state_dict(remove_ddp_dumplicate_key(checkpoint['retarget_shoulder']))
        self.eval()
        return self

    def stitching(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
        """ conduct the stitching
        kp_source: Bxnum_kpx3
        kp_driving: Bxnum_kpx3
        """
        bs, num_kp = kp_source.shape[:2]
        kp_driving_new = kp_driving.clone()
        delta = self._forward(torch.cat([kp_source.view(bs, -1), kp_driving_new.view(bs, -1)], dim=1))
        delta_exp = delta[..., :3*num_kp].reshape(bs, num_kp, 3)  # 1x20x3
        delta_tx_ty = delta[..., 3*num_kp:3*num_kp+2].reshape(bs, 1, 2)  # 1x1x2
        kp_driving_new += delta_exp
        kp_driving_new[..., :2] += delta_tx_ty
        return kp_driving_new

    def forward(self, kp_source, kp_driving):
        out = self.stitching(kp_source, kp_driving)
        return out
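Hedged sketch (not in the commit): with the default sizes, the stitching MLP consumes 2 x 21 x 3 = 126 flattened keypoint coordinates and predicts a 65-d delta (63 per-keypoint offsets plus tx/ty), which is folded back onto the driving keypoints. Random weights and inputs are used purely for illustration.

import torch
from core.models.modules.stitching_network import StitchingNetwork

stitcher = StitchingNetwork().eval()   # input_size=126, output_size=65 defaults
kp_source = torch.randn(1, 21, 3)
kp_driving = torch.randn(1, 21, 3)
with torch.no_grad():
    kp_stitched = stitcher(kp_source, kp_driving)
print(kp_stitched.shape)  # torch.Size([1, 21, 3])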
core/models/modules/util.py
ADDED
@@ -0,0 +1,452 @@
1 |
+
# coding: utf-8
|
2 |
+
|
3 |
+
"""
|
4 |
+
This file defines various neural network modules and utility functions, including convolutional and residual blocks,
|
5 |
+
normalizations, and functions for spatial transformation and tensor manipulation.
|
6 |
+
"""
|
7 |
+
|
8 |
+
from torch import nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import torch
|
11 |
+
import torch.nn.utils.spectral_norm as spectral_norm
|
12 |
+
import math
|
13 |
+
import warnings
|
14 |
+
import collections.abc
|
15 |
+
from itertools import repeat
|
16 |
+
|
17 |
+
def kp2gaussian(kp, spatial_size, kp_variance):
|
18 |
+
"""
|
19 |
+
Transform a keypoint into gaussian like representation
|
20 |
+
"""
|
21 |
+
mean = kp
|
22 |
+
|
23 |
+
coordinate_grid = make_coordinate_grid(spatial_size, mean)
|
24 |
+
number_of_leading_dimensions = len(mean.shape) - 1
|
25 |
+
shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
|
26 |
+
coordinate_grid = coordinate_grid.view(*shape)
|
27 |
+
repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 1)
|
28 |
+
coordinate_grid = coordinate_grid.repeat(*repeats)
|
29 |
+
|
30 |
+
# Preprocess kp shape
|
31 |
+
shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 3)
|
32 |
+
mean = mean.view(*shape)
|
33 |
+
|
34 |
+
mean_sub = (coordinate_grid - mean)
|
35 |
+
|
36 |
+
out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)
|
37 |
+
|
38 |
+
return out
|
39 |
+
|
40 |
+
|
41 |
+
def make_coordinate_grid(spatial_size, ref, **kwargs):
|
42 |
+
d, h, w = spatial_size
|
43 |
+
x = torch.arange(w).type(ref.dtype).to(ref.device)
|
44 |
+
y = torch.arange(h).type(ref.dtype).to(ref.device)
|
45 |
+
z = torch.arange(d).type(ref.dtype).to(ref.device)
|
46 |
+
|
47 |
+
# NOTE: must be right-down-in
|
48 |
+
x = (2 * (x / (w - 1)) - 1) # the x axis faces to the right
|
49 |
+
y = (2 * (y / (h - 1)) - 1) # the y axis faces to the bottom
|
50 |
+
z = (2 * (z / (d - 1)) - 1) # the z axis faces to the inner
|
51 |
+
|
52 |
+
yy = y.view(1, -1, 1).repeat(d, 1, w)
|
53 |
+
xx = x.view(1, 1, -1).repeat(d, h, 1)
|
54 |
+
zz = z.view(-1, 1, 1).repeat(1, h, w)
|
55 |
+
|
56 |
+
meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3), zz.unsqueeze_(3)], 3)
|
57 |
+
|
58 |
+
return meshed
|
59 |
+
|
60 |
+
|
61 |
+
class ConvT2d(nn.Module):
|
62 |
+
"""
|
63 |
+
Upsampling block for use in decoder.
|
64 |
+
"""
|
65 |
+
|
66 |
+
def __init__(self, in_features, out_features, kernel_size=3, stride=2, padding=1, output_padding=1):
|
67 |
+
super(ConvT2d, self).__init__()
|
68 |
+
|
69 |
+
self.convT = nn.ConvTranspose2d(in_features, out_features, kernel_size=kernel_size, stride=stride,
|
70 |
+
padding=padding, output_padding=output_padding)
|
71 |
+
self.norm = nn.InstanceNorm2d(out_features)
|
72 |
+
|
73 |
+
def forward(self, x):
|
74 |
+
out = self.convT(x)
|
75 |
+
out = self.norm(out)
|
76 |
+
out = F.leaky_relu(out)
|
77 |
+
return out
|
78 |
+
|
79 |
+
|
80 |
+
class ResBlock3d(nn.Module):
|
81 |
+
"""
|
82 |
+
Res block, preserve spatial resolution.
|
83 |
+
"""
|
84 |
+
|
85 |
+
def __init__(self, in_features, kernel_size, padding):
|
86 |
+
super(ResBlock3d, self).__init__()
|
87 |
+
self.conv1 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding)
|
88 |
+
self.conv2 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding)
|
89 |
+
self.norm1 = nn.BatchNorm3d(in_features, affine=True)
|
90 |
+
self.norm2 = nn.BatchNorm3d(in_features, affine=True)
|
91 |
+
|
92 |
+
def forward(self, x):
|
93 |
+
out = self.norm1(x)
|
94 |
+
out = F.relu(out)
|
95 |
+
out = self.conv1(out)
|
96 |
+
out = self.norm2(out)
|
97 |
+
out = F.relu(out)
|
98 |
+
out = self.conv2(out)
|
99 |
+
out += x
|
100 |
+
return out
|
101 |
+
|
102 |
+
|
103 |
+
class UpBlock3d(nn.Module):
|
104 |
+
"""
|
105 |
+
Upsampling block for use in decoder.
|
106 |
+
"""
|
107 |
+
|
108 |
+
def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
|
109 |
+
super(UpBlock3d, self).__init__()
|
110 |
+
|
111 |
+
self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
|
112 |
+
padding=padding, groups=groups)
|
113 |
+
self.norm = nn.BatchNorm3d(out_features, affine=True)
|
114 |
+
|
115 |
+
def forward(self, x):
|
116 |
+
out = F.interpolate(x, scale_factor=(1, 2, 2))
|
117 |
+
out = self.conv(out)
|
118 |
+
out = self.norm(out)
|
119 |
+
out = F.relu(out)
|
120 |
+
return out
|
121 |
+
|
122 |
+
|
123 |
+
class DownBlock2d(nn.Module):
|
124 |
+
"""
|
125 |
+
Downsampling block for use in encoder.
|
126 |
+
"""
|
127 |
+
|
128 |
+
def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
|
129 |
+
super(DownBlock2d, self).__init__()
|
130 |
+
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups)
|
131 |
+
self.norm = nn.BatchNorm2d(out_features, affine=True)
|
132 |
+
self.pool = nn.AvgPool2d(kernel_size=(2, 2))
|
133 |
+
|
134 |
+
def forward(self, x):
|
135 |
+
out = self.conv(x)
|
136 |
+
out = self.norm(out)
|
137 |
+
out = F.relu(out)
|
138 |
+
out = self.pool(out)
|
139 |
+
return out
|
140 |
+
|
141 |
+
|
142 |
+
class DownBlock3d(nn.Module):
|
143 |
+
"""
|
144 |
+
Downsampling block for use in encoder.
|
145 |
+
"""
|
146 |
+
|
147 |
+
def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
|
148 |
+
super(DownBlock3d, self).__init__()
|
149 |
+
'''
|
150 |
+
self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
|
151 |
+
padding=padding, groups=groups, stride=(1, 2, 2))
|
152 |
+
'''
|
153 |
+
self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
|
154 |
+
padding=padding, groups=groups)
|
155 |
+
self.norm = nn.BatchNorm3d(out_features, affine=True)
|
156 |
+
self.pool = nn.AvgPool3d(kernel_size=(1, 2, 2))
|
157 |
+
|
158 |
+
def forward(self, x):
|
159 |
+
out = self.conv(x)
|
160 |
+
out = self.norm(out)
|
161 |
+
out = F.relu(out)
|
162 |
+
out = self.pool(out)
|
163 |
+
return out
|
164 |
+
|
165 |
+
|
166 |
+
class SameBlock2d(nn.Module):
|
167 |
+
"""
|
168 |
+
Simple block, preserve spatial resolution.
|
169 |
+
"""
|
170 |
+
|
171 |
+
def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1, lrelu=False):
|
172 |
+
super(SameBlock2d, self).__init__()
|
173 |
+
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups)
|
174 |
+
self.norm = nn.BatchNorm2d(out_features, affine=True)
|
175 |
+
if lrelu:
|
176 |
+
self.ac = nn.LeakyReLU()
|
177 |
+
else:
|
178 |
+
self.ac = nn.ReLU()
|
179 |
+
|
180 |
+
def forward(self, x):
|
181 |
+
out = self.conv(x)
|
182 |
+
out = self.norm(out)
|
183 |
+
out = self.ac(out)
|
184 |
+
return out
|
185 |
+
|
186 |
+
|
187 |
+
class Encoder(nn.Module):
|
188 |
+
"""
|
189 |
+
Hourglass Encoder
|
190 |
+
"""
|
191 |
+
|
192 |
+
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
|
193 |
+
super(Encoder, self).__init__()
|
194 |
+
|
195 |
+
down_blocks = []
|
196 |
+
for i in range(num_blocks):
|
197 |
+
down_blocks.append(DownBlock3d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)), min(max_features, block_expansion * (2 ** (i + 1))), kernel_size=3, padding=1))
|
198 |
+
self.down_blocks = nn.ModuleList(down_blocks)
|
199 |
+
|
200 |
+
def forward(self, x):
|
201 |
+
outs = [x]
|
202 |
+
for down_block in self.down_blocks:
|
203 |
+
outs.append(down_block(outs[-1]))
|
204 |
+
return outs
|
205 |
+
|
206 |
+
|
207 |
+
class Decoder(nn.Module):
|
208 |
+
"""
|
209 |
+
Hourglass Decoder
|
210 |
+
"""
|
211 |
+
|
212 |
+
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
|
213 |
+
super(Decoder, self).__init__()
|
214 |
+
|
215 |
+
up_blocks = []
|
216 |
+
|
217 |
+
for i in range(num_blocks)[::-1]:
|
218 |
+
in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
|
219 |
+
out_filters = min(max_features, block_expansion * (2 ** i))
|
220 |
+
up_blocks.append(UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1))
|
221 |
+
|
222 |
+
self.up_blocks = nn.ModuleList(up_blocks)
|
223 |
+
self.out_filters = block_expansion + in_features
|
224 |
+
|
225 |
+
self.conv = nn.Conv3d(in_channels=self.out_filters, out_channels=self.out_filters, kernel_size=3, padding=1)
|
226 |
+
self.norm = nn.BatchNorm3d(self.out_filters, affine=True)
|
227 |
+
|
228 |
+
def forward(self, x):
|
229 |
+
out = x.pop()
|
230 |
+
for up_block in self.up_blocks:
|
231 |
+
out = up_block(out)
|
232 |
+
skip = x.pop()
|
233 |
+
out = torch.cat([out, skip], dim=1)
|
234 |
+
out = self.conv(out)
|
235 |
+
out = self.norm(out)
|
236 |
+
out = F.relu(out)
|
237 |
+
return out
|
238 |
+
|
239 |
+
|
240 |
+
class Hourglass(nn.Module):
|
241 |
+
"""
|
242 |
+
Hourglass architecture.
|
243 |
+
"""
|
244 |
+
|
245 |
+
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
|
246 |
+
super(Hourglass, self).__init__()
|
247 |
+
self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
|
248 |
+
self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)
|
249 |
+
self.out_filters = self.decoder.out_filters
|
250 |
+
|
251 |
+
def forward(self, x):
|
252 |
+
return self.decoder(self.encoder(x))
|
253 |
+
|
254 |
+
|
255 |
+
class SPADE(nn.Module):
|
256 |
+
def __init__(self, norm_nc, label_nc):
|
257 |
+
super().__init__()
|
258 |
+
|
259 |
+
self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
|
260 |
+
nhidden = 128
|
261 |
+
|
262 |
+
self.mlp_shared = nn.Sequential(
|
263 |
+
nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1),
|
264 |
+
nn.ReLU())
|
265 |
+
self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
|
266 |
+
self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
|
267 |
+
|
268 |
+
def forward(self, x, segmap):
|
269 |
+
normalized = self.param_free_norm(x)
|
270 |
+
segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
|
271 |
+
actv = self.mlp_shared(segmap)
|
272 |
+
gamma = self.mlp_gamma(actv)
|
273 |
+
beta = self.mlp_beta(actv)
|
274 |
+
out = normalized * (1 + gamma) + beta
|
275 |
+
return out
|
276 |
+
|
277 |
+
|
278 |
+
class SPADEResnetBlock(nn.Module):
|
279 |
+
def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1):
|
280 |
+
super().__init__()
|
281 |
+
# Attributes
|
282 |
+
self.learned_shortcut = (fin != fout)
|
283 |
+
fmiddle = min(fin, fout)
|
284 |
+
self.use_se = use_se
|
285 |
+
# create conv layers
|
286 |
+
self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation)
|
287 |
+
self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation)
|
288 |
+
if self.learned_shortcut:
|
289 |
+
self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
|
290 |
+
# apply spectral norm if specified
|
291 |
+
if 'spectral' in norm_G:
|
292 |
+
self.conv_0 = spectral_norm(self.conv_0)
|
293 |
+
self.conv_1 = spectral_norm(self.conv_1)
|
294 |
+
if self.learned_shortcut:
|
295 |
+
self.conv_s = spectral_norm(self.conv_s)
|
296 |
+
# define normalization layers
|
297 |
+
self.norm_0 = SPADE(fin, label_nc)
|
298 |
+
self.norm_1 = SPADE(fmiddle, label_nc)
|
299 |
+
if self.learned_shortcut:
|
300 |
+
self.norm_s = SPADE(fin, label_nc)
|
301 |
+
|
302 |
+
def forward(self, x, seg1):
|
303 |
+
x_s = self.shortcut(x, seg1)
|
304 |
+
dx = self.conv_0(self.actvn(self.norm_0(x, seg1)))
|
305 |
+
dx = self.conv_1(self.actvn(self.norm_1(dx, seg1)))
|
306 |
+
out = x_s + dx
|
307 |
+
return out
|
308 |
+
|
309 |
+
def shortcut(self, x, seg1):
|
310 |
+
if self.learned_shortcut:
|
311 |
+
x_s = self.conv_s(self.norm_s(x, seg1))
|
312 |
+
else:
|
313 |
+
x_s = x
|
314 |
+
return x_s
|
315 |
+
|
316 |
+
def actvn(self, x):
|
317 |
+
return F.leaky_relu(x, 2e-1)
|
318 |
+
|
319 |
+
|
320 |
+
def filter_state_dict(state_dict, remove_name='fc'):
|
321 |
+
new_state_dict = {}
|
322 |
+
for key in state_dict:
|
323 |
+
if remove_name in key:
|
324 |
+
continue
|
325 |
+
new_state_dict[key] = state_dict[key]
|
326 |
+
return new_state_dict
|
327 |
+
|
328 |
+
|
329 |
+
class GRN(nn.Module):
|
330 |
+
""" GRN (Global Response Normalization) layer
|
331 |
+
"""
|
332 |
+
|
333 |
+
def __init__(self, dim):
|
334 |
+
super().__init__()
|
335 |
+
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
|
336 |
+
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
|
337 |
+
|
338 |
+
def forward(self, x):
|
339 |
+
Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
|
340 |
+
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
|
341 |
+
return self.gamma * (x * Nx) + self.beta + x
|
342 |
+
|
343 |
+
|
344 |
+
class LayerNorm(nn.Module):
|
345 |
+
r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
|
346 |
+
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
|
347 |
+
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
|
348 |
+
with shape (batch_size, channels, height, width).
|
349 |
+
"""
|
350 |
+
|
351 |
+
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
|
352 |
+
super().__init__()
|
353 |
+
self.weight = nn.Parameter(torch.ones(normalized_shape))
|
354 |
+
self.bias = nn.Parameter(torch.zeros(normalized_shape))
|
355 |
+
self.eps = eps
|
356 |
+
self.data_format = data_format
|
357 |
+
if self.data_format not in ["channels_last", "channels_first"]:
|
358 |
+
raise NotImplementedError
|
359 |
+
self.normalized_shape = (normalized_shape, )
|
360 |
+
|
361 |
+
def forward(self, x):
|
362 |
+
if self.data_format == "channels_last":
|
363 |
+
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
364 |
+
elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def drop_path(x, drop_prob=0., training=False, scale_by_keep=True):
    """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor


class DropPath(nn.Module):
    """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob=None, scale_by_keep=True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))
    return parse


to_2tuple = _ntuple(2)
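The helpers above are generic timm-style utilities. A minimal usage sketch follows (illustrative only; TinyBlock is a hypothetical layer that is not part of this repository, and the sketch assumes the definitions above are in scope):

import torch
from torch import nn

class TinyBlock(nn.Module):
    # Hypothetical residual block showing the typical wiring of DropPath and trunc_normal_.
    def __init__(self, dim=64, drop_prob=0.1):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.drop_path = DropPath(drop_prob)       # stochastic depth on the residual branch
        trunc_normal_(self.proj.weight, std=0.02)  # truncated-normal init, resampled/clamped to [a, b] = [-2, 2]
        nn.init.zeros_(self.proj.bias)

    def forward(self, x):
        return x + self.drop_path(self.proj(x))

x = torch.randn(4, 64)
print(TinyBlock()(x).shape)  # torch.Size([4, 64])
print(to_2tuple(3))          # (3, 3) -- to_2tuple expands a scalar into a 2-tuple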
core/models/modules/warping_network.py
ADDED
@@ -0,0 +1,87 @@
# coding: utf-8

"""
Warping field estimator (W) defined in the paper, which generates a warping field using the implicit
keypoint representations x_s and x_d, and employs this flow field to warp the source feature volume f_s.
"""
import torch
from torch import nn
import torch.nn.functional as F
from .util import SameBlock2d
from .dense_motion import DenseMotionNetwork


class WarpingNetwork(nn.Module):
    def __init__(
        self,
        num_kp=21,
        block_expansion=64,
        max_features=512,
        num_down_blocks=2,
        reshape_channel=32,
        estimate_occlusion_map=True,
        **kwargs
    ):
        super(WarpingNetwork, self).__init__()

        self.upscale = kwargs.get('upscale', 1)
        self.flag_use_occlusion_map = kwargs.get('flag_use_occlusion_map', True)

        dense_motion_params = {
            "block_expansion": 32,
            "max_features": 1024,
            "num_blocks": 5,
            "reshape_depth": 16,
            "compress": 4,
        }

        self.dense_motion_network = DenseMotionNetwork(
            num_kp=num_kp,
            feature_channel=reshape_channel,
            estimate_occlusion_map=estimate_occlusion_map,
            **dense_motion_params
        )

        self.third = SameBlock2d(max_features, block_expansion * (2 ** num_down_blocks), kernel_size=(3, 3), padding=(1, 1), lrelu=True)
        self.fourth = nn.Conv2d(in_channels=block_expansion * (2 ** num_down_blocks), out_channels=block_expansion * (2 ** num_down_blocks), kernel_size=1, stride=1)

        self.estimate_occlusion_map = estimate_occlusion_map

    def deform_input(self, inp, deformation):
        return F.grid_sample(inp, deformation, align_corners=False)

    def forward(self, feature_3d, kp_source, kp_driving):
        # Feature warper: transform the feature representation according to the deformation and occlusion map
        dense_motion = self.dense_motion_network(
            feature=feature_3d, kp_driving=kp_driving, kp_source=kp_source
        )
        if 'occlusion_map' in dense_motion:
            occlusion_map = dense_motion['occlusion_map']  # Bx1x64x64
        else:
            occlusion_map = None

        deformation = dense_motion['deformation']  # Bx16x64x64x3
        out = self.deform_input(feature_3d, deformation)  # Bx32x16x64x64

        bs, c, d, h, w = out.shape  # Bx32x16x64x64
        out = out.view(bs, c * d, h, w)  # -> Bx512x64x64
        out = self.third(out)  # -> Bx256x64x64
        out = self.fourth(out)  # -> Bx256x64x64

        if self.flag_use_occlusion_map and (occlusion_map is not None):
            out = out * occlusion_map

        # ret_dct = {
        #     'occlusion_map': occlusion_map,
        #     'deformation': deformation,
        #     'out': out,
        # }
        # return ret_dct

        return out

    def load_model(self, ckpt_path):
        self.load_state_dict(torch.load(ckpt_path, map_location=lambda storage, loc: storage))
        self.eval()
        return self
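For orientation, a quick smoke test of WarpingNetwork with dummy tensors, using the shapes noted in the in-code comments (purely illustrative, and assuming the DenseMotionNetwork import resolves with the parameters above):

import torch

net = WarpingNetwork().eval()                  # defaults: num_kp=21, reshape_channel=32, ...
feature_3d = torch.randn(1, 32, 16, 64, 64)    # source feature volume f_s
kp_source = torch.randn(1, 21, 3)              # implicit keypoints x_s
kp_driving = torch.randn(1, 21, 3)             # implicit keypoints x_d

with torch.no_grad():
    out = net(feature_3d, kp_source, kp_driving)
print(out.shape)  # expected per the comments above: torch.Size([1, 256, 64, 64])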
core/models/motion_extractor.py
ADDED
@@ -0,0 +1,49 @@
import numpy as np
import torch
from ..utils.load_model import load_model


class MotionExtractor:
    def __init__(self, model_path, device="cuda"):
        kwargs = {
            "module_name": "MotionExtractor",
        }
        self.model, self.model_type = load_model(model_path, device=device, **kwargs)
        self.device = device

        self.output_names = [
            "pitch",
            "yaw",
            "roll",
            "t",
            "exp",
            "scale",
            "kp",
        ]

    def __call__(self, image):
        """
        image: np.ndarray, shape (1, 3, 256, 256), RGB, 0-1
        """
        outputs = {}
        if self.model_type == "onnx":
            out_list = self.model.run(None, {"image": image})
            for i, name in enumerate(self.output_names):
                outputs[name] = out_list[i]
        elif self.model_type == "tensorrt":
            self.model.setup({"image": image})
            self.model.infer()
            for name in self.output_names:
                outputs[name] = self.model.buffer[name][0].copy()
        elif self.model_type == "pytorch":
            with torch.no_grad(), torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=True):
                pred = self.model(torch.from_numpy(image).to(self.device))
                for i, name in enumerate(self.output_names):
                    outputs[name] = pred[i].float().cpu().numpy()
        else:
            raise ValueError(f"Unsupported model type: {self.model_type}")
        outputs["exp"] = outputs["exp"].reshape(1, -1)
        outputs["kp"] = outputs["kp"].reshape(1, -1)
        return outputs
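A hypothetical call pattern for the wrapper above; the checkpoint path is a placeholder and the exact output dimensionalities are assumptions:

import numpy as np

extractor = MotionExtractor("./checkpoints/motion_extractor.onnx", device="cuda")
image = np.random.rand(1, 3, 256, 256).astype(np.float32)  # RGB, values in [0, 1]
motion = extractor(image)
print(sorted(motion.keys()))  # ['exp', 'kp', 'pitch', 'roll', 'scale', 't', 'yaw']
print(motion["exp"].shape)    # flattened to (1, N); for a 21-keypoint model this is presumably (1, 63)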
core/models/stitch_network.py
ADDED
@@ -0,0 +1,30 @@
import numpy as np
import torch
from ..utils.load_model import load_model


class StitchNetwork:
    def __init__(self, model_path, device="cuda"):
        kwargs = {
            "module_name": "StitchingNetwork",
        }
        self.model, self.model_type = load_model(model_path, device=device, **kwargs)
        self.device = device

    def __call__(self, kp_source, kp_driving):
        if self.model_type == "onnx":
            pred = self.model.run(None, {"kp_source": kp_source, "kp_driving": kp_driving})[0]
        elif self.model_type == "tensorrt":
            self.model.setup({"kp_source": kp_source, "kp_driving": kp_driving})
            self.model.infer()
            pred = self.model.buffer["out"][0].copy()
        elif self.model_type == 'pytorch':
            with torch.no_grad():
                pred = self.model(
                    torch.from_numpy(kp_source).to(self.device),
                    torch.from_numpy(kp_driving).to(self.device)
                ).cpu().numpy()
        else:
            raise ValueError(f"Unsupported model type: {self.model_type}")

        return pred
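A hypothetical call pattern for StitchNetwork; the checkpoint path is a placeholder, and the keypoint layout (1, 21, 3) is borrowed from the sibling WarpNetwork docstring as an assumption -- the exported StitchingNetwork may expect a different (e.g. flattened) layout:

import numpy as np

stitcher = StitchNetwork("./checkpoints/stitch_network.onnx", device="cuda")
kp_source = np.random.randn(1, 21, 3).astype(np.float32)
kp_driving = np.random.randn(1, 21, 3).astype(np.float32)
delta = stitcher(kp_source, kp_driving)  # stitching correction to apply to the driving keypoints
print(delta.shape)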
core/models/warp_network.py
ADDED
@@ -0,0 +1,35 @@
import numpy as np
import torch
from ..utils.load_model import load_model


class WarpNetwork:
    def __init__(self, model_path, device="cuda"):
        kwargs = {
            "module_name": "WarpingNetwork",
        }
        self.model, self.model_type = load_model(model_path, device=device, **kwargs)
        self.device = device

    def __call__(self, feature_3d, kp_source, kp_driving):
        """
        feature_3d: np.ndarray, shape (1, 32, 16, 64, 64)
        kp_source | kp_driving: np.ndarray, shape (1, 21, 3)
        """
        if self.model_type == "onnx":
            pred = self.model.run(None, {"feature_3d": feature_3d, "kp_source": kp_source, "kp_driving": kp_driving})[0]
        elif self.model_type == "tensorrt":
            self.model.setup({"feature_3d": feature_3d, "kp_source": kp_source, "kp_driving": kp_driving})
            self.model.infer()
            pred = self.model.buffer["out"][0].copy()
        elif self.model_type == 'pytorch':
            with torch.no_grad(), torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=True):
                pred = self.model(
                    torch.from_numpy(feature_3d).to(self.device),
                    torch.from_numpy(kp_source).to(self.device),
                    torch.from_numpy(kp_driving).to(self.device)
                ).float().cpu().numpy()
        else:
            raise ValueError(f"Unsupported model type: {self.model_type}")

        return pred
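A dummy end-to-end call through this numpy-facing wrapper; the checkpoint path is a placeholder and the input shapes follow the docstring above:

import numpy as np

warper = WarpNetwork("./checkpoints/warp_network.onnx", device="cuda")
feature_3d = np.random.randn(1, 32, 16, 64, 64).astype(np.float32)
kp_source = np.random.randn(1, 21, 3).astype(np.float32)
kp_driving = np.random.randn(1, 21, 3).astype(np.float32)
warped = warper(feature_3d, kp_source, kp_driving)  # warped feature map, presumably consumed by the decoder
print(warped.shape)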
core/utils/blend/__init__.py
ADDED
@@ -0,0 +1,4 @@
import pyximport
pyximport.install()

from .blend import blend_images_cy
core/utils/blend/blend.pyx
ADDED
@@ -0,0 +1,38 @@
#cython: language_level=3
import numpy as np
cimport numpy as np

cdef extern from "blend_impl.h":
    void _blend_images_cy_impl(
        const float* mask_warped,
        const float* frame_warped,
        const unsigned char* frame_rgb,
        const int height,
        const int width,
        unsigned char* result
    ) noexcept nogil

def blend_images_cy(
    np.ndarray[np.float32_t, ndim=2] mask_warped,
    np.ndarray[np.float32_t, ndim=3] frame_warped,
    np.ndarray[np.uint8_t, ndim=3] frame_rgb,
    np.ndarray[np.uint8_t, ndim=3] result
):
    cdef int h = mask_warped.shape[0]
    cdef int w = mask_warped.shape[1]

    if not mask_warped.flags['C_CONTIGUOUS']:
        mask_warped = np.ascontiguousarray(mask_warped)
    if not frame_warped.flags['C_CONTIGUOUS']:
        frame_warped = np.ascontiguousarray(frame_warped)
    if not frame_rgb.flags['C_CONTIGUOUS']:
        frame_rgb = np.ascontiguousarray(frame_rgb)

    with nogil:
        _blend_images_cy_impl(
            <const float*>mask_warped.data,
            <const float*>frame_warped.data,
            <const unsigned char*>frame_rgb.data,
            h, w,
            <unsigned char*>result.data
        )
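blend_impl.h / blend_impl.c are not part of this diff, so the exact C kernel is not visible here. As a point of reference, a plain-NumPy blend matching this signature would plausibly look like the sketch below (an assumption about the semantics, not the actual implementation):

import numpy as np

def blend_images_np(mask_warped, frame_warped, frame_rgb):
    """mask_warped: (H, W) float32 in [0, 1]; frame_warped: (H, W, 3) float32;
    frame_rgb: (H, W, 3) uint8. Returns a uint8 (H, W, 3) composite."""
    m = mask_warped[..., None]                         # broadcast the mask over the channel axis
    out = m * frame_warped + (1.0 - m) * frame_rgb     # per-pixel alpha blend of warped face and background
    return np.clip(out, 0, 255).astype(np.uint8)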
core/utils/blend/blend.pyxbld
ADDED
@@ -0,0 +1,11 @@
import numpy as np
import os

def make_ext(modname, pyxfilename):
    from distutils.extension import Extension

    return Extension(name=modname,
                     sources=[pyxfilename, os.path.join(os.path.dirname(pyxfilename), "blend_impl.c")],
                     include_dirs=[np.get_include(), os.path.dirname(pyxfilename)],
                     extra_compile_args=["-O3", "-std=c99", "-march=native", "-ffast-math"],
                     )