elungky commited on
Commit
28451f7
·
0 Parent(s):

Initial commit for new Space - pre-built Docker image

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .flake8 +10 -0
  2. .gitattributes +44 -0
  3. .gitignore +247 -0
  4. .gitmodules +27 -0
  5. .pre-commit-config.yaml +55 -0
  6. ATTRIBUTIONS.md +0 -0
  7. CONTRIBUTING.md +51 -0
  8. INSTALL.md +48 -0
  9. LICENSE +201 -0
  10. README.md +248 -0
  11. assets/demo_1.gif +3 -0
  12. assets/demo_2.gif +3 -0
  13. assets/demo_3.gif +3 -0
  14. assets/demo_dynamic.gif +3 -0
  15. assets/diffusion/000000.png +3 -0
  16. assets/diffusion/000001.png +3 -0
  17. assets/diffusion/000002.png +3 -0
  18. assets/diffusion/000003.png +3 -0
  19. assets/diffusion/000004.png +3 -0
  20. assets/diffusion/000005.png +3 -0
  21. assets/diffusion/000006.png +3 -0
  22. assets/diffusion/000007.png +3 -0
  23. assets/diffusion/000008.png +3 -0
  24. assets/diffusion/000009.png +3 -0
  25. assets/diffusion/000010.png +3 -0
  26. assets/diffusion/000011.png +3 -0
  27. assets/diffusion/000012.png +3 -0
  28. assets/diffusion/000013.png +3 -0
  29. assets/diffusion/000014.png +3 -0
  30. assets/diffusion/000015.png +3 -0
  31. checkpoints/README.md +4 -0
  32. cosmos-predict1.yaml +29 -0
  33. cosmos_predict1/__init__.py +14 -0
  34. cosmos_predict1/autoregressive/__init__.py +14 -0
  35. cosmos_predict1/autoregressive/callbacks/video_sampling_teacher_forcing.py +352 -0
  36. cosmos_predict1/autoregressive/configs/__init__.py +14 -0
  37. cosmos_predict1/autoregressive/configs/base/__init__.py +14 -0
  38. cosmos_predict1/autoregressive/configs/base/callbacks.py +33 -0
  39. cosmos_predict1/autoregressive/configs/base/dataloader.py +72 -0
  40. cosmos_predict1/autoregressive/configs/base/dataset.py +39 -0
  41. cosmos_predict1/autoregressive/configs/base/model.py +318 -0
  42. cosmos_predict1/autoregressive/configs/base/model_config.py +718 -0
  43. cosmos_predict1/autoregressive/configs/base/model_parallel.py +33 -0
  44. cosmos_predict1/autoregressive/configs/base/optim.py +86 -0
  45. cosmos_predict1/autoregressive/configs/base/tokenizer.py +139 -0
  46. cosmos_predict1/autoregressive/configs/config.py +111 -0
  47. cosmos_predict1/autoregressive/configs/experiment/video2video/__init__.py +0 -0
  48. cosmos_predict1/autoregressive/configs/experiment/video2video/basic.py +163 -0
  49. cosmos_predict1/autoregressive/configs/inference/inference_config.py +102 -0
  50. cosmos_predict1/autoregressive/configs/registry.py +89 -0
.flake8 ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ [flake8]
2
+ enable-extensions = G
3
+ select = B,C,E,F,G,P,SIM1,T4,W,B9
4
+ max-line-length = 120
5
+ # C408 ignored because we like the dict keyword argument syntax
6
+ # E501 is not flexible enough, we're using B950 instead
7
+ ignore =
8
+ E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,E226,E265
9
+ exclude =
10
+ third_party
.gitattributes ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ <<<<<<< HEAD
37
+ assets/*.gif filter=lfs diff=lfs merge=lfs -text
38
+ *.gif filter=lfs diff=lfs merge=lfs -text
39
+ *.png filter=lfs diff=lfs merge=lfs -text
40
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
41
+ =======
42
+ >>>>>>> 0453ffbfce197070bb0c254a11ef21f15d1ad986
43
+ transformer_engine_torch-1.12.0+cu121-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
44
+ transformer_engine.whl filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Misc
17
+ outputs/
18
+ checkpoints/*
19
+ !checkpoints/README.md
20
+ datasets/*
21
+ !datasets/README.md
22
+ apex/
23
+
24
+ # Data types
25
+ *.jit
26
+ *.pt
27
+ *.hdr
28
+ *.webp
29
+ *.pgm
30
+ *.tiff
31
+ *.tif
32
+ *.tar
33
+ *.tar.gz
34
+ *.gz
35
+ *.pkl
36
+ *.pt
37
+ *.bin
38
+ *.pickle
39
+ *.txt
40
+
41
+ # Other uncheckable file types
42
+ *.zip
43
+ *.exe
44
+ *.dll
45
+ *.swp
46
+ *.vscode
47
+ *.DS_Store
48
+ *.pyc
49
+ *Thumbs.db
50
+ *.patch
51
+
52
+ # Credential information that should never be checked in
53
+ credentials
54
+ *.secret
55
+
56
+ # ------------------------ BELOW IS AUTO-GENERATED FOR PYTHON REPOS ------------------------
57
+
58
+ # Byte-compiled / optimized / DLL files
59
+ **/__pycache__/
60
+ *.py[cod]
61
+ *$py.class
62
+
63
+ # C extensions
64
+ *.so
65
+
66
+ # Distribution / packaging
67
+ .Python
68
+ build/
69
+ develop-eggs/
70
+ dist/
71
+ downloads/
72
+ eggs/
73
+ .eggs/
74
+ lib/
75
+ lib64/
76
+ parts/
77
+ results/
78
+ sdist/
79
+ var/
80
+ wheels/
81
+ share/python-wheels/
82
+ *.egg-info/
83
+ .installed.config
84
+ *.egg
85
+ MANIFEST
86
+
87
+ # PyInstaller
88
+ # Usually these files are written by a python script from a template
89
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
90
+ *.manifest
91
+ *.spec
92
+
93
+ # Installer logs
94
+ pip-log.txt
95
+ pip-delete-this-directory.txt
96
+
97
+ # Unit test / coverage reports
98
+ htmlcov/
99
+ .tox/
100
+ .nox/
101
+ .coverage
102
+ .coverage.*
103
+ .cache
104
+ nosetests.xml
105
+ coverage.xml
106
+ *.cover
107
+ *.py,cover
108
+ .hypothesis/
109
+ .pytest_cache/
110
+ cover/
111
+
112
+ # Translations
113
+ *.mo
114
+ *.pot
115
+
116
+ # Django stuff:
117
+ *.log
118
+ local_settings.py
119
+ db.sqlite3
120
+ db.sqlite3-journal
121
+
122
+ # Flask stuff:
123
+ instance/
124
+ .webassets-cache
125
+
126
+ # Scrapy stuff:
127
+ .scrapy
128
+
129
+ # Sphinx documentation
130
+ docs/_build/
131
+
132
+ # PyBuilder
133
+ .pybuilder/
134
+ target/
135
+
136
+ # Third party
137
+ # Jupyter Notebook
138
+ .ipynb_checkpoints
139
+
140
+ # IPython
141
+ profile_default/
142
+ ipython_config.py
143
+
144
+ # pyenv
145
+ # For a library or package, you might want to ignore these files since the code is
146
+ # intended to run in multiple environments; otherwise, check them in:
147
+ # .python-version
148
+
149
+ # pipenv
150
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
151
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
152
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
153
+ # install all needed dependencies.
154
+ #Pipfile.lock
155
+
156
+ # poetry
157
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
158
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
159
+ # commonly ignored for libraries.
160
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
161
+ #poetry.lock
162
+
163
+ # pdm
164
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
165
+ #pdm.lock
166
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
167
+ # in version control.
168
+ # https://pdm.fming.dev/#use-with-ide
169
+ .pdm.toml
170
+
171
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
172
+ __pypackages__/
173
+
174
+ # Celery stuff
175
+ celerybeat-schedule
176
+ celerybeat.pid
177
+
178
+ # SageMath parsed files
179
+ *.sage.py
180
+
181
+ # Environments
182
+ .env
183
+ .venv
184
+ env/
185
+ venv/
186
+ ENV/
187
+ env.bak/
188
+ venv.bak/
189
+
190
+ # Spyder project settings
191
+ .spyderproject
192
+ .spyproject
193
+
194
+ # Rope project settings
195
+ .ropeproject
196
+
197
+ # mkdocs documentation
198
+ /site
199
+
200
+ # mypy
201
+ .mypy_cache/
202
+ .dmypy.json
203
+ dmypy.json
204
+
205
+ # Pyre type checker
206
+ .pyre/
207
+
208
+ # pytype static type analyzer
209
+ .pytype/
210
+
211
+ # Cython debug symbols
212
+ cython_debug/
213
+
214
+ # ruff
215
+ .ruff_cache
216
+
217
+ # PyCharm
218
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
219
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
220
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
221
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
222
+ #.idea/
223
+ CLIP
224
+ .devcontainer/devcontainer.json
225
+
226
+ # Coverage
227
+ .coverage
228
+ coverage.xml
229
+
230
+ # JUnit Reports
231
+ report.xml
232
+
233
+ # CI-CD
234
+ temp/
235
+ envs.txt
236
+ manifest.json
237
+
238
+
239
+ # locks and t5 temp files
240
+ *.locks*
241
+ *.no_exist*
242
+ *models--t5*
243
+
244
+ # OneLogger
245
+ wandb/
246
+ onelogger.err
247
+ onelogger.log
.gitmodules ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [submodule "gui/dependencies/pybind11"]
2
+ path = gui/dependencies/pybind11
3
+ url = https://github.com/Tom94/pybind11
4
+ [submodule "gui/dependencies/glfw"]
5
+ path = gui/dependencies/glfw
6
+ url = https://github.com/Tom94/glfw
7
+ [submodule "gui/dependencies/args"]
8
+ path = gui/dependencies/args
9
+ url = https://github.com/Taywee/args
10
+ [submodule "gui/dependencies/tinylogger"]
11
+ path = gui/dependencies/tinylogger
12
+ url = https://github.com/Tom94/tinylogger
13
+ [submodule "gui/dependencies/imgui"]
14
+ path = gui/dependencies/imgui
15
+ url = https://github.com/ocornut/imgui.git
16
+ [submodule "gui/dependencies/dlss"]
17
+ path = gui/dependencies/dlss
18
+ url = https://github.com/NVIDIA/DLSS
19
+ [submodule "gui/dependencies/OpenXR-SDK"]
20
+ path = gui/dependencies/OpenXR-SDK
21
+ url = https://github.com/KhronosGroup/OpenXR-SDK.git
22
+ [submodule "gui/dependencies/zlib"]
23
+ path = gui/dependencies/zlib
24
+ url = https://github.com/Tom94/zlib
25
+ [submodule "gui/dependencies/fmt"]
26
+ path = gui/dependencies/fmt
27
+ url = https://github.com/fmtlib/fmt
.pre-commit-config.yaml ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ default_language_version:
17
+ python: python3.10
18
+ repos:
19
+ - repo: https://github.com/pycqa/flake8
20
+ rev: 6.0.0
21
+ hooks:
22
+ - id: flake8
23
+ args:
24
+ - --max-line-length=120
25
+ - --ignore=E501,F401,E203,E402,E265,E741,F841,F821,F811,W503,E231,E225,E702
26
+ exclude: ^dist/|^third_party/
27
+
28
+ - repo: https://github.com/psf/black
29
+ rev: 23.12.1
30
+ hooks:
31
+ - id: black
32
+ args: [--line-length=120]
33
+ exclude: ^dist/|^third_party/
34
+
35
+ - repo: https://github.com/timothycrosley/isort
36
+ rev: 5.12.0
37
+ hooks:
38
+ - id: isort
39
+ args: [--line-length=120]
40
+
41
+ - repo: https://github.com/MarcoGorelli/absolufy-imports
42
+ rev: v0.3.1
43
+ hooks:
44
+ - id: absolufy-imports
45
+
46
+ - repo: https://github.com/pre-commit/pre-commit-hooks
47
+ rev: v4.0.1
48
+ hooks:
49
+ - id: trailing-whitespace
50
+ exclude: ^tests/.*/fixtures/.*
51
+ args: [--markdown-linebreak-ext=md]
52
+ - id: end-of-file-fixer
53
+ exclude: ^tests/.*/fixtures/.*
54
+ - id: check-added-large-files
55
+ args: ['--maxkb=2000']
ATTRIBUTIONS.md ADDED
The diff for this file is too large to render. See raw diff
 
