oKen38461 committed
Commit ac7cda5 · 1 Parent(s): d291c0c

Add files based on the initial commit

This view is limited to 50 files because the commit contains too many changes.

Files changed (50):
  1. .gitignore +43 -0
  2. LICENSE +201 -0
  3. README_ditto-talkinghead.md +232 -0
  4. core/atomic_components/audio2motion.py +196 -0
  5. core/atomic_components/avatar_registrar.py +102 -0
  6. core/atomic_components/cfg.py +111 -0
  7. core/atomic_components/condition_handler.py +168 -0
  8. core/atomic_components/decode_f3d.py +22 -0
  9. core/atomic_components/loader.py +133 -0
  10. core/atomic_components/motion_stitch.py +491 -0
  11. core/atomic_components/putback.py +60 -0
  12. core/atomic_components/source2info.py +155 -0
  13. core/atomic_components/warp_f3d.py +22 -0
  14. core/atomic_components/wav2feat.py +110 -0
  15. core/atomic_components/writer.py +36 -0
  16. core/aux_models/blaze_face.py +351 -0
  17. core/aux_models/face_mesh.py +101 -0
  18. core/aux_models/hubert_stream.py +29 -0
  19. core/aux_models/insightface_det.py +245 -0
  20. core/aux_models/insightface_landmark106.py +100 -0
  21. core/aux_models/landmark203.py +58 -0
  22. core/aux_models/mediapipe_landmark478.py +118 -0
  23. core/aux_models/modules/__init__.py +5 -0
  24. core/aux_models/modules/hubert_stream.py +21 -0
  25. core/aux_models/modules/landmark106.py +83 -0
  26. core/aux_models/modules/landmark203.py +42 -0
  27. core/aux_models/modules/landmark478.py +35 -0
  28. core/aux_models/modules/retinaface.py +215 -0
  29. core/models/appearance_extractor.py +29 -0
  30. core/models/decoder.py +30 -0
  31. core/models/lmdm.py +140 -0
  32. core/models/modules/LMDM.py +154 -0
  33. core/models/modules/__init__.py +6 -0
  34. core/models/modules/appearance_feature_extractor.py +74 -0
  35. core/models/modules/convnextv2.py +150 -0
  36. core/models/modules/dense_motion.py +104 -0
  37. core/models/modules/lmdm_modules/model.py +398 -0
  38. core/models/modules/lmdm_modules/rotary_embedding_torch.py +132 -0
  39. core/models/modules/lmdm_modules/utils.py +96 -0
  40. core/models/modules/motion_extractor.py +25 -0
  41. core/models/modules/spade_generator.py +87 -0
  42. core/models/modules/stitching_network.py +65 -0
  43. core/models/modules/util.py +452 -0
  44. core/models/modules/warping_network.py +87 -0
  45. core/models/motion_extractor.py +49 -0
  46. core/models/stitch_network.py +30 -0
  47. core/models/warp_network.py +35 -0
  48. core/utils/blend/__init__.py +4 -0
  49. core/utils/blend/blend.pyx +38 -0
  50. core/utils/blend/blend.pyxbld +11 -0
.gitignore ADDED
@@ -0,0 +1,43 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*__pycache__
**/__pycache__/
*.py[cod]
**/*.py[cod]
*$py.class

# Model weights
checkpoints
**/*.pth
**/*.onnx
**/*.pt
**/*.pth.tar

.idea
.vscode
.DS_Store
*.DS_Store

*.swp
tmp*

*build
*.egg-info/
*.mp4

log/*
*.mp4
*.png
*.jpg
*.wav
*.pth
*.pyc
*.jpeg

# Folders to ignore
example/
ToDo/

!example/audio.wav
!example/image.png

LICENSE ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
README_ditto-talkinghead.md ADDED
@@ -0,0 +1,232 @@
<h2 align='center'>Ditto: Motion-Space Diffusion for Controllable Realtime Talking Head Synthesis</h2>

<div align='center'>
<a href=""><strong>Tianqi Li</strong></a>
·
<a href=""><strong>Ruobing Zheng</strong></a><sup>†</sup>
·
<a href=""><strong>Minghui Yang</strong></a>
·
<a href=""><strong>Jingdong Chen</strong></a>
·
<a href=""><strong>Ming Yang</strong></a>
</div>
<div align='center'>
Ant Group
</div>
<br>
<div align='center'>
<a href='https://arxiv.org/abs/2411.19509'><img src='https://img.shields.io/badge/Paper-arXiv-red'></a>
<a href='https://digital-avatar.github.io/ai/Ditto/'><img src='https://img.shields.io/badge/Project-Page-blue'></a>
<a href='https://huggingface.co/digital-avatar/ditto-talkinghead'><img src='https://img.shields.io/badge/Model-HuggingFace-yellow'></a>
<a href='https://github.com/antgroup/ditto-talkinghead'><img src='https://img.shields.io/badge/Code-GitHub-purple'></a>
<!-- <a href='https://github.com/antgroup/ditto-talkinghead'><img src='https://img.shields.io/github/stars/antgroup/ditto-talkinghead?style=social'></a> -->
<a href='https://colab.research.google.com/drive/19SUi1TiO32IS-Crmsu9wrkNspWE8tFbs?usp=sharing'><img src='https://img.shields.io/badge/Demo-Colab-orange'></a>
</div>
<br>
<div align="center">
<video style="width: 95%; object-fit: cover;" controls loop src="https://github.com/user-attachments/assets/ef1a0b08-bff3-4997-a6dd-62a7f51cdb40" muted="false"></video>
<p>
✨ For more results, visit our <a href="https://digital-avatar.github.io/ai/Ditto/"><strong>Project Page</strong></a> ✨
</p>
</div>


## 📌 Updates
* [2025.07.11] 🔥 The [PyTorch model](#-pytorch-model) is now available.
* [2025.07.07] 🔥 Ditto has been accepted to ACM MM 2025.
* [2025.01.21] 🔥 We updated the [Colab](https://colab.research.google.com/drive/19SUi1TiO32IS-Crmsu9wrkNspWE8tFbs?usp=sharing) demo; welcome to try it.
* [2025.01.10] 🔥 We released our inference [code](https://github.com/antgroup/ditto-talkinghead) and [models](https://huggingface.co/digital-avatar/ditto-talkinghead).
* [2024.11.29] 🔥 Our [paper](https://arxiv.org/abs/2411.19509) is now public on arXiv.



## 🛠️ Installation

Tested environment:
- System: CentOS 7.2
- GPU: A100
- Python: 3.10
- TensorRT: 8.6.1


Clone the code from [GitHub](https://github.com/antgroup/ditto-talkinghead):
```bash
git clone https://github.com/antgroup/ditto-talkinghead
cd ditto-talkinghead
```

### Conda
Create the `conda` environment:
```bash
conda env create -f environment.yaml
conda activate ditto
```

### Pip
If you have problems creating the conda environment, you can also refer to our [Colab](https://colab.research.google.com/drive/19SUi1TiO32IS-Crmsu9wrkNspWE8tFbs?usp=sharing) notebook.
After correctly installing `pytorch`, `cuda` and `cudnn`, you only need to install a few packages with pip:
```bash
pip install \
    tensorrt==8.6.1 \
    librosa \
    tqdm \
    filetype \
    imageio \
    opencv_python_headless \
    scikit-image \
    cython \
    cuda-python \
    imageio-ffmpeg \
    colored \
    polygraphy \
    numpy==2.0.1
```

If you don't use `conda`, you may also need to install `ffmpeg` according to the [official website](https://www.ffmpeg.org/download.html).

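For example, on a Debian/Ubuntu system it can usually be installed via the system package manager (this is only an illustrative sketch; other platforms should use their own package manager or a static build):

```bash
# Debian/Ubuntu example; adapt to your platform if needed
sudo apt-get update && sudo apt-get install -y ffmpeg

# Verify the installation
ffmpeg -version
```
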
## 📥 Download Checkpoints

Download the checkpoints from [HuggingFace](https://huggingface.co/digital-avatar/ditto-talkinghead) and put them in the `checkpoints` dir:
```bash
git lfs install
git clone https://huggingface.co/digital-avatar/ditto-talkinghead checkpoints
```

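Alternatively, if `git lfs` is not available, the same repository can be fetched with the `huggingface_hub` command-line tool (assuming it is installed, e.g. via `pip install -U huggingface_hub`):

```bash
huggingface-cli download digital-avatar/ditto-talkinghead --local-dir checkpoints
```
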
The `checkpoints` directory should look like this:
```text
./checkpoints/
├── ditto_cfg
│   ├── v0.4_hubert_cfg_trt.pkl
│   └── v0.4_hubert_cfg_trt_online.pkl
├── ditto_onnx
│   ├── appearance_extractor.onnx
│   ├── blaze_face.onnx
│   ├── decoder.onnx
│   ├── face_mesh.onnx
│   ├── hubert.onnx
│   ├── insightface_det.onnx
│   ├── landmark106.onnx
│   ├── landmark203.onnx
│   ├── libgrid_sample_3d_plugin.so
│   ├── lmdm_v0.4_hubert.onnx
│   ├── motion_extractor.onnx
│   ├── stitch_network.onnx
│   └── warp_network.onnx
└── ditto_trt_Ampere_Plus
    ├── appearance_extractor_fp16.engine
    ├── blaze_face_fp16.engine
    ├── decoder_fp16.engine
    ├── face_mesh_fp16.engine
    ├── hubert_fp32.engine
    ├── insightface_det_fp16.engine
    ├── landmark106_fp16.engine
    ├── landmark203_fp16.engine
    ├── lmdm_v0.4_hubert_fp32.engine
    ├── motion_extractor_fp32.engine
    ├── stitch_network_fp16.engine
    └── warp_network_fp16.engine
```

- `ditto_cfg/v0.4_hubert_cfg_trt_online.pkl` is the online (streaming) config
- `ditto_cfg/v0.4_hubert_cfg_trt.pkl` is the offline config


## 🚀 Inference

Run `inference.py`:

```shell
python inference.py \
    --data_root "<path-to-trt-model>" \
    --cfg_pkl "<path-to-cfg-pkl>" \
    --audio_path "<path-to-input-audio>" \
    --source_path "<path-to-input-image>" \
    --output_path "<path-to-output-mp4>"
```

For example:

```shell
python inference.py \
    --data_root "./checkpoints/ditto_trt_Ampere_Plus" \
    --cfg_pkl "./checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl" \
    --audio_path "./example/audio.wav" \
    --source_path "./example/image.png" \
    --output_path "./tmp/result.mp4"
```

❗Note:

We provide TensorRT engines built with `hardware-compatibility-level=Ampere_Plus` (`checkpoints/ditto_trt_Ampere_Plus/`). If your GPU does not support them, run the `cvt_onnx_to_trt.py` script to convert the general ONNX models (`checkpoints/ditto_onnx/`) into TensorRT engines for your GPU.

```bash
python scripts/cvt_onnx_to_trt.py --onnx_dir "./checkpoints/ditto_onnx" --trt_dir "./checkpoints/ditto_trt_custom"
```

Then run `inference.py` with `--data_root=./checkpoints/ditto_trt_custom`.

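For example, reusing the example inputs from above:

```shell
python inference.py \
    --data_root "./checkpoints/ditto_trt_custom" \
    --cfg_pkl "./checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl" \
    --audio_path "./example/audio.wav" \
    --source_path "./example/image.png" \
    --output_path "./tmp/result.mp4"
```
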
## ⚡ PyTorch Model
*Based on community interest and to better support further development, we are now open-sourcing the PyTorch version of the model.*


We have added the PyTorch model and the corresponding configuration files to the [HuggingFace](https://huggingface.co/digital-avatar/ditto-talkinghead) repository. Please refer to [Download Checkpoints](#-download-checkpoints) to prepare the model files.

The `checkpoints` directory should then look like this:
```text
./checkpoints/
├── ditto_cfg
│   ├── ...
│   └── v0.4_hubert_cfg_pytorch.pkl
├── ...
└── ditto_pytorch
    ├── aux_models
    │   ├── 2d106det.onnx
    │   ├── det_10g.onnx
    │   ├── face_landmarker.task
    │   ├── hubert_streaming_fix_kv.onnx
    │   └── landmark203.onnx
    └── models
        ├── appearance_extractor.pth
        ├── decoder.pth
        ├── lmdm_v0.4_hubert.pth
        ├── motion_extractor.pth
        ├── stitch_network.pth
        └── warp_network.pth
```

To run inference, execute the following command:

```shell
python inference.py \
    --data_root "./checkpoints/ditto_pytorch" \
    --cfg_pkl "./checkpoints/ditto_cfg/v0.4_hubert_cfg_pytorch.pkl" \
    --audio_path "./example/audio.wav" \
    --source_path "./example/image.png" \
    --output_path "./tmp/result.mp4"
```


## 📧 Acknowledgement
Our implementation is based on [S2G-MDDiffusion](https://github.com/thuhcsi/S2G-MDDiffusion) and [LivePortrait](https://github.com/KwaiVGI/LivePortrait). Thanks for their remarkable contributions and released code! If we have missed any open-source projects or related articles, we will add them to this acknowledgement promptly.

## ⚖️ License
This repository is released under the Apache-2.0 license as found in the [LICENSE](LICENSE) file.

## 📚 Citation
If you find this codebase useful for your research, please use the following entry.
```BibTeX
@article{li2024ditto,
    title={Ditto: Motion-Space Diffusion for Controllable Realtime Talking Head Synthesis},
    author={Li, Tianqi and Zheng, Ruobing and Yang, Minghui and Chen, Jingdong and Yang, Ming},
    journal={arXiv preprint arXiv:2411.19509},
    year={2024}
}
```


## 🌟 Star History

[![Star History Chart](https://api.star-history.com/svg?repos=antgroup/ditto-talkinghead&type=Date)](https://www.star-history.com/#antgroup/ditto-talkinghead&Date)
core/atomic_components/audio2motion.py ADDED
@@ -0,0 +1,196 @@
import numpy as np
from ..models.lmdm import LMDM


"""
lmdm_cfg = {
    "model_path": "",
    "device": "cuda",
    "motion_feat_dim": 265,
    "audio_feat_dim": 1024+35,
    "seq_frames": 80,
}
"""


def _cvt_LP_motion_info(inp, mode, ignore_keys=()):
    ks_shape_map = [
        ['scale', (1, 1), 1],
        ['pitch', (1, 66), 66],
        ['yaw', (1, 66), 66],
        ['roll', (1, 66), 66],
        ['t', (1, 3), 3],
        ['exp', (1, 63), 63],
        ['kp', (1, 63), 63],
    ]

    def _dic2arr(_dic):
        arr = []
        for k, _, ds in ks_shape_map:
            if k not in _dic or k in ignore_keys:
                continue
            v = _dic[k].reshape(ds)
            if k == 'scale':
                v = v - 1
            arr.append(v)
        arr = np.concatenate(arr, -1)  # (133)
        return arr

    def _arr2dic(_arr):
        dic = {}
        s = 0
        for k, ds, ss in ks_shape_map:
            if k in ignore_keys:
                continue
            v = _arr[s:s + ss].reshape(ds)
            if k == 'scale':
                v = v + 1
            dic[k] = v
            s += ss
            if s >= len(_arr):
                break
        return dic

    if mode == 'dic2arr':
        assert isinstance(inp, dict)
        return _dic2arr(inp)  # (dim)
    elif mode == 'arr2dic':
        assert inp.shape[0] >= 265, f"{inp.shape}"
        return _arr2dic(inp)  # {k: (1, dim)}
    else:
        raise ValueError()

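# Note: with ignore_keys={'kp'} (as used in Audio2Motion.setup below), the packed motion
# vector has 1 (scale) + 66 (pitch) + 66 (yaw) + 66 (roll) + 3 (t) + 63 (exp) = 265 dims,
# matching the "motion_feat_dim": 265 entry in the lmdm_cfg example at the top of this file.
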
class Audio2Motion:
    def __init__(
        self,
        lmdm_cfg,
    ):
        self.lmdm = LMDM(**lmdm_cfg)

    def setup(
        self,
        x_s_info,
        overlap_v2=10,
        fix_kp_cond=0,
        fix_kp_cond_dim=None,
        sampling_timesteps=50,
        online_mode=False,
        v_min_max_for_clip=None,
        smo_k_d=3,
    ):
        self.smo_k_d = smo_k_d
        self.overlap_v2 = overlap_v2
        self.seq_frames = self.lmdm.seq_frames
        self.valid_clip_len = self.seq_frames - self.overlap_v2

        # for fuse
        self.online_mode = online_mode
        if self.online_mode:
            self.fuse_length = min(self.overlap_v2, self.valid_clip_len)
        else:
            self.fuse_length = self.overlap_v2
        self.fuse_alpha = np.arange(self.fuse_length, dtype=np.float32).reshape(1, -1, 1) / self.fuse_length

        self.fix_kp_cond = fix_kp_cond
        self.fix_kp_cond_dim = fix_kp_cond_dim
        self.sampling_timesteps = sampling_timesteps

        self.v_min_max_for_clip = v_min_max_for_clip
        if self.v_min_max_for_clip is not None:
            self.v_min = self.v_min_max_for_clip[0][None]  # [dim, 1]
            self.v_max = self.v_min_max_for_clip[1][None]

        kp_source = _cvt_LP_motion_info(x_s_info, mode='dic2arr', ignore_keys={'kp'})[None]
        self.s_kp_cond = kp_source.copy().reshape(1, -1)
        self.kp_cond = self.s_kp_cond.copy()

        self.lmdm.setup(sampling_timesteps)

        self.clip_idx = 0

    def _fuse(self, res_kp_seq, pred_kp_seq):
        ## ========================
        ## offline fuse mode
        ## last clip: -------
        ## fuse part: *****
        ## curr clip: -------
        ## output: ^^
        #
        ## online fuse mode
        ## last clip: -------
        ## fuse part: **
        ## curr clip: -------
        ## output: ^^
        ## ========================

        fuse_r1_s = res_kp_seq.shape[1] - self.fuse_length
        fuse_r1_e = res_kp_seq.shape[1]
        fuse_r2_s = self.seq_frames - self.valid_clip_len - self.fuse_length
        fuse_r2_e = self.seq_frames - self.valid_clip_len

        r1 = res_kp_seq[:, fuse_r1_s:fuse_r1_e]  # [1, fuse_len, dim]
        r2 = pred_kp_seq[:, fuse_r2_s: fuse_r2_e]  # [1, fuse_len, dim]
        r_fuse = r1 * (1 - self.fuse_alpha) + r2 * self.fuse_alpha

        res_kp_seq[:, fuse_r1_s:fuse_r1_e] = r_fuse  # fuse last
        res_kp_seq = np.concatenate([res_kp_seq, pred_kp_seq[:, fuse_r2_e:]], 1)  # len(res_kp_seq) + valid_clip_len

        return res_kp_seq

    def _update_kp_cond(self, res_kp_seq, idx):
        if self.fix_kp_cond == 0:  # do not reset
            self.kp_cond = res_kp_seq[:, idx-1]
        elif self.fix_kp_cond > 0:
            if self.clip_idx % self.fix_kp_cond == 0:  # reset
                self.kp_cond = self.s_kp_cond.copy()  # reset all dims
                if self.fix_kp_cond_dim is not None:
                    ds, de = self.fix_kp_cond_dim
                    self.kp_cond[:, ds:de] = res_kp_seq[:, idx-1, ds:de]
            else:
                self.kp_cond = res_kp_seq[:, idx-1]

    def _smo(self, res_kp_seq, s, e):
        if self.smo_k_d <= 1:
            return res_kp_seq
        new_res_kp_seq = res_kp_seq.copy()
        n = res_kp_seq.shape[1]
        half_k = self.smo_k_d // 2
        for i in range(s, e):
            ss = max(0, i - half_k)
            ee = min(n, i + half_k + 1)
            res_kp_seq[:, i, :202] = np.mean(new_res_kp_seq[:, ss:ee, :202], axis=1)
        return res_kp_seq

    def __call__(self, aud_cond, res_kp_seq=None):
        """
        aud_cond: (1, seq_frames, dim)
        """

        pred_kp_seq = self.lmdm(self.kp_cond, aud_cond, self.sampling_timesteps)
        if res_kp_seq is None:
            res_kp_seq = pred_kp_seq  # [1, seq_frames, dim]
            res_kp_seq = self._smo(res_kp_seq, 0, res_kp_seq.shape[1])
        else:
            res_kp_seq = self._fuse(res_kp_seq, pred_kp_seq)  # len(res_kp_seq) + valid_clip_len
            res_kp_seq = self._smo(res_kp_seq, res_kp_seq.shape[1] - self.valid_clip_len - self.fuse_length, res_kp_seq.shape[1] - self.valid_clip_len + 1)

        self.clip_idx += 1

        idx = res_kp_seq.shape[1] - self.overlap_v2
        self._update_kp_cond(res_kp_seq, idx)

        return res_kp_seq

    def cvt_fmt(self, res_kp_seq):
        # res_kp_seq: [1, n, dim]
        if self.v_min_max_for_clip is not None:
            tmp_res_kp_seq = np.clip(res_kp_seq[0], self.v_min, self.v_max)
        else:
            tmp_res_kp_seq = res_kp_seq[0]

        x_d_info_list = []
        for i in range(tmp_res_kp_seq.shape[0]):
            x_d_info = _cvt_LP_motion_info(tmp_res_kp_seq[i], 'arr2dic')  # {k: (1, dim)}
            x_d_info_list.append(x_d_info)
        return x_d_info_list
core/atomic_components/avatar_registrar.py ADDED
@@ -0,0 +1,102 @@
import numpy as np

from .loader import load_source_frames
from .source2info import Source2Info


def _mean_filter(arr, k):
    n = arr.shape[0]
    half_k = k // 2
    res = []
    for i in range(n):
        s = max(0, i - half_k)
        e = min(n, i + half_k + 1)
        res.append(arr[s:e].mean(0))
    res = np.stack(res, 0)
    return res


def smooth_x_s_info_lst(x_s_info_list, ignore_keys=(), smo_k=13):
    keys = x_s_info_list[0].keys()
    N = len(x_s_info_list)
    smo_dict = {}
    for k in keys:
        _lst = [x_s_info_list[i][k] for i in range(N)]
        if k not in ignore_keys:
            _lst = np.stack(_lst, 0)
            _smo_lst = _mean_filter(_lst, smo_k)
        else:
            _smo_lst = _lst
        smo_dict[k] = _smo_lst

    smo_res = []
    for i in range(N):
        x_s_info = {k: smo_dict[k][i] for k in keys}
        smo_res.append(x_s_info)
    return smo_res

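# Note: smooth_x_s_info_lst applies _mean_filter, a centered moving average (window smo_k,
# truncated at the sequence boundaries), to every per-frame motion field except ignore_keys.
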
class AvatarRegistrar:
    """
    source image|video -> rgb_list -> source_info
    """
    def __init__(
        self,
        insightface_det_cfg,
        landmark106_cfg,
        landmark203_cfg,
        landmark478_cfg,
        appearance_extractor_cfg,
        motion_extractor_cfg,
    ):
        self.source2info = Source2Info(
            insightface_det_cfg,
            landmark106_cfg,
            landmark203_cfg,
            landmark478_cfg,
            appearance_extractor_cfg,
            motion_extractor_cfg,
        )

    def register(
        self,
        source_path,  # image | video
        max_dim=1920,
        n_frames=-1,
        **kwargs,
    ):
        """
        kwargs:
            crop_scale: 2.3
            crop_vx_ratio: 0
            crop_vy_ratio: -0.125
            crop_flag_do_rot: True
        """
        rgb_list, is_image_flag = load_source_frames(source_path, max_dim=max_dim, n_frames=n_frames)
        source_info = {
            "x_s_info_lst": [],
            "f_s_lst": [],
            "M_c2o_lst": [],
            "eye_open_lst": [],
            "eye_ball_lst": [],
        }
        keys = ["x_s_info", "f_s", "M_c2o", "eye_open", "eye_ball"]
        last_lmk = None
        for rgb in rgb_list:
            info = self.source2info(rgb, last_lmk, **kwargs)
            for k in keys:
                source_info[f"{k}_lst"].append(info[k])

            last_lmk = info["lmk203"]

        sc_f0 = source_info['x_s_info_lst'][0]['kp'].flatten()

        source_info["sc"] = sc_f0
        source_info["is_image_flag"] = is_image_flag
        source_info["img_rgb_lst"] = rgb_list

        return source_info

    def __call__(self, *args, **kwargs):
        return self.register(*args, **kwargs)

core/atomic_components/cfg.py ADDED
@@ -0,0 +1,111 @@
import os
import pickle
import numpy as np


def load_pkl(pkl):
    with open(pkl, "rb") as f:
        return pickle.load(f)


def parse_cfg(cfg_pkl, data_root, replace_cfg=None):

    def _check_path(p):
        if os.path.isfile(p):
            return p
        else:
            return os.path.join(data_root, p)

    cfg = load_pkl(cfg_pkl)

    # ---
    # replace cfg for debug
    if isinstance(replace_cfg, dict):
        for k, v in replace_cfg.items():
            if not isinstance(v, dict):
                continue
            for kk, vv in v.items():
                cfg[k][kk] = vv
    # ---

    base_cfg = cfg["base_cfg"]
    audio2motion_cfg = cfg["audio2motion_cfg"]
    default_kwargs = cfg["default_kwargs"]

    for k in base_cfg:
        if k == "landmark478_cfg":
            for kk in ["task_path", "blaze_face_model_path", "face_mesh_model_path"]:
                if kk in base_cfg[k] and base_cfg[k][kk]:
                    base_cfg[k][kk] = _check_path(base_cfg[k][kk])
        else:
            base_cfg[k]["model_path"] = _check_path(base_cfg[k]["model_path"])

    audio2motion_cfg["model_path"] = _check_path(audio2motion_cfg["model_path"])

    avatar_registrar_cfg = {
        k: base_cfg[k]
        for k in [
            "insightface_det_cfg",
            "landmark106_cfg",
            "landmark203_cfg",
            "landmark478_cfg",
            "appearance_extractor_cfg",
            "motion_extractor_cfg",
        ]
    }

    stitch_network_cfg = base_cfg["stitch_network_cfg"]
    warp_network_cfg = base_cfg["warp_network_cfg"]
    decoder_cfg = base_cfg["decoder_cfg"]

    condition_handler_cfg = {
        k: audio2motion_cfg[k]
        for k in [
            "use_emo",
            "use_sc",
            "use_eye_open",
            "use_eye_ball",
            "seq_frames",
        ]
    }

    lmdm_cfg = {
        k: audio2motion_cfg[k]
        for k in [
            "model_path",
            "device",
            "motion_feat_dim",
            "audio_feat_dim",
            "seq_frames",
        ]
    }

    w2f_type = audio2motion_cfg["w2f_type"]
    wav2feat_cfg = {
        "w2f_cfg": base_cfg["hubert_cfg"] if w2f_type == "hubert" else base_cfg["wavlm_cfg"],
        "w2f_type": w2f_type,
    }

    return [
        avatar_registrar_cfg,
        condition_handler_cfg,
        lmdm_cfg,
        stitch_network_cfg,
        warp_network_cfg,
        decoder_cfg,
        wav2feat_cfg,
        default_kwargs,
    ]

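# Illustrative usage (the paths are examples taken from the README, not values hard-coded here):
#
#   (avatar_registrar_cfg, condition_handler_cfg, lmdm_cfg, stitch_network_cfg,
#    warp_network_cfg, decoder_cfg, wav2feat_cfg, default_kwargs) = parse_cfg(
#       cfg_pkl="./checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl",
#       data_root="./checkpoints/ditto_trt_Ampere_Plus",
#   )
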
def print_cfg(**kwargs):
    for k, v in kwargs.items():
        if k == "ch_info":
            print(k, type(v))
        elif k == "ctrl_info":
            print(k, type(v), len(v))
        else:
            if isinstance(v, np.ndarray):
                print(k, type(v), v.shape)
            else:
                print(k, type(v), v)
core/atomic_components/condition_handler.py ADDED
@@ -0,0 +1,168 @@
import numpy as np
from scipy.special import softmax
import copy


def _get_emo_avg(idx=6):
    emo_avg = np.zeros(8, dtype=np.float32)
    if isinstance(idx, (list, tuple)):
        for i in idx:
            emo_avg[i] = 8
    else:
        emo_avg[idx] = 8
    emo_avg = softmax(emo_avg)
    #emo_avg = None
    # 'Angry', 'Disgust', 'Fear', 'Happy', 'Neutral', 'Sad', 'Surprise', 'Contempt'
    return emo_avg


def _mirror_index(index, size):
    turn = index // size
    res = index % size
    if turn % 2 == 0:
        return res
    else:
        return size - res - 1

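# Note: _mirror_index ping-pongs over a sequence instead of wrapping around; e.g. with
# size=3, indices 0,1,2,3,4,5,6,... map to 0,1,2,2,1,0,0,... It is used below to loop the
# per-frame eye references of a source video back and forth.
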
class ConditionHandler:
    """
    aud_feat, emo_seq, eye_seq, sc_seq -> cond_seq
    """
    def __init__(
        self,
        use_emo=True,
        use_sc=True,
        use_eye_open=True,
        use_eye_ball=True,
        seq_frames=80,
    ):
        self.use_emo = use_emo
        self.use_sc = use_sc
        self.use_eye_open = use_eye_open
        self.use_eye_ball = use_eye_ball

        self.seq_frames = seq_frames

    def setup(self, setup_info, emo, eye_f0_mode=False, ch_info=None):
        """
        emo: int | [int] | [[int]] | numpy
        """
        if ch_info is None:
            source_info = copy.deepcopy(setup_info)
        else:
            source_info = ch_info

        self.eye_f0_mode = eye_f0_mode
        self.x_s_info_0 = source_info['x_s_info_lst'][0]

        if self.use_sc:
            self.sc = source_info["sc"]  # 63
            self.sc_seq = np.stack([self.sc] * self.seq_frames, 0)

        if self.use_eye_open:
            self.eye_open_lst = np.concatenate(source_info["eye_open_lst"], 0)  # [n, 2]
            self.num_eye_open = len(self.eye_open_lst)
            if self.num_eye_open == 1 or self.eye_f0_mode:
                self.eye_open_seq = np.stack([self.eye_open_lst[0]] * self.seq_frames, 0)
            else:
                self.eye_open_seq = None

        if self.use_eye_ball:
            self.eye_ball_lst = np.concatenate(source_info["eye_ball_lst"], 0)  # [n, 6]
            self.num_eye_ball = len(self.eye_ball_lst)
            if self.num_eye_ball == 1 or self.eye_f0_mode:
                self.eye_ball_seq = np.stack([self.eye_ball_lst[0]] * self.seq_frames, 0)
            else:
                self.eye_ball_seq = None

        if self.use_emo:
            self.emo_lst = self._parse_emo_seq(emo)
            self.num_emo = len(self.emo_lst)
            if self.num_emo == 1:
                self.emo_seq = np.concatenate([self.emo_lst] * self.seq_frames, 0)
            else:
                self.emo_seq = None

    @staticmethod
    def _parse_emo_seq(emo, seq_len=-1):
        if isinstance(emo, np.ndarray) and emo.ndim == 2 and emo.shape[1] == 8:
            # emo arr, e.g. real
            emo_seq = emo  # [m, 8]
        elif isinstance(emo, int) and 0 <= emo < 8:
            # emo label, e.g. 4
            emo_seq = _get_emo_avg(emo).reshape(1, 8)  # [1, 8]
        elif isinstance(emo, (list, tuple)) and 0 < len(emo) < 8 and isinstance(emo[0], int):
            # emo labels, e.g. [3,4]
            emo_seq = _get_emo_avg(emo).reshape(1, 8)  # [1, 8]
        elif isinstance(emo, list) and emo and isinstance(emo[0], (list, tuple)):
            # emo label list, e.g. [[4], [3,4], [3],[3,4,5], ...]
            emo_seq = np.stack([_get_emo_avg(i) for i in emo], 0)  # [m, 8]
        else:
            raise ValueError(f"Unsupported emo type: {emo}")

        if seq_len > 0:
            if len(emo_seq) == seq_len:
                return emo_seq
            elif len(emo_seq) == 1:
                return np.concatenate([emo_seq] * seq_len, 0)
            elif len(emo_seq) > seq_len:
                return emo_seq[:seq_len]
            else:
                raise ValueError(f"emo len {len(emo_seq)} can not match seq len ({seq_len})")
        else:
            return emo_seq

    def __call__(self, aud_feat, idx, emo=None):
        """
        aud_feat: [n, 1024]
        idx: int, <0 means pad (first clip buffer)
        """

        frame_num = len(aud_feat)
        more_cond = [aud_feat]
        if self.use_emo:
            if emo is not None:
                emo_seq = self._parse_emo_seq(emo, frame_num)
            elif self.emo_seq is not None and len(self.emo_seq) == frame_num:
                emo_seq = self.emo_seq
            else:
                emo_idx_list = [max(i, 0) % self.num_emo for i in range(idx, idx + frame_num)]
                emo_seq = self.emo_lst[emo_idx_list]
            more_cond.append(emo_seq)

        if self.use_eye_open:
            if self.eye_open_seq is not None and len(self.eye_open_seq) == frame_num:
                eye_open_seq = self.eye_open_seq
            else:
                if self.eye_f0_mode:
                    eye_idx_list = [0] * frame_num
                else:
                    eye_idx_list = [_mirror_index(max(i, 0), self.num_eye_open) for i in range(idx, idx + frame_num)]
                eye_open_seq = self.eye_open_lst[eye_idx_list]
            more_cond.append(eye_open_seq)

        if self.use_eye_ball:
            if self.eye_ball_seq is not None and len(self.eye_ball_seq) == frame_num:
                eye_ball_seq = self.eye_ball_seq
            else:
                if self.eye_f0_mode:
                    eye_idx_list = [0] * frame_num
                else:
                    eye_idx_list = [_mirror_index(max(i, 0), self.num_eye_ball) for i in range(idx, idx + frame_num)]
                eye_ball_seq = self.eye_ball_lst[eye_idx_list]
            more_cond.append(eye_ball_seq)

        if self.use_sc:
            if len(self.sc_seq) == frame_num:
                sc_seq = self.sc_seq
            else:
                sc_seq = np.stack([self.sc] * frame_num, 0)
            more_cond.append(sc_seq)

        if len(more_cond) > 1:
            cond_seq = np.concatenate(more_cond, -1)  # [n, dim_cond]
        else:
            cond_seq = aud_feat

        return cond_seq
core/atomic_components/decode_f3d.py ADDED
@@ -0,0 +1,22 @@
from ..models.decoder import Decoder


"""
# __init__
decoder_cfg = {
    "model_path": "",
    "device": "cuda",
}
"""

class DecodeF3D:
    def __init__(
        self,
        decoder_cfg,
    ):
        self.decoder = Decoder(**decoder_cfg)

    def __call__(self, f_s):
        out = self.decoder(f_s)
        return out

core/atomic_components/loader.py ADDED
@@ -0,0 +1,133 @@
import filetype
import imageio
import cv2


def is_image(file_path):
    return filetype.is_image(file_path)


def is_video(file_path):
    return filetype.is_video(file_path)


def check_resize(h, w, max_dim=1920, division=2):
    rsz_flag = False
    # adjust the size of the image according to the maximum dimension
    if max_dim > 0 and max(h, w) > max_dim:
        rsz_flag = True
        if h > w:
            new_h = max_dim
            new_w = int(round(w * max_dim / h))
        else:
            new_w = max_dim
            new_h = int(round(h * max_dim / w))
    else:
        new_h = h
        new_w = w

    # ensure that the image dimensions are multiples of n
    if new_h % division != 0:
        new_h = new_h - (new_h % division)
        rsz_flag = True
    if new_w % division != 0:
        new_w = new_w - (new_w % division)
        rsz_flag = True

    return new_h, new_w, rsz_flag


def load_image(image_path, max_dim=-1):
    img = cv2.imread(image_path, cv2.IMREAD_COLOR)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    h, w = img.shape[:2]
    new_h, new_w, rsz_flag = check_resize(h, w, max_dim)
    if rsz_flag:
        img = cv2.resize(img, (new_w, new_h))
    return img


def load_video(video_path, n_frames=-1, max_dim=-1):
    reader = imageio.get_reader(video_path, "ffmpeg")

    new_h, new_w, rsz_flag = None, None, None

    ret = []
    for idx, frame_rgb in enumerate(reader):
        if n_frames > 0 and idx >= n_frames:
            break

        if rsz_flag is None:
            h, w = frame_rgb.shape[:2]
            new_h, new_w, rsz_flag = check_resize(h, w, max_dim)

        if rsz_flag:
            frame_rgb = cv2.resize(frame_rgb, (new_w, new_h))

        ret.append(frame_rgb)

    reader.close()
    return ret


def load_source_frames(source_path, max_dim=-1, n_frames=-1):
    if is_image(source_path):
        rgb = load_image(source_path, max_dim)
        rgb_list = [rgb]
        is_image_flag = True
    elif is_video(source_path):
        rgb_list = load_video(source_path, n_frames, max_dim)
        is_image_flag = False
    else:
        raise ValueError(f"Unsupported source type: {source_path}")
    return rgb_list, is_image_flag


def _mirror_index(index, size):
    turn = index // size
    res = index % size
    if turn % 2 == 0:
        return res
    else:
        return size - res - 1


class LoopLoader:
    def __init__(self, item_list, max_iter_num=-1, mirror_loop=True):
        self.item_list = item_list
        self.idx = 0
        self.item_num = len(self.item_list)
        self.max_iter_num = max_iter_num if max_iter_num > 0 else self.item_num
        self.mirror_loop = mirror_loop

    def __len__(self):
        return self.max_iter_num

    def __iter__(self):
        return self

    def __next__(self):
        if self.idx >= self.max_iter_num:
            raise StopIteration

        if self.mirror_loop:
            idx = _mirror_index(self.idx, self.item_num)
        else:
            idx = self.idx % self.item_num
        item = self.item_list[idx]

        self.idx += 1
        return item

    def __call__(self):
        return self.__iter__()

    def reset(self, max_iter_num=-1):
        self.idx = 0
        self.max_iter_num = max_iter_num if max_iter_num > 0 else self.item_num

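# Example: LoopLoader(["f0", "f1", "f2"], max_iter_num=7) yields
# f0, f1, f2, f2, f1, f0, f0; with mirror_loop=True the items are traversed
# back and forth via _mirror_index rather than wrapping to the start.
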
core/atomic_components/motion_stitch.py ADDED
@@ -0,0 +1,491 @@
import copy
import random
import numpy as np
from scipy.special import softmax

from ..models.stitch_network import StitchNetwork


"""
# __init__
stitch_network_cfg = {
    "model_path": "",
    "device": "cuda",
}

# __call__
kwargs:
    fade_alpha
    fade_out_keys

    delta_pitch
    delta_yaw
    delta_roll

"""


def ctrl_motion(x_d_info, **kwargs):
    # pose + offset
    for kk in ["delta_pitch", "delta_yaw", "delta_roll"]:
        if kk in kwargs:
            k = kk[6:]
            x_d_info[k] = bin66_to_degree(x_d_info[k]) + kwargs[kk]

    # pose * alpha
    for kk in ["alpha_pitch", "alpha_yaw", "alpha_roll"]:
        if kk in kwargs:
            k = kk[6:]
            x_d_info[k] = x_d_info[k] * kwargs[kk]

    # exp + offset
    if "delta_exp" in kwargs:
        k = "exp"
        x_d_info[k] = x_d_info[k] + kwargs["delta_exp"]

    return x_d_info


def fade(x_d_info, dst, alpha, keys=None):
    if keys is None:
        keys = x_d_info.keys()
    for k in keys:
        if k == 'kp':
            continue
        x_d_info[k] = x_d_info[k] * alpha + dst[k] * (1 - alpha)
    return x_d_info


def ctrl_vad(x_d_info, dst, alpha):
    exp = x_d_info["exp"]
    exp_dst = dst["exp"]

    _lip = [6, 12, 14, 17, 19, 20]
    _a1 = np.zeros((21, 3), dtype=np.float32)
    _a1[_lip] = alpha
    _a1 = _a1.reshape(1, -1)
    x_d_info["exp"] = exp * alpha + exp_dst * (1 - alpha)

    return x_d_info



def _mix_s_d_info(
    x_s_info,
    x_d_info,
    use_d_keys=("exp", "pitch", "yaw", "roll", "t"),
    d0=None,
):
    if d0 is not None:
        if isinstance(use_d_keys, dict):
            x_d_info = {
                k: x_s_info[k] + (v - d0[k]) * use_d_keys.get(k, 1)
                for k, v in x_d_info.items()
            }
        else:
            x_d_info = {k: x_s_info[k] + (v - d0[k]) for k, v in x_d_info.items()}

    for k, v in x_s_info.items():
        if k not in x_d_info or k not in use_d_keys:
            x_d_info[k] = v

    if isinstance(use_d_keys, dict) and d0 is None:
        for k, alpha in use_d_keys.items():
            x_d_info[k] *= alpha
    return x_d_info


def _set_eye_blink_idx(N, blink_n=15, open_n=-1):
    """
    open_n:
        -1: no blink
        0: random open_n
        >0: fix open_n
        list: loop open_n
    """
    OPEN_MIN = 60
    OPEN_MAX = 100

    idx = [0] * N
    if isinstance(open_n, int):
        if open_n < 0:  # no blink
            return idx
        elif open_n > 0:  # fix open_n
            open_ns = [open_n]
        else:  # open_n == 0: # random open_n, 60-100
            open_ns = []
    elif isinstance(open_n, list):
        open_ns = open_n  # loop open_n
    else:
        raise ValueError()

    blink_idx = list(range(blink_n))

    start_n = open_ns[0] if open_ns else random.randint(OPEN_MIN, OPEN_MAX)
    end_n = open_ns[-1] if open_ns else random.randint(OPEN_MIN, OPEN_MAX)
    max_i = N - max(end_n, blink_n)
    cur_i = start_n
    cur_n_i = 1
    while cur_i < max_i:
        idx[cur_i : cur_i + blink_n] = blink_idx

        if open_ns:
            cur_n = open_ns[cur_n_i % len(open_ns)]
            cur_n_i += 1
        else:
            cur_n = random.randint(OPEN_MIN, OPEN_MAX)

        cur_i = cur_i + blink_n + cur_n

    return idx


def _fix_exp_for_x_d_info(x_d_info, x_s_info, delta_eye=None, drive_eye=True):
    _eye = [11, 13, 15, 16, 18]
    _lip = [6, 12, 14, 17, 19, 20]
    alpha = np.zeros((21, 3), dtype=x_d_info["exp"].dtype)
    alpha[_lip] = 1
    if delta_eye is None and drive_eye:  # use d eye
        alpha[_eye] = 1
    alpha = alpha.reshape(1, -1)
    x_d_info["exp"] = x_d_info["exp"] * alpha + x_s_info["exp"] * (1 - alpha)

    if delta_eye is not None and drive_eye:
        alpha = np.zeros((21, 3), dtype=x_d_info["exp"].dtype)
        alpha[_eye] = 1
        alpha = alpha.reshape(1, -1)
        x_d_info["exp"] = (delta_eye + x_s_info["exp"]) * alpha + x_d_info["exp"] * (
            1 - alpha
        )

    return x_d_info


def _fix_exp_for_x_d_info_v2(x_d_info, x_s_info, delta_eye, a1, a2, a3):
    x_d_info["exp"] = x_d_info["exp"] * a1 + x_s_info["exp"] * a2 + delta_eye * a3
    return x_d_info


def bin66_to_degree(pred):
    if pred.ndim > 1 and pred.shape[1] == 66:
        idx = np.arange(66).astype(np.float32)
        pred = softmax(pred, axis=1)
        degree = np.sum(pred * idx, axis=1) * 3 - 97.5
        return degree
    return pred


def _eye_delta(exp, dx=0, dy=0):
    if dx > 0:
        exp[0, 33] += dx * 0.0007
        exp[0, 45] += dx * 0.001
    else:
        exp[0, 33] += dx * 0.001
        exp[0, 45] += dx * 0.0007

    exp[0, 34] += dy * -0.001
    exp[0, 46] += dy * -0.001
    return exp

def _fix_gaze(pose_s, x_d_info):
    x_ratio = 0.26
    y_ratio = 0.28

    yaw_s, pitch_s = pose_s
    yaw_d = bin66_to_degree(x_d_info['yaw']).item()
    pitch_d = bin66_to_degree(x_d_info['pitch']).item()

    delta_yaw = yaw_d - yaw_s
    delta_pitch = pitch_d - pitch_s

    dx = delta_yaw * x_ratio
    dy = delta_pitch * y_ratio

    x_d_info['exp'] = _eye_delta(x_d_info['exp'], dx, dy)
    return x_d_info


def get_rotation_matrix(pitch_, yaw_, roll_):
    """ the input is in degree
    """
    # transform to radian
    pitch = pitch_ / 180 * np.pi
    yaw = yaw_ / 180 * np.pi
    roll = roll_ / 180 * np.pi

    if pitch.ndim == 1:
        pitch = pitch[:, None]
    if yaw.ndim == 1:
        yaw = yaw[:, None]
    if roll.ndim == 1:
        roll = roll[:, None]

    # calculate the euler matrix
    bs = pitch.shape[0]
    ones = np.ones((bs, 1), dtype=np.float32)
    zeros = np.zeros((bs, 1), dtype=np.float32)
    x, y, z = pitch, yaw, roll

    rot_x = np.concatenate([
        ones, zeros, zeros,
        zeros, np.cos(x), -np.sin(x),
        zeros, np.sin(x), np.cos(x)
    ], axis=1).reshape(bs, 3, 3)

    rot_y = np.concatenate([
        np.cos(y), zeros, np.sin(y),
        zeros, ones, zeros,
        -np.sin(y), zeros, np.cos(y)
    ], axis=1).reshape(bs, 3, 3)

    rot_z = np.concatenate([
        np.cos(z), -np.sin(z), zeros,
        np.sin(z), np.cos(z), zeros,
        zeros, zeros, ones
    ], axis=1).reshape(bs, 3, 3)

    rot = np.matmul(np.matmul(rot_z, rot_y), rot_x)
    return np.transpose(rot, (0, 2, 1))


def transform_keypoint(kp_info: dict):
    """
    transform the implicit keypoints with the pose, shift, and expression deformation
    kp: BxNx3
    """
    kp = kp_info['kp']  # (bs, k, 3)
    pitch, yaw, roll = kp_info['pitch'], kp_info['yaw'], kp_info['roll']

    t, exp = kp_info['t'], kp_info['exp']
    scale = kp_info['scale']

    pitch = bin66_to_degree(pitch)
    yaw = bin66_to_degree(yaw)
    roll = bin66_to_degree(roll)

    bs = kp.shape[0]
    if kp.ndim == 2:
        num_kp = kp.shape[1] // 3  # Bx(num_kpx3)
    else:
        num_kp = kp.shape[1]  # Bxnum_kpx3

    rot_mat = get_rotation_matrix(pitch, yaw, roll)  # (bs, 3, 3)

    # Eqn.2: s * (R * x_c,s + exp) + t
    kp_transformed = np.matmul(kp.reshape(bs, num_kp, 3), rot_mat) + exp.reshape(bs, num_kp, 3)
    kp_transformed *= scale[..., None]  # (bs, k, 3) * (bs, 1, 1) = (bs, k, 3)
    kp_transformed[:, :, 0:2] += t[:, None, 0:2]  # remove z, only apply tx ty

    return kp_transformed


class MotionStitch:
    def __init__(
        self,
        stitch_network_cfg,
    ):
        self.stitch_net = StitchNetwork(**stitch_network_cfg)

    def set_Nd(self, N_d=-1):
        # only for offline (make start|end eye open)
        if N_d == self.N_d:
            return

        self.N_d = N_d
        if self.drive_eye and self.delta_eye_arr is not None:
            N = 3000 if self.N_d == -1 else self.N_d
            self.delta_eye_idx_list = _set_eye_blink_idx(
                N, len(self.delta_eye_arr), self.delta_eye_open_n
            )

    def setup(
        self,
        N_d=-1,
        use_d_keys=None,
        relative_d=True,
        drive_eye=None,  # use d eye or s eye
        delta_eye_arr=None,  # fix eye
        delta_eye_open_n=-1,  # int|list
        fade_out_keys=("exp",),
        fade_type="",  # "" | "d0" | "s"
        flag_stitching=True,
        is_image_flag=True,
        x_s_info=None,
        d0=None,
        ch_info=None,
        overall_ctrl_info=None,
    ):
        self.is_image_flag = is_image_flag
        if use_d_keys is None:
            if self.is_image_flag:
                self.use_d_keys = ("exp", "pitch", "yaw", "roll", "t")
            else:
                self.use_d_keys = ("exp", )
        else:
            self.use_d_keys = use_d_keys

        if drive_eye is None:
            if self.is_image_flag:
                self.drive_eye = True
            else:
                self.drive_eye = False
        else:
            self.drive_eye = drive_eye

        self.N_d = N_d
        self.relative_d = relative_d
        self.delta_eye_arr = delta_eye_arr
        self.delta_eye_open_n = delta_eye_open_n
        self.fade_out_keys = fade_out_keys
        self.fade_type = fade_type
        self.flag_stitching = flag_stitching

        _eye = [11, 13, 15, 16, 18]
        _lip = [6, 12, 14, 17, 19, 20]
        _a1 = np.zeros((21, 3), dtype=np.float32)
        _a1[_lip] = 1
        _a2 = 0
        if self.drive_eye:
            if self.delta_eye_arr is None:
                _a1[_eye] = 1
            else:
                _a2 = np.zeros((21, 3), dtype=np.float32)
                _a2[_eye] = 1
                _a2 = _a2.reshape(1, -1)
        _a1 = _a1.reshape(1, -1)

        self.fix_exp_a1 = _a1 * (1 - _a2)
        self.fix_exp_a2 = (1 - _a1) + _a1 * _a2
        self.fix_exp_a3 = _a2

        if self.drive_eye and self.delta_eye_arr is not None:
            N = 3000 if self.N_d == -1 else self.N_d
            self.delta_eye_idx_list = _set_eye_blink_idx(
                N, len(self.delta_eye_arr), self.delta_eye_open_n
            )

        self.pose_s = None
        self.x_s = None
        self.fade_dst = None
        if self.is_image_flag and x_s_info is not None:
            yaw_s = bin66_to_degree(x_s_info['yaw']).item()
            pitch_s = bin66_to_degree(x_s_info['pitch']).item()
            self.pose_s = [yaw_s, pitch_s]
            self.x_s = transform_keypoint(x_s_info)

            if self.fade_type == "s":
                self.fade_dst = copy.deepcopy(x_s_info)

        if ch_info is not None:
            self.scale_a = ch_info['x_s_info_lst'][0]['scale'].item()
            if x_s_info is not None:
                self.scale_b = x_s_info['scale'].item()
                self.scale_ratio = self.scale_a / self.scale_b
                self._set_scale_ratio(self.scale_ratio)
            else:
                self.scale_ratio = None
        else:
            self.scale_ratio = 1

        self.overall_ctrl_info = overall_ctrl_info

        self.d0 = d0
        self.idx = 0

    def _set_scale_ratio(self, scale_ratio=1):
        if scale_ratio == 1:
            return
        if isinstance(self.use_d_keys, dict):
            self.use_d_keys = {k: v * (scale_ratio if k in {"exp", "pitch", "yaw", "roll"} else 1) for k, v in self.use_d_keys.items()}
        else:
            self.use_d_keys = {k: scale_ratio if k in {"exp", "pitch", "yaw", "roll"} else 1 for k in self.use_d_keys}

    @staticmethod
    def _merge_kwargs(default_kwargs, run_kwargs):
        if default_kwargs is None:
            return run_kwargs

        for k, v in default_kwargs.items():
            if k not in run_kwargs:
                run_kwargs[k] = v
        return run_kwargs

    def __call__(self, x_s_info, x_d_info, **kwargs):
        # return x_s, x_d

        kwargs = self._merge_kwargs(self.overall_ctrl_info, kwargs)

        if self.scale_ratio is None:
            self.scale_b = x_s_info['scale'].item()
            self.scale_ratio = self.scale_a / self.scale_b
            self._set_scale_ratio(self.scale_ratio)

        if self.relative_d and self.d0 is None:
            self.d0 = copy.deepcopy(x_d_info)

        x_d_info = _mix_s_d_info(
            x_s_info,
            x_d_info,
            self.use_d_keys,
            self.d0,
        )

        delta_eye = 0
        if self.drive_eye and self.delta_eye_arr is not None:
            delta_eye = self.delta_eye_arr[
                self.delta_eye_idx_list[self.idx % len(self.delta_eye_idx_list)]
            ][None]
        x_d_info = _fix_exp_for_x_d_info_v2(
            x_d_info,
            x_s_info,
            delta_eye,
            self.fix_exp_a1,
            self.fix_exp_a2,
            self.fix_exp_a3,
        )

        if kwargs.get("vad_alpha", 1) < 1:
            x_d_info = ctrl_vad(x_d_info, x_s_info, kwargs.get("vad_alpha", 1))

        x_d_info = ctrl_motion(x_d_info, **kwargs)

        if self.fade_type == "d0" and self.fade_dst is None:
            self.fade_dst = copy.deepcopy(x_d_info)

        # fade
        if "fade_alpha" in kwargs and self.fade_type in ["d0", "s"]:
            fade_alpha = kwargs["fade_alpha"]
            fade_keys = kwargs.get("fade_out_keys", self.fade_out_keys)
            if self.fade_type == "d0":
                fade_dst = self.fade_dst
            elif self.fade_type == "s":
                if self.fade_dst is not None:
463
+ fade_dst = self.fade_dst
464
+ else:
465
+ fade_dst = copy.deepcopy(x_s_info)
466
+ if self.is_image_flag:
467
+ self.fade_dst = fade_dst
468
+ x_d_info = fade(x_d_info, fade_dst, fade_alpha, fade_keys)
469
+
470
+ if self.drive_eye:
471
+ if self.pose_s is None:
472
+ yaw_s = bin66_to_degree(x_s_info['yaw']).item()
473
+ pitch_s = bin66_to_degree(x_s_info['pitch']).item()
474
+ self.pose_s = [yaw_s, pitch_s]
475
+ x_d_info = _fix_gaze(self.pose_s, x_d_info)
476
+
477
+ if self.x_s is not None:
478
+ x_s = self.x_s
479
+ else:
480
+ x_s = transform_keypoint(x_s_info)
481
+ if self.is_image_flag:
482
+ self.x_s = x_s
483
+
484
+ x_d = transform_keypoint(x_d_info)
485
+
486
+ if self.flag_stitching:
487
+ x_d = self.stitch_net(x_s, x_d)
488
+
489
+ self.idx += 1
490
+
491
+ return x_s, x_d
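For reference, `transform_keypoint` above realises Eqn.2, `s * (R * x_c,s + exp) + t` (with only tx, ty applied). A minimal, self-contained NumPy sketch of that transform, using dummy values in place of real motion-extractor outputs:

```python
import numpy as np

def euler_to_rot(pitch, yaw, roll):
    # degrees -> batched (bs, 3, 3) transposed rotation matrices,
    # mirroring get_rotation_matrix above
    x, y, z = np.radians(pitch), np.radians(yaw), np.radians(roll)
    bs = x.shape[0]
    zeros, ones = np.zeros_like(x), np.ones_like(x)
    rot_x = np.stack([ones, zeros, zeros,
                      zeros, np.cos(x), -np.sin(x),
                      zeros, np.sin(x), np.cos(x)], axis=1).reshape(bs, 3, 3)
    rot_y = np.stack([np.cos(y), zeros, np.sin(y),
                      zeros, ones, zeros,
                      -np.sin(y), zeros, np.cos(y)], axis=1).reshape(bs, 3, 3)
    rot_z = np.stack([np.cos(z), -np.sin(z), zeros,
                      np.sin(z), np.cos(z), zeros,
                      zeros, zeros, ones], axis=1).reshape(bs, 3, 3)
    return np.transpose(rot_z @ rot_y @ rot_x, (0, 2, 1))

bs, num_kp = 1, 21
kp = np.random.randn(bs, num_kp, 3).astype(np.float32)          # canonical keypoints
exp = 0.05 * np.random.randn(bs, num_kp, 3).astype(np.float32)  # expression deltas
rot = euler_to_rot(np.array([10.0]), np.array([-5.0]), np.array([0.0]))
scale = np.array([[1.2]], dtype=np.float32)                     # (bs, 1)
t = np.array([[0.01, -0.02, 0.0]], dtype=np.float32)            # (bs, 3)

kp_t = (kp @ rot + exp) * scale[..., None]   # s * (R * x_c,s + exp)
kp_t[:, :, 0:2] += t[:, None, 0:2]           # + (tx, ty); z translation is dropped
print(kp_t.shape)                            # (1, 21, 3)
```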
core/atomic_components/putback.py ADDED
@@ -0,0 +1,60 @@
1
+ import cv2
2
+ import numpy as np
3
+ from ..utils.blend import blend_images_cy
4
+ from ..utils.get_mask import get_mask
5
+
6
+
7
+ class PutBackNumpy:
8
+ def __init__(
9
+ self,
10
+ mask_template_path=None,
11
+ ):
12
+ if mask_template_path is None:
13
+ mask = get_mask(512, 512, 0.9, 0.9)
14
+ self.mask_ori_float = np.concatenate([mask] * 3, 2)
15
+ else:
16
+ mask = cv2.imread(mask_template_path, cv2.IMREAD_COLOR)
17
+ self.mask_ori_float = mask.astype(np.float32) / 255.0
18
+
19
+ def __call__(self, frame_rgb, render_image, M_c2o):
20
+ h, w = frame_rgb.shape[:2]
21
+ mask_warped = cv2.warpAffine(
22
+ self.mask_ori_float, M_c2o[:2, :], dsize=(w, h), flags=cv2.INTER_LINEAR
23
+ ).clip(0, 1)
24
+ frame_warped = cv2.warpAffine(
25
+ render_image, M_c2o[:2, :], dsize=(w, h), flags=cv2.INTER_LINEAR
26
+ )
27
+ result = mask_warped * frame_warped + (1 - mask_warped) * frame_rgb
28
+ result = np.clip(result, 0, 255)
29
+ result = result.astype(np.uint8)
30
+ return result
31
+
32
+
33
+ class PutBack:
34
+ def __init__(
35
+ self,
36
+ mask_template_path=None,
37
+ ):
38
+ if mask_template_path is None:
39
+ mask = get_mask(512, 512, 0.9, 0.9)
40
+ mask = np.concatenate([mask] * 3, 2)
41
+ else:
42
+ mask = cv2.imread(mask_template_path, cv2.IMREAD_COLOR).astype(np.float32) / 255.0
43
+
44
+ self.mask_ori_float = np.ascontiguousarray(mask)[:,:,0]
45
+ self.result_buffer = None
46
+
47
+ def __call__(self, frame_rgb, render_image, M_c2o):
48
+ h, w = frame_rgb.shape[:2]
49
+ mask_warped = cv2.warpAffine(
50
+ self.mask_ori_float, M_c2o[:2, :], dsize=(w, h), flags=cv2.INTER_LINEAR
51
+ ).clip(0, 1)
52
+ frame_warped = cv2.warpAffine(
53
+ render_image, M_c2o[:2, :], dsize=(w, h), flags=cv2.INTER_LINEAR
54
+ )
55
+ self.result_buffer = np.empty((h, w, 3), dtype=np.uint8)
56
+
57
+ # Use Cython implementation for blending
58
+ blend_images_cy(mask_warped, frame_warped, frame_rgb, self.result_buffer)
59
+
60
+ return self.result_buffer
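A NumPy-only sketch of the paste-back math with synthetic inputs; the Cython kernel `blend_images_cy` is assumed to produce the same per-pixel result as the pure-NumPy path in `PutBackNumpy`:

```python
import cv2
import numpy as np

h, w = 256, 256
frame_rgb = np.full((h, w, 3), 40, dtype=np.uint8)            # original frame
render_image = np.full((512, 512, 3), 200, dtype=np.uint8)    # rendered 512x512 face crop
M_c2o = np.array([[0.4, 0.0, 20.0],                           # crop -> original affine
                  [0.0, 0.4, 30.0],
                  [0.0, 0.0, 1.0]], dtype=np.float32)
mask = np.ones((512, 512, 3), dtype=np.float32)               # stand-in for get_mask()

mask_warped = cv2.warpAffine(mask, M_c2o[:2, :], (w, h)).clip(0, 1)
frame_warped = cv2.warpAffine(render_image, M_c2o[:2, :], (w, h))

out = mask_warped * frame_warped + (1 - mask_warped) * frame_rgb
out = np.clip(out, 0, 255).astype(np.uint8)
print(out.shape, out.dtype)  # (256, 256, 3) uint8
```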
core/atomic_components/source2info.py ADDED
@@ -0,0 +1,155 @@
1
+ import numpy as np
2
+ import cv2
3
+
4
+ from ..aux_models.insightface_det import InsightFaceDet
5
+ from ..aux_models.insightface_landmark106 import Landmark106
6
+ from ..aux_models.landmark203 import Landmark203
7
+ from ..aux_models.mediapipe_landmark478 import Landmark478
8
+ from ..models.appearance_extractor import AppearanceExtractor
9
+ from ..models.motion_extractor import MotionExtractor
10
+
11
+ from ..utils.crop import crop_image
12
+ from ..utils.eye_info import EyeAttrUtilsByMP
13
+
14
+
15
+ """
16
+ insightface_det_cfg = {
17
+ "model_path": "",
18
+ "device": "cuda",
19
+ "force_ori_type": False,
20
+ }
21
+ landmark106_cfg = {
22
+ "model_path": "",
23
+ "device": "cuda",
24
+ "force_ori_type": False,
25
+ }
26
+ landmark203_cfg = {
27
+ "model_path": "",
28
+ "device": "cuda",
29
+ "force_ori_type": False,
30
+ }
31
+ landmark478_cfg = {
32
+ "blaze_face_model_path": "",
33
+ "face_mesh_model_path": "",
34
+ "device": "cuda",
35
+ "force_ori_type": False,
36
+ "task_path": "",
37
+ }
38
+ appearance_extractor_cfg = {
39
+ "model_path": "",
40
+ "device": "cuda",
41
+ }
42
+ motion_extractor_cfg = {
43
+ "model_path": "",
44
+ "device": "cuda",
45
+ }
46
+ """
47
+
48
+
49
+ class Source2Info:
50
+ def __init__(
51
+ self,
52
+ insightface_det_cfg,
53
+ landmark106_cfg,
54
+ landmark203_cfg,
55
+ landmark478_cfg,
56
+ appearance_extractor_cfg,
57
+ motion_extractor_cfg,
58
+ ):
59
+ self.insightface_det = InsightFaceDet(**insightface_det_cfg)
60
+ self.landmark106 = Landmark106(**landmark106_cfg)
61
+ self.landmark203 = Landmark203(**landmark203_cfg)
62
+ self.landmark478 = Landmark478(**landmark478_cfg)
63
+
64
+ self.appearance_extractor = AppearanceExtractor(**appearance_extractor_cfg)
65
+ self.motion_extractor = MotionExtractor(**motion_extractor_cfg)
66
+
67
+ def _crop(self, img, last_lmk=None, **kwargs):
68
+ # img_rgb -> det->landmark106->landmark203->crop
69
+
70
+ if last_lmk is None: # det for first frame or image
71
+ det, _ = self.insightface_det(img)
72
+ boxes = det[np.argsort(-(det[:, 2] - det[:, 0]) * (det[:, 3] - det[:, 1]))]
73
+ if len(boxes) == 0:
74
+ return None
75
+ lmk_for_track = self.landmark106(img, boxes[0]) # 106
76
+ else: # track for video frames
77
+ lmk_for_track = last_lmk # 203
78
+
79
+ crop_dct = crop_image(
80
+ img,
81
+ lmk_for_track,
82
+ dsize=self.landmark203.dsize,
83
+ scale=1.5,
84
+ vy_ratio=-0.1,
85
+ pt_crop_flag=False,
86
+ )
87
+ lmk203 = self.landmark203(crop_dct["img_crop"], crop_dct["M_c2o"])
88
+
89
+ ret_dct = crop_image(
90
+ img,
91
+ lmk203,
92
+ dsize=512,
93
+ scale=kwargs.get("crop_scale", 2.3),
94
+ vx_ratio=kwargs.get("crop_vx_ratio", 0),
95
+ vy_ratio=kwargs.get("crop_vy_ratio", -0.125),
96
+ flag_do_rot=kwargs.get("crop_flag_do_rot", True),
97
+ pt_crop_flag=False,
98
+ )
99
+
100
+ img_crop = ret_dct["img_crop"]
101
+ M_c2o = ret_dct["M_c2o"]
102
+
103
+ return img_crop, M_c2o, lmk203
104
+
105
+ @staticmethod
106
+ def _img_crop_to_bchw256(img_crop):
107
+ rgb_256 = cv2.resize(img_crop, (256, 256), interpolation=cv2.INTER_AREA)
108
+ rgb_256_bchw = (rgb_256.astype(np.float32) / 255.0)[None].transpose(0, 3, 1, 2)
109
+ return rgb_256_bchw
110
+
111
+ def _get_kp_info(self, img):
112
+ # rgb_256_bchw_norm01
113
+ kp_info = self.motion_extractor(img)
114
+ return kp_info
115
+
116
+ def _get_f3d(self, img):
117
+ # rgb_256_bchw_norm01
118
+ fs = self.appearance_extractor(img)
119
+ return fs
120
+
121
+ def _get_eye_info(self, img):
122
+ # rgb uint8
123
+ lmk478 = self.landmark478(img) # [1, 478, 3]
124
+ attr = EyeAttrUtilsByMP(lmk478)
125
+ lr_open = attr.LR_open().reshape(-1, 2) # [1, 2]
126
+ lr_ball = attr.LR_ball_move().reshape(-1, 6) # [1, 3, 2] -> [1, 6]
127
+ return [lr_open, lr_ball]
128
+
129
+ def __call__(self, img, last_lmk=None, **kwargs):
130
+ """
131
+ img: rgb, uint8
132
+ last_lmk: last frame lmk203, for video tracking
133
+ kwargs: optional crop cfg
134
+ crop_scale: 2.3
135
+ crop_vx_ratio: 0
136
+ crop_vy_ratio: -0.125
137
+ crop_flag_do_rot: True
138
+ """
139
+ img_crop, M_c2o, lmk203 = self._crop(img, last_lmk=last_lmk, **kwargs)
140
+
141
+ eye_open, eye_ball = self._get_eye_info(img_crop)
142
+
143
+ rgb_256_bchw = self._img_crop_to_bchw256(img_crop)
144
+ kp_info = self._get_kp_info(rgb_256_bchw)
145
+ fs = self._get_f3d(rgb_256_bchw)
146
+
147
+ source_info = {
148
+ "x_s_info": kp_info,
149
+ "f_s": fs,
150
+ "M_c2o": M_c2o,
151
+ "eye_open": eye_open, # [1, 2]
152
+ "eye_ball": eye_ball, # [1, 6]
153
+ "lmk203": lmk203, # for track
154
+ }
155
+ return source_info
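If the checkpoints referenced in the config docstring are available, wiring `Source2Info` up for a single image could look roughly like the sketch below; every model path is a placeholder, and the import assumes the repo root is on `PYTHONPATH`:

```python
import cv2
from core.atomic_components.source2info import Source2Info  # import path assumed

device = "cuda"
s2i = Source2Info(
    insightface_det_cfg={"model_path": "checkpoints/det.onnx", "device": device},
    landmark106_cfg={"model_path": "checkpoints/lmk106.onnx", "device": device},
    landmark203_cfg={"model_path": "checkpoints/lmk203.onnx", "device": device},
    landmark478_cfg={
        "blaze_face_model_path": "checkpoints/blaze_face.onnx",
        "face_mesh_model_path": "checkpoints/face_mesh.onnx",
        "device": device,
    },
    appearance_extractor_cfg={"model_path": "checkpoints/appearance.onnx", "device": device},
    motion_extractor_cfg={"model_path": "checkpoints/motion.onnx", "device": device},
)

img_rgb = cv2.cvtColor(cv2.imread("example/image.png"), cv2.COLOR_BGR2RGB)
info = s2i(img_rgb)                 # image / first video frame: detection path
# info["lmk203"] can be passed back as last_lmk when tracking later video frames
```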
core/atomic_components/warp_f3d.py ADDED
@@ -0,0 +1,22 @@
1
+ from ..models.warp_network import WarpNetwork
2
+
3
+
4
+ """
5
+ # __init__
6
+ warp_network_cfg = {
7
+ "model_path": "",
8
+ "device": "cuda",
9
+ }
10
+ """
11
+
12
+ class WarpF3D:
13
+ def __init__(
14
+ self,
15
+ warp_network_cfg,
16
+ ):
17
+ self.warp_net = WarpNetwork(**warp_network_cfg)
18
+
19
+ def __call__(self, f_s, x_s, x_d):
20
+ out = self.warp_net(f_s, x_s, x_d)
21
+ return out
22
+
core/atomic_components/wav2feat.py ADDED
@@ -0,0 +1,110 @@
1
+ import librosa
2
+ import numpy as np
3
+ import math
4
+
5
+ from ..aux_models.hubert_stream import HubertStreaming
6
+
7
+ """
8
+ wavlm_cfg = {
9
+ "model_path": "",
10
+ "device": "cuda",
11
+ "force_ori_type": False,
12
+ }
13
+ hubert_cfg = {
14
+ "model_path": "",
15
+ "device": "cuda",
16
+ "force_ori_type": False,
17
+ }
18
+ """
19
+
20
+
21
+ class Wav2Feat:
22
+ def __init__(self, w2f_cfg, w2f_type="hubert"):
23
+ self.w2f_type = w2f_type.lower()
24
+ if self.w2f_type == "hubert":
25
+ self.w2f = Wav2FeatHubert(hubert_cfg=w2f_cfg)
26
+ self.feat_dim = 1024
27
+ self.support_streaming = True
28
+ else:
29
+ raise ValueError(f"Unsupported w2f_type: {w2f_type}")
30
+
31
+ def __call__(
32
+ self,
33
+ audio,
34
+ sr=16000,
35
+ norm_mean_std=None, # for s2g
36
+ chunksize=(3, 5, 2), # for hubert
37
+ ):
38
+ if self.w2f_type == "hubert":
39
+ feat = self.w2f(audio, chunksize=chunksize)
40
+ elif self.w2f_type == "s2g":
41
+ feat = self.w2f(audio, sr=sr, norm_mean_std=norm_mean_std)
42
+ else:
43
+ raise ValueError(f"Unsupported w2f_type: {self.w2f_type}")
44
+ return feat
45
+
46
+ def wav2feat(
47
+ self,
48
+ audio,
49
+ sr=16000,
50
+ norm_mean_std=None, # for s2g
51
+ chunksize=(3, 5, 2),
52
+ ):
53
+ # for offline
54
+ if self.w2f_type == "hubert":
55
+ feat = self.w2f.wav2feat(audio, sr=sr, chunksize=chunksize)
56
+ elif self.w2f_type == "s2g":
57
+ feat = self.w2f(audio, sr=sr, norm_mean_std=norm_mean_std)
58
+ else:
59
+ raise ValueError(f"Unsupported w2f_type: {self.w2f_type}")
60
+ return feat
61
+
62
+
63
+ class Wav2FeatHubert:
64
+ def __init__(
65
+ self,
66
+ hubert_cfg,
67
+ ):
68
+ self.hubert = HubertStreaming(**hubert_cfg)
69
+
70
+ def __call__(self, audio_chunk, chunksize=(3, 5, 2)):
71
+ """
72
+ audio_chunk: int(sum(chunksize) * 0.04 * 16000) + 80 # 6480
73
+ """
74
+ valid_feat_s = - sum(chunksize[1:]) * 2 # -(5+2)*2 = -14
75
+ valid_feat_e = - chunksize[2] * 2 # -2*2 = -4
76
+
77
+ encoding_chunk = self.hubert(audio_chunk)
78
+ valid_encoding = encoding_chunk[valid_feat_s:valid_feat_e]
79
+ valid_feat = valid_encoding.reshape(chunksize[1], 2, 1024).mean(1) # [5, 1024]
80
+ return valid_feat
81
+
82
+ def wav2feat(self, audio, sr, chunksize=(3, 5, 2)):
83
+ # for offline
84
+ if sr != 16000:
85
+ audio_16k = librosa.resample(audio, orig_sr=sr, target_sr=16000)
86
+ else:
87
+ audio_16k = audio
88
+
89
+ num_f = math.ceil(len(audio_16k) / 16000 * 25)
90
+ split_len = int(sum(chunksize) * 0.04 * 16000) + 80 # 6480
91
+
92
+ speech_pad = np.concatenate([
93
+ np.zeros((split_len - int(sum(chunksize[1:]) * 0.04 * 16000),), dtype=audio_16k.dtype),
94
+ audio_16k,
95
+ np.zeros((split_len,), dtype=audio_16k.dtype),
96
+ ], 0)
97
+
98
+ i = 0
99
+ res_lst = []
100
+ while i < num_f:
101
+ sss = int(i * 0.04 * 16000)
102
+ eee = sss + split_len
103
+ audio_chunk = speech_pad[sss:eee]
104
+ valid_feat = self.__call__(audio_chunk, chunksize)
105
+ res_lst.append(valid_feat)
106
+ i += chunksize[1]
107
+
108
+ ret = np.concatenate(res_lst, 0)
109
+ ret = ret[:num_f]
110
+ return ret
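The hard-coded 6480-sample chunk and the `[valid_feat_s:valid_feat_e]` slice follow directly from the default `chunksize=(3, 5, 2)`; a small arithmetic check:

```python
chunksize = (3, 5, 2)           # (past, current, future) 25-fps frames
sr, frame_dur = 16000, 0.04     # 16 kHz audio, 40 ms per video frame

split_len = int(sum(chunksize) * frame_dur * sr) + 80   # 6480 samples per chunk
valid_feat_s = -sum(chunksize[1:]) * 2                  # -14: drop right-context features
valid_feat_e = -chunksize[2] * 2                        # -4
n_kept = valid_feat_e - valid_feat_s                    # 10 HuBERT frames at 50 fps

print(split_len, valid_feat_s, valid_feat_e, n_kept)    # 6480 -14 -4 10
# Averaging the 10 kept 50-fps features in pairs yields chunksize[1] = 5
# features per chunk, i.e. one 1024-d feature per 25-fps video frame.
```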
core/atomic_components/writer.py ADDED
@@ -0,0 +1,36 @@
1
+ import imageio
2
+ import os
3
+
4
+
5
+ class VideoWriterByImageIO:
6
+ def __init__(self, video_path, fps=25, **kwargs):
7
+ video_format = kwargs.get("format", "mp4") # default is mp4 format
8
+ codec = kwargs.get("vcodec", "libx264") # default is libx264 encoding
9
+ quality = kwargs.get("quality") # video quality
10
+ pixelformat = kwargs.get("pixelformat", "yuv420p") # video pixel format
11
+ macro_block_size = kwargs.get("macro_block_size", 2)
12
+ ffmpeg_params = ["-crf", str(kwargs.get("crf", 18))]
13
+
14
+ os.makedirs(os.path.dirname(video_path), exist_ok=True)
15
+
16
+ writer = imageio.get_writer(
17
+ video_path,
18
+ fps=fps,
19
+ format=video_format,
20
+ codec=codec,
21
+ quality=quality,
22
+ ffmpeg_params=ffmpeg_params,
23
+ pixelformat=pixelformat,
24
+ macro_block_size=macro_block_size,
25
+ )
26
+ self.writer = writer
27
+
28
+ def __call__(self, img, fmt="bgr"):
29
+ if fmt == "bgr":
30
+ frame = img[..., ::-1]
31
+ else:
32
+ frame = img
33
+ self.writer.append_data(frame)
34
+
35
+ def close(self):
36
+ self.writer.close()
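A quick usage sketch (the import path is assumed; the keyword arguments shown in `__init__` pass through to imageio):

```python
import numpy as np
from core.atomic_components.writer import VideoWriterByImageIO  # import path assumed

writer = VideoWriterByImageIO("tmp_out/demo.mp4", fps=25, crf=18)
for i in range(25):
    frame_bgr = np.full((256, 256, 3), i * 10, dtype=np.uint8)  # dummy BGR frame
    writer(frame_bgr, fmt="bgr")   # channels flipped to RGB before writing
writer.close()
```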
core/aux_models/blaze_face.py ADDED
@@ -0,0 +1,351 @@
1
+ import numpy as np
2
+ import cv2
3
+ from ..utils.load_model import load_model
4
+
5
+
6
+ def intersect(box_a, box_b):
7
+ """We resize both arrays to [A,B,2] without new malloc:
8
+ [A,2] -> [A,1,2] -> [A,B,2]
9
+ [B,2] -> [1,B,2] -> [A,B,2]
10
+ Then we compute the area of intersect between box_a and box_b.
11
+ Args:
12
+ box_a: (array) bounding boxes, Shape: [A,4].
13
+ box_b: (array) bounding boxes, Shape: [B,4].
14
+ Return:
15
+ (array) intersection area, Shape: [A,B].
16
+ """
17
+ A = box_a.shape[0]
18
+ B = box_b.shape[0]
19
+ max_xy = np.minimum(
20
+ np.expand_dims(box_a[:, 2:], axis=1).repeat(B, axis=1),
21
+ np.expand_dims(box_b[:, 2:], axis=0).repeat(A, axis=0),
22
+ )
23
+ min_xy = np.maximum(
24
+ np.expand_dims(box_a[:, :2], axis=1).repeat(B, axis=1),
25
+ np.expand_dims(box_b[:, :2], axis=0).repeat(A, axis=0),
26
+ )
27
+ inter = np.clip((max_xy - min_xy), a_min=0, a_max=None)
28
+ return inter[:, :, 0] * inter[:, :, 1]
29
+
30
+
31
+ def jaccard(box_a, box_b):
32
+ """Compute the jaccard overlap of two sets of boxes. The jaccard overlap
33
+ is simply the intersection over union of two boxes. Here we operate on
34
+ ground truth boxes and default boxes.
35
+ E.g.:
36
+ A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
37
+ Args:
38
+ box_a: (array) Ground truth bounding boxes, Shape: [num_objects,4]
39
+ box_b: (array) Prior boxes from priorbox layers, Shape: [num_priors,4]
40
+ Return:
41
+ jaccard overlap: (array) Shape: [box_a.size(0), box_b.size(0)]
42
+ """
43
+ inter = intersect(box_a, box_b)
44
+ area_a = (
45
+ ((box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1]))
46
+ .reshape(-1, 1)
47
+ .repeat(box_b.shape[0], axis=1)
48
+ ) # [A,B]
49
+ area_b = (
50
+ ((box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1]))
51
+ .reshape(1, -1)
52
+ .repeat(box_a.shape[0], axis=0)
53
+ ) # [A,B]
54
+ union = area_a + area_b - inter
55
+ return inter / union # [A,B]
56
+
57
+
58
+ def overlap_similarity(box, other_boxes):
59
+ """Computes the IOU between a bounding box and set of other boxes."""
60
+ box = np.expand_dims(box, axis=0) # Equivalent to unsqueeze(0) in PyTorch
61
+ iou = jaccard(box, other_boxes)
62
+ return np.squeeze(iou, axis=0) # Equivalent to squeeze(0) in PyTorch
63
+
64
+
65
+ class BlazeFace:
66
+ def __init__(self, model_path, device="cuda"):
67
+ self.anchor_options = {
68
+ "num_layers": 4,
69
+ "min_scale": 0.1484375,
70
+ "max_scale": 0.75,
71
+ "input_size_height": 128,
72
+ "input_size_width": 128,
73
+ "anchor_offset_x": 0.5,
74
+ "anchor_offset_y": 0.5,
75
+ "strides": [8, 16, 16, 16],
76
+ "aspect_ratios": [1.0],
77
+ "reduce_boxes_in_lowest_layer": False,
78
+ "interpolated_scale_aspect_ratio": 1.0,
79
+ "fixed_anchor_size": True,
80
+ }
81
+ self.num_classes = 1
82
+ self.num_anchors = 896
83
+ self.num_coords = 16
84
+ self.x_scale = 128.0
85
+ self.y_scale = 128.0
86
+ self.h_scale = 128.0
87
+ self.w_scale = 128.0
88
+ self.min_score_thresh = 0.5
89
+ self.min_suppression_threshold = 0.3
90
+ self.anchors = self.generate_anchors(self.anchor_options)
91
+ self.anchors = np.array(self.anchors)
92
+ assert len(self.anchors) == 896
93
+ self.model, self.model_type = load_model(model_path, device=device)
94
+ self.output_names = ["regressors", "classificators"]
95
+
96
+ def __call__(self, image: np.ndarray):
97
+ """
98
+ image: RGB image
99
+ """
100
+ image = cv2.resize(image, (128, 128))
101
+ image = image[np.newaxis, :, :, :].astype(np.float32)
102
+ image = image / 127.5 - 1.0
103
+ outputs = {}
104
+ if self.model_type == "onnx":
105
+ out_list = self.model.run(None, {"input": image})
106
+ for i, name in enumerate(self.output_names):
107
+ outputs[name] = out_list[i]
108
+ elif self.model_type == "tensorrt":
109
+ self.model.setup({"input": image})
110
+ self.model.infer()
111
+ for name in self.output_names:
112
+ outputs[name] = self.model.buffer[name][0]
113
+ else:
114
+ raise ValueError(f"Unsupported model type: {self.model_type}")
115
+ boxes = self.postprocess(outputs["regressors"], outputs["classificators"])
116
+ return boxes
117
+
118
+ def calculate_scale(self, min_scale, max_scale, stride_index, num_strides):
119
+ return min_scale + (max_scale - min_scale) * stride_index / (num_strides - 1.0)
120
+
121
+ def generate_anchors(self, options):
122
+ strides_size = len(options["strides"])
123
+ assert options["num_layers"] == strides_size
124
+
125
+ anchors = []
126
+ layer_id = 0
127
+ while layer_id < strides_size:
128
+ anchor_height = []
129
+ anchor_width = []
130
+ aspect_ratios = []
131
+ scales = []
132
+
133
+ # For same strides, we merge the anchors in the same order.
134
+ last_same_stride_layer = layer_id
135
+ while (last_same_stride_layer < strides_size) and (
136
+ options["strides"][last_same_stride_layer]
137
+ == options["strides"][layer_id]
138
+ ):
139
+ scale = self.calculate_scale(
140
+ options["min_scale"],
141
+ options["max_scale"],
142
+ last_same_stride_layer,
143
+ strides_size,
144
+ )
145
+
146
+ if (
147
+ last_same_stride_layer == 0
148
+ and options["reduce_boxes_in_lowest_layer"]
149
+ ):
150
+ # For first layer, it can be specified to use predefined anchors.
151
+ aspect_ratios.append(1.0)
152
+ aspect_ratios.append(2.0)
153
+ aspect_ratios.append(0.5)
154
+ scales.append(0.1)
155
+ scales.append(scale)
156
+ scales.append(scale)
157
+ else:
158
+ for aspect_ratio in options["aspect_ratios"]:
159
+ aspect_ratios.append(aspect_ratio)
160
+ scales.append(scale)
161
+
162
+ if options["interpolated_scale_aspect_ratio"] > 0.0:
163
+ scale_next = (
164
+ 1.0
165
+ if last_same_stride_layer == strides_size - 1
166
+ else self.calculate_scale(
167
+ options["min_scale"],
168
+ options["max_scale"],
169
+ last_same_stride_layer + 1,
170
+ strides_size,
171
+ )
172
+ )
173
+ scales.append(np.sqrt(scale * scale_next))
174
+ aspect_ratios.append(options["interpolated_scale_aspect_ratio"])
175
+
176
+ last_same_stride_layer += 1
177
+
178
+ for i in range(len(aspect_ratios)):
179
+ ratio_sqrts = np.sqrt(aspect_ratios[i])
180
+ anchor_height.append(scales[i] / ratio_sqrts)
181
+ anchor_width.append(scales[i] * ratio_sqrts)
182
+
183
+ stride = options["strides"][layer_id]
184
+ feature_map_height = int(np.ceil(options["input_size_height"] / stride))
185
+ feature_map_width = int(np.ceil(options["input_size_width"] / stride))
186
+
187
+ for y in range(feature_map_height):
188
+ for x in range(feature_map_width):
189
+ for anchor_id in range(len(anchor_height)):
190
+ x_center = (x + options["anchor_offset_x"]) / feature_map_width
191
+ y_center = (y + options["anchor_offset_y"]) / feature_map_height
192
+
193
+ new_anchor = [x_center, y_center, 0, 0]
194
+ if options["fixed_anchor_size"]:
195
+ new_anchor[2] = 1.0
196
+ new_anchor[3] = 1.0
197
+ else:
198
+ new_anchor[2] = anchor_width[anchor_id]
199
+ new_anchor[3] = anchor_height[anchor_id]
200
+ anchors.append(new_anchor)
201
+
202
+ layer_id = last_same_stride_layer
203
+
204
+ return anchors
205
+
206
+ def _tensors_to_detections(self, raw_box_tensor, raw_score_tensor, anchors):
207
+ """The output of the neural network is a tensor of shape (b, 896, 16)
208
+ containing the bounding box regressor predictions, as well as a tensor
209
+ of shape (b, 896, 1) with the classification confidences.
210
+
211
+ This function converts these two "raw" tensors into proper detections.
212
+ Returns a list of (num_detections, 17) tensors, one for each image in
213
+ the batch.
214
+
215
+ This is based on the source code from:
216
+ mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc
217
+ mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.proto
218
+ """
219
+ assert raw_box_tensor.ndim == 3
220
+ assert raw_box_tensor.shape[1] == self.num_anchors
221
+ assert raw_box_tensor.shape[2] == self.num_coords
222
+
223
+ assert raw_score_tensor.ndim == 3
224
+ assert raw_score_tensor.shape[1] == self.num_anchors
225
+ assert raw_score_tensor.shape[2] == self.num_classes
226
+
227
+ assert raw_box_tensor.shape[0] == raw_score_tensor.shape[0]
228
+
229
+ detection_boxes = self._decode_boxes(raw_box_tensor, anchors)
230
+
231
+ raw_score_tensor = np.clip(raw_score_tensor, -50, 100)
232
+ detection_scores = 1 / (1 + np.exp(-raw_score_tensor))
233
+ mask = detection_scores >= self.min_score_thresh
234
+ mask = mask[0, :, 0]
235
+ boxes = detection_boxes[0, mask, :]
236
+ scores = detection_scores[0, mask, :]
237
+ return np.concatenate((boxes, scores), axis=-1)
238
+
239
+ def _decode_boxes(self, raw_boxes, anchors):
240
+ """Converts the predictions into actual coordinates using
241
+ the anchor boxes. Processes the entire batch at once.
242
+ """
243
+ boxes = np.zeros_like(raw_boxes)
244
+
245
+ x_center = raw_boxes[..., 0] / self.x_scale * anchors[:, 2] + anchors[:, 0]
246
+ y_center = raw_boxes[..., 1] / self.y_scale * anchors[:, 3] + anchors[:, 1]
247
+
248
+ w = raw_boxes[..., 2] / self.w_scale * anchors[:, 2]
249
+ h = raw_boxes[..., 3] / self.h_scale * anchors[:, 3]
250
+
251
+ boxes[..., 0] = self.x_scale * (x_center - w / 2.0) # xmin
252
+ boxes[..., 1] = self.y_scale * (y_center - h / 2.0) # ymin
253
+ boxes[..., 2] = self.w_scale * (x_center + w / 2.0) # xmax
254
+ boxes[..., 3] = self.h_scale * (y_center + h / 2.0) # ymax
255
+
256
+ for k in range(6):
257
+ offset = 4 + k * 2
258
+ keypoint_x = (
259
+ raw_boxes[..., offset] / self.x_scale * anchors[:, 2] + anchors[:, 0]
260
+ )
261
+ keypoint_y = (
262
+ raw_boxes[..., offset + 1] / self.y_scale * anchors[:, 3]
263
+ + anchors[:, 1]
264
+ )
265
+ boxes[..., offset] = keypoint_x
266
+ boxes[..., offset + 1] = keypoint_y
267
+
268
+ return boxes
269
+
270
+ def _weighted_non_max_suppression(self, detections):
271
+ """The alternative NMS method as mentioned in the BlazeFace paper:
272
+
273
+ "We replace the suppression algorithm with a blending strategy that
274
+ estimates the regression parameters of a bounding box as a weighted
275
+ mean between the overlapping predictions."
276
+
277
+ The original MediaPipe code assigns the score of the most confident
278
+ detection to the weighted detection, but we take the average score
279
+ of the overlapping detections.
280
+
281
+ The input detections should be a NumPy array of shape (count, 17).
282
+
283
+ Returns a list of NumPy arrays, one for each detected face.
284
+
285
+ This is based on the source code from:
286
+ mediapipe/calculators/util/non_max_suppression_calculator.cc
287
+ mediapipe/calculators/util/non_max_suppression_calculator.proto
288
+ """
289
+ if len(detections) == 0:
290
+ return []
291
+
292
+ output_detections = []
293
+
294
+ # Sort the detections from highest to lowest score.
295
+ remaining = np.argsort(detections[:, 16])[::-1]
296
+
297
+ while len(remaining) > 0:
298
+ detection = detections[remaining[0]]
299
+
300
+ # Compute the overlap between the first box and the other
301
+ # remaining boxes. (Note that the other_boxes also include
302
+ # the first_box.)
303
+ first_box = detection[:4]
304
+ other_boxes = detections[remaining, :4]
305
+ ious = overlap_similarity(first_box, other_boxes)
306
+
307
+ # If two detections don't overlap enough, they are considered
308
+ # to be from different faces.
309
+ mask = ious > self.min_suppression_threshold
310
+ overlapping = remaining[mask]
311
+ remaining = remaining[~mask]
312
+
313
+ # Take an average of the coordinates from the overlapping
314
+ # detections, weighted by their confidence scores.
315
+ weighted_detection = detection.copy()
316
+ if len(overlapping) > 1:
317
+ coordinates = detections[overlapping, :16]
318
+ scores = detections[overlapping, 16:17]
319
+ total_score = scores.sum()
320
+ weighted = (coordinates * scores).sum(axis=0) / total_score
321
+ weighted_detection[:16] = weighted
322
+ weighted_detection[16] = total_score / len(overlapping)
323
+
324
+ output_detections.append(weighted_detection)
325
+
326
+ return output_detections
327
+
328
+ def postprocess(self, raw_boxes, scores):
329
+ detections = self._tensors_to_detections(raw_boxes, scores, self.anchors)
330
+
331
+ detections = self._weighted_non_max_suppression(detections)
332
+ detections = np.array(detections)
333
+ return detections
334
+
335
+
336
+ if __name__ == "__main__":
337
+ import argparse
338
+
339
+ parser = argparse.ArgumentParser()
340
+ parser.add_argument("--model", type=str, default="")
341
+ parser.add_argument("--image", type=str, default=None)
342
+ args = parser.parse_args()
343
+
344
+ blaze_face = BlazeFace(args.model)
345
+ image = cv2.imread(args.image)
346
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
347
+ image = cv2.resize(image, (128, 128))
348
+ # __call__ resizes to 128x128 and normalizes to [-1, 1] internally,
349
+ # so the RGB image is passed in as uint8 without extra preprocessing.
350
+ boxes = blaze_face(image)
351
+ print(boxes)
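The `_weighted_non_max_suppression` step merges overlapping detections into a score-weighted average rather than discarding them; a toy, model-free illustration of that blend:

```python
import numpy as np

dets = np.array([
    # xmin, ymin, xmax, ymax, 12 keypoint coords (zeros here), score
    [10, 10, 50, 50] + [0] * 12 + [0.90],
    [12, 11, 52, 49] + [0] * 12 + [0.80],    # overlaps the first detection
    [80, 80, 120, 120] + [0] * 12 + [0.70],  # a separate face
], dtype=np.float32)

scores = dets[:2, 16:17]                     # the two overlapping detections
coords = dets[:2, :16]
blended = (coords * scores).sum(axis=0) / scores.sum()
print(blended[:4])                           # blended box ~ [10.9, 10.5, 50.9, 49.5]
print(scores.mean())                         # averaged confidence 0.85
```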
core/aux_models/face_mesh.py ADDED
@@ -0,0 +1,101 @@
1
+ import cv2
2
+ import numpy as np
3
+
4
+ from ..utils.load_model import load_model
5
+
6
+
7
+ class FaceMesh:
8
+ def __init__(self, model_path, device="cuda"):
9
+ self.model, self.model_type = load_model(model_path, device=device)
10
+ self.input_size = (256, 256) # (w, h)
11
+ self.output_names = [
12
+ "Identity",
13
+ "Identity_1",
14
+ "Identity_2",
15
+ ] # Identity is the mesh
16
+
17
+ def project_landmarks(self, points, roi):
18
+ width, height = self.input_size
19
+ points /= (width, height, width)
20
+ sin, cos = np.sin(roi[4]), np.cos(roi[4])
21
+ matrix = np.array([[cos, sin, 0.0], [-sin, cos, 0.0], [1.0, 1.0, 1.0]])
22
+ points -= (0.5, 0.5, 0.0)
23
+ rotated = np.matmul(points * (1, 1, 0), matrix)
24
+ points *= (0, 0, 1)
25
+ points += rotated
26
+ points *= (roi[2], roi[3], roi[2])
27
+ points += (roi[0], roi[1], 0.0)
28
+ return points
29
+
30
+ def __call__(self, image, roi):
31
+ """
32
+ image: np.ndarray, RGB, (H, W, C), [0, 255]
33
+ roi: np.ndarray, (cx, cy, w, h, rotation), rotation in radian
34
+ """
35
+ cx, cy, w, h = roi[:4]
36
+ w_half, h_half = w / 2, h / 2
37
+ pts = [
38
+ (cx - w_half, cy - h_half),
39
+ (cx + w_half, cy - h_half),
40
+ (cx + w_half, cy + h_half),
41
+ (cx - w_half, cy + h_half),
42
+ ]
43
+ rotation = roi[4]
44
+ s, c = np.sin(rotation), np.cos(rotation)
45
+ t = np.array(pts) - (cx, cy)
46
+ r = np.array([[c, s], [-s, c]])
47
+ src_pts = np.matmul(t, r) + (cx, cy)
48
+ src_pts = src_pts.astype(np.float32)
49
+
50
+ dst_pts = np.array(
51
+ [
52
+ [0.0, 0.0],
53
+ [self.input_size[0], 0.0],
54
+ [self.input_size[0], self.input_size[1]],
55
+ [0.0, self.input_size[1]],
56
+ ]
57
+ ).astype(np.float32)
58
+ M = cv2.getPerspectiveTransform(src_pts, dst_pts)
59
+ roi_image = cv2.warpPerspective(
60
+ image, M, self.input_size, flags=cv2.INTER_LINEAR
61
+ )
62
+ # cv2.imwrite('test.jpg', cv2.cvtColor(roi_image, cv2.COLOR_RGB2BGR))
63
+ roi_image = roi_image / 255.0
64
+ roi_image = roi_image.astype(np.float32)
65
+ roi_image = roi_image[np.newaxis, :, :, :]
66
+
67
+ outputs = {}
68
+ if self.model_type == "onnx":
69
+ out_list = self.model.run(None, {"input": roi_image})
70
+ for i, name in enumerate(self.output_names):
71
+ outputs[name] = out_list[i]
72
+ elif self.model_type == "tensorrt":
73
+ self.model.setup({"input": roi_image})
74
+ self.model.infer()
75
+ for name in self.output_names:
76
+ outputs[name] = self.model.buffer[name][0]
77
+ else:
78
+ raise ValueError(f"Unsupported model type: {self.model_type}")
79
+ points = outputs["Identity"].reshape(1434 // 3, 3)
80
+ points = self.project_landmarks(points, roi)
81
+ return points
82
+
83
+
84
+ if __name__ == "__main__":
85
+ import argparse
86
+
87
+ parser = argparse.ArgumentParser()
88
+ parser.add_argument("--model", type=str, help="model path")
89
+ parser.add_argument("--image", type=str, help="image path")
90
+ parser.add_argument("--device", type=str, default="cuda", help="device")
91
+ args = parser.parse_args()
92
+
93
+ face_mesh = FaceMesh(args.model, args.device)
94
+ image = cv2.imread(args.image, cv2.IMREAD_COLOR)
95
+ image = cv2.resize(image, (256, 256))
96
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
97
+
99
+ roi = np.array([128, 128, 256, 256, np.pi / 2])
100
+ mesh = face_mesh(image, roi)
101
+ print(mesh.shape)
core/aux_models/hubert_stream.py ADDED
@@ -0,0 +1,29 @@
1
+ from ..utils.load_model import load_model
2
+
3
+
4
+ class HubertStreaming:
5
+ def __init__(self, model_path, device="cuda", **kwargs):
6
+ kwargs["model_file"] = model_path
7
+ kwargs["module_name"] = "HubertStreamingONNX"
8
+ kwargs["package_name"] = "..aux_models.modules"
9
+
10
+ self.model, self.model_type = load_model(model_path, device=device, **kwargs)
11
+ self.device = device
12
+
13
+ def forward_chunk(self, audio_chunk):
14
+ if self.model_type == "onnx":
15
+ output = self.model.run(None, {"input_values": audio_chunk.reshape(1, -1)})[0]
16
+ elif self.model_type == "tensorrt":
17
+ self.model.setup({"input_values": audio_chunk.reshape(1, -1)})
18
+ self.model.infer()
19
+ output = self.model.buffer["encoding_out"][0]
20
+ else:
21
+ raise ValueError(f"Unsupported model type: {self.model_type}")
22
+ return output
23
+
24
+ def __call__(self, audio_chunk):
25
+ if self.model_type == "ori":
26
+ output = self.model.forward_chunk(audio_chunk)
27
+ else:
28
+ output = self.forward_chunk(audio_chunk)
29
+ return output
core/aux_models/insightface_det.py ADDED
@@ -0,0 +1,245 @@
1
+ from __future__ import division
2
+ import numpy as np
3
+ import cv2
4
+
5
+ from ..utils.load_model import load_model
6
+
7
+
8
+ def distance2bbox(points, distance, max_shape=None):
9
+ """Decode distance prediction to bounding box.
10
+
11
+ Args:
12
+ points (Tensor): Shape (n, 2), [x, y].
13
+ distance (Tensor): Distance from the given point to 4
14
+ boundaries (left, top, right, bottom).
15
+ max_shape (tuple): Shape of the image.
16
+
17
+ Returns:
18
+ Tensor: Decoded bboxes.
19
+ """
20
+ x1 = points[:, 0] - distance[:, 0]
21
+ y1 = points[:, 1] - distance[:, 1]
22
+ x2 = points[:, 0] + distance[:, 2]
23
+ y2 = points[:, 1] + distance[:, 3]
24
+ if max_shape is not None:
25
+ x1 = x1.clamp(min=0, max=max_shape[1])
26
+ y1 = y1.clamp(min=0, max=max_shape[0])
27
+ x2 = x2.clamp(min=0, max=max_shape[1])
28
+ y2 = y2.clamp(min=0, max=max_shape[0])
29
+ return np.stack([x1, y1, x2, y2], axis=-1)
30
+
31
+
32
+ def distance2kps(points, distance, max_shape=None):
33
+ """Decode distance prediction to keypoints.
34
+
35
+ Args:
36
+ points (Tensor): Shape (n, 2), [x, y].
37
+ distance (Tensor): Offsets from the given point to the keypoints,
38
+ interleaved as (dx1, dy1, dx2, dy2, ...).
39
+ max_shape (tuple): Shape of the image.
40
+
41
+ Returns:
42
+ Tensor: Decoded keypoints.
43
+ """
44
+ preds = []
45
+ for i in range(0, distance.shape[1], 2):
46
+ px = points[:, i%2] + distance[:, i]
47
+ py = points[:, i%2+1] + distance[:, i+1]
48
+ if max_shape is not None:
49
+ px = px.clamp(min=0, max=max_shape[1])
50
+ py = py.clamp(min=0, max=max_shape[0])
51
+ preds.append(px)
52
+ preds.append(py)
53
+ return np.stack(preds, axis=-1)
54
+
55
+
56
+ class InsightFaceDet:
57
+ def __init__(self, model_path, device="cuda", **kwargs):
58
+ kwargs["model_file"] = model_path
59
+ kwargs["module_name"] = "RetinaFace"
60
+ kwargs["package_name"] = "..aux_models.modules"
61
+
62
+ self.model, self.model_type = load_model(model_path, device=device, **kwargs)
63
+ self.device = device
64
+
65
+ if self.model_type != "ori":
66
+ self._init_vars()
67
+
68
+ def _init_vars(self):
69
+ self.center_cache = {}
70
+
71
+ self.nms_thresh = 0.4
72
+ self.det_thresh = 0.5
73
+
74
+ self.input_size = (512, 512)
75
+ self.input_mean = 127.5
76
+ self.input_std = 128.0
77
+ self._anchor_ratio = 1.0
78
+ self.fmc = 3
79
+ self._feat_stride_fpn = [8, 16, 32]
80
+ self._num_anchors = 2
81
+ self.use_kps = True
82
+
83
+ self.output_names = [
84
+ "scores1",
85
+ "scores2",
86
+ "scores3",
87
+ "boxes1",
88
+ "boxes2",
89
+ "boxes3",
90
+ "kps1",
91
+ "kps2",
92
+ "kps3",
93
+ ]
94
+
95
+ def _run_model(self, blob):
96
+ if self.model_type == "onnx":
97
+ net_outs = self.model.run(None, {"image": blob})
98
+ elif self.model_type == "tensorrt":
99
+ self.model.setup({"image": blob})
100
+ self.model.infer()
101
+ net_outs = [self.model.buffer[name][0] for name in self.output_names]
102
+ else:
103
+ raise ValueError(f"Unsupported model type: {self.model_type}")
104
+ return net_outs
105
+
106
+ def _forward(self, img, threshold):
107
+ """
108
+ img: np.ndarray, shape (h, w, 3)
109
+ """
110
+ scores_list = []
111
+ bboxes_list = []
112
+ kpss_list = []
113
+ input_size = tuple(img.shape[0:2][::-1])
114
+ blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
115
+ # (1, 3, 512, 512)
116
+ net_outs = self._run_model(blob)
117
+
118
+ input_height = blob.shape[2]
119
+ input_width = blob.shape[3]
120
+ fmc = self.fmc
121
+ for idx, stride in enumerate(self._feat_stride_fpn):
122
+ scores = net_outs[idx]
123
+ bbox_preds = net_outs[idx+fmc]
124
+ bbox_preds = bbox_preds * stride
125
+ if self.use_kps:
126
+ kps_preds = net_outs[idx+fmc*2] * stride
127
+ height = input_height // stride
128
+ width = input_width // stride
129
+ # K = height * width
130
+ key = (height, width, stride)
131
+ if key in self.center_cache:
132
+ anchor_centers = self.center_cache[key]
133
+ else:
134
+ #solution-3:
135
+ anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32)
136
+ anchor_centers = (anchor_centers * stride).reshape( (-1, 2) )
137
+ if self._num_anchors>1:
138
+ anchor_centers = np.stack([anchor_centers]*self._num_anchors, axis=1).reshape( (-1,2) )
139
+ if len(self.center_cache)<100:
140
+ self.center_cache[key] = anchor_centers
141
+
142
+ pos_inds = np.where(scores>=threshold)[0]
143
+ bboxes = distance2bbox(anchor_centers, bbox_preds)
144
+ pos_scores = scores[pos_inds]
145
+ pos_bboxes = bboxes[pos_inds]
146
+ scores_list.append(pos_scores)
147
+ bboxes_list.append(pos_bboxes)
148
+ if self.use_kps:
149
+ kpss = distance2kps(anchor_centers, kps_preds)
150
+ kpss = kpss.reshape( (kpss.shape[0], -1, 2) )
151
+ pos_kpss = kpss[pos_inds]
152
+ kpss_list.append(pos_kpss)
153
+ return scores_list, bboxes_list, kpss_list
154
+
155
+ def detect(self, img, input_size=None, max_num=0, metric='default', det_thresh=None):
156
+ input_size = self.input_size if input_size is None else input_size
157
+ det_thresh = self.det_thresh if det_thresh is None else det_thresh
158
+
159
+ im_ratio = float(img.shape[0]) / img.shape[1]
160
+ model_ratio = float(input_size[1]) / input_size[0]
161
+ if im_ratio>model_ratio:
162
+ new_height = input_size[1]
163
+ new_width = int(new_height / im_ratio)
164
+ else:
165
+ new_width = input_size[0]
166
+ new_height = int(new_width * im_ratio)
167
+ det_scale = float(new_height) / img.shape[0]
168
+ resized_img = cv2.resize(img, (new_width, new_height))
169
+ det_img = np.zeros( (input_size[1], input_size[0], 3), dtype=np.uint8 )
170
+ det_img[:new_height, :new_width, :] = resized_img
171
+
172
+ scores_list, bboxes_list, kpss_list = self._forward(det_img, det_thresh)
173
+
174
+ scores = np.vstack(scores_list)
175
+ scores_ravel = scores.ravel()
176
+ order = scores_ravel.argsort()[::-1]
177
+ bboxes = np.vstack(bboxes_list) / det_scale
178
+ if self.use_kps:
179
+ kpss = np.vstack(kpss_list) / det_scale
180
+ pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False)
181
+ pre_det = pre_det[order, :]
182
+ keep = self.nms(pre_det)
183
+ det = pre_det[keep, :]
184
+ if self.use_kps:
185
+ kpss = kpss[order,:,:]
186
+ kpss = kpss[keep,:,:]
187
+ else:
188
+ kpss = None
189
+ if max_num > 0 and det.shape[0] > max_num:
190
+ area = (det[:, 2] - det[:, 0]) * (det[:, 3] - det[:, 1])
191
+ img_center = img.shape[0] // 2, img.shape[1] // 2
192
+ offsets = np.vstack([
193
+ (det[:, 0] + det[:, 2]) / 2 - img_center[1],
194
+ (det[:, 1] + det[:, 3]) / 2 - img_center[0]
195
+ ])
196
+ offset_dist_squared = np.sum(np.power(offsets, 2.0), 0)
197
+ if metric=='max':
198
+ values = area
199
+ else:
200
+ values = area - offset_dist_squared * 2.0 # some extra weight on the centering
201
+ bindex = np.argsort(values)[::-1] # some extra weight on the centering
202
+ bindex = bindex[0:max_num]
203
+ det = det[bindex, :]
204
+ if kpss is not None:
205
+ kpss = kpss[bindex, :]
206
+ return det, kpss
207
+
208
+ def nms(self, dets):
209
+ thresh = self.nms_thresh
210
+ x1 = dets[:, 0]
211
+ y1 = dets[:, 1]
212
+ x2 = dets[:, 2]
213
+ y2 = dets[:, 3]
214
+ scores = dets[:, 4]
215
+
216
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
217
+ order = scores.argsort()[::-1]
218
+
219
+ keep = []
220
+ while order.size > 0:
221
+ i = order[0]
222
+ keep.append(i)
223
+ xx1 = np.maximum(x1[i], x1[order[1:]])
224
+ yy1 = np.maximum(y1[i], y1[order[1:]])
225
+ xx2 = np.minimum(x2[i], x2[order[1:]])
226
+ yy2 = np.minimum(y2[i], y2[order[1:]])
227
+
228
+ w = np.maximum(0.0, xx2 - xx1 + 1)
229
+ h = np.maximum(0.0, yy2 - yy1 + 1)
230
+ inter = w * h
231
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
232
+
233
+ inds = np.where(ovr <= thresh)[0]
234
+ order = order[inds + 1]
235
+
236
+ return keep
237
+
238
+ def __call__(self, img, **kwargs):
239
+ if self.model_type == "ori":
240
+ det, kpss = self.model.detect(img, **kwargs)
241
+ else:
242
+ det, kpss = self.detect(img, **kwargs)
243
+
244
+ return det, kpss
245
+
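A toy check of the anchor-free box decoding in `distance2bbox`: each anchor centre predicts stride-scaled distances to the four box edges:

```python
import numpy as np

stride = 8
anchor_centers = np.array([[16.0, 16.0], [24.0, 16.0]])   # (n, 2) pixel centres
bbox_preds = np.array([[1.0, 1.0, 2.0, 2.0],
                       [0.5, 0.5, 0.5, 0.5]]) * stride     # (n, 4) edge distances

x1 = anchor_centers[:, 0] - bbox_preds[:, 0]
y1 = anchor_centers[:, 1] - bbox_preds[:, 1]
x2 = anchor_centers[:, 0] + bbox_preds[:, 2]
y2 = anchor_centers[:, 1] + bbox_preds[:, 3]
print(np.stack([x1, y1, x2, y2], axis=-1))
# [[ 8.  8. 32. 32.]
#  [20. 12. 28. 20.]]
```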
core/aux_models/insightface_landmark106.py ADDED
@@ -0,0 +1,100 @@
1
+ from __future__ import division
2
+ import numpy as np
3
+ import torch
4
+ import cv2
5
+ from skimage import transform as trans
6
+
7
+ from ..utils.load_model import load_model
8
+
9
+
10
+ def transform(data, center, output_size, scale, rotation):
11
+ scale_ratio = scale
12
+ rot = float(rotation) * np.pi / 180.0
13
+
14
+ t1 = trans.SimilarityTransform(scale=scale_ratio)
15
+ cx = center[0] * scale_ratio
16
+ cy = center[1] * scale_ratio
17
+ t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy))
18
+ t3 = trans.SimilarityTransform(rotation=rot)
19
+ t4 = trans.SimilarityTransform(translation=(output_size / 2,
20
+ output_size / 2))
21
+ t = t1 + t2 + t3 + t4
22
+ M = t.params[0:2]
23
+ cropped = cv2.warpAffine(data,
24
+ M, (output_size, output_size),
25
+ borderValue=0.0)
26
+ return cropped, M
27
+
28
+
29
+ def trans_points2d(pts, M):
30
+ new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
31
+ for i in range(pts.shape[0]):
32
+ pt = pts[i]
33
+ new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
34
+ new_pt = np.dot(M, new_pt)
35
+ new_pts[i] = new_pt[0:2]
36
+
37
+ return new_pts
38
+
39
+
40
+ class Landmark106:
41
+ def __init__(self, model_path, device="cuda", **kwargs):
42
+ kwargs["model_file"] = model_path
43
+ kwargs["module_name"] = "Landmark106"
44
+ kwargs["package_name"] = "..aux_models.modules"
45
+
46
+ self.model, self.model_type = load_model(model_path, device=device, **kwargs)
47
+ self.device = device
48
+
49
+ if self.model_type != "ori":
50
+ self._init_vars()
51
+
52
+ def _init_vars(self):
53
+ self.input_mean = 0.0
54
+ self.input_std = 1.0
55
+ self.input_size = (192, 192)
56
+ self.lmk_num = 106
57
+
58
+ self.output_names = ["fc1"]
59
+
60
+ def _run_model(self, blob):
61
+ if self.model_type == "onnx":
62
+ pred = self.model.run(None, {"data": blob})[0]
63
+ elif self.model_type == "tensorrt":
64
+ self.model.setup({"data": blob})
65
+ self.model.infer()
66
+ pred = self.model.buffer[self.output_names[0]][0]
67
+ else:
68
+ raise ValueError(f"Unsupported model type: {self.model_type}")
69
+ return pred
70
+
71
+ def get(self, img, bbox):
72
+ w, h = (bbox[2] - bbox[0]), (bbox[3] - bbox[1])
73
+ center = (bbox[2] + bbox[0]) / 2, (bbox[3] + bbox[1]) / 2
74
+ rotate = 0
75
+ _scale = self.input_size[0] / (max(w, h)*1.5)
76
+
77
+ aimg, M = transform(img, center, self.input_size[0], _scale, rotate)
78
+ input_size = tuple(aimg.shape[0:2][::-1])
79
+
80
+ blob = cv2.dnn.blobFromImage(aimg, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
81
+
82
+ pred = self._run_model(blob)
83
+
84
+ pred = pred.reshape((-1, 2))
85
+ if self.lmk_num < pred.shape[0]:
86
+ pred = pred[self.lmk_num*-1:,:]
87
+ pred[:, 0:2] += 1
88
+ pred[:, 0:2] *= (self.input_size[0] // 2)
89
+
90
+ IM = cv2.invertAffineTransform(M)
91
+ pred = trans_points2d(pred, IM)
92
+ return pred
93
+
94
+ def __call__(self, img, bbox):
95
+ if self.model_type == "ori":
96
+ pred = self.model.get(img, bbox)
97
+ else:
98
+ pred = self.get(img, bbox)
99
+
100
+ return pred
core/aux_models/landmark203.py ADDED
@@ -0,0 +1,58 @@
1
+ import numpy as np
2
+ from ..utils.load_model import load_model
3
+
4
+
5
+ def _transform_pts(pts, M):
6
+ """ conduct similarity or affine transformation to the pts
7
+ pts: Nx2 ndarray
8
+ M: 2x3 matrix or 3x3 matrix
9
+ return: Nx2
10
+ """
11
+ return pts @ M[:2, :2].T + M[:2, 2]
12
+
13
+
14
+ class Landmark203:
15
+ def __init__(self, model_path, device="cuda", **kwargs):
16
+ kwargs["model_file"] = model_path
17
+ kwargs["module_name"] = "Landmark203"
18
+ kwargs["package_name"] = "..aux_models.modules"
19
+
20
+ self.model, self.model_type = load_model(model_path, device=device, **kwargs)
21
+ self.device = device
22
+
23
+ self.output_names = ["landmarks"]
24
+ self.dsize = 224
25
+
26
+ def _run_model(self, inp):
27
+ if self.model_type == "onnx":
28
+ out_pts = self.model.run(None, {"input": inp})[0]
29
+ elif self.model_type == "tensorrt":
30
+ self.model.setup({"input": inp})
31
+ self.model.infer()
32
+ out_pts = self.model.buffer[self.output_names[0]][0]
33
+ else:
34
+ raise ValueError(f"Unsupported model type: {self.model_type}")
35
+ return out_pts
36
+
37
+ def run(self, img_crop_rgb, M_c2o=None):
38
+ # img_crop_rgb: 224x224
39
+
40
+ inp = (img_crop_rgb.astype(np.float32) / 255.).transpose(2, 0, 1)[None, ...] # HxWx3 RGB uint8 -> 1x3xHxW float32 in [0, 1]
41
+
42
+ out_pts = self._run_model(inp)
43
+
44
+ # 2d landmarks 203 points
45
+ lmk = out_pts[0].reshape(-1, 2) * self.dsize # scale to 0-224
46
+ if M_c2o is not None:
47
+ lmk = _transform_pts(lmk, M=M_c2o)
48
+
49
+ return lmk
50
+
51
+ def __call__(self, img_crop_rgb, M_c2o=None):
52
+ if self.model_type == "ori":
53
+ lmk = self.model.run(img_crop_rgb, M_c2o)
54
+ else:
55
+ lmk = self.run(img_crop_rgb, M_c2o)
56
+
57
+ return lmk
58
+
core/aux_models/mediapipe_landmark478.py ADDED
@@ -0,0 +1,118 @@
1
+ from enum import Enum
2
+ import numpy as np
3
+
4
+ from ..utils.load_model import load_model
5
+ from .blaze_face import BlazeFace
6
+ from .face_mesh import FaceMesh
7
+
8
+
9
+ class SizeMode(Enum):
10
+ DEFAULT = 0
11
+ SQUARE_LONG = 1
12
+ SQUARE_SHORT = 2
13
+
14
+
15
+ def _select_roi_size(
16
+ bbox: np.ndarray, image_size, size_mode: SizeMode # x1, y1, x2, y2 # w,h
17
+ ):
18
+ """Return the size of an ROI based on bounding box, image size and mode"""
19
+ width, height = bbox[2] - bbox[0], bbox[3] - bbox[1]
20
+ image_width, image_height = image_size
21
+ if size_mode == SizeMode.SQUARE_LONG:
22
+ long_size = max(width, height)
23
+ width, height = long_size, long_size
24
+ elif size_mode == SizeMode.SQUARE_SHORT:
25
+ short_side = min(width, height)
26
+ width, height = short_side, short_side
27
+ return width, height
28
+
29
+
30
+ def bbox_to_roi(
31
+ bbox: np.ndarray,
32
+ image_size, # w,h
33
+ rotation_keypoints=None,
34
+ scale=(1.0, 1.0), # w, h
35
+ size_mode: SizeMode = SizeMode.SQUARE_LONG,
36
+ ):
37
+ PI = np.pi
38
+ TWO_PI = 2 * np.pi
39
+ # select ROI dimensions
40
+ width, height = _select_roi_size(bbox, image_size, size_mode)
41
+ scale_x, scale_y = scale
42
+ # calculate ROI size and -centre
43
+ width, height = width * scale_x, height * scale_y
44
+ cx = (bbox[0] + bbox[2]) / 2
45
+ cy = (bbox[1] + bbox[3]) / 2
46
+ # calculate rotation of required
47
+ if rotation_keypoints is None or len(rotation_keypoints) < 2:
48
+ return np.array([cx, cy, width, height, 0])
49
+ x0, y0 = rotation_keypoints[0]
50
+ x1, y1 = rotation_keypoints[1]
51
+ angle = -np.arctan2(y0 - y1, x1 - x0)
52
+ # normalise to [0, 2*PI]
53
+ rotation = angle - TWO_PI * np.floor((angle + PI) / TWO_PI)
54
+ return np.array([cx, cy, width, height, rotation])
55
+
56
+
57
+ class Landmark478:
58
+ def __init__(self, blaze_face_model_path="", face_mesh_model_path="", device="cuda", **kwargs):
59
+ if kwargs.get("force_ori_type", False):
60
+ assert "task_path" in kwargs
61
+ kwargs["module_name"] = "Landmark478"
62
+ kwargs["package_name"] = "..aux_models.modules"
63
+ self.model, self.model_type = load_model("", device=device, **kwargs)
64
+ else:
65
+ self.blaze_face = BlazeFace(blaze_face_model_path, device)
66
+ self.face_mesh = FaceMesh(face_mesh_model_path, device)
67
+ self.model_type = ""
68
+
69
+ def get(self, image):
70
+ bboxes = self.blaze_face(image)
71
+ if len(bboxes) == 0:
72
+ return None
73
+ bbox = bboxes[0]
74
+ scale = (image.shape[1] / 128.0, image.shape[0] / 128.0)
75
+
76
+ # The first 4 numbers describe the bounding box corners:
77
+ #
78
+ # ymin, xmin, ymax, xmax
79
+ # These are normalized coordinates (between 0 and 1).
80
+ # The next 12 numbers are the x,y-coordinates of the 6 facial landmark keypoints:
81
+ #
82
+ # right_eye_x, right_eye_y
83
+ # left_eye_x, left_eye_y
84
+ # nose_x, nose_y
85
+ # mouth_x, mouth_y
86
+ # right_ear_x, right_ear_y
87
+ # left_ear_x, left_ear_y
88
+ # Tip: these labeled as seen from the perspective of the person, so their right is your left.
89
+ # The final number is the confidence score that this detection really is a face.
90
+
91
+ bbox[0] = bbox[0] * scale[1]
92
+ bbox[1] = bbox[1] * scale[0]
93
+ bbox[2] = bbox[2] * scale[1]
94
+ bbox[3] = bbox[3] * scale[0]
95
+ left_eye = (bbox[4], bbox[5])
96
+ right_eye = (bbox[6], bbox[7])
97
+
98
+ roi = bbox_to_roi(
99
+ bbox,
100
+ (image.shape[1], image.shape[0]),
101
+ rotation_keypoints=[left_eye, right_eye],
102
+ scale=(1.5, 1.5),
103
+ size_mode=SizeMode.SQUARE_LONG,
104
+ )
105
+
106
+ mesh = self.face_mesh(image, roi)
107
+ mesh = mesh / (image.shape[1], image.shape[0], image.shape[1])
108
+ return mesh
109
+
110
+ def __call__(self, image):
111
+ if self.model_type == "ori":
112
+ det = self.model.detect_from_npimage(image.copy())
113
+ lmk = self.model.mplmk_to_nplmk(det)
114
+ return lmk
115
+ else:
116
+ lmk = self.get(image)
117
+ lmk = lmk.reshape(1, -1, 3).astype(np.float32)
118
+ return lmk
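The ROI rotation handed to FaceMesh comes from the eye line, so the crop is roughly upright; a small numeric check of that angle computation with dummy eye coordinates:

```python
import numpy as np

left_eye, right_eye = (40.0, 52.0), (80.0, 48.0)          # dummy (x, y) keypoints
x0, y0 = left_eye
x1, y1 = right_eye
angle = -np.arctan2(y0 - y1, x1 - x0)                     # same formula as bbox_to_roi
rotation = angle - 2 * np.pi * np.floor((angle + np.pi) / (2 * np.pi))
print(np.degrees(rotation))                               # ~ -5.7 degrees of roll
```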
core/aux_models/modules/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .retinaface import RetinaFace
2
+ from .landmark106 import Landmark106
3
+ from .landmark203 import Landmark203
4
+ from .landmark478 import Landmark478
5
+ from .hubert_stream import HubertStreamingONNX
core/aux_models/modules/hubert_stream.py ADDED
@@ -0,0 +1,21 @@
1
+
2
+ import onnxruntime
3
+
4
+
5
+ class HubertStreamingONNX:
6
+ def __init__(self, model_file, device="cuda"):
7
+ if device == "cuda":
8
+ providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
9
+ else:
10
+ providers = ["CPUExecutionProvider"]
11
+
12
+ self.session = onnxruntime.InferenceSession(model_file, providers=providers)
13
+
14
+ def forward_chunk(self, input_values):
15
+ encoding_out = self.session.run(
16
+ None,
17
+ {"input_values": input_values.reshape(1, -1)}
18
+ )[0]
19
+ return encoding_out
20
+
21
+
core/aux_models/modules/landmark106.py ADDED
@@ -0,0 +1,83 @@
1
+ # insightface
2
+ from __future__ import division
3
+ import onnxruntime
4
+ import cv2
5
+ import numpy as np
6
+ from skimage import transform as trans
7
+
8
+
9
+ def transform(data, center, output_size, scale, rotation):
10
+ scale_ratio = scale
11
+ rot = float(rotation) * np.pi / 180.0
12
+
13
+ t1 = trans.SimilarityTransform(scale=scale_ratio)
14
+ cx = center[0] * scale_ratio
15
+ cy = center[1] * scale_ratio
16
+ t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy))
17
+ t3 = trans.SimilarityTransform(rotation=rot)
18
+ t4 = trans.SimilarityTransform(translation=(output_size / 2,
19
+ output_size / 2))
20
+ t = t1 + t2 + t3 + t4
21
+ M = t.params[0:2]
22
+ cropped = cv2.warpAffine(data,
23
+ M, (output_size, output_size),
24
+ borderValue=0.0)
25
+ return cropped, M
26
+
27
+
28
+ def trans_points2d(pts, M):
29
+ new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
30
+ for i in range(pts.shape[0]):
31
+ pt = pts[i]
32
+ new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
33
+ new_pt = np.dot(M, new_pt)
34
+ new_pts[i] = new_pt[0:2]
35
+
36
+ return new_pts
37
+
38
+
39
+
40
+ class Landmark106:
41
+ def __init__(self, model_file, device="cuda"):
42
+ if device == "cuda":
43
+ providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
44
+ else:
45
+ providers = ["CPUExecutionProvider"]
46
+ self.session = onnxruntime.InferenceSession(model_file, providers=providers)
47
+
48
+ self.input_mean = 0.0
49
+ self.input_std = 1.0
50
+ self.input_size = (192, 192)
51
+ input_cfg = self.session.get_inputs()[0]
52
+ input_name = input_cfg.name
53
+ outputs = self.session.get_outputs()
54
+ output_names = []
55
+ for out in outputs:
56
+ output_names.append(out.name)
57
+ self.input_name = input_name
58
+ self.output_names = output_names
59
+ self.lmk_num = 106
60
+
61
+ def get(self, img, bbox):
62
+ w, h = (bbox[2] - bbox[0]), (bbox[3] - bbox[1])
63
+ center = (bbox[2] + bbox[0]) / 2, (bbox[3] + bbox[1]) / 2
64
+ rotate = 0
65
+ _scale = self.input_size[0] / (max(w, h)*1.5)
66
+
67
+ aimg, M = transform(img, center, self.input_size[0], _scale, rotate)
68
+ input_size = tuple(aimg.shape[0:2][::-1])
69
+
70
+ blob = cv2.dnn.blobFromImage(aimg, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
71
+
72
+ pred = self.session.run(self.output_names, {self.input_name : blob})[0][0]
73
+
74
+ pred = pred.reshape((-1, 2))
75
+ if self.lmk_num < pred.shape[0]:
76
+ pred = pred[self.lmk_num*-1:,:]
77
+ pred[:, 0:2] += 1
78
+ pred[:, 0:2] *= (self.input_size[0] // 2)
79
+
80
+ IM = cv2.invertAffineTransform(M)
81
+ pred = trans_points2d(pred, IM)
82
+ return pred
83
+
core/aux_models/modules/landmark203.py ADDED
@@ -0,0 +1,42 @@
+ import onnxruntime
+ import numpy as np
+
+
+ def _transform_pts(pts, M):
+     """ conduct similarity or affine transformation to the pts
+     pts: Nx2 ndarray
+     M: 2x3 matrix or 3x3 matrix
+     return: Nx2
+     """
+     return pts @ M[:2, :2].T + M[:2, 2]
+
+
+ class Landmark203:
+     def __init__(self, model_file, device="cuda"):
+         if device == "cuda":
+             providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+         else:
+             providers = ["CPUExecutionProvider"]
+         self.session = onnxruntime.InferenceSession(model_file, providers=providers)
+
+         self.dsize = 224
+
+     def _run(self, inp):
+         out = self.session.run(None, {'input': inp})
+         return out
+
+     def run(self, img_crop_rgb, M_c2o=None):
+         # img_crop_rgb: 224x224
+
+         inp = (img_crop_rgb.astype(np.float32) / 255.).transpose(2, 0, 1)[None, ...]  # HxWx3 (RGB) -> 1x3xHxW
+
+         out_lst = self._run(inp)
+         out_pts = out_lst[2]
+
+         # 2d landmarks 203 points
+         lmk = out_pts[0].reshape(-1, 2) * self.dsize  # scale to 0-224
+         if M_c2o is not None:
+             lmk = _transform_pts(lmk, M=M_c2o)
+
+         return lmk
+
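A usage sketch, assuming the 203-point ONNX model and a 224x224 RGB face crop (the resize below is only a stand-in for a real crop); `M_c2o`, the crop-to-original transform, is optional and only needed to map landmarks back to the full frame:

    import cv2

    lmk203 = Landmark203("checkpoints/landmark203.onnx", device="cpu")  # hypothetical path
    crop_bgr = cv2.resize(cv2.imread("example/image.png"), (224, 224))  # stand-in for a face crop
    crop_rgb = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB)
    lmk = lmk203.run(crop_rgb)           # (203, 2) in crop coordinates, range 0-224
    # lmk = lmk203.run(crop_rgb, M_c2o)  # with a 2x3/3x3 matrix, coordinates land in the original frame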
core/aux_models/modules/landmark478.py ADDED
@@ -0,0 +1,35 @@
+ import numpy as np
+ import mediapipe as mp
+ from mediapipe.tasks.python import vision, BaseOptions
+
+
+ class Landmark478:
+     def __init__(self, task_path):
+         base_options = BaseOptions(model_asset_path=task_path)
+         options = vision.FaceLandmarkerOptions(
+             base_options=base_options,
+             output_face_blendshapes=True,
+             output_facial_transformation_matrixes=True,
+             num_faces=1,
+         )
+         detector = vision.FaceLandmarker.create_from_options(options)
+         self.detector = detector
+
+     def detect_from_imp(self, imp):
+         image = mp.Image.create_from_file(imp)
+         detection_result = self.detector.detect(image)
+         return detection_result
+
+     def detect_from_npimage(self, img):
+         image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)
+         detection_result = self.detector.detect(image)
+         return detection_result
+
+     @staticmethod
+     def mplmk_to_nplmk(results):
+         face_landmarks_list = results.face_landmarks
+         np_lms = []
+         for face_lms in face_landmarks_list:
+             lms = [[lm.x, lm.y, lm.z] for lm in face_lms]
+             np_lms.append(lms)
+         return np.array(np_lms).astype(np.float32)
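A usage sketch, assuming the MediaPipe FaceLandmarker `.task` asset is available (the path is a placeholder); MediaPipe returns normalized coordinates:

    import cv2

    lmk478 = Landmark478("checkpoints/face_landmarker.task")  # hypothetical path
    img_rgb = cv2.cvtColor(cv2.imread("example/image.png"), cv2.COLOR_BGR2RGB)
    result = lmk478.detect_from_npimage(img_rgb)
    lms = Landmark478.mplmk_to_nplmk(result)
    print(lms.shape)  # (num_faces, 478, 3), x/y normalized to [0, 1]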
core/aux_models/modules/retinaface.py ADDED
@@ -0,0 +1,215 @@
1
+ # insightface
2
+ from __future__ import division
3
+ import onnxruntime
4
+ import cv2
5
+ import numpy as np
6
+
7
+
8
+ def distance2bbox(points, distance, max_shape=None):
9
+ """Decode distance prediction to bounding box.
10
+
11
+ Args:
12
+ points (Tensor): Shape (n, 2), [x, y].
13
+ distance (Tensor): Distance from the given point to 4
14
+ boundaries (left, top, right, bottom).
15
+ max_shape (tuple): Shape of the image.
16
+
17
+ Returns:
18
+ Tensor: Decoded bboxes.
19
+ """
20
+ x1 = points[:, 0] - distance[:, 0]
21
+ y1 = points[:, 1] - distance[:, 1]
22
+ x2 = points[:, 0] + distance[:, 2]
23
+ y2 = points[:, 1] + distance[:, 3]
24
+ if max_shape is not None:
25
+ x1 = x1.clamp(min=0, max=max_shape[1])
26
+ y1 = y1.clamp(min=0, max=max_shape[0])
27
+ x2 = x2.clamp(min=0, max=max_shape[1])
28
+ y2 = y2.clamp(min=0, max=max_shape[0])
29
+ return np.stack([x1, y1, x2, y2], axis=-1)
30
+
31
+
32
+ def distance2kps(points, distance, max_shape=None):
33
+ """Decode distance prediction to bounding box.
34
+
35
+ Args:
36
+ points (Tensor): Shape (n, 2), [x, y].
37
+ distance (Tensor): Distance from the given point to 4
38
+ boundaries (left, top, right, bottom).
39
+ max_shape (tuple): Shape of the image.
40
+
41
+ Returns:
42
+ Tensor: Decoded keypoints.
43
+ """
44
+ preds = []
45
+ for i in range(0, distance.shape[1], 2):
46
+ px = points[:, i%2] + distance[:, i]
47
+ py = points[:, i%2+1] + distance[:, i+1]
48
+ if max_shape is not None:
49
+ px = px.clamp(min=0, max=max_shape[1])
50
+ py = py.clamp(min=0, max=max_shape[0])
51
+ preds.append(px)
52
+ preds.append(py)
53
+ return np.stack(preds, axis=-1)
54
+
55
+
56
+ class RetinaFace:
57
+ def __init__(self, model_file, device="cuda"):
58
+ if device == "cuda":
59
+ providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
60
+ else:
61
+ providers = ["CPUExecutionProvider"]
62
+ self.session = onnxruntime.InferenceSession(model_file, providers=providers)
63
+
64
+ self.center_cache = {}
65
+ self.nms_thresh = 0.4
66
+ self.det_thresh = 0.5
67
+ self._init_vars()
68
+
69
+ def _init_vars(self):
70
+ self.input_size = (512, 512)
71
+ input_cfg = self.session.get_inputs()[0]
72
+ input_name = input_cfg.name
73
+ outputs = self.session.get_outputs()
74
+ output_names = []
75
+ for o in outputs:
76
+ output_names.append(o.name)
77
+ self.input_name = input_name
78
+ self.output_names = output_names
79
+ self.input_mean = 127.5
80
+ self.input_std = 128.0
81
+ self._anchor_ratio = 1.0
82
+ self.fmc = 3
83
+ self._feat_stride_fpn = [8, 16, 32]
84
+ self._num_anchors = 2
85
+ self.use_kps = True
86
+
87
+ def forward(self, img, threshold):
88
+ scores_list = []
89
+ bboxes_list = []
90
+ kpss_list = []
91
+ input_size = tuple(img.shape[0:2][::-1])
92
+ blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
93
+ net_outs = self.session.run(self.output_names, {self.input_name : blob})
94
+
95
+ input_height = blob.shape[2]
96
+ input_width = blob.shape[3]
97
+ fmc = self.fmc
98
+ for idx, stride in enumerate(self._feat_stride_fpn):
99
+ scores = net_outs[idx]
100
+ bbox_preds = net_outs[idx+fmc]
101
+ bbox_preds = bbox_preds * stride
102
+ if self.use_kps:
103
+ kps_preds = net_outs[idx+fmc*2] * stride
104
+ height = input_height // stride
105
+ width = input_width // stride
106
+ # K = height * width
107
+ key = (height, width, stride)
108
+ if key in self.center_cache:
109
+ anchor_centers = self.center_cache[key]
110
+ else:
111
+ #solution-3:
112
+ anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32)
113
+ anchor_centers = (anchor_centers * stride).reshape( (-1, 2) )
114
+ if self._num_anchors>1:
115
+ anchor_centers = np.stack([anchor_centers]*self._num_anchors, axis=1).reshape( (-1,2) )
116
+ if len(self.center_cache)<100:
117
+ self.center_cache[key] = anchor_centers
118
+
119
+ pos_inds = np.where(scores>=threshold)[0]
120
+ bboxes = distance2bbox(anchor_centers, bbox_preds)
121
+ pos_scores = scores[pos_inds]
122
+ pos_bboxes = bboxes[pos_inds]
123
+ scores_list.append(pos_scores)
124
+ bboxes_list.append(pos_bboxes)
125
+ if self.use_kps:
126
+ kpss = distance2kps(anchor_centers, kps_preds)
127
+ kpss = kpss.reshape( (kpss.shape[0], -1, 2) )
128
+ pos_kpss = kpss[pos_inds]
129
+ kpss_list.append(pos_kpss)
130
+ return scores_list, bboxes_list, kpss_list
131
+
132
+
133
+ def detect(self, img, input_size=None, max_num=0, metric='default', det_thresh=None):
134
+ input_size = self.input_size if input_size is None else input_size
135
+ det_thresh = self.det_thresh if det_thresh is None else det_thresh
136
+
137
+ im_ratio = float(img.shape[0]) / img.shape[1]
138
+ model_ratio = float(input_size[1]) / input_size[0]
139
+ if im_ratio>model_ratio:
140
+ new_height = input_size[1]
141
+ new_width = int(new_height / im_ratio)
142
+ else:
143
+ new_width = input_size[0]
144
+ new_height = int(new_width * im_ratio)
145
+ det_scale = float(new_height) / img.shape[0]
146
+ resized_img = cv2.resize(img, (new_width, new_height))
147
+ det_img = np.zeros( (input_size[1], input_size[0], 3), dtype=np.uint8 )
148
+ det_img[:new_height, :new_width, :] = resized_img
149
+
150
+ scores_list, bboxes_list, kpss_list = self.forward(det_img, det_thresh)
151
+
152
+ scores = np.vstack(scores_list)
153
+ scores_ravel = scores.ravel()
154
+ order = scores_ravel.argsort()[::-1]
155
+ bboxes = np.vstack(bboxes_list) / det_scale
156
+ if self.use_kps:
157
+ kpss = np.vstack(kpss_list) / det_scale
158
+ pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False)
159
+ pre_det = pre_det[order, :]
160
+ keep = self.nms(pre_det)
161
+ det = pre_det[keep, :]
162
+ if self.use_kps:
163
+ kpss = kpss[order,:,:]
164
+ kpss = kpss[keep,:,:]
165
+ else:
166
+ kpss = None
167
+ if max_num > 0 and det.shape[0] > max_num:
168
+ area = (det[:, 2] - det[:, 0]) * (det[:, 3] - det[:, 1])
169
+ img_center = img.shape[0] // 2, img.shape[1] // 2
170
+ offsets = np.vstack([
171
+ (det[:, 0] + det[:, 2]) / 2 - img_center[1],
172
+ (det[:, 1] + det[:, 3]) / 2 - img_center[0]
173
+ ])
174
+ offset_dist_squared = np.sum(np.power(offsets, 2.0), 0)
175
+ if metric=='max':
176
+ values = area
177
+ else:
178
+ values = area - offset_dist_squared * 2.0 # some extra weight on the centering
179
+ bindex = np.argsort(values)[::-1] # some extra weight on the centering
180
+ bindex = bindex[0:max_num]
181
+ det = det[bindex, :]
182
+ if kpss is not None:
183
+ kpss = kpss[bindex, :]
184
+ return det, kpss
185
+
186
+ def nms(self, dets):
187
+ thresh = self.nms_thresh
188
+ x1 = dets[:, 0]
189
+ y1 = dets[:, 1]
190
+ x2 = dets[:, 2]
191
+ y2 = dets[:, 3]
192
+ scores = dets[:, 4]
193
+
194
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
195
+ order = scores.argsort()[::-1]
196
+
197
+ keep = []
198
+ while order.size > 0:
199
+ i = order[0]
200
+ keep.append(i)
201
+ xx1 = np.maximum(x1[i], x1[order[1:]])
202
+ yy1 = np.maximum(y1[i], y1[order[1:]])
203
+ xx2 = np.minimum(x2[i], x2[order[1:]])
204
+ yy2 = np.minimum(y2[i], y2[order[1:]])
205
+
206
+ w = np.maximum(0.0, xx2 - xx1 + 1)
207
+ h = np.maximum(0.0, yy2 - yy1 + 1)
208
+ inter = w * h
209
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
210
+
211
+ inds = np.where(ovr <= thresh)[0]
212
+ order = order[inds + 1]
213
+
214
+ return keep
215
+
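A usage sketch, assuming an insightface SCRFD/RetinaFace detection ONNX file (the path is a placeholder):

    import cv2

    det = RetinaFace("checkpoints/det_10g.onnx", device="cpu")  # hypothetical path
    img = cv2.imread("example/image.png")                       # BGR, any resolution
    dets, kpss = det.detect(img, max_num=1)
    # dets: (N, 5) -> x1, y1, x2, y2, score; kpss: (N, 5, 2) five facial keypoints per face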
core/models/appearance_extractor.py ADDED
@@ -0,0 +1,29 @@
+ import numpy as np
+ import torch
+ from ..utils.load_model import load_model
+
+
+ class AppearanceExtractor:
+     def __init__(self, model_path, device="cuda"):
+         kwargs = {
+             "module_name": "AppearanceFeatureExtractor",
+         }
+         self.model, self.model_type = load_model(model_path, device=device, **kwargs)
+         self.device = device
+
+     def __call__(self, image):
+         """
+         image: np.ndarray, shape (1, 3, 256, 256), float32, range [0, 1]
+         """
+         if self.model_type == "onnx":
+             pred = self.model.run(None, {"image": image})[0]
+         elif self.model_type == "tensorrt":
+             self.model.setup({"image": image})
+             self.model.infer()
+             pred = self.model.buffer["pred"][0].copy()
+         elif self.model_type == 'pytorch':
+             with torch.no_grad(), torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=True):
+                 pred = self.model(torch.from_numpy(image).to(self.device)).float().cpu().numpy()
+         else:
+             raise ValueError(f"Unsupported model type: {self.model_type}")
+         return pred
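A usage sketch, assuming `load_model` resolves the given path to an ONNX / TensorRT / PyTorch backend (the path is a placeholder); shapes follow the docstring above and the feature volume produced by AppearanceFeatureExtractor further below:

    import numpy as np

    extractor = AppearanceExtractor("checkpoints/appearance_extractor.onnx", device="cuda")  # hypothetical path
    image = np.random.rand(1, 3, 256, 256).astype(np.float32)  # source crop in [0, 1]
    f_3d = extractor(image)
    print(f_3d.shape)  # expected (1, 32, 16, 64, 64) appearance feature volume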
core/models/decoder.py ADDED
@@ -0,0 +1,30 @@
+ import numpy as np
+ import torch
+ from ..utils.load_model import load_model
+
+
+ class Decoder:
+     def __init__(self, model_path, device="cuda"):
+         kwargs = {
+             "module_name": "SPADEDecoder",
+         }
+         self.model, self.model_type = load_model(model_path, device=device, **kwargs)
+         self.device = device
+
+     def __call__(self, feature):
+
+         if self.model_type == "onnx":
+             pred = self.model.run(None, {"feature": feature})[0]
+         elif self.model_type == "tensorrt":
+             self.model.setup({"feature": feature})
+             self.model.infer()
+             pred = self.model.buffer["output"][0].copy()
+         elif self.model_type == 'pytorch':
+             with torch.no_grad(), torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=True):
+                 pred = self.model(torch.from_numpy(feature).to(self.device)).float().cpu().numpy()
+         else:
+             raise ValueError(f"Unsupported model type: {self.model_type}")
+
+         pred = np.transpose(pred[0], [1, 2, 0]).clip(0, 1) * 255  # [h, w, c]
+
+         return pred
core/models/lmdm.py ADDED
@@ -0,0 +1,140 @@
1
+ import numpy as np
2
+ import torch
3
+ from ..utils.load_model import load_model
4
+
5
+
6
+ def make_beta(n_timestep, cosine_s=8e-3):
7
+ timesteps = (
8
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
9
+ )
10
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
11
+ alphas = torch.cos(alphas).pow(2)
12
+ alphas = alphas / alphas[0]
13
+ betas = 1 - alphas[1:] / alphas[:-1]
14
+ betas = np.clip(betas, a_min=0, a_max=0.999)
15
+ return betas.numpy()
16
+
17
+
18
+ class LMDM:
19
+ def __init__(self, model_path, device="cuda", **kwargs):
20
+ kwargs["module_name"] = "LMDM"
21
+
22
+ self.model, self.model_type = load_model(model_path, device=device, **kwargs)
23
+ self.device = device
24
+
25
+ self.motion_feat_dim = kwargs.get("motion_feat_dim", 265)
26
+ self.audio_feat_dim = kwargs.get("audio_feat_dim", 1024+35)
27
+ self.seq_frames = kwargs.get("seq_frames", 80)
28
+
29
+ if self.model_type == "pytorch":
30
+ pass
31
+ else:
32
+ self._init_np()
33
+
34
+ def setup(self, sampling_timesteps):
35
+ if self.model_type == "pytorch":
36
+ self.model.setup(sampling_timesteps)
37
+ else:
38
+ self._setup_np(sampling_timesteps)
39
+
40
+ def _init_np(self):
41
+ self.sampling_timesteps = None
42
+ self.n_timestep = 1000
43
+
44
+ betas = torch.Tensor(make_beta(n_timestep=self.n_timestep))
45
+ alphas = 1.0 - betas
46
+ self.alphas_cumprod = torch.cumprod(alphas, axis=0).cpu().numpy()
47
+
48
+ def _setup_np(self, sampling_timesteps=50):
49
+ if self.sampling_timesteps == sampling_timesteps:
50
+ return
51
+
52
+ self.sampling_timesteps = sampling_timesteps
53
+
54
+ total_timesteps = self.n_timestep
55
+ eta = 1
56
+ shape = (1, self.seq_frames, self.motion_feat_dim)
57
+
58
+ times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
59
+ times = list(reversed(times.int().tolist()))
60
+ self.time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
61
+
62
+ self.time_cond_list = []
63
+ self.alpha_next_sqrt_list = []
64
+ self.sigma_list = []
65
+ self.c_list = []
66
+ self.noise_list = []
67
+
68
+ for time, time_next in self.time_pairs:
69
+ time_cond = np.full((1,), time, dtype=np.int64)
70
+ self.time_cond_list.append(time_cond)
71
+ if time_next < 0:
72
+ continue
73
+
74
+ alpha = self.alphas_cumprod[time]
75
+ alpha_next = self.alphas_cumprod[time_next]
76
+
77
+ sigma = eta * np.sqrt((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha))
78
+ c = np.sqrt(1 - alpha_next - sigma ** 2)
79
+ noise = np.random.randn(*shape).astype(np.float32)
80
+
81
+ self.alpha_next_sqrt_list.append(np.sqrt(alpha_next))
82
+ self.sigma_list.append(sigma)
83
+ self.c_list.append(c)
84
+ self.noise_list.append(noise)
85
+
86
+ def _one_step(self, x, cond_frame, cond, time_cond):
87
+ if self.model_type == "onnx":
88
+ pred = self.model.run(None, {"x": x, "cond_frame": cond_frame, "cond": cond, "time_cond": time_cond})
89
+ pred_noise, x_start = pred[0], pred[1]
90
+ elif self.model_type == "tensorrt":
91
+ self.model.setup({"x": x, "cond_frame": cond_frame, "cond": cond, "time_cond": time_cond})
92
+ self.model.infer()
93
+ pred_noise, x_start = self.model.buffer["pred_noise"][0], self.model.buffer["x_start"][0]
94
+ elif self.model_type == "pytorch":
95
+ with torch.no_grad():
96
+ pred_noise, x_start = self.model(x, cond_frame, cond, time_cond)
97
+ else:
98
+ raise ValueError(f"Unsupported model type: {self.model_type}")
99
+
100
+ return pred_noise, x_start
101
+
102
+ def _call_np(self, kp_cond, aud_cond, sampling_timesteps):
103
+ self._setup_np(sampling_timesteps)
104
+
105
+ cond_frame = kp_cond
106
+ cond = aud_cond
107
+
108
+ x = np.random.randn(1, self.seq_frames, self.motion_feat_dim).astype(np.float32)
109
+
110
+ x_start = None
111
+ i = 0
112
+ for _, time_next in self.time_pairs:
113
+ time_cond = self.time_cond_list[i]
114
+ pred_noise, x_start = self._one_step(x, cond_frame, cond, time_cond)
115
+ if time_next < 0:
116
+ x = x_start
117
+ continue
118
+
119
+ alpha_next_sqrt = self.alpha_next_sqrt_list[i]
120
+ c = self.c_list[i]
121
+ sigma = self.sigma_list[i]
122
+ noise = self.noise_list[i]
123
+ x = x_start * alpha_next_sqrt + c * pred_noise + sigma * noise
124
+
125
+ i += 1
126
+
127
+ return x
128
+
129
+ def __call__(self, kp_cond, aud_cond, sampling_timesteps):
130
+ if self.model_type == "pytorch":
131
+ pred_kp_seq = self.model.ddim_sample(
132
+ torch.from_numpy(kp_cond).to(self.device),
133
+ torch.from_numpy(aud_cond).to(self.device),
134
+ sampling_timesteps,
135
+ ).cpu().numpy()
136
+ else:
137
+ pred_kp_seq = self._call_np(kp_cond, aud_cond, sampling_timesteps)
138
+ return pred_kp_seq
139
+
140
+
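A rough driving sketch for the wrapper above, assuming a single 265-d conditioning motion frame and 80 frames of 1024+35-d audio features; the checkpoint path and the exact feature layout are assumptions:

    import numpy as np

    lmdm = LMDM("checkpoints/lmdm.onnx", device="cuda")  # hypothetical path
    kp_cond = np.random.randn(1, 265).astype(np.float32)             # conditioning motion frame
    aud_cond = np.random.randn(1, 80, 1024 + 35).astype(np.float32)  # audio condition for one window
    pred_kp_seq = lmdm(kp_cond, aud_cond, sampling_timesteps=50)
    print(pred_kp_seq.shape)  # (1, 80, 265) denoised motion sequence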
core/models/modules/LMDM.py ADDED
@@ -0,0 +1,154 @@
1
+ # Latent Motion Diffusion Model
2
+ import torch
3
+ import torch.nn as nn
4
+ from .lmdm_modules.model import MotionDecoder
5
+ from .lmdm_modules.utils import extract, make_beta_schedule
6
+
7
+
8
+ class LMDM(nn.Module):
9
+ def __init__(
10
+ self,
11
+ motion_feat_dim=265,
12
+ audio_feat_dim=1024+35,
13
+ seq_frames=80,
14
+ checkpoint='',
15
+ device='cuda',
16
+ clip_denoised=False, # clip denoised (-1,1)
17
+ multi_cond_frame=False,
18
+ ):
19
+ super().__init__()
20
+
21
+ self.motion_feat_dim = motion_feat_dim
22
+ self.audio_feat_dim = audio_feat_dim
23
+ self.seq_frames = seq_frames
24
+ self.device = device
25
+
26
+ self.n_timestep = 1000
27
+ self.clip_denoised = clip_denoised
28
+ self.guidance_weight = 2
29
+
30
+ self.model = MotionDecoder(
31
+ nfeats=motion_feat_dim,
32
+ seq_len=seq_frames,
33
+ latent_dim=512,
34
+ ff_size=1024,
35
+ num_layers=8,
36
+ num_heads=8,
37
+ dropout=0.1,
38
+ cond_feature_dim=audio_feat_dim,
39
+ multi_cond_frame=multi_cond_frame,
40
+ )
41
+
42
+ self.init_diff()
43
+
44
+ self.sampling_timesteps = None
45
+
46
+ def init_diff(self):
47
+ n_timestep = self.n_timestep
48
+ betas = torch.Tensor(
49
+ make_beta_schedule(schedule="cosine", n_timestep=n_timestep)
50
+ )
51
+ alphas = 1.0 - betas
52
+ alphas_cumprod = torch.cumprod(alphas, axis=0)
53
+
54
+ self.register_buffer("alphas_cumprod", alphas_cumprod)
55
+ self.register_buffer(
56
+ "sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1)
57
+ )
58
+ self.register_buffer("sqrt_recip1m_alphas_cumprod", torch.sqrt(1.0 / (1.0 - alphas_cumprod)))
59
+
60
+ def predict_noise_from_start(self, x_t, t, x0):
61
+ a = extract(self.sqrt_recip1m_alphas_cumprod, t, x_t.shape)
62
+ b = extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
63
+ return (a * x_t - x0 / b)
64
+
65
+ def maybe_clip(self, x):
66
+ if self.clip_denoised:
67
+ return torch.clamp(x, min=-1., max=1.)
68
+ else:
69
+ return x
70
+
71
+ def model_predictions(self, x, cond_frame, cond, t):
72
+ weight = self.guidance_weight
73
+ x_start = self.model.guided_forward(x, cond_frame, cond, t, weight)
74
+ x_start = self.maybe_clip(x_start)
75
+ pred_noise = self.predict_noise_from_start(x, t, x_start)
76
+ return pred_noise, x_start
77
+
78
+ @torch.no_grad()
79
+ def forward(self, x, cond_frame, cond, time_cond):
80
+ pred_noise, x_start = self.model_predictions(x, cond_frame, cond, time_cond)
81
+ return pred_noise, x_start
82
+
83
+ def load_model(self, ckpt_path):
84
+ checkpoint = torch.load(ckpt_path, map_location='cpu')
85
+ self.model.load_state_dict(checkpoint["model_state_dict"])
86
+ self.eval()
87
+ return self
88
+
89
+ def setup(self, sampling_timesteps=50):
90
+ if self.sampling_timesteps == sampling_timesteps:
91
+ return
92
+
93
+ self.sampling_timesteps = sampling_timesteps
94
+
95
+ total_timesteps = self.n_timestep
96
+ device = self.device
97
+ eta = 1
98
+ shape = (1, self.seq_frames, self.motion_feat_dim)
99
+
100
+ times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
101
+ times = list(reversed(times.int().tolist()))
102
+ self.time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
103
+
104
+ self.time_cond_list = []
105
+ self.alpha_next_sqrt_list = []
106
+ self.sigma_list = []
107
+ self.c_list = []
108
+ self.noise_list = []
109
+
110
+ for time, time_next in self.time_pairs:
111
+ time_cond = torch.full((1,), time, device=device, dtype=torch.long)
112
+ self.time_cond_list.append(time_cond)
113
+ if time_next < 0:
114
+ continue
115
+ alpha = self.alphas_cumprod[time]
116
+ alpha_next = self.alphas_cumprod[time_next]
117
+
118
+ sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
119
+ c = (1 - alpha_next - sigma ** 2).sqrt()
120
+ noise = torch.randn(shape, device=device)
121
+
122
+ self.alpha_next_sqrt_list.append(alpha_next.sqrt())
123
+ self.sigma_list.append(sigma)
124
+ self.c_list.append(c)
125
+ self.noise_list.append(noise)
126
+
127
+ @torch.no_grad()
128
+ def ddim_sample(self, kp_cond, aud_cond, sampling_timesteps):
129
+ self.setup(sampling_timesteps)
130
+
131
+ cond_frame = kp_cond
132
+ cond = aud_cond
133
+
134
+ shape = (1, self.seq_frames, self.motion_feat_dim)
135
+ x = torch.randn(shape, device=self.device)
136
+
137
+ x_start = None
138
+ i = 0
139
+ for _, time_next in self.time_pairs:
140
+ time_cond = self.time_cond_list[i]
141
+ pred_noise, x_start = self.model_predictions(x, cond_frame, cond, time_cond)
142
+ if time_next < 0:
143
+ x = x_start
144
+ continue
145
+
146
+ alpha_next_sqrt = self.alpha_next_sqrt_list[i]
147
+ c = self.c_list[i]
148
+ sigma = self.sigma_list[i]
149
+ noise = self.noise_list[i]
150
+ x = x_start * alpha_next_sqrt + c * pred_noise + sigma * noise
151
+
152
+ i += 1
153
+ return x # pred_kp_seq
154
+
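For reference, each iteration of `ddim_sample` above applies the DDIM update with eta = 1 (the stochastic, DDPM-like end of the family), where \bar\alpha denotes `alphas_cumprod`, t' is the next (smaller) timestep, and the per-step noise z is pre-sampled once in `setup`:

    \sigma_t = \eta \sqrt{ \frac{(1 - \bar\alpha_t/\bar\alpha_{t'}) (1 - \bar\alpha_{t'})}{1 - \bar\alpha_t} }, \qquad
    x_{t'} = \sqrt{\bar\alpha_{t'}} \, \hat{x}_0 + \sqrt{1 - \bar\alpha_{t'} - \sigma_t^2} \, \hat{\epsilon} + \sigma_t z, \quad z \sim \mathcal{N}(0, I)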
core/models/modules/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from .appearance_feature_extractor import AppearanceFeatureExtractor
+ from .motion_extractor import MotionExtractor
+ from .warping_network import WarpingNetwork
+ from .spade_generator import SPADEDecoder
+ from .stitching_network import StitchingNetwork
+ from .LMDM import LMDM
core/models/modules/appearance_feature_extractor.py ADDED
@@ -0,0 +1,74 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ Appearance extractor (F) defined in the paper, which maps the source image s to a 3D appearance feature volume.
5
+ """
6
+
7
+ import torch
8
+ from torch import nn
9
+ from .util import SameBlock2d, DownBlock2d, ResBlock3d
10
+
11
+
12
+ class AppearanceFeatureExtractor(nn.Module):
13
+
14
+ def __init__(
15
+ self,
16
+ image_channel=3,
17
+ block_expansion=64,
18
+ num_down_blocks=2,
19
+ max_features=512,
20
+ reshape_channel=32,
21
+ reshape_depth=16,
22
+ num_resblocks=6,
23
+ ):
24
+ super(AppearanceFeatureExtractor, self).__init__()
25
+ self.image_channel = image_channel
26
+ self.block_expansion = block_expansion
27
+ self.num_down_blocks = num_down_blocks
28
+ self.max_features = max_features
29
+ self.reshape_channel = reshape_channel
30
+ self.reshape_depth = reshape_depth
31
+
32
+ self.first = SameBlock2d(
33
+ image_channel, block_expansion, kernel_size=(3, 3), padding=(1, 1)
34
+ )
35
+
36
+ down_blocks = []
37
+ for i in range(num_down_blocks):
38
+ in_features = min(max_features, block_expansion * (2**i))
39
+ out_features = min(max_features, block_expansion * (2 ** (i + 1)))
40
+ down_blocks.append(
41
+ DownBlock2d(
42
+ in_features, out_features, kernel_size=(3, 3), padding=(1, 1)
43
+ )
44
+ )
45
+ self.down_blocks = nn.ModuleList(down_blocks)
46
+
47
+ self.second = nn.Conv2d(
48
+ in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1
49
+ )
50
+
51
+ self.resblocks_3d = torch.nn.Sequential()
52
+ for i in range(num_resblocks):
53
+ self.resblocks_3d.add_module(
54
+ "3dr" + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1)
55
+ )
56
+
57
+ def forward(self, source_image):
58
+ out = self.first(source_image) # Bx3x256x256 -> Bx64x256x256
59
+
60
+ for i in range(len(self.down_blocks)):
61
+ out = self.down_blocks[i](out)
62
+ out = self.second(out)
63
+ bs, c, h, w = out.shape # ->Bx512x64x64
64
+
65
+ f_s = out.view(
66
+ bs, self.reshape_channel, self.reshape_depth, h, w
67
+ ) # ->Bx32x16x64x64
68
+ f_s = self.resblocks_3d(f_s) # ->Bx32x16x64x64
69
+ return f_s
70
+
71
+ def load_model(self, ckpt_path):
72
+ self.load_state_dict(torch.load(ckpt_path, map_location=lambda storage, loc: storage))
73
+ self.eval()
74
+ return self
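A quick shape check of the module as defined (random weights, no checkpoint), reproducing the Bx3x256x256 -> Bx32x16x64x64 path noted in the comments above:

    import torch

    net = AppearanceFeatureExtractor().eval()
    with torch.no_grad():
        f_s = net(torch.rand(1, 3, 256, 256))
    print(f_s.shape)  # torch.Size([1, 32, 16, 64, 64])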
core/models/modules/convnextv2.py ADDED
@@ -0,0 +1,150 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ This module is adapted to the ConvNeXtV2 version for the extraction of implicit keypoints, poses, and expression deformation.
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ # from timm.models.layers import trunc_normal_, DropPath
10
+ from .util import LayerNorm, DropPath, trunc_normal_, GRN
11
+
12
+ __all__ = ['convnextv2_tiny']
13
+
14
+
15
+ class Block(nn.Module):
16
+ """ ConvNeXtV2 Block.
17
+
18
+ Args:
19
+ dim (int): Number of input channels.
20
+ drop_path (float): Stochastic depth rate. Default: 0.0
21
+ """
22
+
23
+ def __init__(self, dim, drop_path=0.):
24
+ super().__init__()
25
+ self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
26
+ self.norm = LayerNorm(dim, eps=1e-6)
27
+ self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
28
+ self.act = nn.GELU()
29
+ self.grn = GRN(4 * dim)
30
+ self.pwconv2 = nn.Linear(4 * dim, dim)
31
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
32
+
33
+ def forward(self, x):
34
+ input = x
35
+ x = self.dwconv(x)
36
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
37
+ x = self.norm(x)
38
+ x = self.pwconv1(x)
39
+ x = self.act(x)
40
+ x = self.grn(x)
41
+ x = self.pwconv2(x)
42
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
43
+
44
+ x = input + self.drop_path(x)
45
+ return x
46
+
47
+
48
+ class ConvNeXtV2(nn.Module):
49
+ """ ConvNeXt V2
50
+
51
+ Args:
52
+ in_chans (int): Number of input image channels. Default: 3
53
+ num_classes (int): Number of classes for classification head. Default: 1000
54
+ depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
55
+ dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
56
+ drop_path_rate (float): Stochastic depth rate. Default: 0.
57
+ head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
58
+ """
59
+
60
+ def __init__(
61
+ self,
62
+ in_chans=3,
63
+ depths=[3, 3, 9, 3],
64
+ dims=[96, 192, 384, 768],
65
+ drop_path_rate=0.,
66
+ **kwargs
67
+ ):
68
+ super().__init__()
69
+ self.depths = depths
70
+ self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
71
+ stem = nn.Sequential(
72
+ nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
73
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
74
+ )
75
+ self.downsample_layers.append(stem)
76
+ for i in range(3):
77
+ downsample_layer = nn.Sequential(
78
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
79
+ nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
80
+ )
81
+ self.downsample_layers.append(downsample_layer)
82
+
83
+ self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
84
+ dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
85
+ cur = 0
86
+ for i in range(4):
87
+ stage = nn.Sequential(
88
+ *[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])]
89
+ )
90
+ self.stages.append(stage)
91
+ cur += depths[i]
92
+
93
+ self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
94
+
95
+ # NOTE: the output semantic items
96
+ num_bins = kwargs.get('num_bins', 66)
97
+ num_kp = kwargs.get('num_kp', 24) # the number of implicit keypoints
98
+ self.fc_kp = nn.Linear(dims[-1], 3 * num_kp) # implicit keypoints
99
+
100
+ # print('dims[-1]: ', dims[-1])
101
+ self.fc_scale = nn.Linear(dims[-1], 1) # scale
102
+ self.fc_pitch = nn.Linear(dims[-1], num_bins) # pitch bins
103
+ self.fc_yaw = nn.Linear(dims[-1], num_bins) # yaw bins
104
+ self.fc_roll = nn.Linear(dims[-1], num_bins) # roll bins
105
+ self.fc_t = nn.Linear(dims[-1], 3) # translation
106
+ self.fc_exp = nn.Linear(dims[-1], 3 * num_kp) # expression / delta
107
+
108
+ def _init_weights(self, m):
109
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
110
+ trunc_normal_(m.weight, std=.02)
111
+ nn.init.constant_(m.bias, 0)
112
+
113
+ def forward_features(self, x):
114
+ for i in range(4):
115
+ x = self.downsample_layers[i](x)
116
+ x = self.stages[i](x)
117
+ return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)
118
+
119
+ def forward(self, x):
120
+ x = self.forward_features(x)
121
+
122
+ # implicit keypoints
123
+ kp = self.fc_kp(x)
124
+
125
+ # pose and expression deformation
126
+ pitch = self.fc_pitch(x)
127
+ yaw = self.fc_yaw(x)
128
+ roll = self.fc_roll(x)
129
+ t = self.fc_t(x)
130
+ exp = self.fc_exp(x)
131
+ scale = self.fc_scale(x)
132
+
133
+ # ret_dct = {
134
+ # 'pitch': pitch,
135
+ # 'yaw': yaw,
136
+ # 'roll': roll,
137
+ # 't': t,
138
+ # 'exp': exp,
139
+ # 'scale': scale,
140
+
141
+ # 'kp': kp, # canonical keypoint
142
+ # }
143
+
144
+ # return ret_dct
145
+ return pitch, yaw, roll, t, exp, scale, kp
146
+
147
+
148
+ def convnextv2_tiny(**kwargs):
149
+ model = ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
150
+ return model
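A quick shape check with random weights; `num_kp=21` mirrors how MotionExtractor below instantiates the backbone, and `num_bins` defaults to 66:

    import torch

    net = convnextv2_tiny(num_kp=21).eval()
    with torch.no_grad():
        pitch, yaw, roll, t, exp, scale, kp = net(torch.rand(1, 3, 256, 256))
    print(pitch.shape, t.shape, exp.shape, scale.shape, kp.shape)
    # (1, 66) bins per angle, (1, 3), (1, 63), (1, 1), (1, 63)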
core/models/modules/dense_motion.py ADDED
@@ -0,0 +1,104 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ The module that predicts a dense motion field from the sparse motion representation given by kp_source and kp_driving
5
+ """
6
+
7
+ from torch import nn
8
+ import torch.nn.functional as F
9
+ import torch
10
+ from .util import Hourglass, make_coordinate_grid, kp2gaussian
11
+
12
+
13
+ class DenseMotionNetwork(nn.Module):
14
+ def __init__(self, block_expansion, num_blocks, max_features, num_kp, feature_channel, reshape_depth, compress, estimate_occlusion_map=True):
15
+ super(DenseMotionNetwork, self).__init__()
16
+ self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(compress+1), max_features=max_features, num_blocks=num_blocks) # ~60+G
17
+
18
+ self.mask = nn.Conv3d(self.hourglass.out_filters, num_kp + 1, kernel_size=7, padding=3) # 65G! NOTE: computation cost is large
19
+ self.compress = nn.Conv3d(feature_channel, compress, kernel_size=1) # 0.8G
20
+ self.norm = nn.BatchNorm3d(compress, affine=True)
21
+ self.num_kp = num_kp
22
+ self.flag_estimate_occlusion_map = estimate_occlusion_map
23
+
24
+ if self.flag_estimate_occlusion_map:
25
+ self.occlusion = nn.Conv2d(self.hourglass.out_filters*reshape_depth, 1, kernel_size=7, padding=3)
26
+ else:
27
+ self.occlusion = None
28
+
29
+ def create_sparse_motions(self, feature, kp_driving, kp_source):
30
+ bs, _, d, h, w = feature.shape # (bs, 4, 16, 64, 64)
31
+ identity_grid = make_coordinate_grid((d, h, w), ref=kp_source) # (16, 64, 64, 3)
32
+ identity_grid = identity_grid.view(1, 1, d, h, w, 3) # (1, 1, d=16, h=64, w=64, 3)
33
+ coordinate_grid = identity_grid - kp_driving.view(bs, self.num_kp, 1, 1, 1, 3)
34
+
35
+ k = coordinate_grid.shape[1]
36
+
37
+ # NOTE: a first-order flow is missing here
38
+ driving_to_source = coordinate_grid + kp_source.view(bs, self.num_kp, 1, 1, 1, 3) # (bs, num_kp, d, h, w, 3)
39
+
40
+ # adding background feature
41
+ identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1, 1)
42
+ sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) # (bs, 1+num_kp, d, h, w, 3)
43
+ return sparse_motions
44
+
45
+ def create_deformed_feature(self, feature, sparse_motions):
46
+ bs, _, d, h, w = feature.shape
47
+ feature_repeat = feature.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp+1, 1, 1, 1, 1, 1) # (bs, num_kp+1, 1, c, d, h, w)
48
+ feature_repeat = feature_repeat.view(bs * (self.num_kp+1), -1, d, h, w) # (bs*(num_kp+1), c, d, h, w)
49
+ sparse_motions = sparse_motions.view((bs * (self.num_kp+1), d, h, w, -1)) # (bs*(num_kp+1), d, h, w, 3)
50
+ sparse_deformed = F.grid_sample(feature_repeat, sparse_motions, align_corners=False)
51
+ sparse_deformed = sparse_deformed.view((bs, self.num_kp+1, -1, d, h, w)) # (bs, num_kp+1, c, d, h, w)
52
+
53
+ return sparse_deformed
54
+
55
+ def create_heatmap_representations(self, feature, kp_driving, kp_source):
56
+ spatial_size = feature.shape[3:] # (d=16, h=64, w=64)
57
+ gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=0.01) # (bs, num_kp, d, h, w)
58
+ gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=0.01) # (bs, num_kp, d, h, w)
59
+ heatmap = gaussian_driving - gaussian_source # (bs, num_kp, d, h, w)
60
+
61
+ # adding background feature
62
+ zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.dtype).to(heatmap.device)
63
+ heatmap = torch.cat([zeros, heatmap], dim=1)
64
+ heatmap = heatmap.unsqueeze(2) # (bs, 1+num_kp, 1, d, h, w)
65
+ return heatmap
66
+
67
+ def forward(self, feature, kp_driving, kp_source):
68
+ bs, _, d, h, w = feature.shape # (bs, 32, 16, 64, 64)
69
+
70
+ feature = self.compress(feature) # (bs, 4, 16, 64, 64)
71
+ feature = self.norm(feature) # (bs, 4, 16, 64, 64)
72
+ feature = F.relu(feature) # (bs, 4, 16, 64, 64)
73
+
74
+ out_dict = dict()
75
+
76
+ # 1. deform 3d feature
77
+ sparse_motion = self.create_sparse_motions(feature, kp_driving, kp_source) # (bs, 1+num_kp, d, h, w, 3)
78
+ deformed_feature = self.create_deformed_feature(feature, sparse_motion) # (bs, 1+num_kp, c=4, d=16, h=64, w=64)
79
+
80
+ # 2. (bs, 1+num_kp, d, h, w)
81
+ heatmap = self.create_heatmap_representations(deformed_feature, kp_driving, kp_source) # (bs, 1+num_kp, 1, d, h, w)
82
+
83
+ input = torch.cat([heatmap, deformed_feature], dim=2) # (bs, 1+num_kp, c=5, d=16, h=64, w=64)
84
+ input = input.view(bs, -1, d, h, w) # (bs, (1+num_kp)*c=105, d=16, h=64, w=64)
85
+
86
+ prediction = self.hourglass(input)
87
+
88
+ mask = self.mask(prediction)
89
+ mask = F.softmax(mask, dim=1) # (bs, 1+num_kp, d=16, h=64, w=64)
90
+ out_dict['mask'] = mask
91
+ mask = mask.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w)
92
+ sparse_motion = sparse_motion.permute(0, 1, 5, 2, 3, 4) # (bs, num_kp+1, 3, d, h, w)
93
+ deformation = (sparse_motion * mask).sum(dim=1) # (bs, 3, d, h, w) mask take effect in this place
94
+ deformation = deformation.permute(0, 2, 3, 4, 1) # (bs, d, h, w, 3)
95
+
96
+ out_dict['deformation'] = deformation
97
+
98
+ if self.flag_estimate_occlusion_map:
99
+ bs, _, d, h, w = prediction.shape
100
+ prediction_reshape = prediction.view(bs, -1, h, w)
101
+ occlusion_map = torch.sigmoid(self.occlusion(prediction_reshape)) # Bx1x64x64
102
+ out_dict['occlusion_map'] = occlusion_map
103
+
104
+ return out_dict
core/models/modules/lmdm_modules/model.py ADDED
@@ -0,0 +1,398 @@
1
+ from typing import Callable, Optional, Union
2
+ import torch
3
+ import torch.nn as nn
4
+ from einops import rearrange
5
+ from einops.layers.torch import Rearrange
6
+ from torch import Tensor
7
+ from torch.nn import functional as F
8
+
9
+ from .rotary_embedding_torch import RotaryEmbedding
10
+ from .utils import PositionalEncoding, SinusoidalPosEmb, prob_mask_like
11
+
12
+
13
+ class DenseFiLM(nn.Module):
14
+ """Feature-wise linear modulation (FiLM) generator."""
15
+
16
+ def __init__(self, embed_channels):
17
+ super().__init__()
18
+ self.embed_channels = embed_channels
19
+ self.block = nn.Sequential(
20
+ nn.Mish(), nn.Linear(embed_channels, embed_channels * 2)
21
+ )
22
+
23
+ def forward(self, position):
24
+ pos_encoding = self.block(position)
25
+ pos_encoding = rearrange(pos_encoding, "b c -> b 1 c")
26
+ scale_shift = pos_encoding.chunk(2, dim=-1)
27
+ return scale_shift
28
+
29
+
30
+ def featurewise_affine(x, scale_shift):
31
+ scale, shift = scale_shift
32
+ return (scale + 1) * x + shift
33
+
34
+
35
+ class TransformerEncoderLayer(nn.Module):
36
+ def __init__(
37
+ self,
38
+ d_model: int,
39
+ nhead: int,
40
+ dim_feedforward: int = 2048,
41
+ dropout: float = 0.1,
42
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
43
+ layer_norm_eps: float = 1e-5,
44
+ batch_first: bool = False,
45
+ norm_first: bool = True,
46
+ device=None,
47
+ dtype=None,
48
+ rotary=None,
49
+ ) -> None:
50
+ super().__init__()
51
+ self.self_attn = nn.MultiheadAttention(
52
+ d_model, nhead, dropout=dropout, batch_first=batch_first
53
+ )
54
+ # Implementation of Feedforward model
55
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
56
+ self.dropout = nn.Dropout(dropout)
57
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
58
+
59
+ self.norm_first = norm_first
60
+ self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
61
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
62
+ self.dropout1 = nn.Dropout(dropout)
63
+ self.dropout2 = nn.Dropout(dropout)
64
+ self.activation = activation
65
+
66
+ self.rotary = rotary
67
+ self.use_rotary = rotary is not None
68
+
69
+ def forward(
70
+ self,
71
+ src: Tensor,
72
+ src_mask: Optional[Tensor] = None,
73
+ src_key_padding_mask: Optional[Tensor] = None,
74
+ ) -> Tensor:
75
+ x = src
76
+ if self.norm_first:
77
+ x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
78
+ x = x + self._ff_block(self.norm2(x))
79
+ else:
80
+ x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
81
+ x = self.norm2(x + self._ff_block(x))
82
+
83
+ return x
84
+
85
+ # self-attention block
86
+ def _sa_block(
87
+ self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]
88
+ ) -> Tensor:
89
+ qk = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
90
+ x = self.self_attn(
91
+ qk,
92
+ qk,
93
+ x,
94
+ attn_mask=attn_mask,
95
+ key_padding_mask=key_padding_mask,
96
+ need_weights=False,
97
+ )[0]
98
+ return self.dropout1(x)
99
+
100
+ # feed forward block
101
+ def _ff_block(self, x: Tensor) -> Tensor:
102
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
103
+ return self.dropout2(x)
104
+
105
+
106
+ class FiLMTransformerDecoderLayer(nn.Module):
107
+ def __init__(
108
+ self,
109
+ d_model: int,
110
+ nhead: int,
111
+ dim_feedforward=2048,
112
+ dropout=0.1,
113
+ activation=F.relu,
114
+ layer_norm_eps=1e-5,
115
+ batch_first=False,
116
+ norm_first=True,
117
+ device=None,
118
+ dtype=None,
119
+ rotary=None,
120
+ ):
121
+ super().__init__()
122
+ self.self_attn = nn.MultiheadAttention(
123
+ d_model, nhead, dropout=dropout, batch_first=batch_first
124
+ )
125
+ self.multihead_attn = nn.MultiheadAttention(
126
+ d_model, nhead, dropout=dropout, batch_first=batch_first
127
+ )
128
+ # Feedforward
129
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
130
+ self.dropout = nn.Dropout(dropout)
131
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
132
+
133
+ self.norm_first = norm_first
134
+ self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
135
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
136
+ self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
137
+ self.dropout1 = nn.Dropout(dropout)
138
+ self.dropout2 = nn.Dropout(dropout)
139
+ self.dropout3 = nn.Dropout(dropout)
140
+ self.activation = activation
141
+
142
+ self.film1 = DenseFiLM(d_model)
143
+ self.film2 = DenseFiLM(d_model)
144
+ self.film3 = DenseFiLM(d_model)
145
+
146
+ self.rotary = rotary
147
+ self.use_rotary = rotary is not None
148
+
149
+ # x, cond, t
150
+ def forward(
151
+ self,
152
+ tgt,
153
+ memory,
154
+ t,
155
+ tgt_mask=None,
156
+ memory_mask=None,
157
+ tgt_key_padding_mask=None,
158
+ memory_key_padding_mask=None,
159
+ ):
160
+ x = tgt
161
+ if self.norm_first:
162
+ # self-attention -> film -> residual
163
+ x_1 = self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask)
164
+ x = x + featurewise_affine(x_1, self.film1(t))
165
+ # cross-attention -> film -> residual
166
+ x_2 = self._mha_block(
167
+ self.norm2(x), memory, memory_mask, memory_key_padding_mask
168
+ )
169
+ x = x + featurewise_affine(x_2, self.film2(t))
170
+ # feedforward -> film -> residual
171
+ x_3 = self._ff_block(self.norm3(x))
172
+ x = x + featurewise_affine(x_3, self.film3(t))
173
+ else:
174
+ x = self.norm1(
175
+ x
176
+ + featurewise_affine(
177
+ self._sa_block(x, tgt_mask, tgt_key_padding_mask), self.film1(t)
178
+ )
179
+ )
180
+ x = self.norm2(
181
+ x
182
+ + featurewise_affine(
183
+ self._mha_block(x, memory, memory_mask, memory_key_padding_mask),
184
+ self.film2(t),
185
+ )
186
+ )
187
+ x = self.norm3(x + featurewise_affine(self._ff_block(x), self.film3(t)))
188
+ return x
189
+
190
+ # self-attention block
191
+ # qkv
192
+ def _sa_block(self, x, attn_mask, key_padding_mask):
193
+ qk = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
194
+ x = self.self_attn(
195
+ qk,
196
+ qk,
197
+ x,
198
+ attn_mask=attn_mask,
199
+ key_padding_mask=key_padding_mask,
200
+ need_weights=False,
201
+ )[0]
202
+ return self.dropout1(x)
203
+
204
+ # multihead attention block
205
+ # qkv
206
+ def _mha_block(self, x, mem, attn_mask, key_padding_mask):
207
+ q = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
208
+ k = self.rotary.rotate_queries_or_keys(mem) if self.use_rotary else mem
209
+ x = self.multihead_attn(
210
+ q,
211
+ k,
212
+ mem,
213
+ attn_mask=attn_mask,
214
+ key_padding_mask=key_padding_mask,
215
+ need_weights=False,
216
+ )[0]
217
+ return self.dropout2(x)
218
+
219
+ # feed forward block
220
+ def _ff_block(self, x):
221
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
222
+ return self.dropout3(x)
223
+
224
+
225
+ class DecoderLayerStack(nn.Module):
226
+ def __init__(self, stack):
227
+ super().__init__()
228
+ self.stack = stack
229
+
230
+ def forward(self, x, cond, t):
231
+ for layer in self.stack:
232
+ x = layer(x, cond, t)
233
+ return x
234
+
235
+
236
+ class MotionDecoder(nn.Module):
237
+ def __init__(
238
+ self,
239
+ nfeats: int,
240
+ seq_len: int = 100, # 4 seconds, 25 fps
241
+ latent_dim: int = 256,
242
+ ff_size: int = 1024,
243
+ num_layers: int = 4,
244
+ num_heads: int = 4,
245
+ dropout: float = 0.1,
246
+ cond_feature_dim: int = 4800,
247
+ activation: Callable[[Tensor], Tensor] = F.gelu,
248
+ use_rotary=True,
249
+ multi_cond_frame=False,
250
+ **kwargs
251
+ ) -> None:
252
+
253
+ super().__init__()
254
+
255
+ self.multi_cond_frame = multi_cond_frame
256
+
257
+ output_feats = nfeats
258
+
259
+ # positional embeddings
260
+ self.rotary = None
261
+ self.abs_pos_encoding = nn.Identity()
262
+ # if rotary, replace absolute embedding with a rotary embedding instance (absolute becomes an identity)
263
+ if use_rotary:
264
+ self.rotary = RotaryEmbedding(dim=latent_dim)
265
+ else:
266
+ self.abs_pos_encoding = PositionalEncoding(
267
+ latent_dim, dropout, batch_first=True
268
+ )
269
+
270
+ # time embedding processing
271
+ self.time_mlp = nn.Sequential(
272
+ SinusoidalPosEmb(latent_dim), # learned?
273
+ nn.Linear(latent_dim, latent_dim * 4),
274
+ nn.Mish(),
275
+ )
276
+
277
+ self.to_time_cond = nn.Sequential(nn.Linear(latent_dim * 4, latent_dim),)
278
+
279
+ self.to_time_tokens = nn.Sequential(
280
+ nn.Linear(latent_dim * 4, latent_dim * 2), # 2 time tokens
281
+ Rearrange("b (r d) -> b r d", r=2),
282
+ )
283
+
284
+ # null embeddings for guidance dropout
285
+ self.null_cond_embed = nn.Parameter(torch.randn(1, seq_len, latent_dim))
286
+ self.null_cond_hidden = nn.Parameter(torch.randn(1, latent_dim))
287
+
288
+ self.norm_cond = nn.LayerNorm(latent_dim)
289
+
290
+ # input projection
291
+ if self.multi_cond_frame:
292
+ self.input_projection = nn.Linear(nfeats * 2 + 1, latent_dim)
293
+ else:
294
+ self.input_projection = nn.Linear(nfeats * 2, latent_dim)
295
+ self.cond_encoder = nn.Sequential()
296
+ for _ in range(2):
297
+ self.cond_encoder.append(
298
+ TransformerEncoderLayer(
299
+ d_model=latent_dim,
300
+ nhead=num_heads,
301
+ dim_feedforward=ff_size,
302
+ dropout=dropout,
303
+ activation=activation,
304
+ batch_first=True,
305
+ rotary=self.rotary,
306
+ )
307
+ )
308
+ # conditional projection
309
+ self.cond_projection = nn.Linear(cond_feature_dim, latent_dim)
310
+ self.non_attn_cond_projection = nn.Sequential(
311
+ nn.LayerNorm(latent_dim),
312
+ nn.Linear(latent_dim, latent_dim),
313
+ nn.SiLU(),
314
+ nn.Linear(latent_dim, latent_dim),
315
+ )
316
+ # decoder
317
+ decoderstack = nn.ModuleList([])
318
+ for _ in range(num_layers):
319
+ decoderstack.append(
320
+ FiLMTransformerDecoderLayer(
321
+ latent_dim,
322
+ num_heads,
323
+ dim_feedforward=ff_size,
324
+ dropout=dropout,
325
+ activation=activation,
326
+ batch_first=True,
327
+ rotary=self.rotary,
328
+ )
329
+ )
330
+
331
+ self.seqTransDecoder = DecoderLayerStack(decoderstack)
332
+
333
+ self.final_layer = nn.Linear(latent_dim, output_feats)
334
+
335
+ self.epsilon = 0.00001
336
+
337
+ def guided_forward(self, x, cond_frame, cond_embed, times, guidance_weight):
338
+ unc = self.forward(x, cond_frame, cond_embed, times, cond_drop_prob=1)
339
+ conditioned = self.forward(x, cond_frame, cond_embed, times, cond_drop_prob=0)
340
+
341
+ return unc + (conditioned - unc) * guidance_weight
342
+
343
+ def forward(
344
+ self, x: Tensor, cond_frame: Tensor, cond_embed: Tensor, times: Tensor, cond_drop_prob: float = 0.0
345
+ ):
346
+ batch_size, device = x.shape[0], x.device
347
+
348
+ # concat last frame, project to latent space
349
+ # cond_frame: [b, dim] | [b, n, dim+1]
350
+ if self.multi_cond_frame:
351
+ # [b, n, dim+1] (+1 mask)
352
+ x = torch.cat([x, cond_frame], dim=-1)
353
+ else:
354
+ # [b, dim]
355
+ x = torch.cat([x, cond_frame.unsqueeze(1).repeat(1, x.shape[1], 1)], dim=-1)
356
+ x = self.input_projection(x)
357
+ # add the positional embeddings of the input sequence to provide temporal information
358
+ x = self.abs_pos_encoding(x)
359
+
360
+ # create audio conditional embedding with conditional dropout
361
+ keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device=device)
362
+ keep_mask_embed = rearrange(keep_mask, "b -> b 1 1")
363
+ keep_mask_hidden = rearrange(keep_mask, "b -> b 1")
364
+
365
+ cond_tokens = self.cond_projection(cond_embed)
366
+ # encode tokens
367
+ cond_tokens = self.abs_pos_encoding(cond_tokens)
368
+ cond_tokens = self.cond_encoder(cond_tokens)
369
+
370
+ null_cond_embed = self.null_cond_embed.to(cond_tokens.dtype)
371
+ cond_tokens = torch.where(keep_mask_embed, cond_tokens, null_cond_embed)
372
+
373
+ mean_pooled_cond_tokens = cond_tokens.mean(dim=-2)
374
+ cond_hidden = self.non_attn_cond_projection(mean_pooled_cond_tokens)
375
+
376
+ # create the diffusion timestep embedding, add the extra audio projection
377
+ t_hidden = self.time_mlp(times)
378
+
379
+ # project to attention and FiLM conditioning
380
+ t = self.to_time_cond(t_hidden)
381
+ t_tokens = self.to_time_tokens(t_hidden)
382
+
383
+ # FiLM conditioning
384
+ null_cond_hidden = self.null_cond_hidden.to(t.dtype)
385
+ cond_hidden = torch.where(keep_mask_hidden, cond_hidden, null_cond_hidden)
386
+ t += cond_hidden
387
+
388
+ # cross-attention conditioning
389
+ c = torch.cat((cond_tokens, t_tokens), dim=-2)
390
+ cond_tokens = self.norm_cond(c)
391
+
392
+ # Pass through the transformer decoder
393
+ # attending to the conditional embedding
394
+ output = self.seqTransDecoder(x, cond_tokens, t)
395
+
396
+ output = self.final_layer(output)
397
+
398
+ return output
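`guided_forward` above is standard classifier-free guidance applied to the denoised prediction: one pass with the audio condition dropped, one with it kept, blended as

    \hat{x}_0 = \hat{x}_0^{\text{uncond}} + w \, (\hat{x}_0^{\text{cond}} - \hat{x}_0^{\text{uncond}})

with w = guidance_weight (set to 2 in the LMDM wrapper); w = 0 is fully unconditional and w = 1 fully conditional.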
core/models/modules/lmdm_modules/rotary_embedding_torch.py ADDED
@@ -0,0 +1,132 @@
1
+ from inspect import isfunction
2
+ from math import log, pi
3
+
4
+ import torch
5
+ from einops import rearrange, repeat
6
+ from torch import einsum, nn
7
+
8
+ # helper functions
9
+
10
+
11
+ def exists(val):
12
+ return val is not None
13
+
14
+
15
+ def broadcat(tensors, dim=-1):
16
+ num_tensors = len(tensors)
17
+ shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
18
+ assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
19
+ shape_len = list(shape_lens)[0]
20
+
21
+ dim = (dim + shape_len) if dim < 0 else dim
22
+ dims = list(zip(*map(lambda t: list(t.shape), tensors)))
23
+
24
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
25
+ assert all(
26
+ [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
27
+ ), "invalid dimensions for broadcastable concatentation"
28
+ max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
29
+ expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
30
+ expanded_dims.insert(dim, (dim, dims[dim]))
31
+ expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
32
+ tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
33
+ return torch.cat(tensors, dim=dim)
34
+
35
+
36
+ # rotary embedding helper functions
37
+
38
+
39
+ def rotate_half(x):
40
+ x = rearrange(x, "... (d r) -> ... d r", r=2)
41
+ x1, x2 = x.unbind(dim=-1)
42
+ x = torch.stack((-x2, x1), dim=-1)
43
+ return rearrange(x, "... d r -> ... (d r)")
44
+
45
+
46
+ def apply_rotary_emb(freqs, t, start_index=0):
47
+ freqs = freqs.to(t)
48
+ rot_dim = freqs.shape[-1]
49
+ end_index = start_index + rot_dim
50
+ assert (
51
+ rot_dim <= t.shape[-1]
52
+ ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
53
+ t_left, t, t_right = (
54
+ t[..., :start_index],
55
+ t[..., start_index:end_index],
56
+ t[..., end_index:],
57
+ )
58
+ t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
59
+ return torch.cat((t_left, t, t_right), dim=-1)
60
+
61
+
62
+ # learned rotation helpers
63
+
64
+
65
+ def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
66
+ if exists(freq_ranges):
67
+ rotations = einsum("..., f -> ... f", rotations, freq_ranges)
68
+ rotations = rearrange(rotations, "... r f -> ... (r f)")
69
+
70
+ rotations = repeat(rotations, "... n -> ... (n r)", r=2)
71
+ return apply_rotary_emb(rotations, t, start_index=start_index)
72
+
73
+
74
+ # classes
75
+
76
+
77
+ class RotaryEmbedding(nn.Module):
78
+ def __init__(
79
+ self,
80
+ dim,
81
+ custom_freqs=None,
82
+ freqs_for="lang",
83
+ theta=10000,
84
+ max_freq=10,
85
+ num_freqs=1,
86
+ learned_freq=False,
87
+ ):
88
+ super().__init__()
89
+ if exists(custom_freqs):
90
+ freqs = custom_freqs
91
+ elif freqs_for == "lang":
92
+ freqs = 1.0 / (
93
+ theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
94
+ )
95
+ elif freqs_for == "pixel":
96
+ freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
97
+ elif freqs_for == "constant":
98
+ freqs = torch.ones(num_freqs).float()
99
+ else:
100
+ raise ValueError(f"unknown modality {freqs_for}")
101
+
102
+ self.cache = dict()
103
+
104
+ if learned_freq:
105
+ self.freqs = nn.Parameter(freqs)
106
+ else:
107
+ self.register_buffer("freqs", freqs)
108
+
109
+ def rotate_queries_or_keys(self, t, seq_dim=-2):
110
+ device = t.device
111
+ seq_len = t.shape[seq_dim]
112
+ freqs = self.forward(
113
+ lambda: torch.arange(seq_len, device=device), cache_key=seq_len
114
+ )
115
+ return apply_rotary_emb(freqs, t)
116
+
117
+ def forward(self, t, cache_key=None):
118
+ if exists(cache_key) and cache_key in self.cache:
119
+ return self.cache[cache_key]
120
+
121
+ if isfunction(t):
122
+ t = t()
123
+
124
+ freqs = self.freqs
125
+
126
+ freqs = torch.einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
127
+ freqs = repeat(freqs, "... n -> ... (n r)", r=2)
128
+
129
+ if exists(cache_key):
130
+ self.cache[cache_key] = freqs
131
+
132
+ return freqs
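A minimal sketch of how the attention layers above use this module: queries/keys are rotated in place of adding absolute positional encodings, so relative position is carried by the rotation itself:

    import torch

    rotary = RotaryEmbedding(dim=64)
    q = torch.randn(2, 10, 64)             # (batch, seq, dim)
    q_rot = rotary.rotate_queries_or_keys(q)
    print(q_rot.shape)                     # torch.Size([2, 10, 64])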
core/models/modules/lmdm_modules/utils.py ADDED
@@ -0,0 +1,96 @@
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+
6
+
7
+ # absolute positional embedding used for vanilla transformer sequential data
8
+ class PositionalEncoding(nn.Module):
9
+ def __init__(self, d_model, dropout=0.1, max_len=500, batch_first=False):
10
+ super().__init__()
11
+ self.batch_first = batch_first
12
+
13
+ self.dropout = nn.Dropout(p=dropout)
14
+
15
+ pe = torch.zeros(max_len, d_model)
16
+ position = torch.arange(0, max_len).unsqueeze(1)
17
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
18
+ pe[:, 0::2] = torch.sin(position * div_term)
19
+ pe[:, 1::2] = torch.cos(position * div_term)
20
+ pe = pe.unsqueeze(0).transpose(0, 1)
21
+
22
+ self.register_buffer("pe", pe)
23
+
24
+ def forward(self, x):
25
+ if self.batch_first:
26
+ x = x + self.pe.permute(1, 0, 2)[:, : x.shape[1], :]
27
+ else:
28
+ x = x + self.pe[: x.shape[0], :]
29
+ return self.dropout(x)
30
+
31
+
32
+ # very similar positional embedding used for diffusion timesteps
33
+ class SinusoidalPosEmb(nn.Module):
34
+ def __init__(self, dim):
35
+ super().__init__()
36
+ self.dim = dim
37
+
38
+ def forward(self, x):
39
+ device = x.device
40
+ half_dim = self.dim // 2
41
+ emb = math.log(10000) / (half_dim - 1)
42
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
43
+ emb = x[:, None] * emb[None, :]
44
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
45
+ return emb
46
+
47
+
48
+ # dropout mask
49
+ def prob_mask_like(shape, prob, device):
50
+ if prob == 1:
51
+ return torch.ones(shape, device=device, dtype=torch.bool)
52
+ elif prob == 0:
53
+ return torch.zeros(shape, device=device, dtype=torch.bool)
54
+ else:
55
+ return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob
56
+
57
+
58
+ def extract(a, t, x_shape):
59
+ b, *_ = t.shape
60
+ out = a.gather(-1, t)
61
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
62
+
63
+
64
+ def make_beta_schedule(
65
+ schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
66
+ ):
67
+ if schedule == "linear":
68
+ betas = (
69
+ torch.linspace(
70
+ linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64
71
+ )
72
+ ** 2
73
+ )
74
+
75
+ elif schedule == "cosine":
76
+ timesteps = (
77
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
78
+ )
79
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
80
+ alphas = torch.cos(alphas).pow(2)
81
+ alphas = alphas / alphas[0]
82
+ betas = 1 - alphas[1:] / alphas[:-1]
83
+ betas = np.clip(betas, a_min=0, a_max=0.999)
84
+
85
+ elif schedule == "sqrt_linear":
86
+ betas = torch.linspace(
87
+ linear_start, linear_end, n_timestep, dtype=torch.float64
88
+ )
89
+ elif schedule == "sqrt":
90
+ betas = (
91
+ torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
92
+ ** 0.5
93
+ )
94
+ else:
95
+ raise ValueError(f"schedule '{schedule}' unknown.")
96
+ return betas.numpy()
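
A short sketch of how these diffusion helpers typically fit together (illustrative, not part of the commit): `make_beta_schedule` builds the noise schedule and `extract` gathers per-timestep coefficients for a batch; the tensor shapes below are made up for the example.

import torch
from core.models.modules.lmdm_modules.utils import make_beta_schedule, extract

betas = torch.from_numpy(make_beta_schedule("linear", n_timestep=1000)).float()
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)          # cumulative alpha_bar_t

x0 = torch.randn(4, 80, 265)                                # hypothetical clean motion sequences
t = torch.randint(0, 1000, (4,))                            # one diffusion step per sample
noise = torch.randn_like(x0)

# forward noising q(x_t | x_0), with coefficients broadcast by `extract`
x_t = (
    extract(alphas_cumprod.sqrt(), t, x0.shape) * x0
    + extract((1.0 - alphas_cumprod).sqrt(), t, x0.shape) * noise
)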
core/models/modules/motion_extractor.py ADDED
@@ -0,0 +1,25 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ Motion extractor(M), which directly predicts the canonical keypoints, head pose and expression deformation of the input image
5
+ """
6
+
7
+ from torch import nn
8
+ import torch
9
+
10
+ from .convnextv2 import convnextv2_tiny
11
+
12
+
13
+ class MotionExtractor(nn.Module):
14
+ def __init__(self, num_kp=21, backbone="convnextv2_tiny"):
15
+ super(MotionExtractor, self).__init__()
16
+ self.detector = convnextv2_tiny(num_kp=num_kp, backbone=backbone)
17
+
18
+ def forward(self, x):
19
+ out = self.detector(x)
20
+ return out # pitch, yaw, roll, t, exp, scale, kp
21
+
22
+ def load_model(self, ckpt_path):
23
+ self.load_state_dict(torch.load(ckpt_path, map_location=lambda storage, loc: storage))
24
+ self.eval()
25
+ return self
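
A minimal shape sketch for the module above (not part of the commit): the checkpoint path is a placeholder, and the tuple unpacking assumes the ConvNeXtV2 backbone returns the seven tensors in the order noted in forward().

import torch
from core.models.modules.motion_extractor import MotionExtractor

net = MotionExtractor(num_kp=21).load_model("./checkpoints/motion_extractor.pth")  # placeholder path

img = torch.rand(1, 3, 256, 256)    # RGB image in [0, 1]
with torch.no_grad():
    pitch, yaw, roll, t, exp, scale, kp = net(img)
# kp and exp hold the canonical keypoints and their expression offsets (num_kp * 3 values each)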
core/models/modules/spade_generator.py ADDED
@@ -0,0 +1,87 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ Spade decoder(G) defined in the paper, which takes the warped feature as input and generates the animated image.
5
+ """
6
+
7
+ import torch
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+ from .util import SPADEResnetBlock
11
+
12
+
13
+ class SPADEDecoder(nn.Module):
14
+ def __init__(
15
+ self,
16
+ upscale=2,
17
+ max_features=512,
18
+ block_expansion=64,
19
+ out_channels=64,
20
+ num_down_blocks=2,
21
+ ):
22
+ for i in range(num_down_blocks):
23
+ input_channels = min(max_features, block_expansion * (2 ** (i + 1)))
24
+ self.upscale = upscale
25
+ super().__init__()
26
+ norm_G = "spadespectralinstance"
27
+ label_num_channels = input_channels # 256
28
+
29
+ self.fc = nn.Conv2d(input_channels, 2 * input_channels, 3, padding=1)
30
+ self.G_middle_0 = SPADEResnetBlock(
31
+ 2 * input_channels, 2 * input_channels, norm_G, label_num_channels
32
+ )
33
+ self.G_middle_1 = SPADEResnetBlock(
34
+ 2 * input_channels, 2 * input_channels, norm_G, label_num_channels
35
+ )
36
+ self.G_middle_2 = SPADEResnetBlock(
37
+ 2 * input_channels, 2 * input_channels, norm_G, label_num_channels
38
+ )
39
+ self.G_middle_3 = SPADEResnetBlock(
40
+ 2 * input_channels, 2 * input_channels, norm_G, label_num_channels
41
+ )
42
+ self.G_middle_4 = SPADEResnetBlock(
43
+ 2 * input_channels, 2 * input_channels, norm_G, label_num_channels
44
+ )
45
+ self.G_middle_5 = SPADEResnetBlock(
46
+ 2 * input_channels, 2 * input_channels, norm_G, label_num_channels
47
+ )
48
+ self.up_0 = SPADEResnetBlock(
49
+ 2 * input_channels, input_channels, norm_G, label_num_channels
50
+ )
51
+ self.up_1 = SPADEResnetBlock(
52
+ input_channels, out_channels, norm_G, label_num_channels
53
+ )
54
+ self.up = nn.Upsample(scale_factor=2)
55
+
56
+ if self.upscale is None or self.upscale <= 1:
57
+ self.conv_img = nn.Conv2d(out_channels, 3, 3, padding=1)
58
+ else:
59
+ self.conv_img = nn.Sequential(
60
+ nn.Conv2d(out_channels, 3 * (2 * 2), kernel_size=3, padding=1),
61
+ nn.PixelShuffle(upscale_factor=2),
62
+ )
63
+
64
+ def forward(self, feature):
65
+ seg = feature # Bx256x64x64
66
+ x = self.fc(feature) # Bx512x64x64
67
+ x = self.G_middle_0(x, seg)
68
+ x = self.G_middle_1(x, seg)
69
+ x = self.G_middle_2(x, seg)
70
+ x = self.G_middle_3(x, seg)
71
+ x = self.G_middle_4(x, seg)
72
+ x = self.G_middle_5(x, seg)
73
+
74
+ x = self.up(x) # Bx512x64x64 -> Bx512x128x128
75
+ x = self.up_0(x, seg) # Bx512x128x128 -> Bx256x128x128
76
+ x = self.up(x) # Bx256x128x128 -> Bx256x256x256
77
+ x = self.up_1(x, seg) # Bx256x256x256 -> Bx64x256x256
78
+
79
+ x = self.conv_img(F.leaky_relu(x, 2e-1)) # Bx64x256x256 -> Bx3xHxW
80
+ x = torch.sigmoid(x) # Bx3xHxW
81
+
82
+ return x
83
+
84
+ def load_model(self, ckpt_path):
85
+ self.load_state_dict(torch.load(ckpt_path, map_location=lambda storage, loc: storage))
86
+ self.eval()
87
+ return self
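
A shape sketch for the decoder above (illustrative, randomly initialized): with the default configuration the input is the 256-channel warped feature at 64x64, and the pixel-shuffle head on top of the two x2 upsamples yields a 512x512 RGB frame in (0, 1).

import torch
from core.models.modules.spade_generator import SPADEDecoder

decoder = SPADEDecoder(upscale=2)
feature = torch.randn(1, 256, 64, 64)   # warped feature from the warping network
with torch.no_grad():
    img = decoder(feature)
print(img.shape)                        # torch.Size([1, 3, 512, 512])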
core/models/modules/stitching_network.py ADDED
@@ -0,0 +1,65 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ Stitching module(S) and two retargeting modules(R) defined in the paper.
5
+
6
+ - The stitching module pastes the animated portrait back into the original image space without pixel misalignment, such as in
7
+ the stitching region.
8
+
9
+ - The eyes retargeting module is designed to address the issue of incomplete eye closure during cross-id reenactment, especially
10
+ when a person with small eyes drives a person with larger eyes.
11
+
12
+ - The lip retargeting module is designed similarly to the eye retargeting module, and can also normalize the input by ensuring that
13
+ the lips are in a closed state, which facilitates better animation driving.
14
+ """
15
+ import torch
16
+ from torch import nn
17
+
18
+
19
+ def remove_ddp_dumplicate_key(state_dict):
20
+ from collections import OrderedDict
21
+ state_dict_new = OrderedDict()
22
+ for key in state_dict.keys():
23
+ state_dict_new[key.replace('module.', '')] = state_dict[key]
24
+ return state_dict_new
25
+
26
+
27
+ class StitchingNetwork(nn.Module):
28
+ def __init__(self, input_size=126, hidden_sizes=[128, 128, 64], output_size=65):
29
+ super(StitchingNetwork, self).__init__()
30
+ layers = []
31
+ for i in range(len(hidden_sizes)):
32
+ if i == 0:
33
+ layers.append(nn.Linear(input_size, hidden_sizes[i]))
34
+ else:
35
+ layers.append(nn.Linear(hidden_sizes[i - 1], hidden_sizes[i]))
36
+ layers.append(nn.ReLU(inplace=True))
37
+ layers.append(nn.Linear(hidden_sizes[-1], output_size))
38
+ self.mlp = nn.Sequential(*layers)
39
+
40
+ def _forward(self, x):
41
+ return self.mlp(x)
42
+
43
+ def load_model(self, ckpt_path):
44
+ checkpoint = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
45
+ self.load_state_dict(remove_ddp_dumplicate_key(checkpoint['retarget_shoulder']))
46
+ self.eval()
47
+ return self
48
+
49
+ def stitching(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
50
+ """ conduct the stitching
51
+ kp_source: Bxnum_kpx3
52
+ kp_driving: Bxnum_kpx3
53
+ """
54
+ bs, num_kp = kp_source.shape[:2]
55
+ kp_driving_new = kp_driving.clone()
56
+ delta = self._forward(torch.cat([kp_source.view(bs, -1), kp_driving_new.view(bs, -1)], dim=1))
57
+ delta_exp = delta[..., :3*num_kp].reshape(bs, num_kp, 3) # 1x20x3
58
+ delta_tx_ty = delta[..., 3*num_kp:3*num_kp+2].reshape(bs, 1, 2) # 1x1x2
59
+ kp_driving_new += delta_exp
60
+ kp_driving_new[..., :2] += delta_tx_ty
61
+ return kp_driving_new
62
+
63
+ def forward(self, kp_source, kp_driving):
64
+ out = self.stitching(kp_source, kp_driving)
65
+ return out
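
A shape sketch for the stitching module above (illustrative, randomly initialized; load_model() restores the trained weights): the MLP consumes the concatenated, flattened source and driving keypoints (126 = 2 x 21 x 3) and predicts 65 values, i.e. 63 per-keypoint offsets plus a global (tx, ty) shift.

import torch
from core.models.modules.stitching_network import StitchingNetwork

stitcher = StitchingNetwork()
kp_source = torch.randn(1, 21, 3)
kp_driving = torch.randn(1, 21, 3)
with torch.no_grad():
    kp_stitched = stitcher(kp_source, kp_driving)   # (1, 21, 3), driving keypoints nudged toward the source frame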
core/models/modules/util.py ADDED
@@ -0,0 +1,452 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ This file defines various neural network modules and utility functions, including convolutional and residual blocks,
5
+ normalizations, and functions for spatial transformation and tensor manipulation.
6
+ """
7
+
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+ import torch
11
+ import torch.nn.utils.spectral_norm as spectral_norm
12
+ import math
13
+ import warnings
14
+ import collections.abc
15
+ from itertools import repeat
16
+
17
+ def kp2gaussian(kp, spatial_size, kp_variance):
18
+ """
19
+ Transform a keypoint into gaussian like representation
20
+ """
21
+ mean = kp
22
+
23
+ coordinate_grid = make_coordinate_grid(spatial_size, mean)
24
+ number_of_leading_dimensions = len(mean.shape) - 1
25
+ shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
26
+ coordinate_grid = coordinate_grid.view(*shape)
27
+ repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 1)
28
+ coordinate_grid = coordinate_grid.repeat(*repeats)
29
+
30
+ # Preprocess kp shape
31
+ shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 3)
32
+ mean = mean.view(*shape)
33
+
34
+ mean_sub = (coordinate_grid - mean)
35
+
36
+ out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)
37
+
38
+ return out
39
+
40
+
41
+ def make_coordinate_grid(spatial_size, ref, **kwargs):
42
+ d, h, w = spatial_size
43
+ x = torch.arange(w).type(ref.dtype).to(ref.device)
44
+ y = torch.arange(h).type(ref.dtype).to(ref.device)
45
+ z = torch.arange(d).type(ref.dtype).to(ref.device)
46
+
47
+ # NOTE: must be right-down-in
48
+ x = (2 * (x / (w - 1)) - 1) # the x axis faces to the right
49
+ y = (2 * (y / (h - 1)) - 1) # the y axis faces to the bottom
50
+ z = (2 * (z / (d - 1)) - 1) # the z axis faces to the inner
51
+
52
+ yy = y.view(1, -1, 1).repeat(d, 1, w)
53
+ xx = x.view(1, 1, -1).repeat(d, h, 1)
54
+ zz = z.view(-1, 1, 1).repeat(1, h, w)
55
+
56
+ meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3), zz.unsqueeze_(3)], 3)
57
+
58
+ return meshed
59
+
60
+
61
+ class ConvT2d(nn.Module):
62
+ """
63
+ Upsampling block for use in decoder.
64
+ """
65
+
66
+ def __init__(self, in_features, out_features, kernel_size=3, stride=2, padding=1, output_padding=1):
67
+ super(ConvT2d, self).__init__()
68
+
69
+ self.convT = nn.ConvTranspose2d(in_features, out_features, kernel_size=kernel_size, stride=stride,
70
+ padding=padding, output_padding=output_padding)
71
+ self.norm = nn.InstanceNorm2d(out_features)
72
+
73
+ def forward(self, x):
74
+ out = self.convT(x)
75
+ out = self.norm(out)
76
+ out = F.leaky_relu(out)
77
+ return out
78
+
79
+
80
+ class ResBlock3d(nn.Module):
81
+ """
82
+ Res block, preserve spatial resolution.
83
+ """
84
+
85
+ def __init__(self, in_features, kernel_size, padding):
86
+ super(ResBlock3d, self).__init__()
87
+ self.conv1 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding)
88
+ self.conv2 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding)
89
+ self.norm1 = nn.BatchNorm3d(in_features, affine=True)
90
+ self.norm2 = nn.BatchNorm3d(in_features, affine=True)
91
+
92
+ def forward(self, x):
93
+ out = self.norm1(x)
94
+ out = F.relu(out)
95
+ out = self.conv1(out)
96
+ out = self.norm2(out)
97
+ out = F.relu(out)
98
+ out = self.conv2(out)
99
+ out += x
100
+ return out
101
+
102
+
103
+ class UpBlock3d(nn.Module):
104
+ """
105
+ Upsampling block for use in decoder.
106
+ """
107
+
108
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
109
+ super(UpBlock3d, self).__init__()
110
+
111
+ self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
112
+ padding=padding, groups=groups)
113
+ self.norm = nn.BatchNorm3d(out_features, affine=True)
114
+
115
+ def forward(self, x):
116
+ out = F.interpolate(x, scale_factor=(1, 2, 2))
117
+ out = self.conv(out)
118
+ out = self.norm(out)
119
+ out = F.relu(out)
120
+ return out
121
+
122
+
123
+ class DownBlock2d(nn.Module):
124
+ """
125
+ Downsampling block for use in encoder.
126
+ """
127
+
128
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
129
+ super(DownBlock2d, self).__init__()
130
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups)
131
+ self.norm = nn.BatchNorm2d(out_features, affine=True)
132
+ self.pool = nn.AvgPool2d(kernel_size=(2, 2))
133
+
134
+ def forward(self, x):
135
+ out = self.conv(x)
136
+ out = self.norm(out)
137
+ out = F.relu(out)
138
+ out = self.pool(out)
139
+ return out
140
+
141
+
142
+ class DownBlock3d(nn.Module):
143
+ """
144
+ Downsampling block for use in encoder.
145
+ """
146
+
147
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
148
+ super(DownBlock3d, self).__init__()
149
+ '''
150
+ self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
151
+ padding=padding, groups=groups, stride=(1, 2, 2))
152
+ '''
153
+ self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
154
+ padding=padding, groups=groups)
155
+ self.norm = nn.BatchNorm3d(out_features, affine=True)
156
+ self.pool = nn.AvgPool3d(kernel_size=(1, 2, 2))
157
+
158
+ def forward(self, x):
159
+ out = self.conv(x)
160
+ out = self.norm(out)
161
+ out = F.relu(out)
162
+ out = self.pool(out)
163
+ return out
164
+
165
+
166
+ class SameBlock2d(nn.Module):
167
+ """
168
+ Simple block, preserve spatial resolution.
169
+ """
170
+
171
+ def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1, lrelu=False):
172
+ super(SameBlock2d, self).__init__()
173
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups)
174
+ self.norm = nn.BatchNorm2d(out_features, affine=True)
175
+ if lrelu:
176
+ self.ac = nn.LeakyReLU()
177
+ else:
178
+ self.ac = nn.ReLU()
179
+
180
+ def forward(self, x):
181
+ out = self.conv(x)
182
+ out = self.norm(out)
183
+ out = self.ac(out)
184
+ return out
185
+
186
+
187
+ class Encoder(nn.Module):
188
+ """
189
+ Hourglass Encoder
190
+ """
191
+
192
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
193
+ super(Encoder, self).__init__()
194
+
195
+ down_blocks = []
196
+ for i in range(num_blocks):
197
+ down_blocks.append(DownBlock3d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)), min(max_features, block_expansion * (2 ** (i + 1))), kernel_size=3, padding=1))
198
+ self.down_blocks = nn.ModuleList(down_blocks)
199
+
200
+ def forward(self, x):
201
+ outs = [x]
202
+ for down_block in self.down_blocks:
203
+ outs.append(down_block(outs[-1]))
204
+ return outs
205
+
206
+
207
+ class Decoder(nn.Module):
208
+ """
209
+ Hourglass Decoder
210
+ """
211
+
212
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
213
+ super(Decoder, self).__init__()
214
+
215
+ up_blocks = []
216
+
217
+ for i in range(num_blocks)[::-1]:
218
+ in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
219
+ out_filters = min(max_features, block_expansion * (2 ** i))
220
+ up_blocks.append(UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1))
221
+
222
+ self.up_blocks = nn.ModuleList(up_blocks)
223
+ self.out_filters = block_expansion + in_features
224
+
225
+ self.conv = nn.Conv3d(in_channels=self.out_filters, out_channels=self.out_filters, kernel_size=3, padding=1)
226
+ self.norm = nn.BatchNorm3d(self.out_filters, affine=True)
227
+
228
+ def forward(self, x):
229
+ out = x.pop()
230
+ for up_block in self.up_blocks:
231
+ out = up_block(out)
232
+ skip = x.pop()
233
+ out = torch.cat([out, skip], dim=1)
234
+ out = self.conv(out)
235
+ out = self.norm(out)
236
+ out = F.relu(out)
237
+ return out
238
+
239
+
240
+ class Hourglass(nn.Module):
241
+ """
242
+ Hourglass architecture.
243
+ """
244
+
245
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
246
+ super(Hourglass, self).__init__()
247
+ self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
248
+ self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)
249
+ self.out_filters = self.decoder.out_filters
250
+
251
+ def forward(self, x):
252
+ return self.decoder(self.encoder(x))
253
+
254
+
255
+ class SPADE(nn.Module):
256
+ def __init__(self, norm_nc, label_nc):
257
+ super().__init__()
258
+
259
+ self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
260
+ nhidden = 128
261
+
262
+ self.mlp_shared = nn.Sequential(
263
+ nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1),
264
+ nn.ReLU())
265
+ self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
266
+ self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
267
+
268
+ def forward(self, x, segmap):
269
+ normalized = self.param_free_norm(x)
270
+ segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
271
+ actv = self.mlp_shared(segmap)
272
+ gamma = self.mlp_gamma(actv)
273
+ beta = self.mlp_beta(actv)
274
+ out = normalized * (1 + gamma) + beta
275
+ return out
276
+
277
+
278
+ class SPADEResnetBlock(nn.Module):
279
+ def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1):
280
+ super().__init__()
281
+ # Attributes
282
+ self.learned_shortcut = (fin != fout)
283
+ fmiddle = min(fin, fout)
284
+ self.use_se = use_se
285
+ # create conv layers
286
+ self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation)
287
+ self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation)
288
+ if self.learned_shortcut:
289
+ self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
290
+ # apply spectral norm if specified
291
+ if 'spectral' in norm_G:
292
+ self.conv_0 = spectral_norm(self.conv_0)
293
+ self.conv_1 = spectral_norm(self.conv_1)
294
+ if self.learned_shortcut:
295
+ self.conv_s = spectral_norm(self.conv_s)
296
+ # define normalization layers
297
+ self.norm_0 = SPADE(fin, label_nc)
298
+ self.norm_1 = SPADE(fmiddle, label_nc)
299
+ if self.learned_shortcut:
300
+ self.norm_s = SPADE(fin, label_nc)
301
+
302
+ def forward(self, x, seg1):
303
+ x_s = self.shortcut(x, seg1)
304
+ dx = self.conv_0(self.actvn(self.norm_0(x, seg1)))
305
+ dx = self.conv_1(self.actvn(self.norm_1(dx, seg1)))
306
+ out = x_s + dx
307
+ return out
308
+
309
+ def shortcut(self, x, seg1):
310
+ if self.learned_shortcut:
311
+ x_s = self.conv_s(self.norm_s(x, seg1))
312
+ else:
313
+ x_s = x
314
+ return x_s
315
+
316
+ def actvn(self, x):
317
+ return F.leaky_relu(x, 2e-1)
318
+
319
+
320
+ def filter_state_dict(state_dict, remove_name='fc'):
321
+ new_state_dict = {}
322
+ for key in state_dict:
323
+ if remove_name in key:
324
+ continue
325
+ new_state_dict[key] = state_dict[key]
326
+ return new_state_dict
327
+
328
+
329
+ class GRN(nn.Module):
330
+ """ GRN (Global Response Normalization) layer
331
+ """
332
+
333
+ def __init__(self, dim):
334
+ super().__init__()
335
+ self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
336
+ self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
337
+
338
+ def forward(self, x):
339
+ Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
340
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
341
+ return self.gamma * (x * Nx) + self.beta + x
342
+
343
+
344
+ class LayerNorm(nn.Module):
345
+ r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
346
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
347
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
348
+ with shape (batch_size, channels, height, width).
349
+ """
350
+
351
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
352
+ super().__init__()
353
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
354
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
355
+ self.eps = eps
356
+ self.data_format = data_format
357
+ if self.data_format not in ["channels_last", "channels_first"]:
358
+ raise NotImplementedError
359
+ self.normalized_shape = (normalized_shape, )
360
+
361
+ def forward(self, x):
362
+ if self.data_format == "channels_last":
363
+ return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
364
+ elif self.data_format == "channels_first":
365
+ u = x.mean(1, keepdim=True)
366
+ s = (x - u).pow(2).mean(1, keepdim=True)
367
+ x = (x - u) / torch.sqrt(s + self.eps)
368
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
369
+ return x
370
+
371
+
372
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
373
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
374
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
375
+ def norm_cdf(x):
376
+ # Computes standard normal cumulative distribution function
377
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
378
+
379
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
380
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
381
+ "The distribution of values may be incorrect.",
382
+ stacklevel=2)
383
+
384
+ with torch.no_grad():
385
+ # Values are generated by using a truncated uniform distribution and
386
+ # then using the inverse CDF for the normal distribution.
387
+ # Get upper and lower cdf values
388
+ l = norm_cdf((a - mean) / std)
389
+ u = norm_cdf((b - mean) / std)
390
+
391
+ # Uniformly fill tensor with values from [l, u], then translate to
392
+ # [2l-1, 2u-1].
393
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
394
+
395
+ # Use inverse cdf transform for normal distribution to get truncated
396
+ # standard normal
397
+ tensor.erfinv_()
398
+
399
+ # Transform to proper mean, std
400
+ tensor.mul_(std * math.sqrt(2.))
401
+ tensor.add_(mean)
402
+
403
+ # Clamp to ensure it's in the proper range
404
+ tensor.clamp_(min=a, max=b)
405
+ return tensor
406
+
407
+
408
+ def drop_path(x, drop_prob=0., training=False, scale_by_keep=True):
409
+ """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
410
+
411
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
412
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
413
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
414
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
415
+ 'survival rate' as the argument.
416
+
417
+ """
418
+ if drop_prob == 0. or not training:
419
+ return x
420
+ keep_prob = 1 - drop_prob
421
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
422
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
423
+ if keep_prob > 0.0 and scale_by_keep:
424
+ random_tensor.div_(keep_prob)
425
+ return x * random_tensor
426
+
427
+
428
+ class DropPath(nn.Module):
429
+ """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
430
+ """
431
+
432
+ def __init__(self, drop_prob=None, scale_by_keep=True):
433
+ super(DropPath, self).__init__()
434
+ self.drop_prob = drop_prob
435
+ self.scale_by_keep = scale_by_keep
436
+
437
+ def forward(self, x):
438
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
439
+
440
+
441
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
442
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
443
+
444
+ # From PyTorch internals
445
+ def _ntuple(n):
446
+ def parse(x):
447
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
448
+ return tuple(x)
449
+ return tuple(repeat(x, n))
450
+ return parse
451
+
452
+ to_2tuple = _ntuple(2)
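
A small sketch of `kp2gaussian` from this file (illustrative): it rasterizes keypoints in the normalized [-1, 1] cube into Gaussian heatmaps on the grid built by `make_coordinate_grid`, which is how the dense-motion network consumes keypoints.

import torch
from core.models.modules.util import kp2gaussian

kp = torch.rand(1, 21, 3) * 2 - 1                       # 21 keypoints in [-1, 1]^3
heatmaps = kp2gaussian(kp, spatial_size=(16, 64, 64), kp_variance=0.01)
print(heatmaps.shape)                                   # torch.Size([1, 21, 16, 64, 64])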
core/models/modules/warping_network.py ADDED
@@ -0,0 +1,87 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ Warping field estimator(W) defined in the paper, which generates a warping field using the implicit
5
+ keypoint representations x_s and x_d, and employs this flow field to warp the source feature volume f_s.
6
+ """
7
+ import torch
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+ from .util import SameBlock2d
11
+ from .dense_motion import DenseMotionNetwork
12
+
13
+
14
+ class WarpingNetwork(nn.Module):
15
+ def __init__(
16
+ self,
17
+ num_kp=21,
18
+ block_expansion=64,
19
+ max_features=512,
20
+ num_down_blocks=2,
21
+ reshape_channel=32,
22
+ estimate_occlusion_map=True,
23
+ **kwargs
24
+ ):
25
+ super(WarpingNetwork, self).__init__()
26
+
27
+ self.upscale = kwargs.get('upscale', 1)
28
+ self.flag_use_occlusion_map = kwargs.get('flag_use_occlusion_map', True)
29
+
30
+ dense_motion_params = {
31
+ "block_expansion": 32,
32
+ "max_features": 1024,
33
+ "num_blocks": 5,
34
+ "reshape_depth": 16,
35
+ "compress": 4,
36
+ }
37
+
38
+ self.dense_motion_network = DenseMotionNetwork(
39
+ num_kp=num_kp,
40
+ feature_channel=reshape_channel,
41
+ estimate_occlusion_map=estimate_occlusion_map,
42
+ **dense_motion_params
43
+ )
44
+
45
+ self.third = SameBlock2d(max_features, block_expansion * (2 ** num_down_blocks), kernel_size=(3, 3), padding=(1, 1), lrelu=True)
46
+ self.fourth = nn.Conv2d(in_channels=block_expansion * (2 ** num_down_blocks), out_channels=block_expansion * (2 ** num_down_blocks), kernel_size=1, stride=1)
47
+
48
+ self.estimate_occlusion_map = estimate_occlusion_map
49
+
50
+ def deform_input(self, inp, deformation):
51
+ return F.grid_sample(inp, deformation, align_corners=False)
52
+
53
+ def forward(self, feature_3d, kp_source, kp_driving):
54
+ # Feature warper, Transforming feature representation according to deformation and occlusion
55
+ dense_motion = self.dense_motion_network(
56
+ feature=feature_3d, kp_driving=kp_driving, kp_source=kp_source
57
+ )
58
+ if 'occlusion_map' in dense_motion:
59
+ occlusion_map = dense_motion['occlusion_map'] # Bx1x64x64
60
+ else:
61
+ occlusion_map = None
62
+
63
+ deformation = dense_motion['deformation'] # Bx16x64x64x3
64
+ out = self.deform_input(feature_3d, deformation) # Bx32x16x64x64
65
+
66
+ bs, c, d, h, w = out.shape # Bx32x16x64x64
67
+ out = out.view(bs, c * d, h, w) # -> Bx512x64x64
68
+ out = self.third(out) # -> Bx256x64x64
69
+ out = self.fourth(out) # -> Bx256x64x64
70
+
71
+ if self.flag_use_occlusion_map and (occlusion_map is not None):
72
+ out = out * occlusion_map
73
+
74
+ # ret_dct = {
75
+ # 'occlusion_map': occlusion_map,
76
+ # 'deformation': deformation,
77
+ # 'out': out,
78
+ # }
79
+
80
+ # return ret_dct
81
+
82
+ return out
83
+
84
+ def load_model(self, ckpt_path):
85
+ self.load_state_dict(torch.load(ckpt_path, map_location=lambda storage, loc: storage))
86
+ self.eval()
87
+ return self
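
A shape sketch chaining the warping network with the SPADE decoder above (illustrative, randomly initialized weights; in practice trained checkpoints are restored via load_model()).

import torch
from core.models.modules.warping_network import WarpingNetwork
from core.models.modules.spade_generator import SPADEDecoder

warper = WarpingNetwork(num_kp=21, reshape_channel=32)
decoder = SPADEDecoder(upscale=2)

feature_3d = torch.randn(1, 32, 16, 64, 64)             # appearance feature volume f_s
kp_source = torch.randn(1, 21, 3)
kp_driving = torch.randn(1, 21, 3)

with torch.no_grad():
    warped = warper(feature_3d, kp_source, kp_driving)  # (1, 256, 64, 64)
    frame = decoder(warped)                             # (1, 3, 512, 512)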
core/models/motion_extractor.py ADDED
@@ -0,0 +1,49 @@
1
+ import numpy as np
2
+ import torch
3
+ from ..utils.load_model import load_model
4
+
5
+
6
+ class MotionExtractor:
7
+ def __init__(self, model_path, device="cuda"):
8
+ kwargs = {
9
+ "module_name": "MotionExtractor",
10
+ }
11
+ self.model, self.model_type = load_model(model_path, device=device, **kwargs)
12
+ self.device = device
13
+
14
+ self.output_names = [
15
+ "pitch",
16
+ "yaw",
17
+ "roll",
18
+ "t",
19
+ "exp",
20
+ "scale",
21
+ "kp",
22
+ ]
23
+
24
+ def __call__(self, image):
25
+ """
26
+ image: np.ndarray, shape (1, 3, 256, 256), RGB, 0-1
27
+ """
28
+ outputs = {}
29
+ if self.model_type == "onnx":
30
+ out_list = self.model.run(None, {"image": image})
31
+ for i, name in enumerate(self.output_names):
32
+ outputs[name] = out_list[i]
33
+ elif self.model_type == "tensorrt":
34
+ self.model.setup({"image": image})
35
+ self.model.infer()
36
+ for name in self.output_names:
37
+ outputs[name] = self.model.buffer[name][0].copy()
38
+ elif self.model_type == "pytorch":
39
+ with torch.no_grad(), torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=True):
40
+ pred = self.model(torch.from_numpy(image).to(self.device))
41
+ for i, name in enumerate(self.output_names):
42
+ outputs[name] = pred[i].float().cpu().numpy()
43
+ else:
44
+ raise ValueError(f"Unsupported model type: {self.model_type}")
45
+ outputs["exp"] = outputs["exp"].reshape(1, -1)
46
+ outputs["kp"] = outputs["kp"].reshape(1, -1)
47
+ return outputs
48
+
49
+
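
A usage sketch for this runtime wrapper (not part of the commit; the checkpoint path is a placeholder): the same call works for the ONNX, TensorRT and PyTorch backends and takes/returns NumPy arrays.

import numpy as np
from core.models.motion_extractor import MotionExtractor

motion_extractor = MotionExtractor("./checkpoints/motion_extractor.onnx", device="cuda")  # placeholder path

image = np.random.rand(1, 3, 256, 256).astype(np.float32)   # RGB, values in [0, 1]
motion = motion_extractor(image)
# dict with keys pitch, yaw, roll, t, exp, scale, kp; exp and kp are flattened to (1, num_kp * 3)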
core/models/stitch_network.py ADDED
@@ -0,0 +1,30 @@
1
+ import numpy as np
2
+ import torch
3
+ from ..utils.load_model import load_model
4
+
5
+
6
+ class StitchNetwork:
7
+ def __init__(self, model_path, device="cuda"):
8
+ kwargs = {
9
+ "module_name": "StitchingNetwork",
10
+ }
11
+ self.model, self.model_type = load_model(model_path, device=device, **kwargs)
12
+ self.device = device
13
+
14
+ def __call__(self, kp_source, kp_driving):
15
+ if self.model_type == "onnx":
16
+ pred = self.model.run(None, {"kp_source": kp_source, "kp_driving": kp_driving})[0]
17
+ elif self.model_type == "tensorrt":
18
+ self.model.setup({"kp_source": kp_source, "kp_driving": kp_driving})
19
+ self.model.infer()
20
+ pred = self.model.buffer["out"][0].copy()
21
+ elif self.model_type == 'pytorch':
22
+ with torch.no_grad():
23
+ pred = self.model(
24
+ torch.from_numpy(kp_source).to(self.device),
25
+ torch.from_numpy(kp_driving).to(self.device)
26
+ ).cpu().numpy()
27
+ else:
28
+ raise ValueError(f"Unsupported model type: {self.model_type}")
29
+
30
+ return pred
core/models/warp_network.py ADDED
@@ -0,0 +1,35 @@
1
+ import numpy as np
2
+ import torch
3
+ from ..utils.load_model import load_model
4
+
5
+
6
+ class WarpNetwork:
7
+ def __init__(self, model_path, device="cuda"):
8
+ kwargs = {
9
+ "module_name": "WarpingNetwork",
10
+ }
11
+ self.model, self.model_type = load_model(model_path, device=device, **kwargs)
12
+ self.device = device
13
+
14
+ def __call__(self, feature_3d, kp_source, kp_driving):
15
+ """
16
+ feature_3d: np.ndarray, shape (1, 32, 16, 64, 64)
17
+ kp_source | kp_driving: np.ndarray, shape (1, 21, 3)
18
+ """
19
+ if self.model_type == "onnx":
20
+ pred = self.model.run(None, {"feature_3d": feature_3d, "kp_source": kp_source, "kp_driving": kp_driving})[0]
21
+ elif self.model_type == "tensorrt":
22
+ self.model.setup({"feature_3d": feature_3d, "kp_source": kp_source, "kp_driving": kp_driving})
23
+ self.model.infer()
24
+ pred = self.model.buffer["out"][0].copy()
25
+ elif self.model_type == 'pytorch':
26
+ with torch.no_grad(), torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=True):
27
+ pred = self.model(
28
+ torch.from_numpy(feature_3d).to(self.device),
29
+ torch.from_numpy(kp_source).to(self.device),
30
+ torch.from_numpy(kp_driving).to(self.device)
31
+ ).float().cpu().numpy()
32
+ else:
33
+ raise ValueError(f"Unsupported model type: {self.model_type}")
34
+
35
+ return pred
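
A usage sketch for the warping wrapper (not part of the commit; paths are placeholders): like the other runtime wrappers it is backend-agnostic and operates on NumPy arrays, so it can be chained directly after the motion extractor and before the decoder model.

import numpy as np
from core.models.warp_network import WarpNetwork

warp_net = WarpNetwork("./checkpoints/warp_network.onnx", device="cuda")   # placeholder path

feature_3d = np.random.randn(1, 32, 16, 64, 64).astype(np.float32)
kp_source = np.random.randn(1, 21, 3).astype(np.float32)
kp_driving = np.random.randn(1, 21, 3).astype(np.float32)

warped = warp_net(feature_3d, kp_source, kp_driving)    # np.ndarray, fed to the decoder wrapper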
core/utils/blend/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ import pyximport
2
+ pyximport.install()
3
+
4
+ from .blend import blend_images_cy
core/utils/blend/blend.pyx ADDED
@@ -0,0 +1,38 @@
1
+ #cython: language_level=3
2
+ import numpy as np
3
+ cimport numpy as np
4
+
5
+ cdef extern from "blend_impl.h":
6
+ void _blend_images_cy_impl(
7
+ const float* mask_warped,
8
+ const float* frame_warped,
9
+ const unsigned char* frame_rgb,
10
+ const int height,
11
+ const int width,
12
+ unsigned char* result
13
+ ) noexcept nogil
14
+
15
+ def blend_images_cy(
16
+ np.ndarray[np.float32_t, ndim=2] mask_warped,
17
+ np.ndarray[np.float32_t, ndim=3] frame_warped,
18
+ np.ndarray[np.uint8_t, ndim=3] frame_rgb,
19
+ np.ndarray[np.uint8_t, ndim=3] result
20
+ ):
21
+ cdef int h = mask_warped.shape[0]
22
+ cdef int w = mask_warped.shape[1]
23
+
24
+ if not mask_warped.flags['C_CONTIGUOUS']:
25
+ mask_warped = np.ascontiguousarray(mask_warped)
26
+ if not frame_warped.flags['C_CONTIGUOUS']:
27
+ frame_warped = np.ascontiguousarray(frame_warped)
28
+ if not frame_rgb.flags['C_CONTIGUOUS']:
29
+ frame_rgb = np.ascontiguousarray(frame_rgb)
30
+
31
+ with nogil:
32
+ _blend_images_cy_impl(
33
+ <const float*>mask_warped.data,
34
+ <const float*>frame_warped.data,
35
+ <const unsigned char*>frame_rgb.data,
36
+ h, w,
37
+ <unsigned char*>result.data
38
+ )
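
A call sketch for the Cython blend kernel (illustrative): the C implementation in blend_impl.c is not shown in this view, but the declaration above fixes the expected shapes and dtypes. The first import compiles the extension on the fly via pyximport and blend.pyxbld, so blend_impl.c/.h must sit next to blend.pyx.

import numpy as np
from core.utils.blend import blend_images_cy

h, w = 512, 512
mask_warped = np.random.rand(h, w).astype(np.float32)              # per-pixel blend weights
frame_warped = (np.random.rand(h, w, 3) * 255).astype(np.float32)  # warped/generated frame
frame_rgb = np.random.randint(0, 256, (h, w, 3), dtype=np.uint8)   # original frame
result = np.empty_like(frame_rgb)

blend_images_cy(mask_warped, frame_warped, frame_rgb, result)      # blended frame written into `result`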
core/utils/blend/blend.pyxbld ADDED
@@ -0,0 +1,11 @@
1
+ import numpy as np
2
+ import os
3
+
4
+ def make_ext(modname, pyxfilename):
5
+ from distutils.extension import Extension
6
+
7
+ return Extension(name=modname,
8
+ sources=[pyxfilename, os.path.join(os.path.dirname(pyxfilename), "blend_impl.c")],
9
+ include_dirs=[np.get_include(), os.path.dirname(pyxfilename)],
10
+ extra_compile_args=["-O3", "-std=c99", "-march=native", "-ffast-math"],
11
+ )