CONTRIBUTING.md ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # How to Contribute
2
+
3
+ We'd love to receive your patches and contributions. Please keep your PRs as draft until such time that you would like us to review them.
4
+
5
+ ## Code Reviews
6
+
7
+ All submissions, including submissions by project members, require review. We use GitHub pull requests for this purpose. Consult
8
+ [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more information on using pull requests.
9
+
10
+ ## Signing Your Work
11
+
12
+ * We require that all contributors "sign-off" on their commits. This certifies that the contribution is your original work, or you have rights to submit it under the same license, or a compatible license.
13
+
14
+ * Any contribution which contains commits that are not Signed-Off will not be accepted.
15
+
16
+ * To sign off on a commit you simply use the `--signoff` (or `-s`) option when committing your changes:
17
+ ```bash
18
+ $ git commit -s -m "Add cool feature."
19
+ ```
20
+ This will append the following to your commit message:
21
+ ```
22
+ Signed-off-by: Your Name <your@email.com>
23
+ ```
24
+
25
+ * Full text of the DCO:
26
+
27
+ ```
28
+ Developer Certificate of Origin
29
+ Version 1.1
30
+
31
+ Copyright (C) 2004, 2006 The Linux Foundation and its contributors.
32
+ 1 Letterman Drive
33
+ Suite D4700
34
+ San Francisco, CA, 94129
35
+
36
+ Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed.
37
+ ```
38
+
39
+ ```
40
+ Developer's Certificate of Origin 1.1
41
+
42
+ By making a contribution to this project, I certify that:
43
+
44
+ (a) The contribution was created in whole or in part by me and I have the right to submit it under the open source license indicated in the file; or
45
+
46
+ (b) The contribution is based upon previous work that, to the best of my knowledge, is covered under an appropriate open source license and I have the right under that license to submit that work with modifications, whether created in whole or in part by me, under the same open source license (unless I am permitted to submit under a different license), as indicated in the file; or
47
+
48
+ (c) The contribution was provided directly to me by some other person who certified (a), (b) or (c) and I have not modified it.
49
+
50
+ (d) I understand and agree that this project and the contribution are public and that a record of the contribution (including all personal information I submit with it, including my sign-off) is maintained indefinitely and may be redistributed consistent with this project or the open source license(s) involved.
51
+ ```
INSTALL.md ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Environment setup
2
+
3
+ Cosmos runs only on Linux systems. We have tested the installation with Ubuntu 24.04, 22.04, and 20.04.
4
+ Cosmos requires the Python version to be `3.10.x`. Please also make sure you have `conda` installed ([instructions](https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html)).
5
+
6
+ ### Inference
7
+
8
+ The below commands creates the `cosmos-predict1` conda environment and installs the dependencies for inference:
9
+ ```bash
10
+ # Create the cosmos-predict1 conda environment.
11
+ conda env create --file cosmos-predict1.yaml
12
+ # Activate the cosmos-predict1 conda environment.
13
+ conda activate cosmos-predict1
14
+ # Install the dependencies.
15
+ pip install -r requirements.txt
16
+ # Patch Transformer engine linking issues in conda environments.
17
+ ln -sf $CONDA_PREFIX/lib/python3.10/site-packages/nvidia/*/include/* $CONDA_PREFIX/include/
18
+ ln -sf $CONDA_PREFIX/lib/python3.10/site-packages/nvidia/*/include/* $CONDA_PREFIX/include/python3.10
19
+ # Install Transformer engine.
20
+ pip install transformer-engine[pytorch]==1.12.0
21
+ # Install Apex for inference.
22
+ git clone https://github.com/NVIDIA/apex
23
+ CUDA_HOME=$CONDA_PREFIX pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./apex
24
+ # Install MoGe for inference.
25
+ pip install git+https://github.com/microsoft/MoGe.git
26
+ ```
27
+
28
+ * Alternatively, if you are more familiar with a containerized environment, you can build the dockerfile and run it to get an environment with all the packages pre-installed.
29
+ This requires docker to be already present on your system with the [Nvidia Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) installed.
30
+
31
+ ```bash
32
+ docker build -f Dockerfile . -t nvcr.io/$USER/cosmos-predict1:latest
33
+ ```
34
+
35
+ Note: In case you encounter permission issues while mounting local files inside the docker, you can share the folders from your current directory to all users (including docker) using this helpful alias `alias share='sudo chown -R ${USER}:users $PWD && sudo chmod g+w $PWD'` before running the docker.
36
+
37
+
38
+ You can test the environment setup for inference with
39
+ ```bash
40
+ CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/test_environment.py
41
+ ```
42
+
43
+ ### Post-training
44
+
45
+
46
+ 🛠️ *Under construction* 👷
47
+
48
+ Stay tuned!
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: GEN3C Project (from DGX Station)
3
+ emoji: 🫁
4
+ colorFrom: green
5
+ colorTo: blue
6
+ sdk: docker
7
+ image: elungky/gen3c:latest
8
+ # app_port: 7860 # Remove or comment this line as the image handles the port
9
+ ---
10
+
11
+ # GEN3C: 3D-Informed World-Consistent Video Generation with Precise Camera Control
12
+
13
+ <!-- Note: this video is hosted by GitHub and gets embedded automatically when viewing in the GitHub UI -->
14
+
15
+ https://github.com/user-attachments/assets/247e1719-9f8f-4504-bfa3-f9706bd8682d
16
+
17
+
18
+ **GEN3C: 3D-Informed World-Consistent Video Generation with Precise Camera Control**<br>
19
+ [Xuanchi Ren*](https://xuanchiren.com/),
20
+ [Tianchang Shen*](https://www.cs.toronto.edu/~shenti11/),
21
+ [Jiahui Huang](https://huangjh-pub.github.io/),
22
+ [Huan Ling](https://www.cs.toronto.edu/~linghuan/),
23
+ [Yifan Lu](https://yifanlu0227.github.io/),
24
+ [Merlin Nimier-David](https://merlin.nimierdavid.fr/),
25
+ [Thomas Müller](https://research.nvidia.com/person/thomas-muller),
26
+ [Alexander Keller](https://research.nvidia.com/person/alex-keller),
27
+ [Sanja Fidler](https://www.cs.toronto.edu/~fidler/),
28
+ [Jun Gao](https://www.cs.toronto.edu/~jungao/) <br>
29
+ \* indicates equal contribution <br>
30
+ **[Paper](https://arxiv.org/pdf/2503.03751), [Project Page](https://research.nvidia.com/labs/toronto-ai/GEN3C/), [HuggingFace](https://huggingface.co/collections/nvidia/gen3c-683f3f9540a8f9c98cf46a8d)**
31
+
32
+ Abstract: We present GEN3C, a generative video model with precise Camera Control and
33
+ temporal 3D Consistency. Prior video models already generate realistic videos,
34
+ but they tend to leverage little 3D information, leading to inconsistencies,
35
+ such as objects popping in and out of existence. Camera control, if implemented
36
+ at all, is imprecise, because camera parameters are mere inputs to the neural
37
+ network which must then infer how the video depends on the camera. In contrast,
38
+ GEN3C is guided by a 3D cache: point clouds obtained by predicting the
39
+ pixel-wise depth of seed images or previously generated frames. When generating
40
+ the next frames, GEN3C is conditioned on the 2D renderings of the 3D cache with
41
+ the new camera trajectory provided by the user. Crucially, this means that
42
+ GEN3C neither has to remember what it previously generated nor does it have to
43
+ infer the image structure from the camera pose. The model, instead, can focus
44
+ all its generative power on previously unobserved regions, as well as advancing
45
+ the scene state to the next frame. Our results demonstrate more precise camera
46
+ control than prior work, as well as state-of-the-art results in sparse-view
47
+ novel view synthesis, even in challenging settings such as driving scenes and
48
+ monocular dynamic video. Results are best viewed in videos.
49
+
50
+ For business inquiries, please visit our website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/).
51
+ For any other questions related to the model, please contact Xuanchi, Tianchang or Jun.
52
+
53
+ ## News
54
+ - 2025-06-06 Code and model released! In a future update, we plan to include the pipeline for jointly predicting depth and camera pose from video, as well as a driving-finetuned model. Stay tuned!
55
+
56
+ ## Installation
57
+ Please follow the "Inference" section in [INSTALL.md](INSTALL.md) to set up your environment.
58
+
59
+ ## Inference
60
+
61
+ ### Download checkpoints
62
+ 1. Generate a [Hugging Face](https://huggingface.co/settings/tokens) access token (if you haven't done so already). Set the access token to `Read` permission (default is `Fine-grained`).
63
+
64
+ 2. Log in to Hugging Face with the access token:
65
+ ```bash
66
+ huggingface-cli login
67
+ ```
68
+
69
+ 3. Download the GEN3C model weights from [Hugging Face](https://huggingface.co/nvidia/GEN3C-Cosmos-7B):
70
+ ```bash
71
+ CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/download_gen3c_checkpoints.py --checkpoint_dir checkpoints
72
+ ```
73
+
74
+ ### Interactive GUI usage
75
+
76
+ <div align="center">
77
+ <img src="gui/assets/gui_preview.webp" alt="GEN3C interactive GUI" width="1080px"/>
78
+ </div>
79
+
80
+ GEN3C can be used through an interactive GUI, allowing to visualize the inputs in 3D, author arbitrary camera trajectories, and start inference from a single window.
81
+ Please see the [dedicated instructions](gui/README.md).
82
+
83
+
84
+ ### Command-line usage
85
+ GEN3C supports both images and videos as input. Below are examples of running GEN3C on single images and videos with predefined camera trajectory patterns.
86
+
87
+ ### Example 1: Single Image to Video Generation
88
+
89
+ #### Single GPU
90
+ Generate a 121-frame video from a single image:
91
+ ```bash
92
+ CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python cosmos_predict1/diffusion/inference/gen3c_single_image.py \
93
+ --checkpoint_dir checkpoints \
94
+ --input_image_path assets/diffusion/000000.png \
95
+ --video_save_name test_single_image \
96
+ --guidance 1 \
97
+ --foreground_masking
98
+ ```
99
+
100
+ #### Multi-GPU (8 GPUs)
101
+ ```bash
102
+ NUM_GPUS=8
103
+ CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) torchrun --nproc_per_node=${NUM_GPUS} cosmos_predict1/diffusion/inference/gen3c_single_image.py \
104
+ --checkpoint_dir checkpoints \
105
+ --input_image_path assets/diffusion/000000.png \
106
+ --video_save_name test_single_image_multigpu \
107
+ --num_gpus ${NUM_GPUS} \
108
+ --guidance 1 \
109
+ --foreground_masking
110
+ ```
111
+
112
+ #### Additional Options
113
+ - To generate longer videos autoregressively, specify the number of frames using `--num_video_frames`. The number of frames must follow the pattern: 121 * N - 1 (e.g., 241, 361, etc.)
114
+ - To save buffer images alongside the output video, add the `--save_buffer` flag
115
+ - You can control camera trajectories using `--trajectory`, `--camera_rotation`, and `--movement_distance` arguments. See the "Camera Movement Options" section below for details.
116
+
117
+ #### Camera Movement Options
118
+
119
+ ##### Trajectory Types
120
+ The `--trajectory` argument controls the path the camera takes during video generation. Available options:
121
+
122
+ | Option | Description |
123
+ |--------|-------------|
124
+ | `left` | Camera moves to the left (default) |
125
+ | `right` | Camera moves to the right |
126
+ | `up` | Camera moves upward |
127
+ | `down` | Camera moves downward |
128
+ | `zoom_in` | Camera moves closer to the scene |
129
+ | `zoom_out` | Camera moves away from the scene |
130
+ | `clockwise` | Camera moves in a clockwise circular path |
131
+ | `counterclockwise` | Camera moves in a counterclockwise circular path |
132
+
133
+ ##### Camera Rotation Modes
134
+ The `--camera_rotation` argument controls how the camera rotates during movement. Available options:
135
+
136
+ | Option | Description |
137
+ |--------|-------------|
138
+ | `center_facing` | Camera always rotates to look at the (estimated) center of the scene (default) |
139
+ | `no_rotation` | Camera maintains its original orientation while moving |
140
+ | `trajectory_aligned` | Camera rotates to align with the direction of movement |
141
+
142
+ ##### Movement Distance
143
+ The `--movement_distance` argument controls how far the camera moves from its initial position. The default value is 0.3. A larger value will result in more dramatic camera movement, while a smaller value will create more subtle movement.
144
+
145
+ ##### GPU Memory Requirements
146
+
147
+ We have tested GEN3C only on H100 and A100 GPUs. For GPUs with limited memory, you can fully offload all models by appending the following flags to your command:
148
+
149
+ ```bash
150
+ --offload_diffusion_transformer \
151
+ --offload_tokenizer \
152
+ --offload_text_encoder_model \
153
+ --offload_prompt_upsampler \
154
+ --offload_guardrail_models \
155
+ --disable_guardrail \
156
+ --disable_prompt_encoder
157
+ ```
158
+ Maximum observed memory during inference with full offloading: ~43GB. Note: Memory usage may vary depending on system specifications and is provided for reference only.
159
+
160
+
161
+ ### Example 2: Video to Video Generation
162
+ For video input, GEN3C requires additional depth information, camera intrinsics, and extrinsics. These can be obtained using your choice of SLAM packages. For testing purposes, we provide example data.
163
+
164
+ First, you need to download the test samples:
165
+ ```bash
166
+ # Download test samples from Hugging Face
167
+ huggingface-cli download nvidia/GEN3C-Testing-Example --repo-type dataset --local-dir assets/diffusion/dynamic_video_samples
168
+ ```
169
+
170
+ #### Single GPU
171
+ ```bash
172
+ CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python cosmos_predict1/diffusion/inference/gen3c_dynamic.py \
173
+ --checkpoint_dir checkpoints \
174
+ --input_image_path assets/diffusion/dynamic_video_samples/batch_0000 \
175
+ --video_save_name test_dynamic_video \
176
+ --guidance 1
177
+ ```
178
+
179
+ #### Multi-GPU (8 GPUs)
180
+ ```bash
181
+ NUM_GPUS=8
182
+ CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) torchrun --nproc_per_node=${NUM_GPUS} cosmos_predict1/diffusion/inference/gen3c_dynamic.py \
183
+ --checkpoint_dir checkpoints \
184
+ --input_image_path assets/diffusion/dynamic_video_samples/batch_0000 \
185
+ --video_save_name test_dynamic_video_multigpu \
186
+ --num_gpus ${NUM_GPUS} \
187
+ --guidance 1
188
+ ```
189
+
190
+ ## Gallery
191
+
192
+ - **GEN3C** can be easily applied to video/scene creation from a single image
193
+ <div align="center">
194
+ <img src="assets/demo_3.gif" alt="" width="1100" />
195
+ </div>
196
+
197
+ - ... or sparse-view images (we use 5 images here)
198
+ <div align="center">
199
+ <img src="assets/demo_2.gif" alt="" width="1100" />
200
+ </div>
201
+
202
+
203
+ - .. and dynamic videos
204
+ <div align="center">
205
+ <img src="assets/demo_dynamic.gif" alt="" width="1100" />
206
+ </div>
207
+
208
+ ## Acknowledgement
209
+ Our model is based on [NVIDIA Cosmos](https://github.com/NVIDIA/Cosmos) and [Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid).
210
+
211
+ We are also grateful to several other open-source repositories that we drew inspiration from or built upon during the development of our pipeline:
212
+ - [MoGe](https://github.com/microsoft/MoGe)
213
+ - [TrajectoryCrafter](https://github.com/TrajectoryCrafter/TrajectoryCrafter)
214
+ - [DimensionX](https://github.com/wenqsun/DimensionX)
215
+ - [Depth Anything V2](https://github.com/DepthAnything/Depth-Anything-V2)
216
+ - [Video Depth Anything](https://github.com/DepthAnything/Video-Depth-Anything)
217
+
218
+ ## Citation
219
+ ```
220
+ @inproceedings{ren2025gen3c,
221
+ title={GEN3C: 3D-Informed World-Consistent Video Generation with Precise Camera Control},
222
+ author={Ren, Xuanchi and Shen, Tianchang and Huang, Jiahui and Ling, Huan and
223
+ Lu, Yifan and Nimier-David, Merlin and Müller, Thomas and Keller, Alexander and
224
+ Fidler, Sanja and Gao, Jun},
225
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
226
+ year={2025}
227
+ }
228
+ ```
229
+
230
+ ## License and Contact
231
+
232
+ This project will download and install additional third-party open source software projects. Review the license terms of these open source projects before use.
233
+
234
+
235
+ GEN3C source code is released under the [Apache 2 License](https://www.apache.org/licenses/LICENSE-2.0).
236
+
237
+ GEN3C models are released under the [NVIDIA Open Model License](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). For a custom license, please visit our website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/).
238
+ =======
239
+ title: Gen3c
240
+ emoji: 🌍
241
+ colorFrom: indigo
242
+ colorTo: blue
243
+ sdk: docker
244
+ pinned: false
245
+ ---
246
+
247
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
248
+ >>>>>>> 0453ffbfce197070bb0c254a11ef21f15d1ad986
assets/demo_1.gif ADDED

Git LFS Details

  • SHA256: e6162366c56277d084b05a37c617e2994ba75285d421e203556dcff08128b32b
  • Pointer size: 133 Bytes
  • Size of remote file: 14.7 MB
assets/demo_2.gif ADDED

Git LFS Details

  • SHA256: e765e71d3016c6e314b6403f82313a1df42f68f6fb0f9416f197d82e0710f27e
  • Pointer size: 133 Bytes
  • Size of remote file: 10.6 MB
assets/demo_3.gif ADDED

Git LFS Details

  • SHA256: 8c4cf4a4bf62daf03b25ac66c2c3693adbf7cd459e55d3481a65a9ff4a9d09d9
  • Pointer size: 133 Bytes
  • Size of remote file: 35.3 MB
assets/demo_dynamic.gif ADDED

Git LFS Details

  • SHA256: 174faba45ae701eaa432dd14de1297c0479b6c0b832adbc211cbb529fbec6c61
  • Pointer size: 133 Bytes
  • Size of remote file: 24.5 MB
assets/diffusion/000000.png ADDED

Git LFS Details

  • SHA256: b7e6eab7548c2ede900f8b504a5cef981e0cd0ec38af90dbea3f0db860e002c3
  • Pointer size: 132 Bytes
  • Size of remote file: 1.33 MB
assets/diffusion/000001.png ADDED

Git LFS Details

  • SHA256: abe310078829c9e1375ac30c7c270c84c8f68a09f3857bd35c7a5754f3326151
  • Pointer size: 132 Bytes
  • Size of remote file: 1.13 MB
assets/diffusion/000002.png ADDED

Git LFS Details

  • SHA256: 7ad89b53e9fafed0d8eefd1cfc7cc4889c5d2f510ed32d5247c5adab4cb0c622
  • Pointer size: 131 Bytes
  • Size of remote file: 789 kB
assets/diffusion/000003.png ADDED

Git LFS Details

  • SHA256: 22f39915f1b277e70683befbc18ac5859c65c3d389e4dbb5127a539a411fec54
  • Pointer size: 132 Bytes
  • Size of remote file: 1.11 MB
assets/diffusion/000004.png ADDED

Git LFS Details

  • SHA256: e2f957208849c0f86b89545734bb7b243868b574554cb6aeed248b04e7234ad4
  • Pointer size: 132 Bytes
  • Size of remote file: 1.26 MB
assets/diffusion/000005.png ADDED

Git LFS Details

  • SHA256: 267f6ae47d0e2aebda89fac5416bc0915855043131d0d8d8a4fc9506cabd4681
  • Pointer size: 132 Bytes
  • Size of remote file: 1.36 MB
assets/diffusion/000006.png ADDED

Git LFS Details

  • SHA256: 4b6fd098366bcd54bd21a5707ae6d9f78d74c2eefcfbb6919569c0d1741d837f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.21 MB
assets/diffusion/000007.png ADDED

Git LFS Details

  • SHA256: 334733b7428f9521e625a8b310770fbba3e4616ccbe0af625d07e2b065e6e9ad
  • Pointer size: 132 Bytes
  • Size of remote file: 1.15 MB
assets/diffusion/000008.png ADDED

Git LFS Details

  • SHA256: 7eae1abb3343c1e11f4e42172eba85eeed0fb2a5f7701a42e5003cf84f1696cd
  • Pointer size: 132 Bytes
  • Size of remote file: 1.68 MB
assets/diffusion/000009.png ADDED

Git LFS Details

  • SHA256: 2a5c5711d41f56bb307ef6020d0dffec9ce2297bda9ef9ae465237d8347adb34
  • Pointer size: 131 Bytes
  • Size of remote file: 603 kB
assets/diffusion/000010.png ADDED

Git LFS Details

  • SHA256: e4d32f1d1c6d427e421d6f4478d4c2c697cb0406a18ecc3b8ebeeb2a0cbba7f5
  • Pointer size: 132 Bytes
  • Size of remote file: 1.18 MB
assets/diffusion/000011.png ADDED

Git LFS Details

  • SHA256: e352d7435d3b313fcc47efd9bd0dc6e0dd5d5e8af8c50e965c57987bee1c94ec
  • Pointer size: 131 Bytes
  • Size of remote file: 944 kB
assets/diffusion/000012.png ADDED

Git LFS Details

  • SHA256: b672d43521890b2852976a0c12828ad16b9288277efff6c41189dc0c04c9c6e1
  • Pointer size: 132 Bytes
  • Size of remote file: 1.1 MB
assets/diffusion/000013.png ADDED

Git LFS Details

  • SHA256: eab3a655213eede094889bab94313e1cef142b811429bee9e0f3420c2b013105
  • Pointer size: 132 Bytes
  • Size of remote file: 1.24 MB
assets/diffusion/000014.png ADDED

Git LFS Details

  • SHA256: eb014db53082677aca35a3fc27daa1f306452c5cb7130a4ed6468cae144a0b63
  • Pointer size: 132 Bytes
  • Size of remote file: 1.35 MB
assets/diffusion/000015.png ADDED

Git LFS Details

  • SHA256: a6ac0d4e7eb6d4dbc3ae997fafc28721b716db092aaa52ede11e4d87b3e9b20d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.49 MB
checkpoints/README.md ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+
2
+ ### Checkpoint directory
3
+
4
+ Model checkpoints will be downloaded to this directory.
cosmos-predict1.yaml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # conda env create --file cosmos-predict1.yaml
17
+ name: cosmos-predict1
18
+ channels:
19
+ - conda-forge
20
+ dependencies:
21
+ - python=3.10
22
+ - pip=25.0
23
+ - cmake
24
+ - ninja
25
+ - gcc=12.4.0
26
+ - gxx=12.4.0
27
+ - cuda=12.4
28
+ - cuda-nvcc=12.4
29
+ - cuda-toolkit=12.4
cosmos_predict1/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
cosmos_predict1/autoregressive/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
cosmos_predict1/autoregressive/callbacks/video_sampling_teacher_forcing.py ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import glob
17
+ import math
18
+ import os
19
+ from typing import Optional
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torchvision
24
+ import torchvision.transforms.functional as torchvision_F
25
+ import wandb
26
+ from einops import rearrange
27
+ from megatron.core import parallel_state
28
+ from torch.distributed import get_process_group_ranks
29
+
30
+ from cosmos_predict1.autoregressive.utils.parallel import (
31
+ broadcast_data_batch_in_tp_cp_group,
32
+ gather_batch_from_cp_ranks,
33
+ get_batch_on_this_cp_rank,
34
+ )
35
+ from cosmos_predict1.callbacks.every_n import EveryN
36
+ from cosmos_predict1.utils import distributed, log, misc
37
+ from cosmos_predict1.utils.model import Model
38
+ from cosmos_predict1.utils.trainer import Trainer
39
+
40
+
41
+ def resize_image(image: torch.Tensor, resize_factor=0.5) -> torch.Tensor:
42
+ _, _, h, w = image.shape
43
+ new_h, new_w = int(resize_factor * h), int(resize_factor * w)
44
+ return torchvision_F.resize(image, (new_h, new_w))
45
+
46
+
47
+ class VideoSamplingTeacherForcing(EveryN):
48
+ def __init__(
49
+ self,
50
+ every_n: int,
51
+ step_size: int = 1,
52
+ video_latent_shape: list = [6, 24, 40],
53
+ num_frames_to_display: int = 4,
54
+ save_folder: Optional[str] = None,
55
+ num_file_to_log: int = 8,
56
+ ):
57
+ r"""
58
+ This callback enables us to perform teacher forcing inference on the training data.
59
+ By teacher forcing, we mean providing ground truth video tokens as inputs, and simply asking the model
60
+ to predict the next tokens. The predicted next tokens are then visualized. This does not perform
61
+ autoregressive sampling.
62
+ We also upload the downsampled video frames to wandb. Downsampling is needed for wandb to work fast.
63
+
64
+ Args:
65
+ every_n (int): Call this callback every_n steps
66
+ step_size (int): Number of steps taken for gradient accumulation. Global iteration number is
67
+ iteration // self.step_size
68
+ video_latent_shape (list): Shape of the video latent
69
+ num_frames_to_display (int): Number of frames to subsample for displaying in wandb
70
+ save_folder (str): Name of the local folder to save the video
71
+ num_file_to_log (int): Number of files to upload to wandb
72
+ """
73
+ super().__init__(every_n, step_size)
74
+ self.save_folder = save_folder if save_folder else self.__class__.__name__
75
+ self.video_latent_shape = video_latent_shape
76
+ self.num_frames_to_display = num_frames_to_display
77
+ self.num_file_to_log = num_file_to_log
78
+ self.rank = distributed.get_rank()
79
+
80
+ def on_train_start(self, model: Model, iteration: int = 0) -> None:
81
+ config_job = self.config.job
82
+ self.local_dir = f"{config_job.path_local}/{self.save_folder}"
83
+ if self.rank == 0:
84
+ os.makedirs(self.local_dir, exist_ok=True)
85
+ log.info(f"Video Teacher-Forcing Callback: local_dir: {self.local_dir}")
86
+
87
+ @torch.inference_mode()
88
+ def every_n_impl(
89
+ self,
90
+ trainer: Trainer,
91
+ model: Model,
92
+ data_batch: dict[str, torch.Tensor],
93
+ output_batch: dict[str, torch.Tensor],
94
+ loss: torch.Tensor,
95
+ iteration: int,
96
+ ) -> None:
97
+ # Tokenize the data
98
+
99
+ broadcast_data_batch_in_tp_cp_group(data_batch)
100
+
101
+ input_vid = data_batch[model.tokenizer.tokenizer_config.video_tokenizer.data_key]
102
+
103
+ dataset_name = data_batch.get("dataset_name", None)
104
+ if dataset_name is not None and dataset_name.startswith("image"):
105
+ # we disable the callback if the input video is an image batch
106
+ log.info(f"dataset_name is {dataset_name}, skip this callback")
107
+ return
108
+
109
+ # get the caption
110
+ captions = data_batch.get("caption", None)
111
+
112
+ # get the context embedding and mask
113
+ context = data_batch.get("context", None)
114
+ context_mask = data_batch.get("context_mask", None)
115
+ if context is not None:
116
+ context = misc.to(context, "cuda").detach().clone()
117
+ if context_mask is not None:
118
+ context_mask = misc.to(context_mask, "cuda").detach().clone()
119
+ # get the action
120
+ action = data_batch.get("action", None)
121
+ if action is not None:
122
+ action = misc.to(action, "cuda").detach().clone()
123
+
124
+ # Input tokens
125
+ tokens, _ = model.tokenizer.tokenize(data_batch)
126
+ tokens = misc.to(tokens, "cuda").detach().clone()
127
+ skip_save_file = False
128
+ if parallel_state.get_context_parallel_world_size() > 1:
129
+ cp_group = parallel_state.get_context_parallel_group()
130
+ if self.rank != min(get_process_group_ranks(cp_group)):
131
+ skip_save_file = True
132
+ tokens = get_batch_on_this_cp_rank(tokens)
133
+ if parallel_state.get_tensor_model_parallel_world_size() > 1:
134
+ # Turn on TP
135
+ tp_group = parallel_state.get_tensor_model_parallel_group()
136
+ if self.rank != min(get_process_group_ranks(tp_group)):
137
+ skip_save_file = True
138
+ tokens_encoded_in_train = output_batch["encode_tokens"].detach()
139
+ percent_token_diff = (tokens != tokens_encoded_in_train).float().mean()
140
+ percent_token_diff = distributed.dist_reduce_tensor(percent_token_diff)
141
+
142
+ input_tokens = tokens
143
+
144
+ num_tokens_to_generate = np.prod(self.video_latent_shape)
145
+
146
+ # Do a forward pass
147
+ logits = model.model.forward(
148
+ tokens,
149
+ input_pos=None,
150
+ context=context,
151
+ context_mask=context_mask,
152
+ action=action,
153
+ )
154
+ if parallel_state.get_context_parallel_world_size() > 1:
155
+ logits = gather_batch_from_cp_ranks(logits)
156
+ input_tokens = gather_batch_from_cp_ranks(input_tokens)
157
+
158
+ # Start position for video tokens in the vocabulary
159
+ video_token_start = self.config.model.tokenizer_config.video_tokenizer.tokenizer_offset
160
+ video_vocab_size = self.config.model.tokenizer_config.video_tokenizer.vocab_size
161
+
162
+ # Clipping logits only to video tokens. We remove the text vocab predictions.
163
+ # This will ensure that the video tokens only correspond to the video part of the vocabulary.
164
+ logits = logits[:, :, video_token_start : video_token_start + video_vocab_size]
165
+
166
+ # Sample with argmax token. This should be good for teacher forcing experiment.
167
+ logits = logits.contiguous()
168
+ generations = torch.argmax(logits, dim=-1)
169
+
170
+ # For each video in the batch, subsample frames for display
171
+ batch_size = input_tokens.shape[0]
172
+ out_frames = []
173
+ out_videos_gen = []
174
+ out_videos_rec = []
175
+ out_videos_gt = []
176
+ # log the accuracy of teacher-forcing
177
+ acc = []
178
+ loss_list = []
179
+
180
+ for sample_num in range(batch_size):
181
+ # Subsample the generations to the video part.
182
+ # This corresponds to the part from begin of video to end of video.
183
+ bov_token = model.tokenizer.video_special_tokens["<|begin_of_video|>"]
184
+ bov_index = input_tokens[sample_num] == bov_token
185
+ use_special_token = sum(bov_index) != 0
186
+ if use_special_token:
187
+ bov_index = bov_index.nonzero().item()
188
+ # generations: <bov> real_token1 real_token2, ... real_token7680; total 7680
189
+ # gen_video_tokens: real_token1 real_token2, ..., real_token7680; total 7680
190
+ # for vis: real_token1 real_token2, ..., real_token7680; total 7680
191
+ # for accuracy: real_token1 real_token2, ..., real_token7680; total 7680
192
+ gen_video_tokens = generations[sample_num][bov_index : bov_index + num_tokens_to_generate]
193
+ gen_video_tokens_vis = gen_video_tokens
194
+ gen_video_tokens_acc = gen_video_tokens
195
+ logits_loss = logits[sample_num][bov_index : bov_index + num_tokens_to_generate]
196
+ else:
197
+ # generations: real_token1 real_token2, ... real_token7680
198
+ # gen_video_tokens: real_token2 real_token3, ..., real_token7680; total 7679
199
+ # We need different tokens for vis and accuracy compute
200
+ # for acc: real_token2 real_token3, ..., real_token7680; total 7679
201
+ # for vis: pad_token (real_token2, ..., real_token7680); total 1 + 7679
202
+ gen_video_tokens = generations[sample_num][
203
+ : num_tokens_to_generate - 1
204
+ ] # remove the last token since there is no gt
205
+ # Since the first token is not predicted, we need to add the gt first token to make sure the shape is correct
206
+ gen_video_tokens_vis = torch.cat([input_tokens[sample_num][0:1], gen_video_tokens])
207
+ gen_video_tokens_acc = gen_video_tokens
208
+ logits_loss = logits[sample_num][: num_tokens_to_generate - 1]
209
+
210
+ # Rearrange the video to a spatial tensor
211
+ gen_video_tokens_vis_BTHW = rearrange(
212
+ gen_video_tokens_vis.unsqueeze(0),
213
+ "B (T H W) -> B T H W",
214
+ T=self.video_latent_shape[0],
215
+ H=self.video_latent_shape[1],
216
+ W=self.video_latent_shape[2],
217
+ )
218
+
219
+ # for real videos, we need to skip the bov and eov tokens for decoding
220
+ if use_special_token:
221
+ # input_tokens: <bov> real_token1 real_token2 ... <eov> <eov> ...
222
+ # real_video_tokens: real_token1 real_token2 ... real_token7680; total 7680
223
+ # for vis: real_token1 real_token2 ... real_token7680; total 7680
224
+ # for accuracy: real_token1 real_token2 ... real_token7680; total 7680; we include real_token1 since the output prediction also includes it, see gen_video_tokens_acc above
225
+ real_video_tokens = (
226
+ input_tokens[sample_num][bov_index + 1 : bov_index + num_tokens_to_generate + 1] - video_token_start
227
+ )
228
+ real_video_tokens_vis = real_video_tokens
229
+ real_video_tokens_acc = real_video_tokens
230
+ else:
231
+ # input_tokens: real_token1 real_token2,... real_token7680; total 7680
232
+ # real_video_tokens: real_token1 real_token2,... real_token7680; total 7680
233
+ # for acc: gt start from real_token2, real_token3; total 7679, remove the first token since it is not predicted
234
+ # for vis: gt start from real_token1, real_token2; total 7680
235
+ real_video_tokens = input_tokens[sample_num][:num_tokens_to_generate] - video_token_start
236
+ real_video_tokens_vis = real_video_tokens
237
+ real_video_tokens_acc = real_video_tokens[1:].flatten()
238
+
239
+ real_video_tokens_vis_BTHW = rearrange(
240
+ real_video_tokens_vis.unsqueeze(0),
241
+ "B (T H W) -> B T H W",
242
+ T=self.video_latent_shape[0],
243
+ H=self.video_latent_shape[1],
244
+ W=self.video_latent_shape[2],
245
+ )
246
+ # Calculate accuracy
247
+ correct_predictions = (gen_video_tokens_acc == real_video_tokens_acc).float()
248
+ labels = real_video_tokens_acc.clone()
249
+
250
+ if model.config.ignore_first_num_tokens > 0:
251
+ labels[: model.config.ignore_first_num_tokens] = model.tokenizer.ignore_index
252
+ select_index = labels != model.tokenizer.ignore_index
253
+ correct_predictions = correct_predictions[select_index]
254
+
255
+ loss = torch.nn.functional.cross_entropy(
256
+ logits_loss, labels, ignore_index=model.tokenizer.ignore_index, reduction="none"
257
+ )
258
+ acc.append(correct_predictions.mean() * 100.0)
259
+ loss_list.append(loss.mean())
260
+
261
+ # Decode the predicted latents
262
+ if model.tokenizer.tokenizer_config.video_tokenizer.temporal_overlap == 0:
263
+ vid_decoded = model.tokenizer.video_tokenizer.decode(gen_video_tokens_vis_BTHW.cuda())
264
+ else:
265
+ vid_decoded = model.tokenizer.video_tokenizer.decode_with_overlap(
266
+ gen_video_tokens_vis_BTHW.cuda(),
267
+ temporal_overlap=model.tokenizer.tokenizer_config.video_tokenizer.temporal_overlap,
268
+ )
269
+ # normalize decoded images from [-1, 1] to [0, 1], and clip value
270
+ vid_decoded = (vid_decoded * 0.5 + 0.5).clamp_(0, 1)
271
+ vid_decoded = vid_decoded[0]
272
+
273
+ # Decode the GT latents
274
+ if model.tokenizer.tokenizer_config.video_tokenizer.temporal_overlap == 0:
275
+ vid_rec = model.tokenizer.video_tokenizer.decode(real_video_tokens_vis_BTHW.cuda())
276
+ else:
277
+ vid_rec = model.tokenizer.video_tokenizer.decode_with_overlap(
278
+ real_video_tokens_vis_BTHW.cuda(),
279
+ temporal_overlap=model.tokenizer.tokenizer_config.video_tokenizer.temporal_overlap,
280
+ )
281
+ # normalize decoded image from [-1, 1] to [0, 1], and clip value
282
+ vid_rec = (vid_rec * 0.5 + 0.5).clamp_(0, 1)
283
+ vid_rec = vid_rec[0]
284
+
285
+ vid_input = input_vid[sample_num] # [-1, 1], input_vid shape: [B, C, L, H, W]
286
+ vid_input = (vid_input * 0.5 + 0.5).clamp_(0, 1).cuda() # Convert to [0, 1], [C, L, H, W]
287
+
288
+ # Subsample real and generated video frames
289
+ input_video_frames = vid_input.transpose(0, 1) # [L, C, H, W]
290
+ rec_video_frames = vid_rec.transpose(0, 1)
291
+ gen_video_frames = vid_decoded.transpose(0, 1)
292
+ out_videos_gen.append(gen_video_frames)
293
+ out_videos_rec.append(rec_video_frames)
294
+ out_videos_gt.append(input_video_frames)
295
+
296
+ stride = math.ceil(rec_video_frames.shape[0] / self.num_frames_to_display)
297
+
298
+ input_video_frames_subsampled = resize_image(input_video_frames[0::stride], resize_factor=0.5)
299
+ input_video_frames_subsampled = torchvision.utils.make_grid(
300
+ input_video_frames_subsampled, nrow=input_video_frames_subsampled.shape[0]
301
+ )
302
+
303
+ gt_video_frames_subsampled = resize_image(rec_video_frames[0::stride], resize_factor=0.5)
304
+ gt_video_frames_subsampled = torchvision.utils.make_grid(
305
+ gt_video_frames_subsampled, nrow=gt_video_frames_subsampled.shape[0]
306
+ )
307
+ gen_video_frames_subsampled = resize_image(gen_video_frames[0::stride], resize_factor=0.5)
308
+ gen_video_frames_subsampled = torchvision.utils.make_grid(
309
+ gen_video_frames_subsampled, nrow=gen_video_frames_subsampled.shape[0]
310
+ )
311
+
312
+ out_frames.append(input_video_frames_subsampled)
313
+ out_frames.append(gt_video_frames_subsampled)
314
+ out_frames.append(gen_video_frames_subsampled)
315
+
316
+ scaled_num_rank_to_log = (
317
+ self.num_file_to_log
318
+ * parallel_state.get_context_parallel_world_size()
319
+ * parallel_state.get_tensor_model_parallel_world_size()
320
+ )
321
+ if self.rank < scaled_num_rank_to_log and not skip_save_file:
322
+ local_path = f"{self.local_dir}/vid_teacher_forcing_iter_{iteration:09d}_{self.rank:04d}.jpg"
323
+ out_image_grid = torchvision.utils.make_grid(out_frames, nrow=1, padding=0, normalize=False)
324
+ os.makedirs(os.path.dirname(local_path), exist_ok=True)
325
+ torchvision.utils.save_image(out_image_grid, local_path)
326
+
327
+ # Log to wandb
328
+ avg_acc = distributed.dist_reduce_tensor(torch.stack(acc).mean()).item()
329
+ avg_loss = distributed.dist_reduce_tensor(torch.stack(loss_list).mean()).item()
330
+ log_info = ""
331
+ if "acc" in output_batch:
332
+ log_info = f"train acc: {(output_batch['acc'].mean().item()):.6f}%"
333
+ if percent_token_diff is not None:
334
+ log_info += f"; percent_token_diff_train_val: {percent_token_diff.item() * 100:.6f}%"
335
+ log.info(
336
+ f"Eval iteration {iteration} teacher-forcing accuracy: {avg_acc:.6f}%, loss: {avg_loss:.4f}; {log_info}"
337
+ )
338
+ if self.rank == 0 and wandb.run:
339
+ local_files = glob.glob(f"{self.local_dir}/vid_teacher_forcing_iter_{iteration:09d}_*.jpg")
340
+ local_files = sorted(local_files)[: self.num_file_to_log]
341
+ if captions is None:
342
+ captions = ["vid_frames_teacher_forcing"] * len(local_files)
343
+ for local_path, caption in zip(local_files, captions):
344
+ wandb.log(
345
+ {"frames": [wandb.Image(local_path, caption=caption)]},
346
+ step=iteration,
347
+ )
348
+
349
+ wandb.log({"eval/teacher_forcing_acc": avg_acc}, step=iteration)
350
+ wandb.log({"eval/teacher_forcing_loss": avg_loss}, step=iteration)
351
+ if percent_token_diff is not None:
352
+ wandb.log({"eval/percent_token_diff_train_val": percent_token_diff.item() * 100}, step=iteration)
cosmos_predict1/autoregressive/configs/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
cosmos_predict1/autoregressive/configs/base/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
cosmos_predict1/autoregressive/configs/base/callbacks.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from cosmos_predict1.autoregressive.callbacks.video_sampling_teacher_forcing import VideoSamplingTeacherForcing
17
+ from cosmos_predict1.callbacks.grad_clip import GradClip
18
+ from cosmos_predict1.utils.callback import ProgressBarCallback
19
+ from cosmos_predict1.utils.lazy_config import LazyCall as L
20
+
21
+ BASIC_CALLBACKS = dict(
22
+ progress_bar=L(ProgressBarCallback)(),
23
+ grad_clip=L(GradClip)(clip_norm=1.0, fsdp_enabled="${model.model_config.fsdp_enabled}", model_key="model"),
24
+ )
25
+
26
+ VIDEO_TEACHER_FORCING_CALLBACK = dict(
27
+ vid_sampling_tf=L(VideoSamplingTeacherForcing)(
28
+ every_n=500,
29
+ video_latent_shape="${model.model_config.video_latent_shape}",
30
+ num_frames_to_display=4,
31
+ save_folder="video_sampling_teacher_forcing",
32
+ )
33
+ )
cosmos_predict1/autoregressive/configs/base/dataloader.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from megatron.core import parallel_state
17
+ from torch.utils.data import DataLoader, DistributedSampler
18
+
19
+ from cosmos_predict1.autoregressive.configs.base.dataset import VideoDatasetConfig
20
+ from cosmos_predict1.autoregressive.datasets.video_dataset import VideoDataset
21
+ from cosmos_predict1.utils import log
22
+ from cosmos_predict1.utils.lazy_config import LazyCall as L
23
+
24
+ DATALOADER_OPTIONS = {}
25
+
26
+
27
+ def get_sampler(dataset):
28
+ return DistributedSampler(
29
+ dataset,
30
+ num_replicas=parallel_state.get_data_parallel_world_size(),
31
+ rank=parallel_state.get_data_parallel_rank(),
32
+ shuffle=True,
33
+ seed=0,
34
+ )
35
+
36
+
37
+ def dataloader_register(key):
38
+ log.info(f"registering dataloader {key}...")
39
+
40
+ def decorator(func):
41
+ DATALOADER_OPTIONS[key] = func
42
+ return func
43
+
44
+ return decorator
45
+
46
+
47
+ @dataloader_register("tealrobot_video")
48
+ def get_tealrobot_video(
49
+ batch_size: int = 1,
50
+ dataset_dir: str = "datasets/cosmos_nemo_assets/videos/",
51
+ sequence_interval: int = 1,
52
+ num_frames: int = 33,
53
+ video_size: list[int, int] = [640, 848],
54
+ start_frame_interval: int = 1,
55
+ ):
56
+ dataset = L(VideoDataset)(
57
+ config=VideoDatasetConfig(
58
+ dataset_dir=dataset_dir,
59
+ sequence_interval=sequence_interval,
60
+ num_frames=num_frames,
61
+ video_size=video_size,
62
+ start_frame_interval=start_frame_interval,
63
+ )
64
+ )
65
+ return L(DataLoader)(
66
+ dataset=dataset,
67
+ sampler=L(get_sampler)(dataset=dataset),
68
+ batch_size=batch_size,
69
+ drop_last=True,
70
+ pin_memory=True,
71
+ num_workers=8,
72
+ )
cosmos_predict1/autoregressive/configs/base/dataset.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Dataset config class."""
17
+
18
+ import attrs
19
+
20
+ from cosmos_predict1.utils.config import make_freezable
21
+
22
+
23
+ @make_freezable
24
+ @attrs.define(slots=False)
25
+ class VideoDatasetConfig:
26
+ """
27
+ Args:
28
+ dataset_dir (str): Base path to the dataset directory
29
+ sequence_interval (int): Interval between sampled frames in a sequence
30
+ num_frames (int): Number of frames to load per sequence
31
+ video_size (list): Target size [H,W] for video frames
32
+ start_frame_interval (int): Interval between starting frames of sequences
33
+ """
34
+
35
+ dataset_dir: str = "datasets/cosmos_nemo_assets/videos/"
36
+ sequence_interval: int = 1
37
+ num_frames: int = 33
38
+ video_size: list[int, int] = [640, 848]
39
+ start_frame_interval: int = 1
cosmos_predict1/autoregressive/configs/base/model.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Optional
17
+
18
+ import attrs
19
+
20
+ from cosmos_predict1.autoregressive.configs.base.tokenizer import TokenizerConfig
21
+ from cosmos_predict1.utils import config
22
+
23
+ _ACTION_DIM = 8
24
+ from cosmos_predict1.utils.lazy_config import LazyDict
25
+
26
+
27
+ @attrs.define
28
+ class ModelConfig:
29
+ """
30
+ A class to hold model configuration arguments.
31
+
32
+ Args:
33
+ dim (int): The dimensionality of the input and output of each transformer block.
34
+ n_layers (int): Number of layers in the transformer.
35
+ n_heads (int): Number of attention heads.
36
+ n_kv_heads (Optional[int]): Number of key-value heads. If None, defaults to n_heads. Note: this is equivalent to
37
+ `num_gqa_groups` in TransformerEngine, where GQA means Grouped Query Attention.
38
+ head_dim (Optional[int]): Dimensionality of each head. If None, defaults to dim // n_heads.
39
+ vocab_size (int): Vocabulary size.
40
+ ffn_hidden_size (int): Hidden size for feedforward network.
41
+ norm_eps (float): Epsilon value for normalization.
42
+ rope_theta (float): Theta value for rotary positional embeddings.
43
+ apply_abs_pos_emb (bool): Whether to apply absolute position embeddings.
44
+ max_batch_size (int): Maximum batch size for inference.
45
+ max_seq_len (int): Maximum sequence length for input text.
46
+ fuse_qkv (bool): Whether to fuse QKV in attention. Defaults to True.
47
+ causal_mask (bool): Whether to use causal mask. Defaults to True.
48
+ norm_type (str): Type of normalization layer. Choices: "rmsnorm", "fused_rmsnorm", "layernorm", "np_layernorm".
49
+ precision (str): Data type for the model.
50
+ use_qk_normalization (bool): Whether to enable QK normalization.
51
+ tensor_model_parallel_size (int): Tensor model parallel size. Defaults to 1.
52
+ ckpt_dir (str): Checkpoint directory.
53
+ ckpt_path (str): Checkpoint path.
54
+ apply_yarn (Optional[bool]): Whether to apply YaRN (long-context extension).
55
+ yarn_scale (Optional[float]): Scale factor for YaRN.
56
+ yarn_beta_fast (Optional[int]): Beta fast variable for YaRN (i.e., low_freq_factor in Llama 3.1 RoPE scaling code)
57
+ yarn_beta_slow (Optional[int]): Beta slow variable for YaRN (i.e., high_freq_factor in Llama 3.1 RoPE scaling code)
58
+ original_seq_len (Optional[int]): Original sequence length.
59
+ vision_encoder (Optional[str]): Vision encoder name.
60
+ mm_projector (Optional[str]): Multi-modal projector name.
61
+ vision_encoder_in_channels (Optional[int]): Number of channels in the input image for the vision encoder. Default is 3, you can specify to int larger than 3. E.g. if you have 4-channel images with the last channel as the alpha channel, set this to 4.
62
+ rope_dim (Optional[str]): Dimensionality of the RoPE. Choices: "1D", "3D".
63
+ pytorch_rope_version (Optional[str]): Version of the PyTorch RoPE implementation. Choices: "v1", "v2".
64
+ original_latent_shape (Optional[list]): Original shape of the latent tensor needed for rope extension.
65
+ pad_to_multiple_of (Optional[int]): Pad the position embedding to a multiple of this value.
66
+ vision_encoder_in_channels (Optional[int]): Number of channels in the input image for the vision encoder. Default is 3.
67
+ insert_cross_attn (bool): Whether to insert the cross-attention layers after each multi-head self-attention (MSA) layer.
68
+ insert_cross_attn_every_k_layers (int): Insert cross-attention layers every k TransformerLayers.
69
+ context_dim (Optional[int]): The dimensionality of cross-attention embedding, e.g., T5 embed feature dim.
70
+ num_video_frames (Optional[int]): Number of video frames.
71
+ video_height (Optional[int]): Raw video pixel height dimension.
72
+ video_width (Optional[int]): Raw video pixel width dimension.
73
+ video_latent_shape (Optional[list]): Video tokenizer output dimension, in (T,H,W).
74
+ """
75
+
76
+ dim: int = attrs.field(default=4096)
77
+ n_layers: int = attrs.field(default=32)
78
+ n_heads: int = attrs.field(default=32)
79
+ n_kv_heads: Optional[int] = attrs.field(default=8)
80
+ head_dim: Optional[int] = attrs.field(default=None)
81
+ vocab_size: int = attrs.field(default=128256)
82
+ ffn_hidden_size: int = attrs.field(default=14336)
83
+ norm_eps: float = attrs.field(default=1e-5)
84
+ rope_theta: float = attrs.field(default=500000)
85
+ apply_abs_pos_emb: bool = attrs.field(default=False)
86
+ max_batch_size: int = attrs.field(default=1)
87
+ max_seq_len: int = attrs.field(default=8192)
88
+ fuse_qkv: bool = attrs.field(default=False)
89
+ causal_mask: bool = attrs.field(default=True)
90
+ norm_type: str = attrs.field(default="rmsnorm")
91
+ precision: str = attrs.field(default="bfloat16")
92
+ use_qk_normalization: bool = False
93
+ tokenizer: Optional[TokenizerConfig] = None
94
+ tensor_model_parallel_size: int = attrs.field(default=1)
95
+ ckpt_dir: Optional[str] = attrs.field(default=None)
96
+ ckpt_path: Optional[str] = attrs.field(
97
+ default=None
98
+ ) # If not None, load the model from this path instead of ckpt_dir
99
+ apply_yarn: Optional[bool] = attrs.field(default=False)
100
+ yarn_scale: Optional[float] = attrs.field(default=None)
101
+ yarn_beta_fast: Optional[int] = attrs.field(default=None)
102
+ yarn_beta_slow: Optional[int] = attrs.field(default=None)
103
+ original_seq_len: Optional[int] = attrs.field(default=None)
104
+ vision_encoder: Optional[str] = attrs.field(default=None)
105
+ vision_encoder_in_channels: Optional[int] = attrs.field(default=3)
106
+ mm_projector: Optional[str] = attrs.field(default=None)
107
+ rope_dim: Optional[str] = attrs.field(default="1D")
108
+ pytorch_rope_version: Optional[str] = attrs.field(default="v2")
109
+ original_latent_shape: Optional[list] = None
110
+ pad_to_multiple_of: Optional[int] = None
111
+ vision_encoder_in_channels: Optional[int] = attrs.field(default=3)
112
+ insert_cross_attn: bool = False
113
+ insert_cross_attn_every_k_layers: int = 1
114
+ context_dim: Optional[int] = attrs.field(default=1024)
115
+ # For video training
116
+ num_video_frames: Optional[int] = None
117
+ # Raw video pixel dimension
118
+ video_height: Optional[int] = None
119
+ video_width: Optional[int] = None
120
+ # Video tokenizer output dimension, in (T,H,W), it's computed by num_video_frames/temporal_compress_factor, video_height/spatial_compression_fact, video_width/spatial_compression_fact
121
+ video_latent_shape: Optional[list] = None
122
+
123
+ def __getitem__(self, item):
124
+ return getattr(self, item)
125
+
126
+
127
+ @attrs.define
128
+ class TrainingModelConfig:
129
+ """
130
+ A class to hold model configuration arguments.
131
+
132
+ Args:
133
+ dim (int): The dimensionality of the input and output of each transformer block.
134
+ n_layers (int): Number of layers in the transformer.
135
+ n_heads (int): Number of attention heads.
136
+ n_kv_heads (Optional[int]): Number of key-value heads. If None, defaults to n_heads. Note: this is equivalent to
137
+ `num_gqa_groups` in TransformerEngine, where GQA means Grouped Query Attention.
138
+ head_dim (Optional[int]): Dimensionality of each head. If None, defaults to dim // n_heads.
139
+ vocab_size (int): Vocabulary size.
140
+ multiple_of (int): Ensures the hidden layer size is a multiple of this value for SwiGLU activation.
141
+ ffn_dim_multiplier (Optional[float]): Multiplier for feedforward network dimension.
142
+ ffn_hidden_size (Optional[int]): Hidden size for feedforward network. If None, use ffn_dim_multiplier to compute it.
143
+ norm_eps (float): Epsilon value for normalization.
144
+ rope_theta (float): Theta value for rotary positional embeddings.
145
+ apply_abs_pos_emb (bool): Whether to apply absolute position embeddings.
146
+ max_batch_size (int): Maximum batch size for inference (determines KV cache size).
147
+ max_seq_len (int): Maximum sequence length for input text (determines KV cache size).
148
+ fuse_qkv (bool): Whether to fuse QKV in attention. Flag for the pytorch backend.
149
+ causal_mask (bool): Whether to use causal mask. Defaults to True.
150
+ flash_attn (bool): Whether to use Flash attention.
151
+ norm_type (str): Type of normalization layer. Choices: "rmsnorm", "fused_rmsnorm", "layernorm", "np_layernorm".
152
+ backend (str): Backend for the model.
153
+ precision (str): Data type for the model.
154
+ ema (config.EMAConfig): Configuration for exponential moving average.
155
+ embedding_dropout(float): Dropout rate for the embedding layer.
156
+ attention_dropout(float): Dropout rate for attention.
157
+ hidden_dropout(float): Dropout after the attention and feed-forward layers (following TransformerEngine's
158
+ implementation in its TransformerLayer class).
159
+ use_qk_normalization (bool): Whether to enable QK normalization.
160
+ inference (bool): Whether the model is used for inference.
161
+ act_ckpt_enabled (bool): Whether to enable activation checkpointing.
162
+ fsdp_enabled (bool): Whether to enable FSDP.
163
+ fsdp (LazyDict): Configuration for FSDP.
164
+ ckpt_dir (str): Checkpoint directory.
165
+ ckpt_path (str): Checkpoint path.
166
+ cache_dir (str): Cache directory.
167
+ apply_yarn (Optional[bool]): Whether to apply YaRN (long-context extension).
168
+ yarn_scale (Optional[float]): Scale factor for YaRN.
169
+ yarn_beta_fast (Optional[int]): Beta fast variable for YaRN (i.e., low_freq_factor in Llama 3.1 RoPE scaling code)
170
+ yarn_beta_slow (Optional[int]): Beta slow variable for YaRN (i.e., high_freq_factor in Llama 3.1 RoPE scaling code)
171
+ original_seq_len (Optional[int]): Original sequence length.
172
+ depth_init (bool): If `True`, then each transformer block init uses its layer ID, and if `False`, each uses the
173
+ total number of transformer blocks. Defaults to `True` (following the TorchTitan implementation of Llama3).
174
+ context_parallel_size (int): Context parallel size. Defaults to 1.
175
+ tensor_model_parallel_size (int): Tensor model parallel size. Defaults to 1.
176
+ sequence_parallel (bool): Whether to use sequence parallelism. Defaults to False.
177
+ set_parallel_mode (bool): It is a boolean flag used by TransformerEngine to handle Tensor Parallelism.
178
+ Essentially, it is equivalent to `tensor_model_parallel_size > 1`. Defaults to `False`.
179
+ attention_tp (bool): Whether to use tensor parallelism for attention layers.
180
+ mm_projector (Optional[str]): Multimodal projector used for vision-language modeling. Defaults to None.
181
+ Choices: "identity", "linear", "mlp", "mlp_downsample".
182
+ video_latent_shape (Optional[list]): Shape of the video latent tensor. [T, H, W]
183
+ image_latent_shape (Optional[list]): Shape of the image latent tensor. [H, W]
184
+ num_video_frames (Optional[int]): Number of video frames.
185
+ rope_dim (Optional[str]): Dimensionality of the RoPE. Choices: "1D", "2D", "3D".
186
+ pytorch_rope_version (Optional[str]): Version of the RoPE for the `pytorch` backend. "v1" is the Llama implementation, and "v2" is HuggingFace/TransformerEngine implementation.
187
+ original_latent_shape (Optional[list]): Original shape of the latent tensor needed for rope extension.
188
+ pad_to_multiple_of (Optional[int]): Pad the position embedding to a multiple of this value.
189
+ peft_last_n_layers (Optional[int]): Number of last few layers to fine-tune in Parameter Efficient Fine-Tuning (PEFT). When this and peft_every_n_layers are both 0, it means all layers are fine-tuned (FFT).
190
+ peft_every_n_layers (Optional[int]): In Parameter Efficient Fine-Tuning (PEFT), every n layers are unfrozen and can be trained (in flamingo style). When this and peft_last_n_layers are both 0,
191
+ it means all layers are fine-tuned (FFT). For example, for a 40 layer model, n=8 means training layers 7, 15, 23, 31, 39, which includes the final layer.
192
+ It is advised to pick n such that the final layer is included.
193
+ freeze_vision_encoder (bool): Whether to freeze the vision encoder in vision-language model training. Defaults to False.
194
+ vision_encoder_in_channels (Optional[int]): Number of channels in the input image for the vision encoder. Default is 3, you can specify to int larger than 3. E.g. if you have 4-channel images with the last channel as the alpha channel, set this to 4.
195
+ insert_cross_attn (bool): Whether to insert the cross-attention layers after each multi-head self-attention (MSA) layer.
196
+ insert_cross_attn_every_k_layers (int): Insert cross-attention layers every k TransformerLayers.
197
+ context_dim (Optional[int]): The dimensionality of cross-attention embedding, e.g., T5 embed feature dim.
198
+ finetune_layers_with_cross_attn (bool): Whether to finetune Transformer layers w/ CA (cross-attn).
199
+ finetune_layers_without_cross_attn (bool): Whether to finetune Transformer layers w/o CA (cross-attn).
200
+ use_action_condition (bool): Whether to use the robot action condition.
201
+ action_embedding_mode (Optional[str]): The mode of the robot action embedding. Choices: "matrix", "mlp".
202
+ action_dim (Optional[int]): The dimensionality of the raw robot action tensor (e.g., 7 for DROID, [Δx, Δy, Δz, rx, ry, rz, gripper_open]).
203
+ action_embedding_dim (Optional[int]): The dimensionality of the robot action embedding.
204
+ group_causal_mask_mode (Optional[str]): The mode of the group causal mask. Choices: "causal", "group_diagonal".
205
+ sync_1d_parameters (bool): Whether to synchronize layernorm parameters (1D) across tensor parallel ranks (default True).
206
+ Note: this is to ensure all TP-ranks have the same layernorm parameters.
207
+ z_loss_coeff (float): The coefficient for the z-loss.
208
+ insert_medusa_head (bool): Whether to insert the Medusa head.
209
+ ft_medusa_option (str): Options on which layers to finetune, choices like:
210
+ "fft": fully fine-tune both medusa heads and all LLM backbone;
211
+ "head": fine-tune medusa heads;
212
+ "head_out": fine-tune medusa heads, and the output layer;
213
+ "head_out_last_k_layer": fine-tune medusa heads, the output layer, and the last k layer(s) of the LLM backbone.
214
+ medusa_num_heads (int): Number of heads in the Medusa head.
215
+ medusa_num_layers (int): Number of layers in the Medusa head.
216
+ medusa_concat_heads (bool): Whether to concatenate multiple medusa heads into fused matrix, only applicable when medusa_num_layers = 1.
217
+ zero_init_cross_attn_proj (bool): Whether to initialize the cross-attn proj layer with zeros (default False).
218
+ concat_action_to_context (bool): Whether to concatenate the action embedding to the context (default False).
219
+ """
220
+
221
+ dim: int = attrs.field(default=4096)
222
+ n_layers: int = attrs.field(default=32)
223
+ n_heads: int = attrs.field(default=32)
224
+ n_kv_heads: Optional[int] = attrs.field(default=8)
225
+ head_dim: Optional[int] = attrs.field(default=None)
226
+ vocab_size: int = attrs.field(default=128256)
227
+ multiple_of: int = attrs.field(default=1024) # make SwiGLU hidden layer size multiple of large power of 2
228
+ ffn_dim_multiplier: Optional[float] = attrs.field(default=1.3)
229
+ ffn_hidden_size: Optional[int] = attrs.field(default=None)
230
+ norm_eps: float = attrs.field(default=1e-5)
231
+ rope_theta: float = attrs.field(default=500000)
232
+ apply_abs_pos_emb: bool = attrs.field(default=False)
233
+ max_batch_size: int = attrs.field(default=1)
234
+ max_seq_len: int = attrs.field(default=8192)
235
+ fuse_qkv: bool = attrs.field(default=False)
236
+ causal_mask: bool = attrs.field(default=True)
237
+ flash_attn: bool = attrs.field(default=True)
238
+ norm_type: str = attrs.field(default="rmsnorm")
239
+ backend: str = attrs.field(default="pytorch")
240
+ precision: str = attrs.field(default="bfloat16")
241
+ ema: config.EMAConfig = config.EMAConfig(enabled=False)
242
+ embedding_dropout: float = 0.0
243
+ attention_dropout: float = 0.0
244
+ hidden_dropout: float = 0.0
245
+ use_qk_normalization: bool = False
246
+ tokenizer: Optional[TokenizerConfig] = None
247
+ inference: bool = False
248
+ act_ckpt_enabled: bool = False
249
+ fsdp_enabled: bool = False
250
+ context_parallel_size: int = attrs.field(default=1)
251
+ tensor_model_parallel_size: int = attrs.field(default=1)
252
+ sequence_parallel: bool = attrs.field(default=False)
253
+ set_parallel_mode: bool = attrs.field(default=False)
254
+ fsdp: LazyDict = LazyDict(
255
+ dict(
256
+ policy="auto", # choices: ["size", "auto"]
257
+ min_num_params=1024, # Used as policy == "size"
258
+ sharding_strategy="hybrid", # Choices: ["full", "hybrid"]. "full" means sharding_group_size = world_size
259
+ sharding_group_size=8, # If None, defaults to min(world_size, 8). Recommends 8 for training on 8-GPU nodes.
260
+ )
261
+ )
262
+ ckpt_dir: Optional[str] = attrs.field(default="")
263
+ ckpt_path: Optional[str] = attrs.field(
264
+ default=None
265
+ ) # If not None, load the model from this path instead of ckpt_dir
266
+ cache_dir: Optional[str] = attrs.field(default="/project/cosmos/ar/cache")
267
+ apply_yarn: Optional[bool] = attrs.field(default=False)
268
+ yarn_scale: Optional[float] = attrs.field(default=None)
269
+ yarn_beta_fast: Optional[int] = attrs.field(default=None)
270
+ yarn_beta_slow: Optional[int] = attrs.field(default=None)
271
+ original_seq_len: Optional[int] = attrs.field(default=None)
272
+ depth_init: bool = attrs.field(default=True)
273
+ ignore_first_num_tokens: int = 0
274
+ z_loss_coeff: float = 1e-4
275
+ attention_tp: bool = False
276
+ vision_encoder: Optional[str] = attrs.field(default=None)
277
+ mm_projector: Optional[str] = attrs.field(default=None)
278
+ rope_dim: Optional[str] = attrs.field(default="1D")
279
+ pytorch_rope_version: Optional[str] = attrs.field(default="v2")
280
+ original_latent_shape: Optional[list] = None
281
+ pad_to_multiple_of: Optional[int] = None
282
+ peft_last_n_layers: Optional[int] = attrs.field(default=0)
283
+ peft_every_n_layers: Optional[int] = attrs.field(default=0)
284
+ freeze_vision_encoder: bool = False
285
+ vision_encoder_in_channels: Optional[int] = attrs.field(default=3)
286
+ insert_cross_attn: bool = False
287
+ insert_cross_attn_every_k_layers: int = 1
288
+ context_dim: Optional[int] = attrs.field(default=1024)
289
+ finetune_layers_with_cross_attn: bool = False
290
+ finetune_layers_without_cross_attn: bool = False
291
+ use_action_condition: bool = False
292
+ action_embedding_mode: Optional[str] = attrs.field(default="mlp")
293
+ action_dim: Optional[int] = attrs.field(default=_ACTION_DIM)
294
+ action_embedding_dim: Optional[int] = attrs.field(default=1024)
295
+ group_causal_mask_mode: Optional[str] = attrs.field(default=None)
296
+ sync_1d_parameters: bool = True
297
+ # hyper-parameters for the medusa head configs
298
+ insert_medusa_head: bool = False
299
+ ft_medusa_option: str = "fft"
300
+ medusa_num_heads: int = 7
301
+ medusa_num_layers: int = 1
302
+ medusa_concat_heads: bool = True
303
+ # For video training
304
+ num_video_frames: Optional[int] = None
305
+ # Raw video pixel dimension
306
+ video_height: Optional[int] = None
307
+ video_width: Optional[int] = None
308
+ # Video tokenizer output dimension, in (T,H,W), it's computed by num_video_frames/temporal_compress_factor, video_height/spatial_compression_fact, video_width/spatial_compression_fact
309
+ video_latent_shape: Optional[list] = None
310
+ # For image training
311
+ image_latent_shape: Optional[list] = None
312
+ # For robot training (action)
313
+ zero_init_cross_attn_proj: bool = False
314
+ # For robot training (action)
315
+ concat_action_to_context: bool = False
316
+
317
+ def __getitem__(self, item):
318
+ return getattr(self, item)
cosmos_predict1/autoregressive/configs/base/model_config.py ADDED
@@ -0,0 +1,718 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import copy
17
+ from typing import Callable, List, Optional
18
+
19
+ import torch
20
+ from megatron.core import ModelParallelConfig
21
+
22
+ from cosmos_predict1.autoregressive.configs.base.model import ModelConfig, TrainingModelConfig
23
+ from cosmos_predict1.autoregressive.configs.base.tokenizer import (
24
+ TextTokenizerConfig,
25
+ TokenizerConfig,
26
+ VideoTokenizerConfig,
27
+ create_discrete_video_fsq_tokenizer_state_dict_config,
28
+ )
29
+ from cosmos_predict1.autoregressive.tokenizer.image_text_tokenizer import ImageTextTokenizer
30
+ from cosmos_predict1.autoregressive.tokenizer.text_tokenizer import TextTokenizer
31
+ from cosmos_predict1.autoregressive.training.model import AutoRegressiveTrainingModel
32
+ from cosmos_predict1.utils import log
33
+ from cosmos_predict1.utils.config import EMAConfig
34
+ from cosmos_predict1.utils.lazy_config import LazyCall as L
35
+
36
+ # Common architecture specifications
37
+ BASE_CONFIG = {"n_kv_heads": 8, "norm_type": "rmsnorm", "norm_eps": 1e-5, "ffn_hidden_size": 14336}
38
+ COSMOS_ARCHITECTURES = {
39
+ "1b": {
40
+ "n_layers": 16,
41
+ "dim": 2048,
42
+ "n_heads": 32,
43
+ },
44
+ "4b": {
45
+ "n_layers": 16,
46
+ "dim": 4096,
47
+ "n_heads": 32,
48
+ },
49
+ "12b": {
50
+ "n_layers": 40,
51
+ "dim": 5120,
52
+ "n_heads": 32,
53
+ "head_dim": 128,
54
+ },
55
+ }
56
+
57
+ COSMOS_YARN_CONFIG = {
58
+ "original_latent_shape": [3, 40, 64],
59
+ "apply_yarn": True,
60
+ "yarn_beta_fast": 4,
61
+ "yarn_beta_slow": 1,
62
+ "yarn_scale": 2,
63
+ }
64
+
65
+ # Llama3 architecture specifications for different model sizes
66
+ LLAMA3_ARCHITECTURES = {
67
+ "8b": {
68
+ "n_layers": 32,
69
+ "dim": 4096,
70
+ "n_heads": 32,
71
+ "ffn_hidden_size": 14336,
72
+ },
73
+ }
74
+ # Llama3.1 uses YaRN for long context support (context of 128k tokens)
75
+ LLAMA_YARN_CONFIG = {
76
+ "apply_yarn": True,
77
+ "yarn_scale": 8,
78
+ "yarn_beta_fast": 4,
79
+ "yarn_beta_slow": 1,
80
+ }
81
+
82
+ # Mistral architecture specifications for different model sizes
83
+ MISTRAL_ARCHITECTURES = {
84
+ "12b": {
85
+ "n_layers": 40,
86
+ "dim": 5120,
87
+ "n_heads": 32,
88
+ "ffn_hidden_size": 14336,
89
+ "head_dim": 128,
90
+ },
91
+ }
92
+
93
+ PIXTRAL_VISION_ARCHITECTURES = {
94
+ "12b": {"vision_encoder": "pixtral-12b-vit", "mm_projector": "mlp"},
95
+ }
96
+
97
+
98
+ def get_model_arch_specs(model_size: str, model_family: str = "mistral", pretrained: bool = False) -> dict:
99
+ """
100
+ Get the model architecture specifications for the given model size, model family and pretrained status.
101
+
102
+ Args:
103
+ model_size (str): Model size. Choices: "1b", "3b", "4b", "7b", etc.
104
+ model_family (str): Model family. Choices: "llama", "llama3", "llama3.1", "mistral"
105
+ pretrained (bool): Whether to load pretrained weights.
106
+
107
+ Returns:
108
+ dict: A dictionary containing the model architecture specifications.
109
+ """
110
+ arch_specs = copy.deepcopy(BASE_CONFIG)
111
+ model_size = model_size.lower()
112
+ if model_family.startswith("cosmos"):
113
+ arch_specs.update(COSMOS_ARCHITECTURES[model_size])
114
+ elif model_family.startswith("llama"):
115
+ arch_specs.update(LLAMA3_ARCHITECTURES[model_size])
116
+ elif model_family in ["mistral", "pixtral"]:
117
+ arch_specs.update(MISTRAL_ARCHITECTURES[model_size])
118
+ if model_family == "pixtral":
119
+ arch_specs.update(PIXTRAL_VISION_ARCHITECTURES[model_size])
120
+ else:
121
+ raise ValueError(f"Model family {model_family} is not supported.")
122
+
123
+ if pretrained:
124
+ if model_family == "cosmos":
125
+ if model_size == "12b":
126
+ arch_specs.update(COSMOS_YARN_CONFIG)
127
+ log.debug(f"Using YaRN for RoPE extension with config: {COSMOS_YARN_CONFIG}")
128
+ else:
129
+ pass
130
+ elif model_family in ["llama", "llama3"]:
131
+ pretrained_specs = {
132
+ "rope_theta": 500000,
133
+ "max_seq_len": 8192,
134
+ "vocab_size": 128256,
135
+ }
136
+ arch_specs.update(pretrained_specs)
137
+ elif model_family == "llama3.1":
138
+ pretrained_specs = {
139
+ "rope_theta": 500000,
140
+ "max_seq_len": 131072,
141
+ "original_seq_len": 8192,
142
+ "vocab_size": 128256,
143
+ **LLAMA_YARN_CONFIG,
144
+ }
145
+ arch_specs.update(pretrained_specs)
146
+ elif model_family == "mistral":
147
+ assert model_size == "12b", "We only support Mistral-Nemo-12B model."
148
+ pretrained_specs = {
149
+ "rope_theta": 1000000,
150
+ "max_seq_len": 128000,
151
+ "vocab_size": 131072,
152
+ }
153
+ arch_specs.update(pretrained_specs)
154
+ elif model_family == "pixtral":
155
+ assert model_size == "12b", "We only support Pixtral 12B model."
156
+ pretrained_specs = {"rope_theta": 1000000000, "max_seq_len": 128000, "vocab_size": 131072}
157
+ arch_specs.update(pretrained_specs)
158
+ else:
159
+ raise ValueError(f"Model family {model_family} doesn't have a pretrained config.")
160
+
161
+ return arch_specs
162
+
163
+
164
+ def create_text_model_config(
165
+ model_ckpt_path: str,
166
+ tokenizer_path: str,
167
+ tensor_model_parallel_size: int = 1,
168
+ model_family: str = "mistral",
169
+ model_size: str = "12b",
170
+ is_instruct_model: bool = True,
171
+ max_seq_len: int = None,
172
+ max_batch_size: int = 1,
173
+ rope_dim: str = "1D",
174
+ add_special_tokens: bool = True,
175
+ pytorch_rope_version: str = None,
176
+ ) -> dict:
177
+ """Create a text model for training or inference.
178
+ Args:
179
+ model_ckpt_path (str): Path to the model checkpoint.
180
+ tokenizer_path (str): Path to the tokenizer folder.
181
+ tensor_model_parallel_size (int): Number of tensor model parallel groups.
182
+ model_family (str): Model family. Choices: "llama", "llama3", "llama3.1", "mistral".
183
+ model_size (str): Model size. Choices: "1b", "3b", "4b", "7b", "8b", "72b", etc.
184
+ is_instruct_model (bool): Whether the model is an instruct model.
185
+ inference (bool): Whether to create the model for inference.
186
+ max_seq_len (int): Maximum sequence length.
187
+ max_batch_size (int): Maximum batch size.
188
+ rope_dim (str): RoPE dimension. Choices: "1D", "3D".
189
+ add_special_tokens (bool): Whether to add special tokens.
190
+ Returns:
191
+ dict: A dictionary containing the model configuration, which can be used to instantiate the model object.
192
+ """
193
+ # Model size specific parameters
194
+ model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=True)
195
+ if max_seq_len is not None:
196
+ # Override the max_seq_len if provided
197
+ model_arch_specs["max_seq_len"] = max_seq_len
198
+ if pytorch_rope_version is not None:
199
+ model_arch_specs["pytorch_rope_version"] = pytorch_rope_version
200
+ model_config = ModelConfig(
201
+ max_batch_size=max_batch_size,
202
+ precision="bfloat16",
203
+ ckpt_path=model_ckpt_path,
204
+ use_qk_normalization=False,
205
+ tensor_model_parallel_size=tensor_model_parallel_size,
206
+ rope_dim=rope_dim,
207
+ **model_arch_specs,
208
+ )
209
+
210
+ tokenizer_config = TokenizerConfig(
211
+ text_tokenizer=TextTokenizerConfig(
212
+ config=L(TextTokenizer)(
213
+ model_family=model_family,
214
+ is_instruct_model=is_instruct_model,
215
+ local_path=tokenizer_path,
216
+ ),
217
+ data_key="text",
218
+ tokenizer_offset=model_config.vocab_size,
219
+ tokenize_here=False,
220
+ vocab_size=model_config.vocab_size,
221
+ ),
222
+ seq_len=model_config.max_seq_len,
223
+ training_type="text_only",
224
+ add_special_tokens=add_special_tokens,
225
+ )
226
+ return model_config, tokenizer_config
227
+
228
+
229
+ def create_vision_language_model_config(
230
+ model_ckpt_path: str,
231
+ tokenizer_ckpt_path: str,
232
+ tensor_model_parallel_size: int = 1,
233
+ model_family: str = "pixtral",
234
+ model_size: str = "12b",
235
+ is_instruct_model: bool = True,
236
+ max_batch_size: int = 1,
237
+ rope_dim: str = "1D",
238
+ add_special_tokens: bool = True,
239
+ max_seq_len: int = None,
240
+ vision_encoder_in_channels: int = 3,
241
+ fuse_qkv: bool = False,
242
+ pytorch_rope_version: str = None,
243
+ ) -> dict:
244
+ """Create a vision-language model for training or inference.
245
+ Args:
246
+ model_ckpt_path (str): Path to the model checkpoint.
247
+ tokenizer_ckpt_path (str): Path to the tokenizer checkpoint.
248
+ tensor_model_parallel_size (int): Number of tensor model parallel groups.
249
+ model_family (str): Model family. Choices: "pixtral".
250
+ model_size (str): Model size. Choices: "12b".
251
+ is_instruct_model (bool): Whether the model is an instruct model.
252
+ rope_dim (str): RoPE dimension. Choices: "1D".
253
+ add_special_tokens (bool): Whether to add special tokens.
254
+ max_seq_len (int): Maximum sequence length.
255
+ vision_encoder_in_channels (int): Number of channels in the input image for the vision encoder. Default is 3, you can specify to int larger than 3. E.g. if you have 4 channel images where last channel is binary mask, set this to 4.
256
+ fuse_qkv (bool): Whether to fuse the QKV linear layers.
257
+ Returns:
258
+ dict: A dictionary containing the model configuration, which can be used to instantiate the model object.
259
+ """
260
+ # Model size specific parameters
261
+ model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=True)
262
+ if max_seq_len is not None:
263
+ # Override the max_seq_len if provided
264
+ model_arch_specs["max_seq_len"] = max_seq_len
265
+ if pytorch_rope_version is not None:
266
+ model_arch_specs["pytorch_rope_version"] = pytorch_rope_version
267
+
268
+ model_config = ModelConfig(
269
+ max_batch_size=max_batch_size,
270
+ precision="bfloat16",
271
+ ckpt_path=model_ckpt_path,
272
+ use_qk_normalization=False,
273
+ tensor_model_parallel_size=tensor_model_parallel_size,
274
+ rope_dim=rope_dim,
275
+ vision_encoder_in_channels=vision_encoder_in_channels,
276
+ fuse_qkv=fuse_qkv,
277
+ **model_arch_specs,
278
+ )
279
+ # Vision-language tokenizer
280
+ tokenizer_config = TokenizerConfig(
281
+ text_tokenizer=TextTokenizerConfig(
282
+ config=L(ImageTextTokenizer)(
283
+ model_family=model_family,
284
+ is_instruct_model=is_instruct_model,
285
+ image_processor_path=tokenizer_ckpt_path,
286
+ tokenizer_path=tokenizer_ckpt_path,
287
+ ),
288
+ data_key="image_text_interleaved",
289
+ tokenizer_offset=model_config.vocab_size,
290
+ tokenize_here=False,
291
+ vocab_size=model_config.vocab_size,
292
+ ),
293
+ seq_len=model_config.max_seq_len,
294
+ training_type="image_text_interleaved",
295
+ add_special_tokens=add_special_tokens,
296
+ )
297
+ return model_config, tokenizer_config
298
+
299
+
300
+ def create_video2world_model_config(
301
+ model_ckpt_path: str,
302
+ tokenizer_ckpt_path: str,
303
+ tensor_model_parallel_size: int = 1,
304
+ model_family: str = "cosmos",
305
+ model_size: str = "4b",
306
+ pixel_chunk_duration: int = 9,
307
+ num_video_frames: int = 36,
308
+ compression_ratio: List[int] = [8, 16, 16],
309
+ original_seq_len: int = 8192,
310
+ num_condition_latents_t: int = 1,
311
+ num_tokens_to_ignore: int = -1,
312
+ batch_size: int = 2,
313
+ video_tokenizer_config_creator: Callable = create_discrete_video_fsq_tokenizer_state_dict_config,
314
+ rope_dim: str = "3D",
315
+ add_special_tokens: bool = True,
316
+ video_height: int = 384,
317
+ video_width: int = 640,
318
+ use_qk_normalization: bool = True,
319
+ insert_cross_attn: bool = False,
320
+ insert_cross_attn_every_k_layers: int = 1,
321
+ context_dim: int = 1024,
322
+ training_type: str = "video_to_video",
323
+ pad_to_multiple_of: Optional[int] = 64,
324
+ vocab_size: int = 64000,
325
+ apply_abs_pos_emb: bool = False,
326
+ ) -> dict:
327
+ """Create a video-to-world model config.
328
+ Args:
329
+ tensor_model_parallel_size (int): Number of tensor model parallel groups.
330
+ model_family (str): Model family. Choices: "llama", "llama3", "llama3.1", "mistral".
331
+ model_size (str): Model size. Choices: "1b", "8b", "3b".
332
+ pixel_chunk_duration (int): Number of frames in each chunk.
333
+ num_video_frames (int): Number of video frames.
334
+ compression_ratio (List[int]): Compression ratio for the video frames. Choices: [8, 16, 16] or [4, 8, 8].
335
+ original_seq_len (int): Original sequence length.
336
+ apply_yarn (bool): Whether to apply YaRN for long context scaling.
337
+ yarn_beta_fast (Optional[int]): Fast beta for YaRN.
338
+ yarn_beta_slow (Optional[int]): Slow beta for YaRN.
339
+ yarn_scale (Optional[int]): Scale factor for ctx extension.
340
+ use_qk_normalization (bool): Whether to use Query-Key normalization.
341
+ training_type (str): Type of training task.
342
+ batch_size (int): Batch size.
343
+ video_tokenizer_config_creator (Callable): Method that takes "pixel_chunk_duration: int" and "version: str" as arguments and returns video tokenizer config
344
+ video_tokenizer_version (str): Version of the video tokenizer.
345
+ num_condition_latents_t (int): Number of conditioning latent channels
346
+ num_tokens_to_ignore (int) = Number of tokens to ignore. This takes the precedence
347
+ video_height (int): Height of the video frame. Defaults to 384.
348
+ video_width (int): Width of the video frame. Defaults to 640.
349
+ rope_dim (str): RoPE dimension. Choices: "1D", "3D".
350
+ add_special_tokens (bool): Whether to add special tokens, use False for 2D/3D RoPE.
351
+ pad_to_multiple_of (int): Pad the token sequence length to the nearest multiple of this number. Defaults to 64.
352
+ vocab_size (int): Vocabulary size.
353
+ apply_abs_pos_emb (bool): Whether to apply absolute positional embeddings.
354
+ Returns:
355
+ dict: A dictionary containing the model configuration representing the model object, can be instantiated.
356
+ """
357
+ assert (
358
+ pixel_chunk_duration % compression_ratio[0] == 1
359
+ ), f"pixel_chunk_duration({pixel_chunk_duration}) should be k*n + 1 (k={compression_ratio[0]})"
360
+ latent_chunk_duration = (pixel_chunk_duration - 1) // compression_ratio[0] + 1
361
+ latent_height = video_height // compression_ratio[1]
362
+ latent_width = video_width // compression_ratio[2]
363
+ # Do some math to compute the video latent shape and sequence length
364
+ assert (
365
+ num_video_frames % pixel_chunk_duration == 0
366
+ ), f"num_video_frames {num_video_frames} should be divisible by pixel_chunk_duration {pixel_chunk_duration}"
367
+ video_latent_shape = [
368
+ num_video_frames // pixel_chunk_duration * latent_chunk_duration,
369
+ latent_height,
370
+ latent_width,
371
+ ]
372
+ # product of video_latent_shape
373
+ num_token_video_latent = video_latent_shape[0] * video_latent_shape[1] * video_latent_shape[2]
374
+ if add_special_tokens:
375
+ seq_len = num_token_video_latent + 3 # Sequence length per batch, max_seq_len + 3
376
+ seq_len = (seq_len + 63) // 64 * 64 # Round up to multiple of 64
377
+ # for text to video, we need to add <bov> token to indicate the start of the video
378
+ elif training_type == "text_to_video":
379
+ seq_len = num_token_video_latent + 1
380
+ else:
381
+ seq_len = num_token_video_latent
382
+
383
+ if seq_len % pad_to_multiple_of != 0:
384
+ # Round up to the nearest multiple of pad_to_multiple_of
385
+ seq_len = ((seq_len + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
386
+
387
+ # Model size specific parameters
388
+ model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=True)
389
+
390
+ # Whether skip the loss for first chunk or not, note the first token is already skipped when computing the loss
391
+ # If num_tokens_to_ignore is specified, use it.
392
+ # Else compute it from num_condition_latents_t
393
+ if num_tokens_to_ignore < 0:
394
+ num_tokens_to_ignore = latent_height * latent_width * num_condition_latents_t
395
+ if not add_special_tokens and num_condition_latents_t > 0:
396
+ # If there are no special tokens (bov), do a -1 so that you can compute the loss
397
+ # from the first token of the next chunk
398
+ num_tokens_to_ignore -= 1
399
+
400
+ model_config = ModelConfig(
401
+ video_height=video_height,
402
+ video_width=video_width,
403
+ max_seq_len=seq_len,
404
+ max_batch_size=batch_size,
405
+ precision="bfloat16",
406
+ ckpt_path=model_ckpt_path,
407
+ use_qk_normalization=use_qk_normalization,
408
+ vocab_size=64000,
409
+ original_seq_len=original_seq_len,
410
+ tensor_model_parallel_size=tensor_model_parallel_size,
411
+ video_latent_shape=video_latent_shape,
412
+ num_video_frames=num_video_frames,
413
+ rope_dim=rope_dim,
414
+ pad_to_multiple_of=pad_to_multiple_of,
415
+ insert_cross_attn=insert_cross_attn,
416
+ insert_cross_attn_every_k_layers=insert_cross_attn_every_k_layers,
417
+ context_dim=context_dim,
418
+ apply_abs_pos_emb=apply_abs_pos_emb,
419
+ **model_arch_specs,
420
+ )
421
+
422
+ video_tokenizer_config = video_tokenizer_config_creator(
423
+ tokenizer_ckpt_path, pixel_chunk_duration, compression_ratio
424
+ )
425
+ tokenizer_config = TokenizerConfig(
426
+ text_tokenizer=None,
427
+ video_tokenizer=VideoTokenizerConfig(
428
+ config=video_tokenizer_config,
429
+ data_key="video",
430
+ tokenizer_offset=0, # Since there is no text embeddings in the model. Note this only apply when the model is trained from scratch. If we use text pretrained model, the offset will be vocab_size of text token.
431
+ tokenize_here=True,
432
+ max_seq_len=num_token_video_latent,
433
+ vocab_size=vocab_size,
434
+ ),
435
+ seq_len=seq_len,
436
+ training_type=training_type,
437
+ add_special_tokens=add_special_tokens,
438
+ pad_to_multiple_of=pad_to_multiple_of,
439
+ )
440
+ return model_config, tokenizer_config
441
+
442
+
443
+ def create_video2world_model(
444
+ tensor_model_parallel_size: int = 1,
445
+ context_parallel_size: int = 1,
446
+ shard_checkpoint: bool = False,
447
+ model_family: str = "cosmos",
448
+ model_size: str = "1b",
449
+ backend: str = "pytorch",
450
+ pixel_chunk_duration: int = 9,
451
+ num_video_frames: int = 36,
452
+ compression_ratio: List[int] = [8, 16, 16],
453
+ original_seq_len: int = 8192,
454
+ apply_yarn: bool = False,
455
+ yarn_beta_fast: Optional[int] = None,
456
+ yarn_beta_slow: Optional[int] = None,
457
+ yarn_scale: Optional[int] = None,
458
+ num_condition_latents_t: int = 1,
459
+ num_tokens_to_ignore: int = -1,
460
+ batch_size: int = 1,
461
+ fsdp_enabled: bool = False,
462
+ act_ckpt_enabled: bool = False,
463
+ video_tokenizer_config_creator: Callable = create_discrete_video_fsq_tokenizer_state_dict_config,
464
+ rope_dim: str = "3D",
465
+ add_special_tokens: bool = False,
466
+ video_height: int = 384,
467
+ video_width: int = 640,
468
+ original_latent_shape: Optional[List[int]] = None,
469
+ use_qk_normalization: bool = True,
470
+ sequence_parallel: bool = False,
471
+ insert_cross_attn: bool = False,
472
+ insert_cross_attn_every_k_layers: int = 1,
473
+ context_dim: int = 1024,
474
+ finetune_layers_with_cross_attn: bool = False,
475
+ finetune_layers_without_cross_attn: bool = False,
476
+ use_action_condition: bool = False,
477
+ action_embedding_mode: Optional[str] = "mlp",
478
+ action_dim: int = 8, # ACTION_DIM,
479
+ action_embedding_dim: int = 1024,
480
+ group_causal_mask_mode: Optional[str] = None,
481
+ training_type: str = "video_to_video",
482
+ pad_to_multiple_of: Optional[int] = 1,
483
+ z_loss_coeff: float = 1e-4,
484
+ temporal_overlap: int = 0,
485
+ embedding_dropout: float = 0.0,
486
+ insert_medusa_head: bool = False,
487
+ ft_medusa_option: str = "fft",
488
+ medusa_num_heads: int = 7,
489
+ medusa_num_layers: int = 1,
490
+ medusa_concat_heads: bool = True,
491
+ fuse_qkv: bool = False,
492
+ zero_init_cross_attn_proj: bool = False,
493
+ concat_action_to_context: bool = False,
494
+ tokenizer_ckpt_path: str = "checkpoints/Cosmos-1.0-Tokenizer-DV8x16x16/ema.jit",
495
+ ) -> dict:
496
+ """Create a video-to-video model for training.
497
+ Args:
498
+ tensor_model_parallel_size (int): Number of tensor model parallel groups.
499
+ context_parallel_size (int): Number of context parallel groups.
500
+ model_family (str): Model family. Choices: "llama", "llama3", "llama3.1", "mistral".
501
+ model_size (str): Model size. Choices: "1b", "8b", "3b".
502
+ backend (str): Backend for the model. Choices: "pytorch", "transformer_engine".
503
+ pixel_chunk_duration (int): Number of frames in each chunk.
504
+ num_video_frames (int): Number of video frames.
505
+ compression_ratio (List[int]): Compression ratio for the video frames. Choices: [8, 16, 16] or [4, 8, 8].
506
+ original_seq_len (int): Original sequence length.
507
+ apply_yarn (bool): Whether to apply YaRN for long context scaling.
508
+ yarn_beta_fast (Optional[int]): Fast beta for YaRN.
509
+ yarn_beta_slow (Optional[int]): Slow beta for YaRN.
510
+ yarn_scale (Optional[int]): Scale factor for ctx extension.
511
+ fsdp_enabled (bool): Whether Fully Sharded Data Parallel (FSDP) is enabled.
512
+ act_ckpt_enabled (bool): Whether activation checkpointing is enabled.
513
+ use_qk_normalization (bool): Whether to use Query-Key normalization.
514
+ training_type (str): Type of training task.
515
+ batch_size (int): Batch size.
516
+ video_tokenizer_config_creator (Callable): Method that takes "pixel_chunk_duration: int" and "version: str" as arguments and returns video tokenizer config
517
+ video_tokenizer_version (str): Version of the video tokenizer.
518
+ num_condition_latents_t (int): Number of conditioning latent channels
519
+ num_tokens_to_ignore (int) = Number of tokens to ignore. This takes the precedence
520
+ video_height (int): Height of the video frame. Defaults to 384.
521
+ video_width (int): Width of the video frame. Defaults to 640.
522
+ rope_dim (str): RoPE dimension. Choices: "1D", "2D", "3D".
523
+ add_special_tokens (bool): Whether to add special tokens, use False for 2D/3D RoPE.
524
+ original_latent_shape (list): Original latent shape before RoPE scaling.
525
+ sequence_parallel (bool): Whether to enable sequence parallelism.
526
+ insert_cross_attn (bool): Whether to insert the cross-attention layers after each multi-head self-attention (MSA) layer.
527
+ insert_cross_attn_every_k_layers (int): Insert cross-attention layers every k TransformerLayers.
528
+ context_dim (Optional[int]): The dimensionality of cross-attention embedding, e.g., T5 embed feature dim.
529
+ finetune_layers_with_cross_attn (bool): Whether to finetune Transformer layers w/ CA (cross-attn).
530
+ finetune_layers_without_cross_attn (bool): Whether to finetune Transformer layers w/o CA (cross-attn).
531
+ use_action_condition (bool): Whether to use action condition.
532
+ action_embedding_mode (Optional[str]): The mode of the robot action embedding. Choices: "matrix", "mlp".
533
+ action_dim (int): Dimension of the raw robot action tensor (e.g., 7 for DROID, [Δx, Δy, Δz, rx, ry, rz, gripper_open]).
534
+ action_embedding_dim (int): Dimension of the action embedding.
535
+ group_causal_mask_mode (Optional[str]): The mode of the group causal mask. Choices: "causal", "group_diagonal".
536
+ pad_to_multiple_of (int): Pad the token sequence length to the nearest multiple of this number. Defaults to 64.
537
+ z_loss_coeff (float): Coefficient for the z loss.
538
+ temporal_overlap (int): Temporal overlap in the latent space.
539
+ embedding_dropout (float): Dropout rate for the embeddings.
540
+ insert_medusa_head (bool): Whether to insert the Medusa head.
541
+ ft_medusa_option (str): Options on which layers to finetune, choices like:
542
+ "fft": fully fine-tune both medusa heads and all LLM backbone;
543
+ "head": fine-tune medusa heads;
544
+ "head_out": fine-tune medusa heads, and the output layer;
545
+ "head_out_last_k_layer": fine-tune medusa heads, the output layer, and the last k layer(s) of the LLM backbone.
546
+ medusa_num_heads (int): Number of heads in the Medusa head.
547
+ medusa_num_layers (int): Number of layers in the Medusa head.
548
+ medusa_concat_heads (bool): Whether to concatenate multiple medusa heads into fused matrix, only applicable when medusa_num_layers = 1.
549
+ fuse_qkv (bool): Whether to fuse the QKV linear layers.
550
+ zero_init_cross_attn_proj (bool): Whether to zero-initialize the cross-attention projection weights (default False).
551
+ concat_action_to_context (bool): Whether to concatenate the action embedding to the context (default False).
552
+ Returns:
553
+ dict: A dictionary containing the model configuration representing the model object, can be instantiated.
554
+ """
555
+ assert (
556
+ pixel_chunk_duration % compression_ratio[0] == 1
557
+ ), f"pixel_chunk_duration({pixel_chunk_duration}) should be k*n + 1 (k={compression_ratio[0]})"
558
+ latent_chunk_duration = (pixel_chunk_duration - 1) // compression_ratio[0] + 1
559
+ latent_height = video_height // compression_ratio[1]
560
+ latent_width = video_width // compression_ratio[2]
561
+ # Compute the video latent shape and sequence length
562
+ if temporal_overlap == 0:
563
+ assert (
564
+ num_video_frames % pixel_chunk_duration == 0
565
+ ), f"num_video_frames {num_video_frames} should be divisible by pixel_chunk_duration {pixel_chunk_duration}"
566
+ video_latent_shape = [
567
+ num_video_frames // pixel_chunk_duration * latent_chunk_duration,
568
+ latent_height,
569
+ latent_width,
570
+ ]
571
+
572
+ else:
573
+ # Calculate temporal overlap in the latent space
574
+ temporal_overlap_latent = temporal_overlap // compression_ratio[0]
575
+
576
+ # Calculate the effective number of latent chunks for the video
577
+ latent_chunks = (num_video_frames - temporal_overlap) // (pixel_chunk_duration - temporal_overlap)
578
+
579
+ # Compute the total duration of the latent chunks, accounting for overlap
580
+ effective_latent_duration = (
581
+ latent_chunk_duration - temporal_overlap_latent
582
+ ) * latent_chunks + temporal_overlap_latent
583
+
584
+ # Define the shape of the video in the latent space
585
+ video_latent_shape = [
586
+ effective_latent_duration, # Temporal dimension
587
+ latent_height, # Height in the latent space
588
+ latent_width, # Width in the latent space
589
+ ]
590
+
591
+ # product of video_latent_shape
592
+ num_token_video_latent = video_latent_shape[0] * video_latent_shape[1] * video_latent_shape[2]
593
+ if add_special_tokens:
594
+ seq_len = num_token_video_latent + 3 # Sequence length per batch, max_seq_len + 3
595
+ seq_len = (seq_len + 63) // 64 * 64 # Round up to multiple of 64
596
+ # for text to video, we need to add <bov> token to indicate the start of the video
597
+ elif training_type == "text_to_video":
598
+ seq_len = num_token_video_latent + 1
599
+ else:
600
+ seq_len = num_token_video_latent
601
+
602
+ if seq_len % pad_to_multiple_of != 0:
603
+ # Round up to the nearest multiple of pad_to_multiple_of
604
+ seq_len = ((seq_len + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
605
+
606
+ # Model size specific parameters
607
+ model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=False)
608
+
609
+ inference = False # False for training, True for inference
610
+ # set_parallel_mode = True
611
+ set_parallel_mode = tensor_model_parallel_size > 1
612
+ attention_tp = True
613
+
614
+ if context_parallel_size > 1:
615
+ assert backend == "transformer_engine", "Context parallelism is only supported in transformer engine."
616
+
617
+ if tensor_model_parallel_size > 1:
618
+ assert set_parallel_mode, "Tensor model parallelism is only supported in parallel mode."
619
+
620
+ # Whether skip the loss for first chunk or not, note the first token is already skipped when computing the loss
621
+ # If num_tokens_to_ignore is specified, use it.
622
+ # Else compute it from num_condition_latents_t
623
+ if num_tokens_to_ignore < 0:
624
+ num_tokens_to_ignore = latent_height * latent_width * num_condition_latents_t
625
+ if not add_special_tokens and num_condition_latents_t > 0:
626
+ # If there are no special tokens (bov), do a -1 so that you can compute the loss
627
+ # from the first token of the next chunk
628
+ num_tokens_to_ignore -= 1
629
+
630
+ model_config = TrainingModelConfig(
631
+ video_height=video_height,
632
+ video_width=video_width,
633
+ max_seq_len=seq_len,
634
+ max_batch_size=batch_size,
635
+ inference=inference,
636
+ backend=backend,
637
+ precision="bfloat16",
638
+ ema=EMAConfig(enabled=False),
639
+ act_ckpt_enabled=act_ckpt_enabled,
640
+ fsdp_enabled=fsdp_enabled,
641
+ cache_dir=None,
642
+ ckpt_path="checkpoints/Cosmos-Predict1-4B/model.pt",
643
+ use_qk_normalization=use_qk_normalization,
644
+ vocab_size=64000,
645
+ ignore_first_num_tokens=num_tokens_to_ignore,
646
+ apply_yarn=apply_yarn,
647
+ yarn_beta_fast=yarn_beta_fast,
648
+ yarn_beta_slow=yarn_beta_slow,
649
+ original_seq_len=original_seq_len,
650
+ yarn_scale=yarn_scale,
651
+ context_parallel_size=context_parallel_size,
652
+ tensor_model_parallel_size=tensor_model_parallel_size,
653
+ set_parallel_mode=set_parallel_mode,
654
+ attention_tp=attention_tp,
655
+ video_latent_shape=video_latent_shape,
656
+ num_video_frames=num_video_frames,
657
+ rope_dim=rope_dim,
658
+ original_latent_shape=original_latent_shape,
659
+ pad_to_multiple_of=pad_to_multiple_of,
660
+ sequence_parallel=sequence_parallel,
661
+ insert_cross_attn=insert_cross_attn,
662
+ insert_cross_attn_every_k_layers=insert_cross_attn_every_k_layers,
663
+ context_dim=context_dim,
664
+ finetune_layers_with_cross_attn=finetune_layers_with_cross_attn,
665
+ finetune_layers_without_cross_attn=finetune_layers_without_cross_attn,
666
+ use_action_condition=use_action_condition,
667
+ action_embedding_mode=action_embedding_mode,
668
+ action_dim=action_dim,
669
+ action_embedding_dim=action_embedding_dim,
670
+ group_causal_mask_mode=group_causal_mask_mode,
671
+ z_loss_coeff=z_loss_coeff,
672
+ embedding_dropout=embedding_dropout,
673
+ insert_medusa_head=insert_medusa_head,
674
+ ft_medusa_option=ft_medusa_option,
675
+ medusa_num_heads=medusa_num_heads,
676
+ medusa_num_layers=medusa_num_layers,
677
+ medusa_concat_heads=medusa_concat_heads,
678
+ fuse_qkv=fuse_qkv,
679
+ zero_init_cross_attn_proj=zero_init_cross_attn_proj,
680
+ concat_action_to_context=concat_action_to_context,
681
+ **model_arch_specs,
682
+ )
683
+
684
+ tokenizer_config = TokenizerConfig(
685
+ text_tokenizer=None,
686
+ video_tokenizer=VideoTokenizerConfig(
687
+ config=video_tokenizer_config_creator(
688
+ ckpt_path=tokenizer_ckpt_path, pixel_chunk_duration=pixel_chunk_duration
689
+ ),
690
+ data_key="video",
691
+ tokenizer_offset=0,
692
+ vocab_size=64000,
693
+ tokenize_here=True,
694
+ max_seq_len=num_token_video_latent,
695
+ temporal_overlap=temporal_overlap,
696
+ ),
697
+ seq_len="${model.model_config.max_seq_len}",
698
+ training_type=training_type,
699
+ add_special_tokens=add_special_tokens,
700
+ pad_to_multiple_of=pad_to_multiple_of,
701
+ )
702
+
703
+ model_parallel = ModelParallelConfig(
704
+ bf16=True,
705
+ params_dtype=getattr(torch, "bfloat16"),
706
+ )
707
+ model_parallel.tensor_model_parallel_size = "${model.model_config.tensor_model_parallel_size}"
708
+ model_parallel.context_parallel_size = "${model.model_config.context_parallel_size}"
709
+ model_parallel.sequence_parallel = "${model.model_config.sequence_parallel}"
710
+ return L(AutoRegressiveTrainingModel.build)(
711
+ seed=0,
712
+ train_from_scratch=True,
713
+ model_config=model_config,
714
+ fsdp_checkpointer=None,
715
+ tokenizer_config=tokenizer_config,
716
+ model_parallel=model_parallel,
717
+ shard_checkpoint=shard_checkpoint,
718
+ )
cosmos_predict1/autoregressive/configs/base/model_parallel.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ from megatron.core import ModelParallelConfig
18
+
19
+ from cosmos_predict1.utils.lazy_config import LazyDict
20
+
21
+
22
+ def create_model_parallel_config():
23
+ model_parallel = ModelParallelConfig(bf16=True, params_dtype=getattr(torch, "bfloat16"))
24
+ model_parallel.tensor_model_parallel_size = "${model.model_parallel.tensor_model_parallel_size}"
25
+ model_parallel.context_parallel_size = "${model.model_parallel.context_parallel_size}"
26
+ model_parallel.sequence_parallel = "${model.model_parallel.sequence_parallel}"
27
+ MODEL_PARALLELS = LazyDict(
28
+ dict(
29
+ model_parallel_bf16=model_parallel,
30
+ ),
31
+ flags={"allow_objects": True},
32
+ )
33
+ return MODEL_PARALLELS["model_parallel_bf16"]
cosmos_predict1/autoregressive/configs/base/optim.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+
18
+ from cosmos_predict1.utils.lazy_config import LazyCall as L
19
+
20
+
21
+ class LambdaLinearWarmupScheduler:
22
+ """
23
+ A learning rate scheduler that implements linear warm-up and cool-down.
24
+
25
+ This scheduler provides three phases:
26
+ 1. Warm-up: Learning rate linearly increases from 0 to 1.
27
+ 2. Constant: Learning rate remains at 1.
28
+ 3. Cool-down: Learning rate linearly decreases from 1 to 0.
29
+
30
+ Args:
31
+ warmup_steps (int): Number of steps for the warm-up phase.
32
+ warmup_offset (int): Starts warmup from this offset.
33
+ max_iter (int, optional): Total number of iterations. Required if cooldown_steps is provided.
34
+ cooldown_steps (int, optional): Number of steps for the cool-down phase.
35
+
36
+ Raises:
37
+ ValueError: If cooldown_steps is provided without max_iter, or if an invalid step is given.
38
+ """
39
+
40
+ def __init__(self, warmup_steps: int, warmup_offset: int = 0, max_iter: int = None, cooldown_steps: int = None):
41
+ self.warmup_steps = warmup_steps
42
+ self.warmup_offset = warmup_offset
43
+ self.max_iter = max_iter
44
+ self.cooldown_steps = cooldown_steps
45
+
46
+ if cooldown_steps is not None:
47
+ if max_iter is None:
48
+ raise ValueError("max_iter must be specified when cooldown_steps is provided")
49
+ self.cooldown_start = max_iter - cooldown_steps
50
+ else:
51
+ self.cooldown_start = None
52
+
53
+ def __call__(self, step):
54
+ # Warm-up phase
55
+ if step < self.warmup_offset:
56
+ return 0
57
+
58
+ if step < self.warmup_steps + self.warmup_offset:
59
+ return float(step - self.warmup_offset) / float(max(1, self.warmup_steps))
60
+
61
+ # Constant phase (no cool-down)
62
+ elif self.cooldown_steps is None:
63
+ return 1.0
64
+
65
+ # Constant phase (before cool-down starts)
66
+ elif step < self.cooldown_start:
67
+ return 1.0
68
+
69
+ # Cool-down phase
70
+ elif self.cooldown_start <= step < self.max_iter:
71
+ cooldown_progress = (step - self.cooldown_start) / self.cooldown_steps
72
+ return 1.0 - cooldown_progress
73
+
74
+ # After max_iter
75
+ elif step >= self.max_iter:
76
+ return 0.0
77
+
78
+ # Unexpected case
79
+ else:
80
+ raise ValueError(f"Invalid step {step}")
81
+
82
+
83
+ LambdaLinearLR = L(torch.optim.lr_scheduler.LambdaLR)(
84
+ optimizer=None,
85
+ lr_lambda=L(LambdaLinearWarmupScheduler)(warmup_steps=5000),
86
+ )
cosmos_predict1/autoregressive/configs/base/tokenizer.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Optional
17
+
18
+ import attrs
19
+
20
+ from cosmos_predict1.autoregressive.tokenizer.discrete_video import DiscreteVideoFSQStateDictTokenizer
21
+ from cosmos_predict1.autoregressive.tokenizer.networks import CausalDiscreteVideoTokenizer
22
+ from cosmos_predict1.utils.lazy_config import LazyCall as L
23
+ from cosmos_predict1.utils.lazy_config import LazyDict
24
+
25
+
26
+ def create_discrete_video_fsq_tokenizer_state_dict_config(
27
+ ckpt_path, pixel_chunk_duration=33, compression_ratio=[8, 16, 16]
28
+ ) -> LazyDict:
29
+ CausalDiscreteFactorizedVideoTokenizerConfig: LazyDict = L(CausalDiscreteVideoTokenizer)(
30
+ # The new causal discrete tokenizer, that is at least 2x more efficient in memory and runtime.
31
+ # - It relies on fully 3D discrete wavelet transform
32
+ # - Uses a layer norm instead of a group norm
33
+ # - Factorizes full convolutions into spatial and temporal convolutions
34
+ # - Factorizes full attention into spatial and temporal attention
35
+ # - Strictly causal, with flexible temporal length at inference.
36
+ attn_resolutions=[32],
37
+ channels=128,
38
+ channels_mult=[2, 4, 4],
39
+ dropout=0.0,
40
+ in_channels=3,
41
+ num_res_blocks=2,
42
+ out_channels=3,
43
+ resolution=1024,
44
+ patch_size=4,
45
+ patch_method="haar",
46
+ z_channels=16,
47
+ z_factor=1,
48
+ num_groups=1,
49
+ legacy_mode=False,
50
+ spatial_compression=16,
51
+ temporal_compression=8,
52
+ embedding_dim=6,
53
+ levels=[8, 8, 8, 5, 5, 5],
54
+ name="CausalDiscreteFactorizedVideoTokenizer",
55
+ )
56
+
57
+ return L(DiscreteVideoFSQStateDictTokenizer)(
58
+ enc_fp=ckpt_path.replace("ema.jit", "encoder.jit"),
59
+ dec_fp=ckpt_path.replace("ema.jit", "decoder.jit"),
60
+ tokenizer_module=CausalDiscreteFactorizedVideoTokenizerConfig,
61
+ name="discrete_video_fsq",
62
+ latent_ch=6,
63
+ is_bf16=True,
64
+ pixel_chunk_duration=pixel_chunk_duration,
65
+ latent_chunk_duration=1 + (pixel_chunk_duration - 1) // compression_ratio[0],
66
+ max_enc_batch_size=8,
67
+ max_dec_batch_size=4,
68
+ levels=[8, 8, 8, 5, 5, 5],
69
+ compression_ratio=compression_ratio,
70
+ )
71
+
72
+
73
+ @attrs.define(slots=False)
74
+ class TextTokenizerConfig:
75
+ """
76
+ Text tokenizer config
77
+
78
+ Args:
79
+ config: Config file to define the text tokenizer class.
80
+ data_key (str): The input key from data_dict that will be passed to the text tokenizer.
81
+ tokenize_here (bool): Whether to use the tokenizer to perform online tokenization.
82
+ tokenizer_offset (int): Offset that is added to the tokens.
83
+ vocab_size (int): Vocabulary size of the tokenizer.
84
+ """
85
+
86
+ config: LazyDict
87
+ data_key: str = ""
88
+ tokenize_here: bool = False
89
+ tokenizer_offset: int = 0
90
+ vocab_size: int = 0
91
+
92
+
93
+ @attrs.define(slots=False)
94
+ class VideoTokenizerConfig:
95
+ """
96
+ Video tokenizer config
97
+
98
+ Args:
99
+ config: Config file to define the video tokenizer class.
100
+ data_key (str): The input key from data_dict that will be passed to the video tokenizer.
101
+ tokenize_here (bool): Whether to use the tokenizer to perform online tokenization.
102
+ tokenizer_offset (int): Offset that is added to the tokens. In case of joint text-video tokenizers, we
103
+ add an offset to make sure that video tokens and text tokens don't overlap.
104
+ vocab_size (int): Vocabulary size of the tokenizer.
105
+ max_seq_len (int): Maximum token length for an input video.
106
+ temporal_overlap (int): Overlap between consecutive video chunks.
107
+ """
108
+
109
+ config: LazyDict
110
+ data_key: str = ""
111
+ tokenize_here: bool = True
112
+ tokenizer_offset: int = 0
113
+ vocab_size: int = 0
114
+ max_seq_len: int = -1
115
+ temporal_overlap: int = 0
116
+
117
+
118
+ @attrs.define(slots=False)
119
+ class TokenizerConfig:
120
+ """
121
+ Joint tokenizer config
122
+
123
+ Args:
124
+ text_tokenizer (TextTokenizerConfig): Text tokenizer config file
125
+ class_tokenizer (ClassTokenizerConfig): Class tokenizer config file
126
+ video_tokenizer (VideoTokenizerConfig): Video tokenizer config file
127
+ image_tokenizer (ImageTokenizerConfig): Image tokenizer config file
128
+ seq_len (int): Final token sequence length
129
+ training_type (str): Type of training we use. Supports ["text_only", "text_to_video", "class_to_image", "image_text_interleaved"]
130
+ add_special_tokens (bool): Whether to add special tokens to the output tokens
131
+ pad_to_multiple_of (int): Pad the token sequence length to the nearest multiple of this number. Defaults to 64.
132
+ """
133
+
134
+ text_tokenizer: Optional[TextTokenizerConfig] = None
135
+ video_tokenizer: Optional[VideoTokenizerConfig] = None
136
+ seq_len: int = 4096
137
+ training_type: str = None
138
+ add_special_tokens: bool = True
139
+ pad_to_multiple_of: Optional[int] = 64
cosmos_predict1/autoregressive/configs/config.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Default config for cosmos_ar project."""
17
+
18
+ import os
19
+ from typing import Any, List
20
+
21
+ import attrs
22
+
23
+ from cosmos_predict1.autoregressive.configs.registry import register_configs
24
+ from cosmos_predict1.autoregressive.trainer import Trainer
25
+ from cosmos_predict1.utils import config, log
26
+ from cosmos_predict1.utils.config_helper import import_all_modules_from_package
27
+
28
+
29
+ @attrs.define(slots=False)
30
+ class Config(config.Config):
31
+ defaults: List[Any] = attrs.field(
32
+ factory=lambda: [
33
+ "_self_",
34
+ {"model": None},
35
+ {"data_train": "mock_video"},
36
+ {"data_val": None},
37
+ {"optimizer": "fused_adamw"},
38
+ {"scheduler": "warmup_cosine_lr"},
39
+ {"checkpoint": "local"},
40
+ {"callbacks": "basic"},
41
+ {"global_config": None},
42
+ {"experiment": None},
43
+ ]
44
+ )
45
+
46
+ def validate(self) -> None:
47
+ """Validate that the config has all required fields."""
48
+ assert self.job.project != "", "job.project is not set"
49
+ assert self.job.group != "", "job.group is not set"
50
+ assert self.job.name != "", "job.name is not set"
51
+ log.info("Validating config for cosmos_autoregressive job")
52
+ # FSDP config check
53
+ if self.model.model_config.fsdp_enabled:
54
+ assert self.trainer.distributed_parallelism == "fsdp"
55
+ else:
56
+ assert self.trainer.distributed_parallelism == "ddp"
57
+
58
+ # Transformer Engine config check
59
+ if self.model.model_config.backend == "transformer_engine":
60
+ assert (
61
+ "NVTE_FLASH_ATTN" in os.environ and os.environ["NVTE_FLASH_ATTN"] == "1"
62
+ ) # Enable Flash attention for transformer engine
63
+
64
+ # TP, CP config check
65
+ if self.model_parallel is not None:
66
+ if self.model_parallel.context_parallel_size > 1:
67
+ assert (
68
+ self.model.model_config.backend == "transformer_engine"
69
+ ), "Context parallelism is only supported in transformer engine."
70
+
71
+ if self.model_parallel.tensor_model_parallel_size > 1:
72
+ assert (
73
+ self.model.model_config.set_parallel_mode
74
+ ), "Tensor model parallelism is only supported in parallel mode."
75
+
76
+ if self.model_parallel.sequence_parallel:
77
+ assert (
78
+ self.model_parallel.tensor_model_parallel_size > 1
79
+ ), "Sequence parallelism is only supported in tensor model parallelism."
80
+ assert (
81
+ self.model.model_config.backend == "transformer_engine"
82
+ ), "Sequence parallelism is only supported in transformer engine."
83
+
84
+
85
+ def make_config():
86
+ c = Config(
87
+ model=None,
88
+ optimizer=None,
89
+ scheduler=None,
90
+ dataloader_train=None,
91
+ dataloader_val=None,
92
+ checkpoint=None,
93
+ )
94
+
95
+ c.job.project = "cosmos_autoregressive"
96
+ c.job.group = "debug"
97
+ c.job.name = "default_${now:%Y-%m-%d}_${now:%H-%M-%S}"
98
+
99
+ c.trainer.type = Trainer
100
+ c.trainer.run_validation = True
101
+
102
+ c.trainer.seed = 0
103
+ c.trainer.max_iter = 10
104
+ c.trainer.logging_iter = 1
105
+
106
+ c.trainer.callbacks = None
107
+ register_configs()
108
+ # experiment config are defined in the experiment folder
109
+ # call import_all_modules_from_package to register them
110
+ import_all_modules_from_package("cosmos_predict1.autoregressive.configs.experiment")
111
+ return c
cosmos_predict1/autoregressive/configs/experiment/video2video/__init__.py ADDED
File without changes
cosmos_predict1/autoregressive/configs/experiment/video2video/basic.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """
17
+ This file contains a basic configuration for video2video experiments.
18
+ """
19
+
20
+ from hydra.core.config_store import ConfigStore
21
+
22
+ from cosmos_predict1.autoregressive.configs.base.model_config import create_video2world_model
23
+ from cosmos_predict1.autoregressive.configs.base.model_parallel import create_model_parallel_config
24
+ from cosmos_predict1.utils import log
25
+ from cosmos_predict1.utils.lazy_config import LazyDict
26
+
27
+ cs = ConfigStore.instance()
28
+
29
+
30
+ """
31
+ Finetune 4B model with TP=1, pytorch backend, low resolution tealrobot data, frames 33, chunk 33.
32
+ Usage:
33
+ torchrun --nproc_per_node=1 -m cosmos_predict1.autoregressive.train --config=cosmos_predict1/autoregressive/configs/config.py -- experiment=base_4b_example_tealrobotsmall_tp1
34
+ """
35
+ base_4b_example_tealrobotsmall_tp1: LazyDict = LazyDict(
36
+ dict(
37
+ defaults=[
38
+ {"override /data_train": "tealrobot_video_small"},
39
+ {
40
+ "override /callbacks": [
41
+ "basic",
42
+ "video_teacher_forcing",
43
+ ]
44
+ },
45
+ {"override /checkpoint": "local"},
46
+ {"override /optimizer": "fused_adamw"},
47
+ {"override /scheduler": "warmup_cosine_lr"},
48
+ "_self_",
49
+ ],
50
+ job=dict(
51
+ project="posttraining",
52
+ group="autoregressive_base",
53
+ name="base_4b_example_tealrobotsmall_tp1",
54
+ ),
55
+ model=create_video2world_model(
56
+ model_size="4b",
57
+ model_family="cosmos",
58
+ backend="pytorch",
59
+ tensor_model_parallel_size=1,
60
+ batch_size=1,
61
+ pixel_chunk_duration=33,
62
+ num_video_frames=33,
63
+ video_height=384,
64
+ video_width=640,
65
+ tokenizer_ckpt_path="checkpoints/Cosmos-Tokenize1-DV8x16x16-720p/ema.jit",
66
+ add_special_tokens=False,
67
+ ),
68
+ trainer=dict(
69
+ max_iter=50000,
70
+ grad_accum_iter=1,
71
+ grad_scaler_args=dict(enabled=False),
72
+ run_validation=False, # No need for validation as epoch <= 1
73
+ distributed_parallelism="ddp",
74
+ callbacks=dict(
75
+ vid_sampling_tf=dict(
76
+ every_n=500,
77
+ ),
78
+ ),
79
+ ),
80
+ checkpoint=dict(
81
+ load_path="checkpoints/Cosmos-Predict1-4B/model.pt",
82
+ load_training_state=False,
83
+ strict_resume=True,
84
+ save_iter=1000,
85
+ ),
86
+ model_parallel=create_model_parallel_config(),
87
+ ),
88
+ )
89
+
90
+
91
+ """
92
+ Finetune 4B model with TP=4, pytorch backend, high resolution tealrobot data, frame 33, chunk 33.
93
+ Usage:
94
+ torchrun --nproc_per_node=4 -m cosmos_predict1.autoregressive.train --config=cosmos_predict1/autoregressive/configs/config.py -- experiment=base_4b_example_tealrobot_tp4
95
+ """
96
+ base_4b_example_tealrobot_tp4: LazyDict = LazyDict(
97
+ dict(
98
+ defaults=[
99
+ {"override /data_train": "tealrobot_video"},
100
+ {
101
+ "override /callbacks": [
102
+ "basic",
103
+ "video_teacher_forcing",
104
+ ]
105
+ },
106
+ {"override /checkpoint": "local"},
107
+ {"override /optimizer": "fused_adamw"},
108
+ {"override /scheduler": "warmup_cosine_lr"},
109
+ "_self_",
110
+ ],
111
+ job=dict(
112
+ project="posttraining",
113
+ group="autoregressive_base",
114
+ name="base_4b_example_tealrobot_tp4",
115
+ ),
116
+ model=create_video2world_model(
117
+ model_size="4b",
118
+ model_family="cosmos",
119
+ backend="pytorch",
120
+ tensor_model_parallel_size=4,
121
+ batch_size=1,
122
+ pixel_chunk_duration=33,
123
+ num_video_frames=33,
124
+ video_height=640,
125
+ video_width=848,
126
+ tokenizer_ckpt_path="checkpoints/Cosmos-Tokenize1-DV8x16x16-720p/ema.jit",
127
+ add_special_tokens=False,
128
+ ),
129
+ trainer=dict(
130
+ max_iter=50000,
131
+ grad_accum_iter=1,
132
+ grad_scaler_args=dict(enabled=False),
133
+ run_validation=False, # No need for validation as epoch <= 1
134
+ distributed_parallelism="ddp",
135
+ callbacks=dict(
136
+ vid_sampling_tf=dict(
137
+ every_n=500,
138
+ ),
139
+ ),
140
+ ),
141
+ checkpoint=dict(
142
+ load_path="checkpoints/Cosmos-Predict1-4B/model.pt",
143
+ load_training_state=False,
144
+ strict_resume=False,
145
+ save_iter=1000,
146
+ ),
147
+ model_parallel=create_model_parallel_config(),
148
+ ),
149
+ )
150
+
151
+
152
+ def register_experiments(cs):
153
+ # Register the experiments
154
+ for _item in [
155
+ base_4b_example_tealrobotsmall_tp1,
156
+ base_4b_example_tealrobot_tp4,
157
+ ]:
158
+ cs.store(
159
+ group="experiment",
160
+ package="_global_",
161
+ name=_item["job"]["name"],
162
+ node=_item,
163
+ )
cosmos_predict1/autoregressive/configs/inference/inference_config.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any, List, Optional, Union
17
+
18
+ import attrs
19
+
20
+ from cosmos_predict1.autoregressive.configs.base.model import ModelConfig, TokenizerConfig
21
+
22
+
23
+ @attrs.define(slots=False)
24
+ class DataShapeConfig:
25
+ latent_shape: list = []
26
+ num_video_frames: Union[None, int] = None
27
+ height: Union[None, int] = None
28
+ width: Union[None, int] = None
29
+
30
+
31
+ @attrs.define(slots=False)
32
+ class SamplingConfig:
33
+ """
34
+ Sampling config
35
+ Args:
36
+ temperature (float): Temperature value for controlling randomness in sampling. Defaults to 0.6.
37
+ top_p (float): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
38
+ logprobs (bool): Flag indicating whether to compute token log probabilities. Defaults to False.
39
+ echo (bool): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
40
+
41
+ """
42
+
43
+ temperature: float = 0.6
44
+ top_k: int = None
45
+ top_p: float = 0.9
46
+ compile_prefill: bool = False
47
+ compile_sampling: bool = True
48
+ logprobs: bool = False
49
+ echo: bool = False
50
+
51
+
52
+ @attrs.define(slots=False)
53
+ class DiffusionDecoderSamplingConfig:
54
+ """
55
+ Diffusion decoder sampling config
56
+ Args:
57
+ guidance (float): Guidance scale for the diffusion process. Controls how much the model follows the conditioning. Defaults to 0.8.
58
+ sigma_min (float): Minimum noise level for the diffusion process. Defaults to 0.02.
59
+ sigma (float): Initial noise level for the diffusion process. Defaults to 8.
60
+ num_steps (int): Number of denoising steps to perform. Defaults to 35.
61
+ overlap (int): Number of overlapping frames between video chunks during processing. Defaults to 2.
62
+ continuous_tokenizer_channel (int): Number of channels in the continuous tokenizer of diffusion decoder. Defaults to 16.
63
+ continuous_tokenizer_spatial_compression_ratio (int): Spatial compression ratio for the continuous tokenizer of diffusion decoder. Defaults to 8.
64
+ dd_train_num_video_frames (int): Number of video frames used during training for diffusion decoder. Defaults to 57.
65
+ """
66
+
67
+ guidance: float = 1.8
68
+ sigma_min: float = 0.02
69
+ sigma: float = 8
70
+ num_steps: int = 15
71
+ overlap: int = 2
72
+ continuous_tokenizer_channel = 16
73
+ continuous_tokenizer_spatial_compression_ratio = 8
74
+ dd_train_num_video_frames: int = 57
75
+ max_iter: int = 99
76
+ fps: int = 24
77
+
78
+
79
+ @attrs.define(slots=False)
80
+ class InferenceConfig:
81
+ """
82
+ Inference config
83
+ Args:
84
+ model_config (ModelConfig): Model config
85
+ tokenizer_config (TokenizerConfig): Tokenizer config
86
+ ckpt_path (str): Path to the checkpoint
87
+ latent_shape (list): Shape of the latent
88
+ """
89
+
90
+ model_config: ModelConfig = None
91
+ tokenizer_config: TokenizerConfig = None
92
+ ckpt_path: str = ""
93
+ data_shape_config: DataShapeConfig = None
94
+
95
+ defaults: List[Any] = attrs.field(
96
+ factory=lambda: [
97
+ "_self_",
98
+ {"data_val": None},
99
+ {"data_shape_config": "video_shape_as_model_config"},
100
+ {"eval_job": None},
101
+ ]
102
+ )
cosmos_predict1/autoregressive/configs/registry.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ from hydra.core.config_store import ConfigStore
18
+
19
+ from cosmos_predict1.autoregressive.configs.base.callbacks import BASIC_CALLBACKS, VIDEO_TEACHER_FORCING_CALLBACK
20
+ from cosmos_predict1.autoregressive.configs.base.dataloader import get_tealrobot_video
21
+ from cosmos_predict1.autoregressive.configs.base.optim import LambdaLinearLR
22
+ from cosmos_predict1.autoregressive.configs.experiment.video2video.basic import register_experiments
23
+ from cosmos_predict1.utils import config, log
24
+ from cosmos_predict1.utils.lazy_config import LazyCall as L
25
+ from cosmos_predict1.utils.scheduler import WarmupCosineLR
26
+
27
+
28
+ def register_checkpoint(cs):
29
+ checkpoint_local = config.CheckpointConfig(
30
+ save_iter=5000,
31
+ broadcast_via_filesystem=True,
32
+ )
33
+ cs.store(group="checkpoint", package="checkpoint", name="local", node=checkpoint_local)
34
+
35
+
36
+ def register_callbacks(cs):
37
+ cs.store(group="callbacks", package="trainer.callbacks", name="basic", node=BASIC_CALLBACKS)
38
+ cs.store(
39
+ group="callbacks",
40
+ package="trainer.callbacks",
41
+ name="video_teacher_forcing",
42
+ node=VIDEO_TEACHER_FORCING_CALLBACK,
43
+ )
44
+
45
+
46
+ def register_scheduler(cs):
47
+ cs.store(
48
+ group="scheduler",
49
+ package="scheduler",
50
+ name="warmup_cosine_lr",
51
+ node=L(WarmupCosineLR)(optimizer=None, warmup_iters=5000, lr_decay_iters="${trainer.max_iter}", min_lr=1e-8),
52
+ )
53
+ cs.store(group="scheduler", package="scheduler", name="lambdalinear", node=LambdaLinearLR)
54
+
55
+
56
+ def register_optimizer(cs):
57
+ cs.store(
58
+ group="optimizer",
59
+ package="optimizer",
60
+ name="fused_adamw",
61
+ node=L(torch.optim.AdamW)(params=None, lr=1e-3, weight_decay=0.05, fused=True),
62
+ )
63
+ cs.store(
64
+ group="optimizer",
65
+ package="optimizer",
66
+ name="sgd",
67
+ node=L(torch.optim.SGD)(params=None, lr=5e-6, momentum=0.9),
68
+ )
69
+
70
+
71
+ def register_training_data(cs):
72
+ cs.store(
73
+ group="data_train",
74
+ package="dataloader_train",
75
+ name="tealrobot_video_small",
76
+ node=get_tealrobot_video(num_frames=33, video_size=[384, 640]),
77
+ )
78
+ cs.store(group="data_train", package="dataloader_train", name="tealrobot_video", node=get_tealrobot_video())
79
+
80
+
81
+ def register_configs():
82
+ log.info("Registering configs for autoregressive_base")
83
+ cs = ConfigStore.instance()
84
+ register_callbacks(cs)
85
+ register_checkpoint(cs)
86
+ register_optimizer(cs)
87
+ register_scheduler(cs)
88
+ register_training_data(cs)
89
+ register_experiments(cs)