Spaces:
Build error
Build error
Commit
·
28451f7
0
Parent(s):
Initial commit for new Space - pre-built Docker image
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .flake8 +10 -0
- .gitattributes +44 -0
- .gitignore +247 -0
- .gitmodules +27 -0
- .pre-commit-config.yaml +55 -0
- ATTRIBUTIONS.md +0 -0
- CONTRIBUTING.md +51 -0
- INSTALL.md +48 -0
- LICENSE +201 -0
- README.md +248 -0
- assets/demo_1.gif +3 -0
- assets/demo_2.gif +3 -0
- assets/demo_3.gif +3 -0
- assets/demo_dynamic.gif +3 -0
- assets/diffusion/000000.png +3 -0
- assets/diffusion/000001.png +3 -0
- assets/diffusion/000002.png +3 -0
- assets/diffusion/000003.png +3 -0
- assets/diffusion/000004.png +3 -0
- assets/diffusion/000005.png +3 -0
- assets/diffusion/000006.png +3 -0
- assets/diffusion/000007.png +3 -0
- assets/diffusion/000008.png +3 -0
- assets/diffusion/000009.png +3 -0
- assets/diffusion/000010.png +3 -0
- assets/diffusion/000011.png +3 -0
- assets/diffusion/000012.png +3 -0
- assets/diffusion/000013.png +3 -0
- assets/diffusion/000014.png +3 -0
- assets/diffusion/000015.png +3 -0
- checkpoints/README.md +4 -0
- cosmos-predict1.yaml +29 -0
- cosmos_predict1/__init__.py +14 -0
- cosmos_predict1/autoregressive/__init__.py +14 -0
- cosmos_predict1/autoregressive/callbacks/video_sampling_teacher_forcing.py +352 -0
- cosmos_predict1/autoregressive/configs/__init__.py +14 -0
- cosmos_predict1/autoregressive/configs/base/__init__.py +14 -0
- cosmos_predict1/autoregressive/configs/base/callbacks.py +33 -0
- cosmos_predict1/autoregressive/configs/base/dataloader.py +72 -0
- cosmos_predict1/autoregressive/configs/base/dataset.py +39 -0
- cosmos_predict1/autoregressive/configs/base/model.py +318 -0
- cosmos_predict1/autoregressive/configs/base/model_config.py +718 -0
- cosmos_predict1/autoregressive/configs/base/model_parallel.py +33 -0
- cosmos_predict1/autoregressive/configs/base/optim.py +86 -0
- cosmos_predict1/autoregressive/configs/base/tokenizer.py +139 -0
- cosmos_predict1/autoregressive/configs/config.py +111 -0
- cosmos_predict1/autoregressive/configs/experiment/video2video/__init__.py +0 -0
- cosmos_predict1/autoregressive/configs/experiment/video2video/basic.py +163 -0
- cosmos_predict1/autoregressive/configs/inference/inference_config.py +102 -0
- cosmos_predict1/autoregressive/configs/registry.py +89 -0
.flake8
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[flake8]
|
2 |
+
enable-extensions = G
|
3 |
+
select = B,C,E,F,G,P,SIM1,T4,W,B9
|
4 |
+
max-line-length = 120
|
5 |
+
# C408 ignored because we like the dict keyword argument syntax
|
6 |
+
# E501 is not flexible enough, we're using B950 instead
|
7 |
+
ignore =
|
8 |
+
E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,E226,E265
|
9 |
+
exclude =
|
10 |
+
third_party
|
.gitattributes
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
<<<<<<< HEAD
|
37 |
+
assets/*.gif filter=lfs diff=lfs merge=lfs -text
|
38 |
+
*.gif filter=lfs diff=lfs merge=lfs -text
|
39 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
40 |
+
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
41 |
+
=======
|
42 |
+
>>>>>>> 0453ffbfce197070bb0c254a11ef21f15d1ad986
|
43 |
+
transformer_engine_torch-1.12.0+cu121-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
44 |
+
transformer_engine.whl filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
# Misc
|
17 |
+
outputs/
|
18 |
+
checkpoints/*
|
19 |
+
!checkpoints/README.md
|
20 |
+
datasets/*
|
21 |
+
!datasets/README.md
|
22 |
+
apex/
|
23 |
+
|
24 |
+
# Data types
|
25 |
+
*.jit
|
26 |
+
*.pt
|
27 |
+
*.hdr
|
28 |
+
*.webp
|
29 |
+
*.pgm
|
30 |
+
*.tiff
|
31 |
+
*.tif
|
32 |
+
*.tar
|
33 |
+
*.tar.gz
|
34 |
+
*.gz
|
35 |
+
*.pkl
|
36 |
+
*.pt
|
37 |
+
*.bin
|
38 |
+
*.pickle
|
39 |
+
*.txt
|
40 |
+
|
41 |
+
# Other uncheckable file types
|
42 |
+
*.zip
|
43 |
+
*.exe
|
44 |
+
*.dll
|
45 |
+
*.swp
|
46 |
+
*.vscode
|
47 |
+
*.DS_Store
|
48 |
+
*.pyc
|
49 |
+
*Thumbs.db
|
50 |
+
*.patch
|
51 |
+
|
52 |
+
# Credential information that should never be checked in
|
53 |
+
credentials
|
54 |
+
*.secret
|
55 |
+
|
56 |
+
# ------------------------ BELOW IS AUTO-GENERATED FOR PYTHON REPOS ------------------------
|
57 |
+
|
58 |
+
# Byte-compiled / optimized / DLL files
|
59 |
+
**/__pycache__/
|
60 |
+
*.py[cod]
|
61 |
+
*$py.class
|
62 |
+
|
63 |
+
# C extensions
|
64 |
+
*.so
|
65 |
+
|
66 |
+
# Distribution / packaging
|
67 |
+
.Python
|
68 |
+
build/
|
69 |
+
develop-eggs/
|
70 |
+
dist/
|
71 |
+
downloads/
|
72 |
+
eggs/
|
73 |
+
.eggs/
|
74 |
+
lib/
|
75 |
+
lib64/
|
76 |
+
parts/
|
77 |
+
results/
|
78 |
+
sdist/
|
79 |
+
var/
|
80 |
+
wheels/
|
81 |
+
share/python-wheels/
|
82 |
+
*.egg-info/
|
83 |
+
.installed.config
|
84 |
+
*.egg
|
85 |
+
MANIFEST
|
86 |
+
|
87 |
+
# PyInstaller
|
88 |
+
# Usually these files are written by a python script from a template
|
89 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
90 |
+
*.manifest
|
91 |
+
*.spec
|
92 |
+
|
93 |
+
# Installer logs
|
94 |
+
pip-log.txt
|
95 |
+
pip-delete-this-directory.txt
|
96 |
+
|
97 |
+
# Unit test / coverage reports
|
98 |
+
htmlcov/
|
99 |
+
.tox/
|
100 |
+
.nox/
|
101 |
+
.coverage
|
102 |
+
.coverage.*
|
103 |
+
.cache
|
104 |
+
nosetests.xml
|
105 |
+
coverage.xml
|
106 |
+
*.cover
|
107 |
+
*.py,cover
|
108 |
+
.hypothesis/
|
109 |
+
.pytest_cache/
|
110 |
+
cover/
|
111 |
+
|
112 |
+
# Translations
|
113 |
+
*.mo
|
114 |
+
*.pot
|
115 |
+
|
116 |
+
# Django stuff:
|
117 |
+
*.log
|
118 |
+
local_settings.py
|
119 |
+
db.sqlite3
|
120 |
+
db.sqlite3-journal
|
121 |
+
|
122 |
+
# Flask stuff:
|
123 |
+
instance/
|
124 |
+
.webassets-cache
|
125 |
+
|
126 |
+
# Scrapy stuff:
|
127 |
+
.scrapy
|
128 |
+
|
129 |
+
# Sphinx documentation
|
130 |
+
docs/_build/
|
131 |
+
|
132 |
+
# PyBuilder
|
133 |
+
.pybuilder/
|
134 |
+
target/
|
135 |
+
|
136 |
+
# Third party
|
137 |
+
# Jupyter Notebook
|
138 |
+
.ipynb_checkpoints
|
139 |
+
|
140 |
+
# IPython
|
141 |
+
profile_default/
|
142 |
+
ipython_config.py
|
143 |
+
|
144 |
+
# pyenv
|
145 |
+
# For a library or package, you might want to ignore these files since the code is
|
146 |
+
# intended to run in multiple environments; otherwise, check them in:
|
147 |
+
# .python-version
|
148 |
+
|
149 |
+
# pipenv
|
150 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
151 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
152 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
153 |
+
# install all needed dependencies.
|
154 |
+
#Pipfile.lock
|
155 |
+
|
156 |
+
# poetry
|
157 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
158 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
159 |
+
# commonly ignored for libraries.
|
160 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
161 |
+
#poetry.lock
|
162 |
+
|
163 |
+
# pdm
|
164 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
165 |
+
#pdm.lock
|
166 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
167 |
+
# in version control.
|
168 |
+
# https://pdm.fming.dev/#use-with-ide
|
169 |
+
.pdm.toml
|
170 |
+
|
171 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
172 |
+
__pypackages__/
|
173 |
+
|
174 |
+
# Celery stuff
|
175 |
+
celerybeat-schedule
|
176 |
+
celerybeat.pid
|
177 |
+
|
178 |
+
# SageMath parsed files
|
179 |
+
*.sage.py
|
180 |
+
|
181 |
+
# Environments
|
182 |
+
.env
|
183 |
+
.venv
|
184 |
+
env/
|
185 |
+
venv/
|
186 |
+
ENV/
|
187 |
+
env.bak/
|
188 |
+
venv.bak/
|
189 |
+
|
190 |
+
# Spyder project settings
|
191 |
+
.spyderproject
|
192 |
+
.spyproject
|
193 |
+
|
194 |
+
# Rope project settings
|
195 |
+
.ropeproject
|
196 |
+
|
197 |
+
# mkdocs documentation
|
198 |
+
/site
|
199 |
+
|
200 |
+
# mypy
|
201 |
+
.mypy_cache/
|
202 |
+
.dmypy.json
|
203 |
+
dmypy.json
|
204 |
+
|
205 |
+
# Pyre type checker
|
206 |
+
.pyre/
|
207 |
+
|
208 |
+
# pytype static type analyzer
|
209 |
+
.pytype/
|
210 |
+
|
211 |
+
# Cython debug symbols
|
212 |
+
cython_debug/
|
213 |
+
|
214 |
+
# ruff
|
215 |
+
.ruff_cache
|
216 |
+
|
217 |
+
# PyCharm
|
218 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
219 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
220 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
221 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
222 |
+
#.idea/
|
223 |
+
CLIP
|
224 |
+
.devcontainer/devcontainer.json
|
225 |
+
|
226 |
+
# Coverage
|
227 |
+
.coverage
|
228 |
+
coverage.xml
|
229 |
+
|
230 |
+
# JUnit Reports
|
231 |
+
report.xml
|
232 |
+
|
233 |
+
# CI-CD
|
234 |
+
temp/
|
235 |
+
envs.txt
|
236 |
+
manifest.json
|
237 |
+
|
238 |
+
|
239 |
+
# locks and t5 temp files
|
240 |
+
*.locks*
|
241 |
+
*.no_exist*
|
242 |
+
*models--t5*
|
243 |
+
|
244 |
+
# OneLogger
|
245 |
+
wandb/
|
246 |
+
onelogger.err
|
247 |
+
onelogger.log
|
.gitmodules
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[submodule "gui/dependencies/pybind11"]
|
2 |
+
path = gui/dependencies/pybind11
|
3 |
+
url = https://github.com/Tom94/pybind11
|
4 |
+
[submodule "gui/dependencies/glfw"]
|
5 |
+
path = gui/dependencies/glfw
|
6 |
+
url = https://github.com/Tom94/glfw
|
7 |
+
[submodule "gui/dependencies/args"]
|
8 |
+
path = gui/dependencies/args
|
9 |
+
url = https://github.com/Taywee/args
|
10 |
+
[submodule "gui/dependencies/tinylogger"]
|
11 |
+
path = gui/dependencies/tinylogger
|
12 |
+
url = https://github.com/Tom94/tinylogger
|
13 |
+
[submodule "gui/dependencies/imgui"]
|
14 |
+
path = gui/dependencies/imgui
|
15 |
+
url = https://github.com/ocornut/imgui.git
|
16 |
+
[submodule "gui/dependencies/dlss"]
|
17 |
+
path = gui/dependencies/dlss
|
18 |
+
url = https://github.com/NVIDIA/DLSS
|
19 |
+
[submodule "gui/dependencies/OpenXR-SDK"]
|
20 |
+
path = gui/dependencies/OpenXR-SDK
|
21 |
+
url = https://github.com/KhronosGroup/OpenXR-SDK.git
|
22 |
+
[submodule "gui/dependencies/zlib"]
|
23 |
+
path = gui/dependencies/zlib
|
24 |
+
url = https://github.com/Tom94/zlib
|
25 |
+
[submodule "gui/dependencies/fmt"]
|
26 |
+
path = gui/dependencies/fmt
|
27 |
+
url = https://github.com/fmtlib/fmt
|
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
default_language_version:
|
17 |
+
python: python3.10
|
18 |
+
repos:
|
19 |
+
- repo: https://github.com/pycqa/flake8
|
20 |
+
rev: 6.0.0
|
21 |
+
hooks:
|
22 |
+
- id: flake8
|
23 |
+
args:
|
24 |
+
- --max-line-length=120
|
25 |
+
- --ignore=E501,F401,E203,E402,E265,E741,F841,F821,F811,W503,E231,E225,E702
|
26 |
+
exclude: ^dist/|^third_party/
|
27 |
+
|
28 |
+
- repo: https://github.com/psf/black
|
29 |
+
rev: 23.12.1
|
30 |
+
hooks:
|
31 |
+
- id: black
|
32 |
+
args: [--line-length=120]
|
33 |
+
exclude: ^dist/|^third_party/
|
34 |
+
|
35 |
+
- repo: https://github.com/timothycrosley/isort
|
36 |
+
rev: 5.12.0
|
37 |
+
hooks:
|
38 |
+
- id: isort
|
39 |
+
args: [--line-length=120]
|
40 |
+
|
41 |
+
- repo: https://github.com/MarcoGorelli/absolufy-imports
|
42 |
+
rev: v0.3.1
|
43 |
+
hooks:
|
44 |
+
- id: absolufy-imports
|
45 |
+
|
46 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
47 |
+
rev: v4.0.1
|
48 |
+
hooks:
|
49 |
+
- id: trailing-whitespace
|
50 |
+
exclude: ^tests/.*/fixtures/.*
|
51 |
+
args: [--markdown-linebreak-ext=md]
|
52 |
+
- id: end-of-file-fixer
|
53 |
+
exclude: ^tests/.*/fixtures/.*
|
54 |
+
- id: check-added-large-files
|
55 |
+
args: ['--maxkb=2000']
|
ATTRIBUTIONS.md
ADDED
The diff for this file is too large to render.
See raw diff
|
|
CONTRIBUTING.md
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# How to Contribute
|
2 |
+
|
3 |
+
We'd love to receive your patches and contributions. Please keep your PRs as draft until such time that you would like us to review them.
|
4 |
+
|
5 |
+
## Code Reviews
|
6 |
+
|
7 |
+
All submissions, including submissions by project members, require review. We use GitHub pull requests for this purpose. Consult
|
8 |
+
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more information on using pull requests.
|
9 |
+
|
10 |
+
## Signing Your Work
|
11 |
+
|
12 |
+
* We require that all contributors "sign-off" on their commits. This certifies that the contribution is your original work, or you have rights to submit it under the same license, or a compatible license.
|
13 |
+
|
14 |
+
* Any contribution which contains commits that are not Signed-Off will not be accepted.
|
15 |
+
|
16 |
+
* To sign off on a commit you simply use the `--signoff` (or `-s`) option when committing your changes:
|
17 |
+
```bash
|
18 |
+
$ git commit -s -m "Add cool feature."
|
19 |
+
```
|
20 |
+
This will append the following to your commit message:
|
21 |
+
```
|
22 |
+
Signed-off-by: Your Name <your@email.com>
|
23 |
+
```
|
24 |
+
|
25 |
+
* Full text of the DCO:
|
26 |
+
|
27 |
+
```
|
28 |
+
Developer Certificate of Origin
|
29 |
+
Version 1.1
|
30 |
+
|
31 |
+
Copyright (C) 2004, 2006 The Linux Foundation and its contributors.
|
32 |
+
1 Letterman Drive
|
33 |
+
Suite D4700
|
34 |
+
San Francisco, CA, 94129
|
35 |
+
|
36 |
+
Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed.
|
37 |
+
```
|
38 |
+
|
39 |
+
```
|
40 |
+
Developer's Certificate of Origin 1.1
|
41 |
+
|
42 |
+
By making a contribution to this project, I certify that:
|
43 |
+
|
44 |
+
(a) The contribution was created in whole or in part by me and I have the right to submit it under the open source license indicated in the file; or
|
45 |
+
|
46 |
+
(b) The contribution is based upon previous work that, to the best of my knowledge, is covered under an appropriate open source license and I have the right under that license to submit that work with modifications, whether created in whole or in part by me, under the same open source license (unless I am permitted to submit under a different license), as indicated in the file; or
|
47 |
+
|
48 |
+
(c) The contribution was provided directly to me by some other person who certified (a), (b) or (c) and I have not modified it.
|
49 |
+
|
50 |
+
(d) I understand and agree that this project and the contribution are public and that a record of the contribution (including all personal information I submit with it, including my sign-off) is maintained indefinitely and may be redistributed consistent with this project or the open source license(s) involved.
|
51 |
+
```
|
INSTALL.md
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Environment setup
|
2 |
+
|
3 |
+
Cosmos runs only on Linux systems. We have tested the installation with Ubuntu 24.04, 22.04, and 20.04.
|
4 |
+
Cosmos requires the Python version to be `3.10.x`. Please also make sure you have `conda` installed ([instructions](https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html)).
|
5 |
+
|
6 |
+
### Inference
|
7 |
+
|
8 |
+
The below commands creates the `cosmos-predict1` conda environment and installs the dependencies for inference:
|
9 |
+
```bash
|
10 |
+
# Create the cosmos-predict1 conda environment.
|
11 |
+
conda env create --file cosmos-predict1.yaml
|
12 |
+
# Activate the cosmos-predict1 conda environment.
|
13 |
+
conda activate cosmos-predict1
|
14 |
+
# Install the dependencies.
|
15 |
+
pip install -r requirements.txt
|
16 |
+
# Patch Transformer engine linking issues in conda environments.
|
17 |
+
ln -sf $CONDA_PREFIX/lib/python3.10/site-packages/nvidia/*/include/* $CONDA_PREFIX/include/
|
18 |
+
ln -sf $CONDA_PREFIX/lib/python3.10/site-packages/nvidia/*/include/* $CONDA_PREFIX/include/python3.10
|
19 |
+
# Install Transformer engine.
|
20 |
+
pip install transformer-engine[pytorch]==1.12.0
|
21 |
+
# Install Apex for inference.
|
22 |
+
git clone https://github.com/NVIDIA/apex
|
23 |
+
CUDA_HOME=$CONDA_PREFIX pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./apex
|
24 |
+
# Install MoGe for inference.
|
25 |
+
pip install git+https://github.com/microsoft/MoGe.git
|
26 |
+
```
|
27 |
+
|
28 |
+
* Alternatively, if you are more familiar with a containerized environment, you can build the dockerfile and run it to get an environment with all the packages pre-installed.
|
29 |
+
This requires docker to be already present on your system with the [Nvidia Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) installed.
|
30 |
+
|
31 |
+
```bash
|
32 |
+
docker build -f Dockerfile . -t nvcr.io/$USER/cosmos-predict1:latest
|
33 |
+
```
|
34 |
+
|
35 |
+
Note: In case you encounter permission issues while mounting local files inside the docker, you can share the folders from your current directory to all users (including docker) using this helpful alias `alias share='sudo chown -R ${USER}:users $PWD && sudo chmod g+w $PWD'` before running the docker.
|
36 |
+
|
37 |
+
|
38 |
+
You can test the environment setup for inference with
|
39 |
+
```bash
|
40 |
+
CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/test_environment.py
|
41 |
+
```
|
42 |
+
|
43 |
+
### Post-training
|
44 |
+
|
45 |
+
|
46 |
+
🛠️ *Under construction* 👷
|
47 |
+
|
48 |
+
Stay tuned!
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
ADDED
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: GEN3C Project (from DGX Station)
|
3 |
+
emoji: 🫁
|
4 |
+
colorFrom: green
|
5 |
+
colorTo: blue
|
6 |
+
sdk: docker
|
7 |
+
image: elungky/gen3c:latest
|
8 |
+
# app_port: 7860 # Remove or comment this line as the image handles the port
|
9 |
+
---
|
10 |
+
|
11 |
+
# GEN3C: 3D-Informed World-Consistent Video Generation with Precise Camera Control
|
12 |
+
|
13 |
+
<!-- Note: this video is hosted by GitHub and gets embedded automatically when viewing in the GitHub UI -->
|
14 |
+
|
15 |
+
https://github.com/user-attachments/assets/247e1719-9f8f-4504-bfa3-f9706bd8682d
|
16 |
+
|
17 |
+
|
18 |
+
**GEN3C: 3D-Informed World-Consistent Video Generation with Precise Camera Control**<br>
|
19 |
+
[Xuanchi Ren*](https://xuanchiren.com/),
|
20 |
+
[Tianchang Shen*](https://www.cs.toronto.edu/~shenti11/),
|
21 |
+
[Jiahui Huang](https://huangjh-pub.github.io/),
|
22 |
+
[Huan Ling](https://www.cs.toronto.edu/~linghuan/),
|
23 |
+
[Yifan Lu](https://yifanlu0227.github.io/),
|
24 |
+
[Merlin Nimier-David](https://merlin.nimierdavid.fr/),
|
25 |
+
[Thomas Müller](https://research.nvidia.com/person/thomas-muller),
|
26 |
+
[Alexander Keller](https://research.nvidia.com/person/alex-keller),
|
27 |
+
[Sanja Fidler](https://www.cs.toronto.edu/~fidler/),
|
28 |
+
[Jun Gao](https://www.cs.toronto.edu/~jungao/) <br>
|
29 |
+
\* indicates equal contribution <br>
|
30 |
+
**[Paper](https://arxiv.org/pdf/2503.03751), [Project Page](https://research.nvidia.com/labs/toronto-ai/GEN3C/), [HuggingFace](https://huggingface.co/collections/nvidia/gen3c-683f3f9540a8f9c98cf46a8d)**
|
31 |
+
|
32 |
+
Abstract: We present GEN3C, a generative video model with precise Camera Control and
|
33 |
+
temporal 3D Consistency. Prior video models already generate realistic videos,
|
34 |
+
but they tend to leverage little 3D information, leading to inconsistencies,
|
35 |
+
such as objects popping in and out of existence. Camera control, if implemented
|
36 |
+
at all, is imprecise, because camera parameters are mere inputs to the neural
|
37 |
+
network which must then infer how the video depends on the camera. In contrast,
|
38 |
+
GEN3C is guided by a 3D cache: point clouds obtained by predicting the
|
39 |
+
pixel-wise depth of seed images or previously generated frames. When generating
|
40 |
+
the next frames, GEN3C is conditioned on the 2D renderings of the 3D cache with
|
41 |
+
the new camera trajectory provided by the user. Crucially, this means that
|
42 |
+
GEN3C neither has to remember what it previously generated nor does it have to
|
43 |
+
infer the image structure from the camera pose. The model, instead, can focus
|
44 |
+
all its generative power on previously unobserved regions, as well as advancing
|
45 |
+
the scene state to the next frame. Our results demonstrate more precise camera
|
46 |
+
control than prior work, as well as state-of-the-art results in sparse-view
|
47 |
+
novel view synthesis, even in challenging settings such as driving scenes and
|
48 |
+
monocular dynamic video. Results are best viewed in videos.
|
49 |
+
|
50 |
+
For business inquiries, please visit our website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/).
|
51 |
+
For any other questions related to the model, please contact Xuanchi, Tianchang or Jun.
|
52 |
+
|
53 |
+
## News
|
54 |
+
- 2025-06-06 Code and model released! In a future update, we plan to include the pipeline for jointly predicting depth and camera pose from video, as well as a driving-finetuned model. Stay tuned!
|
55 |
+
|
56 |
+
## Installation
|
57 |
+
Please follow the "Inference" section in [INSTALL.md](INSTALL.md) to set up your environment.
|
58 |
+
|
59 |
+
## Inference
|
60 |
+
|
61 |
+
### Download checkpoints
|
62 |
+
1. Generate a [Hugging Face](https://huggingface.co/settings/tokens) access token (if you haven't done so already). Set the access token to `Read` permission (default is `Fine-grained`).
|
63 |
+
|
64 |
+
2. Log in to Hugging Face with the access token:
|
65 |
+
```bash
|
66 |
+
huggingface-cli login
|
67 |
+
```
|
68 |
+
|
69 |
+
3. Download the GEN3C model weights from [Hugging Face](https://huggingface.co/nvidia/GEN3C-Cosmos-7B):
|
70 |
+
```bash
|
71 |
+
CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/download_gen3c_checkpoints.py --checkpoint_dir checkpoints
|
72 |
+
```
|
73 |
+
|
74 |
+
### Interactive GUI usage
|
75 |
+
|
76 |
+
<div align="center">
|
77 |
+
<img src="gui/assets/gui_preview.webp" alt="GEN3C interactive GUI" width="1080px"/>
|
78 |
+
</div>
|
79 |
+
|
80 |
+
GEN3C can be used through an interactive GUI, allowing to visualize the inputs in 3D, author arbitrary camera trajectories, and start inference from a single window.
|
81 |
+
Please see the [dedicated instructions](gui/README.md).
|
82 |
+
|
83 |
+
|
84 |
+
### Command-line usage
|
85 |
+
GEN3C supports both images and videos as input. Below are examples of running GEN3C on single images and videos with predefined camera trajectory patterns.
|
86 |
+
|
87 |
+
### Example 1: Single Image to Video Generation
|
88 |
+
|
89 |
+
#### Single GPU
|
90 |
+
Generate a 121-frame video from a single image:
|
91 |
+
```bash
|
92 |
+
CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python cosmos_predict1/diffusion/inference/gen3c_single_image.py \
|
93 |
+
--checkpoint_dir checkpoints \
|
94 |
+
--input_image_path assets/diffusion/000000.png \
|
95 |
+
--video_save_name test_single_image \
|
96 |
+
--guidance 1 \
|
97 |
+
--foreground_masking
|
98 |
+
```
|
99 |
+
|
100 |
+
#### Multi-GPU (8 GPUs)
|
101 |
+
```bash
|
102 |
+
NUM_GPUS=8
|
103 |
+
CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) torchrun --nproc_per_node=${NUM_GPUS} cosmos_predict1/diffusion/inference/gen3c_single_image.py \
|
104 |
+
--checkpoint_dir checkpoints \
|
105 |
+
--input_image_path assets/diffusion/000000.png \
|
106 |
+
--video_save_name test_single_image_multigpu \
|
107 |
+
--num_gpus ${NUM_GPUS} \
|
108 |
+
--guidance 1 \
|
109 |
+
--foreground_masking
|
110 |
+
```
|
111 |
+
|
112 |
+
#### Additional Options
|
113 |
+
- To generate longer videos autoregressively, specify the number of frames using `--num_video_frames`. The number of frames must follow the pattern: 121 * N - 1 (e.g., 241, 361, etc.)
|
114 |
+
- To save buffer images alongside the output video, add the `--save_buffer` flag
|
115 |
+
- You can control camera trajectories using `--trajectory`, `--camera_rotation`, and `--movement_distance` arguments. See the "Camera Movement Options" section below for details.
|
116 |
+
|
117 |
+
#### Camera Movement Options
|
118 |
+
|
119 |
+
##### Trajectory Types
|
120 |
+
The `--trajectory` argument controls the path the camera takes during video generation. Available options:
|
121 |
+
|
122 |
+
| Option | Description |
|
123 |
+
|--------|-------------|
|
124 |
+
| `left` | Camera moves to the left (default) |
|
125 |
+
| `right` | Camera moves to the right |
|
126 |
+
| `up` | Camera moves upward |
|
127 |
+
| `down` | Camera moves downward |
|
128 |
+
| `zoom_in` | Camera moves closer to the scene |
|
129 |
+
| `zoom_out` | Camera moves away from the scene |
|
130 |
+
| `clockwise` | Camera moves in a clockwise circular path |
|
131 |
+
| `counterclockwise` | Camera moves in a counterclockwise circular path |
|
132 |
+
|
133 |
+
##### Camera Rotation Modes
|
134 |
+
The `--camera_rotation` argument controls how the camera rotates during movement. Available options:
|
135 |
+
|
136 |
+
| Option | Description |
|
137 |
+
|--------|-------------|
|
138 |
+
| `center_facing` | Camera always rotates to look at the (estimated) center of the scene (default) |
|
139 |
+
| `no_rotation` | Camera maintains its original orientation while moving |
|
140 |
+
| `trajectory_aligned` | Camera rotates to align with the direction of movement |
|
141 |
+
|
142 |
+
##### Movement Distance
|
143 |
+
The `--movement_distance` argument controls how far the camera moves from its initial position. The default value is 0.3. A larger value will result in more dramatic camera movement, while a smaller value will create more subtle movement.
|
144 |
+
|
145 |
+
##### GPU Memory Requirements
|
146 |
+
|
147 |
+
We have tested GEN3C only on H100 and A100 GPUs. For GPUs with limited memory, you can fully offload all models by appending the following flags to your command:
|
148 |
+
|
149 |
+
```bash
|
150 |
+
--offload_diffusion_transformer \
|
151 |
+
--offload_tokenizer \
|
152 |
+
--offload_text_encoder_model \
|
153 |
+
--offload_prompt_upsampler \
|
154 |
+
--offload_guardrail_models \
|
155 |
+
--disable_guardrail \
|
156 |
+
--disable_prompt_encoder
|
157 |
+
```
|
158 |
+
Maximum observed memory during inference with full offloading: ~43GB. Note: Memory usage may vary depending on system specifications and is provided for reference only.
|
159 |
+
|
160 |
+
|
161 |
+
### Example 2: Video to Video Generation
|
162 |
+
For video input, GEN3C requires additional depth information, camera intrinsics, and extrinsics. These can be obtained using your choice of SLAM packages. For testing purposes, we provide example data.
|
163 |
+
|
164 |
+
First, you need to download the test samples:
|
165 |
+
```bash
|
166 |
+
# Download test samples from Hugging Face
|
167 |
+
huggingface-cli download nvidia/GEN3C-Testing-Example --repo-type dataset --local-dir assets/diffusion/dynamic_video_samples
|
168 |
+
```
|
169 |
+
|
170 |
+
#### Single GPU
|
171 |
+
```bash
|
172 |
+
CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python cosmos_predict1/diffusion/inference/gen3c_dynamic.py \
|
173 |
+
--checkpoint_dir checkpoints \
|
174 |
+
--input_image_path assets/diffusion/dynamic_video_samples/batch_0000 \
|
175 |
+
--video_save_name test_dynamic_video \
|
176 |
+
--guidance 1
|
177 |
+
```
|
178 |
+
|
179 |
+
#### Multi-GPU (8 GPUs)
|
180 |
+
```bash
|
181 |
+
NUM_GPUS=8
|
182 |
+
CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) torchrun --nproc_per_node=${NUM_GPUS} cosmos_predict1/diffusion/inference/gen3c_dynamic.py \
|
183 |
+
--checkpoint_dir checkpoints \
|
184 |
+
--input_image_path assets/diffusion/dynamic_video_samples/batch_0000 \
|
185 |
+
--video_save_name test_dynamic_video_multigpu \
|
186 |
+
--num_gpus ${NUM_GPUS} \
|
187 |
+
--guidance 1
|
188 |
+
```
|
189 |
+
|
190 |
+
## Gallery
|
191 |
+
|
192 |
+
- **GEN3C** can be easily applied to video/scene creation from a single image
|
193 |
+
<div align="center">
|
194 |
+
<img src="assets/demo_3.gif" alt="" width="1100" />
|
195 |
+
</div>
|
196 |
+
|
197 |
+
- ... or sparse-view images (we use 5 images here)
|
198 |
+
<div align="center">
|
199 |
+
<img src="assets/demo_2.gif" alt="" width="1100" />
|
200 |
+
</div>
|
201 |
+
|
202 |
+
|
203 |
+
- .. and dynamic videos
|
204 |
+
<div align="center">
|
205 |
+
<img src="assets/demo_dynamic.gif" alt="" width="1100" />
|
206 |
+
</div>
|
207 |
+
|
208 |
+
## Acknowledgement
|
209 |
+
Our model is based on [NVIDIA Cosmos](https://github.com/NVIDIA/Cosmos) and [Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid).
|
210 |
+
|
211 |
+
We are also grateful to several other open-source repositories that we drew inspiration from or built upon during the development of our pipeline:
|
212 |
+
- [MoGe](https://github.com/microsoft/MoGe)
|
213 |
+
- [TrajectoryCrafter](https://github.com/TrajectoryCrafter/TrajectoryCrafter)
|
214 |
+
- [DimensionX](https://github.com/wenqsun/DimensionX)
|
215 |
+
- [Depth Anything V2](https://github.com/DepthAnything/Depth-Anything-V2)
|
216 |
+
- [Video Depth Anything](https://github.com/DepthAnything/Video-Depth-Anything)
|
217 |
+
|
218 |
+
## Citation
|
219 |
+
```
|
220 |
+
@inproceedings{ren2025gen3c,
|
221 |
+
title={GEN3C: 3D-Informed World-Consistent Video Generation with Precise Camera Control},
|
222 |
+
author={Ren, Xuanchi and Shen, Tianchang and Huang, Jiahui and Ling, Huan and
|
223 |
+
Lu, Yifan and Nimier-David, Merlin and Müller, Thomas and Keller, Alexander and
|
224 |
+
Fidler, Sanja and Gao, Jun},
|
225 |
+
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
|
226 |
+
year={2025}
|
227 |
+
}
|
228 |
+
```
|
229 |
+
|
230 |
+
## License and Contact
|
231 |
+
|
232 |
+
This project will download and install additional third-party open source software projects. Review the license terms of these open source projects before use.
|
233 |
+
|
234 |
+
|
235 |
+
GEN3C source code is released under the [Apache 2 License](https://www.apache.org/licenses/LICENSE-2.0).
|
236 |
+
|
237 |
+
GEN3C models are released under the [NVIDIA Open Model License](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). For a custom license, please visit our website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/).
|
238 |
+
=======
|
239 |
+
title: Gen3c
|
240 |
+
emoji: 🌍
|
241 |
+
colorFrom: indigo
|
242 |
+
colorTo: blue
|
243 |
+
sdk: docker
|
244 |
+
pinned: false
|
245 |
+
---
|
246 |
+
|
247 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
248 |
+
>>>>>>> 0453ffbfce197070bb0c254a11ef21f15d1ad986
|
assets/demo_1.gif
ADDED
![]() |
Git LFS Details
|
assets/demo_2.gif
ADDED
![]() |
Git LFS Details
|
assets/demo_3.gif
ADDED
![]() |
Git LFS Details
|
assets/demo_dynamic.gif
ADDED
![]() |
Git LFS Details
|
assets/diffusion/000000.png
ADDED
![]() |
Git LFS Details
|
assets/diffusion/000001.png
ADDED
![]() |
Git LFS Details
|
assets/diffusion/000002.png
ADDED
![]() |
Git LFS Details
|
assets/diffusion/000003.png
ADDED
![]() |
Git LFS Details
|
assets/diffusion/000004.png
ADDED
![]() |
Git LFS Details
|
assets/diffusion/000005.png
ADDED
![]() |
Git LFS Details
|
assets/diffusion/000006.png
ADDED
![]() |
Git LFS Details
|
assets/diffusion/000007.png
ADDED
![]() |
Git LFS Details
|
assets/diffusion/000008.png
ADDED
![]() |
Git LFS Details
|
assets/diffusion/000009.png
ADDED
![]() |
Git LFS Details
|
assets/diffusion/000010.png
ADDED
![]() |
Git LFS Details
|
assets/diffusion/000011.png
ADDED
![]() |
Git LFS Details
|
assets/diffusion/000012.png
ADDED
![]() |
Git LFS Details
|
assets/diffusion/000013.png
ADDED
![]() |
Git LFS Details
|
assets/diffusion/000014.png
ADDED
![]() |
Git LFS Details
|
assets/diffusion/000015.png
ADDED
![]() |
Git LFS Details
|
checkpoints/README.md
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
### Checkpoint directory
|
3 |
+
|
4 |
+
Model checkpoints will be downloaded to this directory.
|
cosmos-predict1.yaml
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
# conda env create --file cosmos-predict1.yaml
|
17 |
+
name: cosmos-predict1
|
18 |
+
channels:
|
19 |
+
- conda-forge
|
20 |
+
dependencies:
|
21 |
+
- python=3.10
|
22 |
+
- pip=25.0
|
23 |
+
- cmake
|
24 |
+
- ninja
|
25 |
+
- gcc=12.4.0
|
26 |
+
- gxx=12.4.0
|
27 |
+
- cuda=12.4
|
28 |
+
- cuda-nvcc=12.4
|
29 |
+
- cuda-toolkit=12.4
|
cosmos_predict1/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
cosmos_predict1/autoregressive/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
cosmos_predict1/autoregressive/callbacks/video_sampling_teacher_forcing.py
ADDED
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import glob
|
17 |
+
import math
|
18 |
+
import os
|
19 |
+
from typing import Optional
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
import torch
|
23 |
+
import torchvision
|
24 |
+
import torchvision.transforms.functional as torchvision_F
|
25 |
+
import wandb
|
26 |
+
from einops import rearrange
|
27 |
+
from megatron.core import parallel_state
|
28 |
+
from torch.distributed import get_process_group_ranks
|
29 |
+
|
30 |
+
from cosmos_predict1.autoregressive.utils.parallel import (
|
31 |
+
broadcast_data_batch_in_tp_cp_group,
|
32 |
+
gather_batch_from_cp_ranks,
|
33 |
+
get_batch_on_this_cp_rank,
|
34 |
+
)
|
35 |
+
from cosmos_predict1.callbacks.every_n import EveryN
|
36 |
+
from cosmos_predict1.utils import distributed, log, misc
|
37 |
+
from cosmos_predict1.utils.model import Model
|
38 |
+
from cosmos_predict1.utils.trainer import Trainer
|
39 |
+
|
40 |
+
|
41 |
+
def resize_image(image: torch.Tensor, resize_factor=0.5) -> torch.Tensor:
|
42 |
+
_, _, h, w = image.shape
|
43 |
+
new_h, new_w = int(resize_factor * h), int(resize_factor * w)
|
44 |
+
return torchvision_F.resize(image, (new_h, new_w))
|
45 |
+
|
46 |
+
|
47 |
+
class VideoSamplingTeacherForcing(EveryN):
|
48 |
+
def __init__(
|
49 |
+
self,
|
50 |
+
every_n: int,
|
51 |
+
step_size: int = 1,
|
52 |
+
video_latent_shape: list = [6, 24, 40],
|
53 |
+
num_frames_to_display: int = 4,
|
54 |
+
save_folder: Optional[str] = None,
|
55 |
+
num_file_to_log: int = 8,
|
56 |
+
):
|
57 |
+
r"""
|
58 |
+
This callback enables us to perform teacher forcing inference on the training data.
|
59 |
+
By teacher forcing, we mean providing ground truth video tokens as inputs, and simply asking the model
|
60 |
+
to predict the next tokens. The predicted next tokens are then visualized. This does not perform
|
61 |
+
autoregressive sampling.
|
62 |
+
We also upload the downsampled video frames to wandb. Downsampling is needed for wandb to work fast.
|
63 |
+
|
64 |
+
Args:
|
65 |
+
every_n (int): Call this callback every_n steps
|
66 |
+
step_size (int): Number of steps taken for gradient accumulation. Global iteration number is
|
67 |
+
iteration // self.step_size
|
68 |
+
video_latent_shape (list): Shape of the video latent
|
69 |
+
num_frames_to_display (int): Number of frames to subsample for displaying in wandb
|
70 |
+
save_folder (str): Name of the local folder to save the video
|
71 |
+
num_file_to_log (int): Number of files to upload to wandb
|
72 |
+
"""
|
73 |
+
super().__init__(every_n, step_size)
|
74 |
+
self.save_folder = save_folder if save_folder else self.__class__.__name__
|
75 |
+
self.video_latent_shape = video_latent_shape
|
76 |
+
self.num_frames_to_display = num_frames_to_display
|
77 |
+
self.num_file_to_log = num_file_to_log
|
78 |
+
self.rank = distributed.get_rank()
|
79 |
+
|
80 |
+
def on_train_start(self, model: Model, iteration: int = 0) -> None:
|
81 |
+
config_job = self.config.job
|
82 |
+
self.local_dir = f"{config_job.path_local}/{self.save_folder}"
|
83 |
+
if self.rank == 0:
|
84 |
+
os.makedirs(self.local_dir, exist_ok=True)
|
85 |
+
log.info(f"Video Teacher-Forcing Callback: local_dir: {self.local_dir}")
|
86 |
+
|
87 |
+
@torch.inference_mode()
|
88 |
+
def every_n_impl(
|
89 |
+
self,
|
90 |
+
trainer: Trainer,
|
91 |
+
model: Model,
|
92 |
+
data_batch: dict[str, torch.Tensor],
|
93 |
+
output_batch: dict[str, torch.Tensor],
|
94 |
+
loss: torch.Tensor,
|
95 |
+
iteration: int,
|
96 |
+
) -> None:
|
97 |
+
# Tokenize the data
|
98 |
+
|
99 |
+
broadcast_data_batch_in_tp_cp_group(data_batch)
|
100 |
+
|
101 |
+
input_vid = data_batch[model.tokenizer.tokenizer_config.video_tokenizer.data_key]
|
102 |
+
|
103 |
+
dataset_name = data_batch.get("dataset_name", None)
|
104 |
+
if dataset_name is not None and dataset_name.startswith("image"):
|
105 |
+
# we disable the callback if the input video is an image batch
|
106 |
+
log.info(f"dataset_name is {dataset_name}, skip this callback")
|
107 |
+
return
|
108 |
+
|
109 |
+
# get the caption
|
110 |
+
captions = data_batch.get("caption", None)
|
111 |
+
|
112 |
+
# get the context embedding and mask
|
113 |
+
context = data_batch.get("context", None)
|
114 |
+
context_mask = data_batch.get("context_mask", None)
|
115 |
+
if context is not None:
|
116 |
+
context = misc.to(context, "cuda").detach().clone()
|
117 |
+
if context_mask is not None:
|
118 |
+
context_mask = misc.to(context_mask, "cuda").detach().clone()
|
119 |
+
# get the action
|
120 |
+
action = data_batch.get("action", None)
|
121 |
+
if action is not None:
|
122 |
+
action = misc.to(action, "cuda").detach().clone()
|
123 |
+
|
124 |
+
# Input tokens
|
125 |
+
tokens, _ = model.tokenizer.tokenize(data_batch)
|
126 |
+
tokens = misc.to(tokens, "cuda").detach().clone()
|
127 |
+
skip_save_file = False
|
128 |
+
if parallel_state.get_context_parallel_world_size() > 1:
|
129 |
+
cp_group = parallel_state.get_context_parallel_group()
|
130 |
+
if self.rank != min(get_process_group_ranks(cp_group)):
|
131 |
+
skip_save_file = True
|
132 |
+
tokens = get_batch_on_this_cp_rank(tokens)
|
133 |
+
if parallel_state.get_tensor_model_parallel_world_size() > 1:
|
134 |
+
# Turn on TP
|
135 |
+
tp_group = parallel_state.get_tensor_model_parallel_group()
|
136 |
+
if self.rank != min(get_process_group_ranks(tp_group)):
|
137 |
+
skip_save_file = True
|
138 |
+
tokens_encoded_in_train = output_batch["encode_tokens"].detach()
|
139 |
+
percent_token_diff = (tokens != tokens_encoded_in_train).float().mean()
|
140 |
+
percent_token_diff = distributed.dist_reduce_tensor(percent_token_diff)
|
141 |
+
|
142 |
+
input_tokens = tokens
|
143 |
+
|
144 |
+
num_tokens_to_generate = np.prod(self.video_latent_shape)
|
145 |
+
|
146 |
+
# Do a forward pass
|
147 |
+
logits = model.model.forward(
|
148 |
+
tokens,
|
149 |
+
input_pos=None,
|
150 |
+
context=context,
|
151 |
+
context_mask=context_mask,
|
152 |
+
action=action,
|
153 |
+
)
|
154 |
+
if parallel_state.get_context_parallel_world_size() > 1:
|
155 |
+
logits = gather_batch_from_cp_ranks(logits)
|
156 |
+
input_tokens = gather_batch_from_cp_ranks(input_tokens)
|
157 |
+
|
158 |
+
# Start position for video tokens in the vocabulary
|
159 |
+
video_token_start = self.config.model.tokenizer_config.video_tokenizer.tokenizer_offset
|
160 |
+
video_vocab_size = self.config.model.tokenizer_config.video_tokenizer.vocab_size
|
161 |
+
|
162 |
+
# Clipping logits only to video tokens. We remove the text vocab predictions.
|
163 |
+
# This will ensure that the video tokens only correspond to the video part of the vocabulary.
|
164 |
+
logits = logits[:, :, video_token_start : video_token_start + video_vocab_size]
|
165 |
+
|
166 |
+
# Sample with argmax token. This should be good for teacher forcing experiment.
|
167 |
+
logits = logits.contiguous()
|
168 |
+
generations = torch.argmax(logits, dim=-1)
|
169 |
+
|
170 |
+
# For each video in the batch, subsample frames for display
|
171 |
+
batch_size = input_tokens.shape[0]
|
172 |
+
out_frames = []
|
173 |
+
out_videos_gen = []
|
174 |
+
out_videos_rec = []
|
175 |
+
out_videos_gt = []
|
176 |
+
# log the accuracy of teacher-forcing
|
177 |
+
acc = []
|
178 |
+
loss_list = []
|
179 |
+
|
180 |
+
for sample_num in range(batch_size):
|
181 |
+
# Subsample the generations to the video part.
|
182 |
+
# This corresponds to the part from begin of video to end of video.
|
183 |
+
bov_token = model.tokenizer.video_special_tokens["<|begin_of_video|>"]
|
184 |
+
bov_index = input_tokens[sample_num] == bov_token
|
185 |
+
use_special_token = sum(bov_index) != 0
|
186 |
+
if use_special_token:
|
187 |
+
bov_index = bov_index.nonzero().item()
|
188 |
+
# generations: <bov> real_token1 real_token2, ... real_token7680; total 7680
|
189 |
+
# gen_video_tokens: real_token1 real_token2, ..., real_token7680; total 7680
|
190 |
+
# for vis: real_token1 real_token2, ..., real_token7680; total 7680
|
191 |
+
# for accuracy: real_token1 real_token2, ..., real_token7680; total 7680
|
192 |
+
gen_video_tokens = generations[sample_num][bov_index : bov_index + num_tokens_to_generate]
|
193 |
+
gen_video_tokens_vis = gen_video_tokens
|
194 |
+
gen_video_tokens_acc = gen_video_tokens
|
195 |
+
logits_loss = logits[sample_num][bov_index : bov_index + num_tokens_to_generate]
|
196 |
+
else:
|
197 |
+
# generations: real_token1 real_token2, ... real_token7680
|
198 |
+
# gen_video_tokens: real_token2 real_token3, ..., real_token7680; total 7679
|
199 |
+
# We need different tokens for vis and accuracy compute
|
200 |
+
# for acc: real_token2 real_token3, ..., real_token7680; total 7679
|
201 |
+
# for vis: pad_token (real_token2, ..., real_token7680); total 1 + 7679
|
202 |
+
gen_video_tokens = generations[sample_num][
|
203 |
+
: num_tokens_to_generate - 1
|
204 |
+
] # remove the last token since there is no gt
|
205 |
+
# Since the first token is not predicted, we need to add the gt first token to make sure the shape is correct
|
206 |
+
gen_video_tokens_vis = torch.cat([input_tokens[sample_num][0:1], gen_video_tokens])
|
207 |
+
gen_video_tokens_acc = gen_video_tokens
|
208 |
+
logits_loss = logits[sample_num][: num_tokens_to_generate - 1]
|
209 |
+
|
210 |
+
# Rearrange the video to a spatial tensor
|
211 |
+
gen_video_tokens_vis_BTHW = rearrange(
|
212 |
+
gen_video_tokens_vis.unsqueeze(0),
|
213 |
+
"B (T H W) -> B T H W",
|
214 |
+
T=self.video_latent_shape[0],
|
215 |
+
H=self.video_latent_shape[1],
|
216 |
+
W=self.video_latent_shape[2],
|
217 |
+
)
|
218 |
+
|
219 |
+
# for real videos, we need to skip the bov and eov tokens for decoding
|
220 |
+
if use_special_token:
|
221 |
+
# input_tokens: <bov> real_token1 real_token2 ... <eov> <eov> ...
|
222 |
+
# real_video_tokens: real_token1 real_token2 ... real_token7680; total 7680
|
223 |
+
# for vis: real_token1 real_token2 ... real_token7680; total 7680
|
224 |
+
# for accuracy: real_token1 real_token2 ... real_token7680; total 7680; we include real_token1 since the output prediction also includes it, see gen_video_tokens_acc above
|
225 |
+
real_video_tokens = (
|
226 |
+
input_tokens[sample_num][bov_index + 1 : bov_index + num_tokens_to_generate + 1] - video_token_start
|
227 |
+
)
|
228 |
+
real_video_tokens_vis = real_video_tokens
|
229 |
+
real_video_tokens_acc = real_video_tokens
|
230 |
+
else:
|
231 |
+
# input_tokens: real_token1 real_token2,... real_token7680; total 7680
|
232 |
+
# real_video_tokens: real_token1 real_token2,... real_token7680; total 7680
|
233 |
+
# for acc: gt start from real_token2, real_token3; total 7679, remove the first token since it is not predicted
|
234 |
+
# for vis: gt start from real_token1, real_token2; total 7680
|
235 |
+
real_video_tokens = input_tokens[sample_num][:num_tokens_to_generate] - video_token_start
|
236 |
+
real_video_tokens_vis = real_video_tokens
|
237 |
+
real_video_tokens_acc = real_video_tokens[1:].flatten()
|
238 |
+
|
239 |
+
real_video_tokens_vis_BTHW = rearrange(
|
240 |
+
real_video_tokens_vis.unsqueeze(0),
|
241 |
+
"B (T H W) -> B T H W",
|
242 |
+
T=self.video_latent_shape[0],
|
243 |
+
H=self.video_latent_shape[1],
|
244 |
+
W=self.video_latent_shape[2],
|
245 |
+
)
|
246 |
+
# Calculate accuracy
|
247 |
+
correct_predictions = (gen_video_tokens_acc == real_video_tokens_acc).float()
|
248 |
+
labels = real_video_tokens_acc.clone()
|
249 |
+
|
250 |
+
if model.config.ignore_first_num_tokens > 0:
|
251 |
+
labels[: model.config.ignore_first_num_tokens] = model.tokenizer.ignore_index
|
252 |
+
select_index = labels != model.tokenizer.ignore_index
|
253 |
+
correct_predictions = correct_predictions[select_index]
|
254 |
+
|
255 |
+
loss = torch.nn.functional.cross_entropy(
|
256 |
+
logits_loss, labels, ignore_index=model.tokenizer.ignore_index, reduction="none"
|
257 |
+
)
|
258 |
+
acc.append(correct_predictions.mean() * 100.0)
|
259 |
+
loss_list.append(loss.mean())
|
260 |
+
|
261 |
+
# Decode the predicted latents
|
262 |
+
if model.tokenizer.tokenizer_config.video_tokenizer.temporal_overlap == 0:
|
263 |
+
vid_decoded = model.tokenizer.video_tokenizer.decode(gen_video_tokens_vis_BTHW.cuda())
|
264 |
+
else:
|
265 |
+
vid_decoded = model.tokenizer.video_tokenizer.decode_with_overlap(
|
266 |
+
gen_video_tokens_vis_BTHW.cuda(),
|
267 |
+
temporal_overlap=model.tokenizer.tokenizer_config.video_tokenizer.temporal_overlap,
|
268 |
+
)
|
269 |
+
# normalize decoded images from [-1, 1] to [0, 1], and clip value
|
270 |
+
vid_decoded = (vid_decoded * 0.5 + 0.5).clamp_(0, 1)
|
271 |
+
vid_decoded = vid_decoded[0]
|
272 |
+
|
273 |
+
# Decode the GT latents
|
274 |
+
if model.tokenizer.tokenizer_config.video_tokenizer.temporal_overlap == 0:
|
275 |
+
vid_rec = model.tokenizer.video_tokenizer.decode(real_video_tokens_vis_BTHW.cuda())
|
276 |
+
else:
|
277 |
+
vid_rec = model.tokenizer.video_tokenizer.decode_with_overlap(
|
278 |
+
real_video_tokens_vis_BTHW.cuda(),
|
279 |
+
temporal_overlap=model.tokenizer.tokenizer_config.video_tokenizer.temporal_overlap,
|
280 |
+
)
|
281 |
+
# normalize decoded image from [-1, 1] to [0, 1], and clip value
|
282 |
+
vid_rec = (vid_rec * 0.5 + 0.5).clamp_(0, 1)
|
283 |
+
vid_rec = vid_rec[0]
|
284 |
+
|
285 |
+
vid_input = input_vid[sample_num] # [-1, 1], input_vid shape: [B, C, L, H, W]
|
286 |
+
vid_input = (vid_input * 0.5 + 0.5).clamp_(0, 1).cuda() # Convert to [0, 1], [C, L, H, W]
|
287 |
+
|
288 |
+
# Subsample real and generated video frames
|
289 |
+
input_video_frames = vid_input.transpose(0, 1) # [L, C, H, W]
|
290 |
+
rec_video_frames = vid_rec.transpose(0, 1)
|
291 |
+
gen_video_frames = vid_decoded.transpose(0, 1)
|
292 |
+
out_videos_gen.append(gen_video_frames)
|
293 |
+
out_videos_rec.append(rec_video_frames)
|
294 |
+
out_videos_gt.append(input_video_frames)
|
295 |
+
|
296 |
+
stride = math.ceil(rec_video_frames.shape[0] / self.num_frames_to_display)
|
297 |
+
|
298 |
+
input_video_frames_subsampled = resize_image(input_video_frames[0::stride], resize_factor=0.5)
|
299 |
+
input_video_frames_subsampled = torchvision.utils.make_grid(
|
300 |
+
input_video_frames_subsampled, nrow=input_video_frames_subsampled.shape[0]
|
301 |
+
)
|
302 |
+
|
303 |
+
gt_video_frames_subsampled = resize_image(rec_video_frames[0::stride], resize_factor=0.5)
|
304 |
+
gt_video_frames_subsampled = torchvision.utils.make_grid(
|
305 |
+
gt_video_frames_subsampled, nrow=gt_video_frames_subsampled.shape[0]
|
306 |
+
)
|
307 |
+
gen_video_frames_subsampled = resize_image(gen_video_frames[0::stride], resize_factor=0.5)
|
308 |
+
gen_video_frames_subsampled = torchvision.utils.make_grid(
|
309 |
+
gen_video_frames_subsampled, nrow=gen_video_frames_subsampled.shape[0]
|
310 |
+
)
|
311 |
+
|
312 |
+
out_frames.append(input_video_frames_subsampled)
|
313 |
+
out_frames.append(gt_video_frames_subsampled)
|
314 |
+
out_frames.append(gen_video_frames_subsampled)
|
315 |
+
|
316 |
+
scaled_num_rank_to_log = (
|
317 |
+
self.num_file_to_log
|
318 |
+
* parallel_state.get_context_parallel_world_size()
|
319 |
+
* parallel_state.get_tensor_model_parallel_world_size()
|
320 |
+
)
|
321 |
+
if self.rank < scaled_num_rank_to_log and not skip_save_file:
|
322 |
+
local_path = f"{self.local_dir}/vid_teacher_forcing_iter_{iteration:09d}_{self.rank:04d}.jpg"
|
323 |
+
out_image_grid = torchvision.utils.make_grid(out_frames, nrow=1, padding=0, normalize=False)
|
324 |
+
os.makedirs(os.path.dirname(local_path), exist_ok=True)
|
325 |
+
torchvision.utils.save_image(out_image_grid, local_path)
|
326 |
+
|
327 |
+
# Log to wandb
|
328 |
+
avg_acc = distributed.dist_reduce_tensor(torch.stack(acc).mean()).item()
|
329 |
+
avg_loss = distributed.dist_reduce_tensor(torch.stack(loss_list).mean()).item()
|
330 |
+
log_info = ""
|
331 |
+
if "acc" in output_batch:
|
332 |
+
log_info = f"train acc: {(output_batch['acc'].mean().item()):.6f}%"
|
333 |
+
if percent_token_diff is not None:
|
334 |
+
log_info += f"; percent_token_diff_train_val: {percent_token_diff.item() * 100:.6f}%"
|
335 |
+
log.info(
|
336 |
+
f"Eval iteration {iteration} teacher-forcing accuracy: {avg_acc:.6f}%, loss: {avg_loss:.4f}; {log_info}"
|
337 |
+
)
|
338 |
+
if self.rank == 0 and wandb.run:
|
339 |
+
local_files = glob.glob(f"{self.local_dir}/vid_teacher_forcing_iter_{iteration:09d}_*.jpg")
|
340 |
+
local_files = sorted(local_files)[: self.num_file_to_log]
|
341 |
+
if captions is None:
|
342 |
+
captions = ["vid_frames_teacher_forcing"] * len(local_files)
|
343 |
+
for local_path, caption in zip(local_files, captions):
|
344 |
+
wandb.log(
|
345 |
+
{"frames": [wandb.Image(local_path, caption=caption)]},
|
346 |
+
step=iteration,
|
347 |
+
)
|
348 |
+
|
349 |
+
wandb.log({"eval/teacher_forcing_acc": avg_acc}, step=iteration)
|
350 |
+
wandb.log({"eval/teacher_forcing_loss": avg_loss}, step=iteration)
|
351 |
+
if percent_token_diff is not None:
|
352 |
+
wandb.log({"eval/percent_token_diff_train_val": percent_token_diff.item() * 100}, step=iteration)
|
cosmos_predict1/autoregressive/configs/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
cosmos_predict1/autoregressive/configs/base/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
cosmos_predict1/autoregressive/configs/base/callbacks.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from cosmos_predict1.autoregressive.callbacks.video_sampling_teacher_forcing import VideoSamplingTeacherForcing
|
17 |
+
from cosmos_predict1.callbacks.grad_clip import GradClip
|
18 |
+
from cosmos_predict1.utils.callback import ProgressBarCallback
|
19 |
+
from cosmos_predict1.utils.lazy_config import LazyCall as L
|
20 |
+
|
21 |
+
BASIC_CALLBACKS = dict(
|
22 |
+
progress_bar=L(ProgressBarCallback)(),
|
23 |
+
grad_clip=L(GradClip)(clip_norm=1.0, fsdp_enabled="${model.model_config.fsdp_enabled}", model_key="model"),
|
24 |
+
)
|
25 |
+
|
26 |
+
VIDEO_TEACHER_FORCING_CALLBACK = dict(
|
27 |
+
vid_sampling_tf=L(VideoSamplingTeacherForcing)(
|
28 |
+
every_n=500,
|
29 |
+
video_latent_shape="${model.model_config.video_latent_shape}",
|
30 |
+
num_frames_to_display=4,
|
31 |
+
save_folder="video_sampling_teacher_forcing",
|
32 |
+
)
|
33 |
+
)
|
cosmos_predict1/autoregressive/configs/base/dataloader.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from megatron.core import parallel_state
|
17 |
+
from torch.utils.data import DataLoader, DistributedSampler
|
18 |
+
|
19 |
+
from cosmos_predict1.autoregressive.configs.base.dataset import VideoDatasetConfig
|
20 |
+
from cosmos_predict1.autoregressive.datasets.video_dataset import VideoDataset
|
21 |
+
from cosmos_predict1.utils import log
|
22 |
+
from cosmos_predict1.utils.lazy_config import LazyCall as L
|
23 |
+
|
24 |
+
DATALOADER_OPTIONS = {}
|
25 |
+
|
26 |
+
|
27 |
+
def get_sampler(dataset):
|
28 |
+
return DistributedSampler(
|
29 |
+
dataset,
|
30 |
+
num_replicas=parallel_state.get_data_parallel_world_size(),
|
31 |
+
rank=parallel_state.get_data_parallel_rank(),
|
32 |
+
shuffle=True,
|
33 |
+
seed=0,
|
34 |
+
)
|
35 |
+
|
36 |
+
|
37 |
+
def dataloader_register(key):
|
38 |
+
log.info(f"registering dataloader {key}...")
|
39 |
+
|
40 |
+
def decorator(func):
|
41 |
+
DATALOADER_OPTIONS[key] = func
|
42 |
+
return func
|
43 |
+
|
44 |
+
return decorator
|
45 |
+
|
46 |
+
|
47 |
+
@dataloader_register("tealrobot_video")
|
48 |
+
def get_tealrobot_video(
|
49 |
+
batch_size: int = 1,
|
50 |
+
dataset_dir: str = "datasets/cosmos_nemo_assets/videos/",
|
51 |
+
sequence_interval: int = 1,
|
52 |
+
num_frames: int = 33,
|
53 |
+
video_size: list[int, int] = [640, 848],
|
54 |
+
start_frame_interval: int = 1,
|
55 |
+
):
|
56 |
+
dataset = L(VideoDataset)(
|
57 |
+
config=VideoDatasetConfig(
|
58 |
+
dataset_dir=dataset_dir,
|
59 |
+
sequence_interval=sequence_interval,
|
60 |
+
num_frames=num_frames,
|
61 |
+
video_size=video_size,
|
62 |
+
start_frame_interval=start_frame_interval,
|
63 |
+
)
|
64 |
+
)
|
65 |
+
return L(DataLoader)(
|
66 |
+
dataset=dataset,
|
67 |
+
sampler=L(get_sampler)(dataset=dataset),
|
68 |
+
batch_size=batch_size,
|
69 |
+
drop_last=True,
|
70 |
+
pin_memory=True,
|
71 |
+
num_workers=8,
|
72 |
+
)
|
cosmos_predict1/autoregressive/configs/base/dataset.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Dataset config class."""
|
17 |
+
|
18 |
+
import attrs
|
19 |
+
|
20 |
+
from cosmos_predict1.utils.config import make_freezable
|
21 |
+
|
22 |
+
|
23 |
+
@make_freezable
|
24 |
+
@attrs.define(slots=False)
|
25 |
+
class VideoDatasetConfig:
|
26 |
+
"""
|
27 |
+
Args:
|
28 |
+
dataset_dir (str): Base path to the dataset directory
|
29 |
+
sequence_interval (int): Interval between sampled frames in a sequence
|
30 |
+
num_frames (int): Number of frames to load per sequence
|
31 |
+
video_size (list): Target size [H,W] for video frames
|
32 |
+
start_frame_interval (int): Interval between starting frames of sequences
|
33 |
+
"""
|
34 |
+
|
35 |
+
dataset_dir: str = "datasets/cosmos_nemo_assets/videos/"
|
36 |
+
sequence_interval: int = 1
|
37 |
+
num_frames: int = 33
|
38 |
+
video_size: list[int, int] = [640, 848]
|
39 |
+
start_frame_interval: int = 1
|
cosmos_predict1/autoregressive/configs/base/model.py
ADDED
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from typing import Optional
|
17 |
+
|
18 |
+
import attrs
|
19 |
+
|
20 |
+
from cosmos_predict1.autoregressive.configs.base.tokenizer import TokenizerConfig
|
21 |
+
from cosmos_predict1.utils import config
|
22 |
+
|
23 |
+
_ACTION_DIM = 8
|
24 |
+
from cosmos_predict1.utils.lazy_config import LazyDict
|
25 |
+
|
26 |
+
|
27 |
+
@attrs.define
|
28 |
+
class ModelConfig:
|
29 |
+
"""
|
30 |
+
A class to hold model configuration arguments.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
dim (int): The dimensionality of the input and output of each transformer block.
|
34 |
+
n_layers (int): Number of layers in the transformer.
|
35 |
+
n_heads (int): Number of attention heads.
|
36 |
+
n_kv_heads (Optional[int]): Number of key-value heads. If None, defaults to n_heads. Note: this is equivalent to
|
37 |
+
`num_gqa_groups` in TransformerEngine, where GQA means Grouped Query Attention.
|
38 |
+
head_dim (Optional[int]): Dimensionality of each head. If None, defaults to dim // n_heads.
|
39 |
+
vocab_size (int): Vocabulary size.
|
40 |
+
ffn_hidden_size (int): Hidden size for feedforward network.
|
41 |
+
norm_eps (float): Epsilon value for normalization.
|
42 |
+
rope_theta (float): Theta value for rotary positional embeddings.
|
43 |
+
apply_abs_pos_emb (bool): Whether to apply absolute position embeddings.
|
44 |
+
max_batch_size (int): Maximum batch size for inference.
|
45 |
+
max_seq_len (int): Maximum sequence length for input text.
|
46 |
+
fuse_qkv (bool): Whether to fuse QKV in attention. Defaults to True.
|
47 |
+
causal_mask (bool): Whether to use causal mask. Defaults to True.
|
48 |
+
norm_type (str): Type of normalization layer. Choices: "rmsnorm", "fused_rmsnorm", "layernorm", "np_layernorm".
|
49 |
+
precision (str): Data type for the model.
|
50 |
+
use_qk_normalization (bool): Whether to enable QK normalization.
|
51 |
+
tensor_model_parallel_size (int): Tensor model parallel size. Defaults to 1.
|
52 |
+
ckpt_dir (str): Checkpoint directory.
|
53 |
+
ckpt_path (str): Checkpoint path.
|
54 |
+
apply_yarn (Optional[bool]): Whether to apply YaRN (long-context extension).
|
55 |
+
yarn_scale (Optional[float]): Scale factor for YaRN.
|
56 |
+
yarn_beta_fast (Optional[int]): Beta fast variable for YaRN (i.e., low_freq_factor in Llama 3.1 RoPE scaling code)
|
57 |
+
yarn_beta_slow (Optional[int]): Beta slow variable for YaRN (i.e., high_freq_factor in Llama 3.1 RoPE scaling code)
|
58 |
+
original_seq_len (Optional[int]): Original sequence length.
|
59 |
+
vision_encoder (Optional[str]): Vision encoder name.
|
60 |
+
mm_projector (Optional[str]): Multi-modal projector name.
|
61 |
+
vision_encoder_in_channels (Optional[int]): Number of channels in the input image for the vision encoder. Default is 3, you can specify to int larger than 3. E.g. if you have 4-channel images with the last channel as the alpha channel, set this to 4.
|
62 |
+
rope_dim (Optional[str]): Dimensionality of the RoPE. Choices: "1D", "3D".
|
63 |
+
pytorch_rope_version (Optional[str]): Version of the PyTorch RoPE implementation. Choices: "v1", "v2".
|
64 |
+
original_latent_shape (Optional[list]): Original shape of the latent tensor needed for rope extension.
|
65 |
+
pad_to_multiple_of (Optional[int]): Pad the position embedding to a multiple of this value.
|
66 |
+
vision_encoder_in_channels (Optional[int]): Number of channels in the input image for the vision encoder. Default is 3.
|
67 |
+
insert_cross_attn (bool): Whether to insert the cross-attention layers after each multi-head self-attention (MSA) layer.
|
68 |
+
insert_cross_attn_every_k_layers (int): Insert cross-attention layers every k TransformerLayers.
|
69 |
+
context_dim (Optional[int]): The dimensionality of cross-attention embedding, e.g., T5 embed feature dim.
|
70 |
+
num_video_frames (Optional[int]): Number of video frames.
|
71 |
+
video_height (Optional[int]): Raw video pixel height dimension.
|
72 |
+
video_width (Optional[int]): Raw video pixel width dimension.
|
73 |
+
video_latent_shape (Optional[list]): Video tokenizer output dimension, in (T,H,W).
|
74 |
+
"""
|
75 |
+
|
76 |
+
dim: int = attrs.field(default=4096)
|
77 |
+
n_layers: int = attrs.field(default=32)
|
78 |
+
n_heads: int = attrs.field(default=32)
|
79 |
+
n_kv_heads: Optional[int] = attrs.field(default=8)
|
80 |
+
head_dim: Optional[int] = attrs.field(default=None)
|
81 |
+
vocab_size: int = attrs.field(default=128256)
|
82 |
+
ffn_hidden_size: int = attrs.field(default=14336)
|
83 |
+
norm_eps: float = attrs.field(default=1e-5)
|
84 |
+
rope_theta: float = attrs.field(default=500000)
|
85 |
+
apply_abs_pos_emb: bool = attrs.field(default=False)
|
86 |
+
max_batch_size: int = attrs.field(default=1)
|
87 |
+
max_seq_len: int = attrs.field(default=8192)
|
88 |
+
fuse_qkv: bool = attrs.field(default=False)
|
89 |
+
causal_mask: bool = attrs.field(default=True)
|
90 |
+
norm_type: str = attrs.field(default="rmsnorm")
|
91 |
+
precision: str = attrs.field(default="bfloat16")
|
92 |
+
use_qk_normalization: bool = False
|
93 |
+
tokenizer: Optional[TokenizerConfig] = None
|
94 |
+
tensor_model_parallel_size: int = attrs.field(default=1)
|
95 |
+
ckpt_dir: Optional[str] = attrs.field(default=None)
|
96 |
+
ckpt_path: Optional[str] = attrs.field(
|
97 |
+
default=None
|
98 |
+
) # If not None, load the model from this path instead of ckpt_dir
|
99 |
+
apply_yarn: Optional[bool] = attrs.field(default=False)
|
100 |
+
yarn_scale: Optional[float] = attrs.field(default=None)
|
101 |
+
yarn_beta_fast: Optional[int] = attrs.field(default=None)
|
102 |
+
yarn_beta_slow: Optional[int] = attrs.field(default=None)
|
103 |
+
original_seq_len: Optional[int] = attrs.field(default=None)
|
104 |
+
vision_encoder: Optional[str] = attrs.field(default=None)
|
105 |
+
vision_encoder_in_channels: Optional[int] = attrs.field(default=3)
|
106 |
+
mm_projector: Optional[str] = attrs.field(default=None)
|
107 |
+
rope_dim: Optional[str] = attrs.field(default="1D")
|
108 |
+
pytorch_rope_version: Optional[str] = attrs.field(default="v2")
|
109 |
+
original_latent_shape: Optional[list] = None
|
110 |
+
pad_to_multiple_of: Optional[int] = None
|
111 |
+
vision_encoder_in_channels: Optional[int] = attrs.field(default=3)
|
112 |
+
insert_cross_attn: bool = False
|
113 |
+
insert_cross_attn_every_k_layers: int = 1
|
114 |
+
context_dim: Optional[int] = attrs.field(default=1024)
|
115 |
+
# For video training
|
116 |
+
num_video_frames: Optional[int] = None
|
117 |
+
# Raw video pixel dimension
|
118 |
+
video_height: Optional[int] = None
|
119 |
+
video_width: Optional[int] = None
|
120 |
+
# Video tokenizer output dimension, in (T,H,W), it's computed by num_video_frames/temporal_compress_factor, video_height/spatial_compression_fact, video_width/spatial_compression_fact
|
121 |
+
video_latent_shape: Optional[list] = None
|
122 |
+
|
123 |
+
def __getitem__(self, item):
|
124 |
+
return getattr(self, item)
|
125 |
+
|
126 |
+
|
127 |
+
@attrs.define
|
128 |
+
class TrainingModelConfig:
|
129 |
+
"""
|
130 |
+
A class to hold model configuration arguments.
|
131 |
+
|
132 |
+
Args:
|
133 |
+
dim (int): The dimensionality of the input and output of each transformer block.
|
134 |
+
n_layers (int): Number of layers in the transformer.
|
135 |
+
n_heads (int): Number of attention heads.
|
136 |
+
n_kv_heads (Optional[int]): Number of key-value heads. If None, defaults to n_heads. Note: this is equivalent to
|
137 |
+
`num_gqa_groups` in TransformerEngine, where GQA means Grouped Query Attention.
|
138 |
+
head_dim (Optional[int]): Dimensionality of each head. If None, defaults to dim // n_heads.
|
139 |
+
vocab_size (int): Vocabulary size.
|
140 |
+
multiple_of (int): Ensures the hidden layer size is a multiple of this value for SwiGLU activation.
|
141 |
+
ffn_dim_multiplier (Optional[float]): Multiplier for feedforward network dimension.
|
142 |
+
ffn_hidden_size (Optional[int]): Hidden size for feedforward network. If None, use ffn_dim_multiplier to compute it.
|
143 |
+
norm_eps (float): Epsilon value for normalization.
|
144 |
+
rope_theta (float): Theta value for rotary positional embeddings.
|
145 |
+
apply_abs_pos_emb (bool): Whether to apply absolute position embeddings.
|
146 |
+
max_batch_size (int): Maximum batch size for inference (determines KV cache size).
|
147 |
+
max_seq_len (int): Maximum sequence length for input text (determines KV cache size).
|
148 |
+
fuse_qkv (bool): Whether to fuse QKV in attention. Flag for the pytorch backend.
|
149 |
+
causal_mask (bool): Whether to use causal mask. Defaults to True.
|
150 |
+
flash_attn (bool): Whether to use Flash attention.
|
151 |
+
norm_type (str): Type of normalization layer. Choices: "rmsnorm", "fused_rmsnorm", "layernorm", "np_layernorm".
|
152 |
+
backend (str): Backend for the model.
|
153 |
+
precision (str): Data type for the model.
|
154 |
+
ema (config.EMAConfig): Configuration for exponential moving average.
|
155 |
+
embedding_dropout(float): Dropout rate for the embedding layer.
|
156 |
+
attention_dropout(float): Dropout rate for attention.
|
157 |
+
hidden_dropout(float): Dropout after the attention and feed-forward layers (following TransformerEngine's
|
158 |
+
implementation in its TransformerLayer class).
|
159 |
+
use_qk_normalization (bool): Whether to enable QK normalization.
|
160 |
+
inference (bool): Whether the model is used for inference.
|
161 |
+
act_ckpt_enabled (bool): Whether to enable activation checkpointing.
|
162 |
+
fsdp_enabled (bool): Whether to enable FSDP.
|
163 |
+
fsdp (LazyDict): Configuration for FSDP.
|
164 |
+
ckpt_dir (str): Checkpoint directory.
|
165 |
+
ckpt_path (str): Checkpoint path.
|
166 |
+
cache_dir (str): Cache directory.
|
167 |
+
apply_yarn (Optional[bool]): Whether to apply YaRN (long-context extension).
|
168 |
+
yarn_scale (Optional[float]): Scale factor for YaRN.
|
169 |
+
yarn_beta_fast (Optional[int]): Beta fast variable for YaRN (i.e., low_freq_factor in Llama 3.1 RoPE scaling code)
|
170 |
+
yarn_beta_slow (Optional[int]): Beta slow variable for YaRN (i.e., high_freq_factor in Llama 3.1 RoPE scaling code)
|
171 |
+
original_seq_len (Optional[int]): Original sequence length.
|
172 |
+
depth_init (bool): If `True`, then each transformer block init uses its layer ID, and if `False`, each uses the
|
173 |
+
total number of transformer blocks. Defaults to `True` (following the TorchTitan implementation of Llama3).
|
174 |
+
context_parallel_size (int): Context parallel size. Defaults to 1.
|
175 |
+
tensor_model_parallel_size (int): Tensor model parallel size. Defaults to 1.
|
176 |
+
sequence_parallel (bool): Whether to use sequence parallelism. Defaults to False.
|
177 |
+
set_parallel_mode (bool): It is a boolean flag used by TransformerEngine to handle Tensor Parallelism.
|
178 |
+
Essentially, it is equivalent to `tensor_model_parallel_size > 1`. Defaults to `False`.
|
179 |
+
attention_tp (bool): Whether to use tensor parallelism for attention layers.
|
180 |
+
mm_projector (Optional[str]): Multimodal projector used for vision-language modeling. Defaults to None.
|
181 |
+
Choices: "identity", "linear", "mlp", "mlp_downsample".
|
182 |
+
video_latent_shape (Optional[list]): Shape of the video latent tensor. [T, H, W]
|
183 |
+
image_latent_shape (Optional[list]): Shape of the image latent tensor. [H, W]
|
184 |
+
num_video_frames (Optional[int]): Number of video frames.
|
185 |
+
rope_dim (Optional[str]): Dimensionality of the RoPE. Choices: "1D", "2D", "3D".
|
186 |
+
pytorch_rope_version (Optional[str]): Version of the RoPE for the `pytorch` backend. "v1" is the Llama implementation, and "v2" is HuggingFace/TransformerEngine implementation.
|
187 |
+
original_latent_shape (Optional[list]): Original shape of the latent tensor needed for rope extension.
|
188 |
+
pad_to_multiple_of (Optional[int]): Pad the position embedding to a multiple of this value.
|
189 |
+
peft_last_n_layers (Optional[int]): Number of last few layers to fine-tune in Parameter Efficient Fine-Tuning (PEFT). When this and peft_every_n_layers are both 0, it means all layers are fine-tuned (FFT).
|
190 |
+
peft_every_n_layers (Optional[int]): In Parameter Efficient Fine-Tuning (PEFT), every n layers are unfrozen and can be trained (in flamingo style). When this and peft_last_n_layers are both 0,
|
191 |
+
it means all layers are fine-tuned (FFT). For example, for a 40 layer model, n=8 means training layers 7, 15, 23, 31, 39, which includes the final layer.
|
192 |
+
It is advised to pick n such that the final layer is included.
|
193 |
+
freeze_vision_encoder (bool): Whether to freeze the vision encoder in vision-language model training. Defaults to False.
|
194 |
+
vision_encoder_in_channels (Optional[int]): Number of channels in the input image for the vision encoder. Default is 3, you can specify to int larger than 3. E.g. if you have 4-channel images with the last channel as the alpha channel, set this to 4.
|
195 |
+
insert_cross_attn (bool): Whether to insert the cross-attention layers after each multi-head self-attention (MSA) layer.
|
196 |
+
insert_cross_attn_every_k_layers (int): Insert cross-attention layers every k TransformerLayers.
|
197 |
+
context_dim (Optional[int]): The dimensionality of cross-attention embedding, e.g., T5 embed feature dim.
|
198 |
+
finetune_layers_with_cross_attn (bool): Whether to finetune Transformer layers w/ CA (cross-attn).
|
199 |
+
finetune_layers_without_cross_attn (bool): Whether to finetune Transformer layers w/o CA (cross-attn).
|
200 |
+
use_action_condition (bool): Whether to use the robot action condition.
|
201 |
+
action_embedding_mode (Optional[str]): The mode of the robot action embedding. Choices: "matrix", "mlp".
|
202 |
+
action_dim (Optional[int]): The dimensionality of the raw robot action tensor (e.g., 7 for DROID, [Δx, Δy, Δz, rx, ry, rz, gripper_open]).
|
203 |
+
action_embedding_dim (Optional[int]): The dimensionality of the robot action embedding.
|
204 |
+
group_causal_mask_mode (Optional[str]): The mode of the group causal mask. Choices: "causal", "group_diagonal".
|
205 |
+
sync_1d_parameters (bool): Whether to synchronize layernorm parameters (1D) across tensor parallel ranks (default True).
|
206 |
+
Note: this is to ensure all TP-ranks have the same layernorm parameters.
|
207 |
+
z_loss_coeff (float): The coefficient for the z-loss.
|
208 |
+
insert_medusa_head (bool): Whether to insert the Medusa head.
|
209 |
+
ft_medusa_option (str): Options on which layers to finetune, choices like:
|
210 |
+
"fft": fully fine-tune both medusa heads and all LLM backbone;
|
211 |
+
"head": fine-tune medusa heads;
|
212 |
+
"head_out": fine-tune medusa heads, and the output layer;
|
213 |
+
"head_out_last_k_layer": fine-tune medusa heads, the output layer, and the last k layer(s) of the LLM backbone.
|
214 |
+
medusa_num_heads (int): Number of heads in the Medusa head.
|
215 |
+
medusa_num_layers (int): Number of layers in the Medusa head.
|
216 |
+
medusa_concat_heads (bool): Whether to concatenate multiple medusa heads into fused matrix, only applicable when medusa_num_layers = 1.
|
217 |
+
zero_init_cross_attn_proj (bool): Whether to initialize the cross-attn proj layer with zeros (default False).
|
218 |
+
concat_action_to_context (bool): Whether to concatenate the action embedding to the context (default False).
|
219 |
+
"""
|
220 |
+
|
221 |
+
dim: int = attrs.field(default=4096)
|
222 |
+
n_layers: int = attrs.field(default=32)
|
223 |
+
n_heads: int = attrs.field(default=32)
|
224 |
+
n_kv_heads: Optional[int] = attrs.field(default=8)
|
225 |
+
head_dim: Optional[int] = attrs.field(default=None)
|
226 |
+
vocab_size: int = attrs.field(default=128256)
|
227 |
+
multiple_of: int = attrs.field(default=1024) # make SwiGLU hidden layer size multiple of large power of 2
|
228 |
+
ffn_dim_multiplier: Optional[float] = attrs.field(default=1.3)
|
229 |
+
ffn_hidden_size: Optional[int] = attrs.field(default=None)
|
230 |
+
norm_eps: float = attrs.field(default=1e-5)
|
231 |
+
rope_theta: float = attrs.field(default=500000)
|
232 |
+
apply_abs_pos_emb: bool = attrs.field(default=False)
|
233 |
+
max_batch_size: int = attrs.field(default=1)
|
234 |
+
max_seq_len: int = attrs.field(default=8192)
|
235 |
+
fuse_qkv: bool = attrs.field(default=False)
|
236 |
+
causal_mask: bool = attrs.field(default=True)
|
237 |
+
flash_attn: bool = attrs.field(default=True)
|
238 |
+
norm_type: str = attrs.field(default="rmsnorm")
|
239 |
+
backend: str = attrs.field(default="pytorch")
|
240 |
+
precision: str = attrs.field(default="bfloat16")
|
241 |
+
ema: config.EMAConfig = config.EMAConfig(enabled=False)
|
242 |
+
embedding_dropout: float = 0.0
|
243 |
+
attention_dropout: float = 0.0
|
244 |
+
hidden_dropout: float = 0.0
|
245 |
+
use_qk_normalization: bool = False
|
246 |
+
tokenizer: Optional[TokenizerConfig] = None
|
247 |
+
inference: bool = False
|
248 |
+
act_ckpt_enabled: bool = False
|
249 |
+
fsdp_enabled: bool = False
|
250 |
+
context_parallel_size: int = attrs.field(default=1)
|
251 |
+
tensor_model_parallel_size: int = attrs.field(default=1)
|
252 |
+
sequence_parallel: bool = attrs.field(default=False)
|
253 |
+
set_parallel_mode: bool = attrs.field(default=False)
|
254 |
+
fsdp: LazyDict = LazyDict(
|
255 |
+
dict(
|
256 |
+
policy="auto", # choices: ["size", "auto"]
|
257 |
+
min_num_params=1024, # Used as policy == "size"
|
258 |
+
sharding_strategy="hybrid", # Choices: ["full", "hybrid"]. "full" means sharding_group_size = world_size
|
259 |
+
sharding_group_size=8, # If None, defaults to min(world_size, 8). Recommends 8 for training on 8-GPU nodes.
|
260 |
+
)
|
261 |
+
)
|
262 |
+
ckpt_dir: Optional[str] = attrs.field(default="")
|
263 |
+
ckpt_path: Optional[str] = attrs.field(
|
264 |
+
default=None
|
265 |
+
) # If not None, load the model from this path instead of ckpt_dir
|
266 |
+
cache_dir: Optional[str] = attrs.field(default="/project/cosmos/ar/cache")
|
267 |
+
apply_yarn: Optional[bool] = attrs.field(default=False)
|
268 |
+
yarn_scale: Optional[float] = attrs.field(default=None)
|
269 |
+
yarn_beta_fast: Optional[int] = attrs.field(default=None)
|
270 |
+
yarn_beta_slow: Optional[int] = attrs.field(default=None)
|
271 |
+
original_seq_len: Optional[int] = attrs.field(default=None)
|
272 |
+
depth_init: bool = attrs.field(default=True)
|
273 |
+
ignore_first_num_tokens: int = 0
|
274 |
+
z_loss_coeff: float = 1e-4
|
275 |
+
attention_tp: bool = False
|
276 |
+
vision_encoder: Optional[str] = attrs.field(default=None)
|
277 |
+
mm_projector: Optional[str] = attrs.field(default=None)
|
278 |
+
rope_dim: Optional[str] = attrs.field(default="1D")
|
279 |
+
pytorch_rope_version: Optional[str] = attrs.field(default="v2")
|
280 |
+
original_latent_shape: Optional[list] = None
|
281 |
+
pad_to_multiple_of: Optional[int] = None
|
282 |
+
peft_last_n_layers: Optional[int] = attrs.field(default=0)
|
283 |
+
peft_every_n_layers: Optional[int] = attrs.field(default=0)
|
284 |
+
freeze_vision_encoder: bool = False
|
285 |
+
vision_encoder_in_channels: Optional[int] = attrs.field(default=3)
|
286 |
+
insert_cross_attn: bool = False
|
287 |
+
insert_cross_attn_every_k_layers: int = 1
|
288 |
+
context_dim: Optional[int] = attrs.field(default=1024)
|
289 |
+
finetune_layers_with_cross_attn: bool = False
|
290 |
+
finetune_layers_without_cross_attn: bool = False
|
291 |
+
use_action_condition: bool = False
|
292 |
+
action_embedding_mode: Optional[str] = attrs.field(default="mlp")
|
293 |
+
action_dim: Optional[int] = attrs.field(default=_ACTION_DIM)
|
294 |
+
action_embedding_dim: Optional[int] = attrs.field(default=1024)
|
295 |
+
group_causal_mask_mode: Optional[str] = attrs.field(default=None)
|
296 |
+
sync_1d_parameters: bool = True
|
297 |
+
# hyper-parameters for the medusa head configs
|
298 |
+
insert_medusa_head: bool = False
|
299 |
+
ft_medusa_option: str = "fft"
|
300 |
+
medusa_num_heads: int = 7
|
301 |
+
medusa_num_layers: int = 1
|
302 |
+
medusa_concat_heads: bool = True
|
303 |
+
# For video training
|
304 |
+
num_video_frames: Optional[int] = None
|
305 |
+
# Raw video pixel dimension
|
306 |
+
video_height: Optional[int] = None
|
307 |
+
video_width: Optional[int] = None
|
308 |
+
# Video tokenizer output dimension, in (T,H,W), it's computed by num_video_frames/temporal_compress_factor, video_height/spatial_compression_fact, video_width/spatial_compression_fact
|
309 |
+
video_latent_shape: Optional[list] = None
|
310 |
+
# For image training
|
311 |
+
image_latent_shape: Optional[list] = None
|
312 |
+
# For robot training (action)
|
313 |
+
zero_init_cross_attn_proj: bool = False
|
314 |
+
# For robot training (action)
|
315 |
+
concat_action_to_context: bool = False
|
316 |
+
|
317 |
+
def __getitem__(self, item):
|
318 |
+
return getattr(self, item)
|
cosmos_predict1/autoregressive/configs/base/model_config.py
ADDED
@@ -0,0 +1,718 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import copy
|
17 |
+
from typing import Callable, List, Optional
|
18 |
+
|
19 |
+
import torch
|
20 |
+
from megatron.core import ModelParallelConfig
|
21 |
+
|
22 |
+
from cosmos_predict1.autoregressive.configs.base.model import ModelConfig, TrainingModelConfig
|
23 |
+
from cosmos_predict1.autoregressive.configs.base.tokenizer import (
|
24 |
+
TextTokenizerConfig,
|
25 |
+
TokenizerConfig,
|
26 |
+
VideoTokenizerConfig,
|
27 |
+
create_discrete_video_fsq_tokenizer_state_dict_config,
|
28 |
+
)
|
29 |
+
from cosmos_predict1.autoregressive.tokenizer.image_text_tokenizer import ImageTextTokenizer
|
30 |
+
from cosmos_predict1.autoregressive.tokenizer.text_tokenizer import TextTokenizer
|
31 |
+
from cosmos_predict1.autoregressive.training.model import AutoRegressiveTrainingModel
|
32 |
+
from cosmos_predict1.utils import log
|
33 |
+
from cosmos_predict1.utils.config import EMAConfig
|
34 |
+
from cosmos_predict1.utils.lazy_config import LazyCall as L
|
35 |
+
|
36 |
+
# Common architecture specifications
|
37 |
+
BASE_CONFIG = {"n_kv_heads": 8, "norm_type": "rmsnorm", "norm_eps": 1e-5, "ffn_hidden_size": 14336}
|
38 |
+
COSMOS_ARCHITECTURES = {
|
39 |
+
"1b": {
|
40 |
+
"n_layers": 16,
|
41 |
+
"dim": 2048,
|
42 |
+
"n_heads": 32,
|
43 |
+
},
|
44 |
+
"4b": {
|
45 |
+
"n_layers": 16,
|
46 |
+
"dim": 4096,
|
47 |
+
"n_heads": 32,
|
48 |
+
},
|
49 |
+
"12b": {
|
50 |
+
"n_layers": 40,
|
51 |
+
"dim": 5120,
|
52 |
+
"n_heads": 32,
|
53 |
+
"head_dim": 128,
|
54 |
+
},
|
55 |
+
}
|
56 |
+
|
57 |
+
COSMOS_YARN_CONFIG = {
|
58 |
+
"original_latent_shape": [3, 40, 64],
|
59 |
+
"apply_yarn": True,
|
60 |
+
"yarn_beta_fast": 4,
|
61 |
+
"yarn_beta_slow": 1,
|
62 |
+
"yarn_scale": 2,
|
63 |
+
}
|
64 |
+
|
65 |
+
# Llama3 architecture specifications for different model sizes
|
66 |
+
LLAMA3_ARCHITECTURES = {
|
67 |
+
"8b": {
|
68 |
+
"n_layers": 32,
|
69 |
+
"dim": 4096,
|
70 |
+
"n_heads": 32,
|
71 |
+
"ffn_hidden_size": 14336,
|
72 |
+
},
|
73 |
+
}
|
74 |
+
# Llama3.1 uses YaRN for long context support (context of 128k tokens)
|
75 |
+
LLAMA_YARN_CONFIG = {
|
76 |
+
"apply_yarn": True,
|
77 |
+
"yarn_scale": 8,
|
78 |
+
"yarn_beta_fast": 4,
|
79 |
+
"yarn_beta_slow": 1,
|
80 |
+
}
|
81 |
+
|
82 |
+
# Mistral architecture specifications for different model sizes
|
83 |
+
MISTRAL_ARCHITECTURES = {
|
84 |
+
"12b": {
|
85 |
+
"n_layers": 40,
|
86 |
+
"dim": 5120,
|
87 |
+
"n_heads": 32,
|
88 |
+
"ffn_hidden_size": 14336,
|
89 |
+
"head_dim": 128,
|
90 |
+
},
|
91 |
+
}
|
92 |
+
|
93 |
+
PIXTRAL_VISION_ARCHITECTURES = {
|
94 |
+
"12b": {"vision_encoder": "pixtral-12b-vit", "mm_projector": "mlp"},
|
95 |
+
}
|
96 |
+
|
97 |
+
|
98 |
+
def get_model_arch_specs(model_size: str, model_family: str = "mistral", pretrained: bool = False) -> dict:
|
99 |
+
"""
|
100 |
+
Get the model architecture specifications for the given model size, model family and pretrained status.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
model_size (str): Model size. Choices: "1b", "3b", "4b", "7b", etc.
|
104 |
+
model_family (str): Model family. Choices: "llama", "llama3", "llama3.1", "mistral"
|
105 |
+
pretrained (bool): Whether to load pretrained weights.
|
106 |
+
|
107 |
+
Returns:
|
108 |
+
dict: A dictionary containing the model architecture specifications.
|
109 |
+
"""
|
110 |
+
arch_specs = copy.deepcopy(BASE_CONFIG)
|
111 |
+
model_size = model_size.lower()
|
112 |
+
if model_family.startswith("cosmos"):
|
113 |
+
arch_specs.update(COSMOS_ARCHITECTURES[model_size])
|
114 |
+
elif model_family.startswith("llama"):
|
115 |
+
arch_specs.update(LLAMA3_ARCHITECTURES[model_size])
|
116 |
+
elif model_family in ["mistral", "pixtral"]:
|
117 |
+
arch_specs.update(MISTRAL_ARCHITECTURES[model_size])
|
118 |
+
if model_family == "pixtral":
|
119 |
+
arch_specs.update(PIXTRAL_VISION_ARCHITECTURES[model_size])
|
120 |
+
else:
|
121 |
+
raise ValueError(f"Model family {model_family} is not supported.")
|
122 |
+
|
123 |
+
if pretrained:
|
124 |
+
if model_family == "cosmos":
|
125 |
+
if model_size == "12b":
|
126 |
+
arch_specs.update(COSMOS_YARN_CONFIG)
|
127 |
+
log.debug(f"Using YaRN for RoPE extension with config: {COSMOS_YARN_CONFIG}")
|
128 |
+
else:
|
129 |
+
pass
|
130 |
+
elif model_family in ["llama", "llama3"]:
|
131 |
+
pretrained_specs = {
|
132 |
+
"rope_theta": 500000,
|
133 |
+
"max_seq_len": 8192,
|
134 |
+
"vocab_size": 128256,
|
135 |
+
}
|
136 |
+
arch_specs.update(pretrained_specs)
|
137 |
+
elif model_family == "llama3.1":
|
138 |
+
pretrained_specs = {
|
139 |
+
"rope_theta": 500000,
|
140 |
+
"max_seq_len": 131072,
|
141 |
+
"original_seq_len": 8192,
|
142 |
+
"vocab_size": 128256,
|
143 |
+
**LLAMA_YARN_CONFIG,
|
144 |
+
}
|
145 |
+
arch_specs.update(pretrained_specs)
|
146 |
+
elif model_family == "mistral":
|
147 |
+
assert model_size == "12b", "We only support Mistral-Nemo-12B model."
|
148 |
+
pretrained_specs = {
|
149 |
+
"rope_theta": 1000000,
|
150 |
+
"max_seq_len": 128000,
|
151 |
+
"vocab_size": 131072,
|
152 |
+
}
|
153 |
+
arch_specs.update(pretrained_specs)
|
154 |
+
elif model_family == "pixtral":
|
155 |
+
assert model_size == "12b", "We only support Pixtral 12B model."
|
156 |
+
pretrained_specs = {"rope_theta": 1000000000, "max_seq_len": 128000, "vocab_size": 131072}
|
157 |
+
arch_specs.update(pretrained_specs)
|
158 |
+
else:
|
159 |
+
raise ValueError(f"Model family {model_family} doesn't have a pretrained config.")
|
160 |
+
|
161 |
+
return arch_specs
|
162 |
+
|
163 |
+
|
164 |
+
def create_text_model_config(
|
165 |
+
model_ckpt_path: str,
|
166 |
+
tokenizer_path: str,
|
167 |
+
tensor_model_parallel_size: int = 1,
|
168 |
+
model_family: str = "mistral",
|
169 |
+
model_size: str = "12b",
|
170 |
+
is_instruct_model: bool = True,
|
171 |
+
max_seq_len: int = None,
|
172 |
+
max_batch_size: int = 1,
|
173 |
+
rope_dim: str = "1D",
|
174 |
+
add_special_tokens: bool = True,
|
175 |
+
pytorch_rope_version: str = None,
|
176 |
+
) -> dict:
|
177 |
+
"""Create a text model for training or inference.
|
178 |
+
Args:
|
179 |
+
model_ckpt_path (str): Path to the model checkpoint.
|
180 |
+
tokenizer_path (str): Path to the tokenizer folder.
|
181 |
+
tensor_model_parallel_size (int): Number of tensor model parallel groups.
|
182 |
+
model_family (str): Model family. Choices: "llama", "llama3", "llama3.1", "mistral".
|
183 |
+
model_size (str): Model size. Choices: "1b", "3b", "4b", "7b", "8b", "72b", etc.
|
184 |
+
is_instruct_model (bool): Whether the model is an instruct model.
|
185 |
+
inference (bool): Whether to create the model for inference.
|
186 |
+
max_seq_len (int): Maximum sequence length.
|
187 |
+
max_batch_size (int): Maximum batch size.
|
188 |
+
rope_dim (str): RoPE dimension. Choices: "1D", "3D".
|
189 |
+
add_special_tokens (bool): Whether to add special tokens.
|
190 |
+
Returns:
|
191 |
+
dict: A dictionary containing the model configuration, which can be used to instantiate the model object.
|
192 |
+
"""
|
193 |
+
# Model size specific parameters
|
194 |
+
model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=True)
|
195 |
+
if max_seq_len is not None:
|
196 |
+
# Override the max_seq_len if provided
|
197 |
+
model_arch_specs["max_seq_len"] = max_seq_len
|
198 |
+
if pytorch_rope_version is not None:
|
199 |
+
model_arch_specs["pytorch_rope_version"] = pytorch_rope_version
|
200 |
+
model_config = ModelConfig(
|
201 |
+
max_batch_size=max_batch_size,
|
202 |
+
precision="bfloat16",
|
203 |
+
ckpt_path=model_ckpt_path,
|
204 |
+
use_qk_normalization=False,
|
205 |
+
tensor_model_parallel_size=tensor_model_parallel_size,
|
206 |
+
rope_dim=rope_dim,
|
207 |
+
**model_arch_specs,
|
208 |
+
)
|
209 |
+
|
210 |
+
tokenizer_config = TokenizerConfig(
|
211 |
+
text_tokenizer=TextTokenizerConfig(
|
212 |
+
config=L(TextTokenizer)(
|
213 |
+
model_family=model_family,
|
214 |
+
is_instruct_model=is_instruct_model,
|
215 |
+
local_path=tokenizer_path,
|
216 |
+
),
|
217 |
+
data_key="text",
|
218 |
+
tokenizer_offset=model_config.vocab_size,
|
219 |
+
tokenize_here=False,
|
220 |
+
vocab_size=model_config.vocab_size,
|
221 |
+
),
|
222 |
+
seq_len=model_config.max_seq_len,
|
223 |
+
training_type="text_only",
|
224 |
+
add_special_tokens=add_special_tokens,
|
225 |
+
)
|
226 |
+
return model_config, tokenizer_config
|
227 |
+
|
228 |
+
|
229 |
+
def create_vision_language_model_config(
|
230 |
+
model_ckpt_path: str,
|
231 |
+
tokenizer_ckpt_path: str,
|
232 |
+
tensor_model_parallel_size: int = 1,
|
233 |
+
model_family: str = "pixtral",
|
234 |
+
model_size: str = "12b",
|
235 |
+
is_instruct_model: bool = True,
|
236 |
+
max_batch_size: int = 1,
|
237 |
+
rope_dim: str = "1D",
|
238 |
+
add_special_tokens: bool = True,
|
239 |
+
max_seq_len: int = None,
|
240 |
+
vision_encoder_in_channels: int = 3,
|
241 |
+
fuse_qkv: bool = False,
|
242 |
+
pytorch_rope_version: str = None,
|
243 |
+
) -> dict:
|
244 |
+
"""Create a vision-language model for training or inference.
|
245 |
+
Args:
|
246 |
+
model_ckpt_path (str): Path to the model checkpoint.
|
247 |
+
tokenizer_ckpt_path (str): Path to the tokenizer checkpoint.
|
248 |
+
tensor_model_parallel_size (int): Number of tensor model parallel groups.
|
249 |
+
model_family (str): Model family. Choices: "pixtral".
|
250 |
+
model_size (str): Model size. Choices: "12b".
|
251 |
+
is_instruct_model (bool): Whether the model is an instruct model.
|
252 |
+
rope_dim (str): RoPE dimension. Choices: "1D".
|
253 |
+
add_special_tokens (bool): Whether to add special tokens.
|
254 |
+
max_seq_len (int): Maximum sequence length.
|
255 |
+
vision_encoder_in_channels (int): Number of channels in the input image for the vision encoder. Default is 3, you can specify to int larger than 3. E.g. if you have 4 channel images where last channel is binary mask, set this to 4.
|
256 |
+
fuse_qkv (bool): Whether to fuse the QKV linear layers.
|
257 |
+
Returns:
|
258 |
+
dict: A dictionary containing the model configuration, which can be used to instantiate the model object.
|
259 |
+
"""
|
260 |
+
# Model size specific parameters
|
261 |
+
model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=True)
|
262 |
+
if max_seq_len is not None:
|
263 |
+
# Override the max_seq_len if provided
|
264 |
+
model_arch_specs["max_seq_len"] = max_seq_len
|
265 |
+
if pytorch_rope_version is not None:
|
266 |
+
model_arch_specs["pytorch_rope_version"] = pytorch_rope_version
|
267 |
+
|
268 |
+
model_config = ModelConfig(
|
269 |
+
max_batch_size=max_batch_size,
|
270 |
+
precision="bfloat16",
|
271 |
+
ckpt_path=model_ckpt_path,
|
272 |
+
use_qk_normalization=False,
|
273 |
+
tensor_model_parallel_size=tensor_model_parallel_size,
|
274 |
+
rope_dim=rope_dim,
|
275 |
+
vision_encoder_in_channels=vision_encoder_in_channels,
|
276 |
+
fuse_qkv=fuse_qkv,
|
277 |
+
**model_arch_specs,
|
278 |
+
)
|
279 |
+
# Vision-language tokenizer
|
280 |
+
tokenizer_config = TokenizerConfig(
|
281 |
+
text_tokenizer=TextTokenizerConfig(
|
282 |
+
config=L(ImageTextTokenizer)(
|
283 |
+
model_family=model_family,
|
284 |
+
is_instruct_model=is_instruct_model,
|
285 |
+
image_processor_path=tokenizer_ckpt_path,
|
286 |
+
tokenizer_path=tokenizer_ckpt_path,
|
287 |
+
),
|
288 |
+
data_key="image_text_interleaved",
|
289 |
+
tokenizer_offset=model_config.vocab_size,
|
290 |
+
tokenize_here=False,
|
291 |
+
vocab_size=model_config.vocab_size,
|
292 |
+
),
|
293 |
+
seq_len=model_config.max_seq_len,
|
294 |
+
training_type="image_text_interleaved",
|
295 |
+
add_special_tokens=add_special_tokens,
|
296 |
+
)
|
297 |
+
return model_config, tokenizer_config
|
298 |
+
|
299 |
+
|
300 |
+
def create_video2world_model_config(
|
301 |
+
model_ckpt_path: str,
|
302 |
+
tokenizer_ckpt_path: str,
|
303 |
+
tensor_model_parallel_size: int = 1,
|
304 |
+
model_family: str = "cosmos",
|
305 |
+
model_size: str = "4b",
|
306 |
+
pixel_chunk_duration: int = 9,
|
307 |
+
num_video_frames: int = 36,
|
308 |
+
compression_ratio: List[int] = [8, 16, 16],
|
309 |
+
original_seq_len: int = 8192,
|
310 |
+
num_condition_latents_t: int = 1,
|
311 |
+
num_tokens_to_ignore: int = -1,
|
312 |
+
batch_size: int = 2,
|
313 |
+
video_tokenizer_config_creator: Callable = create_discrete_video_fsq_tokenizer_state_dict_config,
|
314 |
+
rope_dim: str = "3D",
|
315 |
+
add_special_tokens: bool = True,
|
316 |
+
video_height: int = 384,
|
317 |
+
video_width: int = 640,
|
318 |
+
use_qk_normalization: bool = True,
|
319 |
+
insert_cross_attn: bool = False,
|
320 |
+
insert_cross_attn_every_k_layers: int = 1,
|
321 |
+
context_dim: int = 1024,
|
322 |
+
training_type: str = "video_to_video",
|
323 |
+
pad_to_multiple_of: Optional[int] = 64,
|
324 |
+
vocab_size: int = 64000,
|
325 |
+
apply_abs_pos_emb: bool = False,
|
326 |
+
) -> dict:
|
327 |
+
"""Create a video-to-world model config.
|
328 |
+
Args:
|
329 |
+
tensor_model_parallel_size (int): Number of tensor model parallel groups.
|
330 |
+
model_family (str): Model family. Choices: "llama", "llama3", "llama3.1", "mistral".
|
331 |
+
model_size (str): Model size. Choices: "1b", "8b", "3b".
|
332 |
+
pixel_chunk_duration (int): Number of frames in each chunk.
|
333 |
+
num_video_frames (int): Number of video frames.
|
334 |
+
compression_ratio (List[int]): Compression ratio for the video frames. Choices: [8, 16, 16] or [4, 8, 8].
|
335 |
+
original_seq_len (int): Original sequence length.
|
336 |
+
apply_yarn (bool): Whether to apply YaRN for long context scaling.
|
337 |
+
yarn_beta_fast (Optional[int]): Fast beta for YaRN.
|
338 |
+
yarn_beta_slow (Optional[int]): Slow beta for YaRN.
|
339 |
+
yarn_scale (Optional[int]): Scale factor for ctx extension.
|
340 |
+
use_qk_normalization (bool): Whether to use Query-Key normalization.
|
341 |
+
training_type (str): Type of training task.
|
342 |
+
batch_size (int): Batch size.
|
343 |
+
video_tokenizer_config_creator (Callable): Method that takes "pixel_chunk_duration: int" and "version: str" as arguments and returns video tokenizer config
|
344 |
+
video_tokenizer_version (str): Version of the video tokenizer.
|
345 |
+
num_condition_latents_t (int): Number of conditioning latent channels
|
346 |
+
num_tokens_to_ignore (int) = Number of tokens to ignore. This takes the precedence
|
347 |
+
video_height (int): Height of the video frame. Defaults to 384.
|
348 |
+
video_width (int): Width of the video frame. Defaults to 640.
|
349 |
+
rope_dim (str): RoPE dimension. Choices: "1D", "3D".
|
350 |
+
add_special_tokens (bool): Whether to add special tokens, use False for 2D/3D RoPE.
|
351 |
+
pad_to_multiple_of (int): Pad the token sequence length to the nearest multiple of this number. Defaults to 64.
|
352 |
+
vocab_size (int): Vocabulary size.
|
353 |
+
apply_abs_pos_emb (bool): Whether to apply absolute positional embeddings.
|
354 |
+
Returns:
|
355 |
+
dict: A dictionary containing the model configuration representing the model object, can be instantiated.
|
356 |
+
"""
|
357 |
+
assert (
|
358 |
+
pixel_chunk_duration % compression_ratio[0] == 1
|
359 |
+
), f"pixel_chunk_duration({pixel_chunk_duration}) should be k*n + 1 (k={compression_ratio[0]})"
|
360 |
+
latent_chunk_duration = (pixel_chunk_duration - 1) // compression_ratio[0] + 1
|
361 |
+
latent_height = video_height // compression_ratio[1]
|
362 |
+
latent_width = video_width // compression_ratio[2]
|
363 |
+
# Do some math to compute the video latent shape and sequence length
|
364 |
+
assert (
|
365 |
+
num_video_frames % pixel_chunk_duration == 0
|
366 |
+
), f"num_video_frames {num_video_frames} should be divisible by pixel_chunk_duration {pixel_chunk_duration}"
|
367 |
+
video_latent_shape = [
|
368 |
+
num_video_frames // pixel_chunk_duration * latent_chunk_duration,
|
369 |
+
latent_height,
|
370 |
+
latent_width,
|
371 |
+
]
|
372 |
+
# product of video_latent_shape
|
373 |
+
num_token_video_latent = video_latent_shape[0] * video_latent_shape[1] * video_latent_shape[2]
|
374 |
+
if add_special_tokens:
|
375 |
+
seq_len = num_token_video_latent + 3 # Sequence length per batch, max_seq_len + 3
|
376 |
+
seq_len = (seq_len + 63) // 64 * 64 # Round up to multiple of 64
|
377 |
+
# for text to video, we need to add <bov> token to indicate the start of the video
|
378 |
+
elif training_type == "text_to_video":
|
379 |
+
seq_len = num_token_video_latent + 1
|
380 |
+
else:
|
381 |
+
seq_len = num_token_video_latent
|
382 |
+
|
383 |
+
if seq_len % pad_to_multiple_of != 0:
|
384 |
+
# Round up to the nearest multiple of pad_to_multiple_of
|
385 |
+
seq_len = ((seq_len + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
|
386 |
+
|
387 |
+
# Model size specific parameters
|
388 |
+
model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=True)
|
389 |
+
|
390 |
+
# Whether skip the loss for first chunk or not, note the first token is already skipped when computing the loss
|
391 |
+
# If num_tokens_to_ignore is specified, use it.
|
392 |
+
# Else compute it from num_condition_latents_t
|
393 |
+
if num_tokens_to_ignore < 0:
|
394 |
+
num_tokens_to_ignore = latent_height * latent_width * num_condition_latents_t
|
395 |
+
if not add_special_tokens and num_condition_latents_t > 0:
|
396 |
+
# If there are no special tokens (bov), do a -1 so that you can compute the loss
|
397 |
+
# from the first token of the next chunk
|
398 |
+
num_tokens_to_ignore -= 1
|
399 |
+
|
400 |
+
model_config = ModelConfig(
|
401 |
+
video_height=video_height,
|
402 |
+
video_width=video_width,
|
403 |
+
max_seq_len=seq_len,
|
404 |
+
max_batch_size=batch_size,
|
405 |
+
precision="bfloat16",
|
406 |
+
ckpt_path=model_ckpt_path,
|
407 |
+
use_qk_normalization=use_qk_normalization,
|
408 |
+
vocab_size=64000,
|
409 |
+
original_seq_len=original_seq_len,
|
410 |
+
tensor_model_parallel_size=tensor_model_parallel_size,
|
411 |
+
video_latent_shape=video_latent_shape,
|
412 |
+
num_video_frames=num_video_frames,
|
413 |
+
rope_dim=rope_dim,
|
414 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
415 |
+
insert_cross_attn=insert_cross_attn,
|
416 |
+
insert_cross_attn_every_k_layers=insert_cross_attn_every_k_layers,
|
417 |
+
context_dim=context_dim,
|
418 |
+
apply_abs_pos_emb=apply_abs_pos_emb,
|
419 |
+
**model_arch_specs,
|
420 |
+
)
|
421 |
+
|
422 |
+
video_tokenizer_config = video_tokenizer_config_creator(
|
423 |
+
tokenizer_ckpt_path, pixel_chunk_duration, compression_ratio
|
424 |
+
)
|
425 |
+
tokenizer_config = TokenizerConfig(
|
426 |
+
text_tokenizer=None,
|
427 |
+
video_tokenizer=VideoTokenizerConfig(
|
428 |
+
config=video_tokenizer_config,
|
429 |
+
data_key="video",
|
430 |
+
tokenizer_offset=0, # Since there is no text embeddings in the model. Note this only apply when the model is trained from scratch. If we use text pretrained model, the offset will be vocab_size of text token.
|
431 |
+
tokenize_here=True,
|
432 |
+
max_seq_len=num_token_video_latent,
|
433 |
+
vocab_size=vocab_size,
|
434 |
+
),
|
435 |
+
seq_len=seq_len,
|
436 |
+
training_type=training_type,
|
437 |
+
add_special_tokens=add_special_tokens,
|
438 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
439 |
+
)
|
440 |
+
return model_config, tokenizer_config
|
441 |
+
|
442 |
+
|
443 |
+
def create_video2world_model(
|
444 |
+
tensor_model_parallel_size: int = 1,
|
445 |
+
context_parallel_size: int = 1,
|
446 |
+
shard_checkpoint: bool = False,
|
447 |
+
model_family: str = "cosmos",
|
448 |
+
model_size: str = "1b",
|
449 |
+
backend: str = "pytorch",
|
450 |
+
pixel_chunk_duration: int = 9,
|
451 |
+
num_video_frames: int = 36,
|
452 |
+
compression_ratio: List[int] = [8, 16, 16],
|
453 |
+
original_seq_len: int = 8192,
|
454 |
+
apply_yarn: bool = False,
|
455 |
+
yarn_beta_fast: Optional[int] = None,
|
456 |
+
yarn_beta_slow: Optional[int] = None,
|
457 |
+
yarn_scale: Optional[int] = None,
|
458 |
+
num_condition_latents_t: int = 1,
|
459 |
+
num_tokens_to_ignore: int = -1,
|
460 |
+
batch_size: int = 1,
|
461 |
+
fsdp_enabled: bool = False,
|
462 |
+
act_ckpt_enabled: bool = False,
|
463 |
+
video_tokenizer_config_creator: Callable = create_discrete_video_fsq_tokenizer_state_dict_config,
|
464 |
+
rope_dim: str = "3D",
|
465 |
+
add_special_tokens: bool = False,
|
466 |
+
video_height: int = 384,
|
467 |
+
video_width: int = 640,
|
468 |
+
original_latent_shape: Optional[List[int]] = None,
|
469 |
+
use_qk_normalization: bool = True,
|
470 |
+
sequence_parallel: bool = False,
|
471 |
+
insert_cross_attn: bool = False,
|
472 |
+
insert_cross_attn_every_k_layers: int = 1,
|
473 |
+
context_dim: int = 1024,
|
474 |
+
finetune_layers_with_cross_attn: bool = False,
|
475 |
+
finetune_layers_without_cross_attn: bool = False,
|
476 |
+
use_action_condition: bool = False,
|
477 |
+
action_embedding_mode: Optional[str] = "mlp",
|
478 |
+
action_dim: int = 8, # ACTION_DIM,
|
479 |
+
action_embedding_dim: int = 1024,
|
480 |
+
group_causal_mask_mode: Optional[str] = None,
|
481 |
+
training_type: str = "video_to_video",
|
482 |
+
pad_to_multiple_of: Optional[int] = 1,
|
483 |
+
z_loss_coeff: float = 1e-4,
|
484 |
+
temporal_overlap: int = 0,
|
485 |
+
embedding_dropout: float = 0.0,
|
486 |
+
insert_medusa_head: bool = False,
|
487 |
+
ft_medusa_option: str = "fft",
|
488 |
+
medusa_num_heads: int = 7,
|
489 |
+
medusa_num_layers: int = 1,
|
490 |
+
medusa_concat_heads: bool = True,
|
491 |
+
fuse_qkv: bool = False,
|
492 |
+
zero_init_cross_attn_proj: bool = False,
|
493 |
+
concat_action_to_context: bool = False,
|
494 |
+
tokenizer_ckpt_path: str = "checkpoints/Cosmos-1.0-Tokenizer-DV8x16x16/ema.jit",
|
495 |
+
) -> dict:
|
496 |
+
"""Create a video-to-video model for training.
|
497 |
+
Args:
|
498 |
+
tensor_model_parallel_size (int): Number of tensor model parallel groups.
|
499 |
+
context_parallel_size (int): Number of context parallel groups.
|
500 |
+
model_family (str): Model family. Choices: "llama", "llama3", "llama3.1", "mistral".
|
501 |
+
model_size (str): Model size. Choices: "1b", "8b", "3b".
|
502 |
+
backend (str): Backend for the model. Choices: "pytorch", "transformer_engine".
|
503 |
+
pixel_chunk_duration (int): Number of frames in each chunk.
|
504 |
+
num_video_frames (int): Number of video frames.
|
505 |
+
compression_ratio (List[int]): Compression ratio for the video frames. Choices: [8, 16, 16] or [4, 8, 8].
|
506 |
+
original_seq_len (int): Original sequence length.
|
507 |
+
apply_yarn (bool): Whether to apply YaRN for long context scaling.
|
508 |
+
yarn_beta_fast (Optional[int]): Fast beta for YaRN.
|
509 |
+
yarn_beta_slow (Optional[int]): Slow beta for YaRN.
|
510 |
+
yarn_scale (Optional[int]): Scale factor for ctx extension.
|
511 |
+
fsdp_enabled (bool): Whether Fully Sharded Data Parallel (FSDP) is enabled.
|
512 |
+
act_ckpt_enabled (bool): Whether activation checkpointing is enabled.
|
513 |
+
use_qk_normalization (bool): Whether to use Query-Key normalization.
|
514 |
+
training_type (str): Type of training task.
|
515 |
+
batch_size (int): Batch size.
|
516 |
+
video_tokenizer_config_creator (Callable): Method that takes "pixel_chunk_duration: int" and "version: str" as arguments and returns video tokenizer config
|
517 |
+
video_tokenizer_version (str): Version of the video tokenizer.
|
518 |
+
num_condition_latents_t (int): Number of conditioning latent channels
|
519 |
+
num_tokens_to_ignore (int) = Number of tokens to ignore. This takes the precedence
|
520 |
+
video_height (int): Height of the video frame. Defaults to 384.
|
521 |
+
video_width (int): Width of the video frame. Defaults to 640.
|
522 |
+
rope_dim (str): RoPE dimension. Choices: "1D", "2D", "3D".
|
523 |
+
add_special_tokens (bool): Whether to add special tokens, use False for 2D/3D RoPE.
|
524 |
+
original_latent_shape (list): Original latent shape before RoPE scaling.
|
525 |
+
sequence_parallel (bool): Whether to enable sequence parallelism.
|
526 |
+
insert_cross_attn (bool): Whether to insert the cross-attention layers after each multi-head self-attention (MSA) layer.
|
527 |
+
insert_cross_attn_every_k_layers (int): Insert cross-attention layers every k TransformerLayers.
|
528 |
+
context_dim (Optional[int]): The dimensionality of cross-attention embedding, e.g., T5 embed feature dim.
|
529 |
+
finetune_layers_with_cross_attn (bool): Whether to finetune Transformer layers w/ CA (cross-attn).
|
530 |
+
finetune_layers_without_cross_attn (bool): Whether to finetune Transformer layers w/o CA (cross-attn).
|
531 |
+
use_action_condition (bool): Whether to use action condition.
|
532 |
+
action_embedding_mode (Optional[str]): The mode of the robot action embedding. Choices: "matrix", "mlp".
|
533 |
+
action_dim (int): Dimension of the raw robot action tensor (e.g., 7 for DROID, [Δx, Δy, Δz, rx, ry, rz, gripper_open]).
|
534 |
+
action_embedding_dim (int): Dimension of the action embedding.
|
535 |
+
group_causal_mask_mode (Optional[str]): The mode of the group causal mask. Choices: "causal", "group_diagonal".
|
536 |
+
pad_to_multiple_of (int): Pad the token sequence length to the nearest multiple of this number. Defaults to 64.
|
537 |
+
z_loss_coeff (float): Coefficient for the z loss.
|
538 |
+
temporal_overlap (int): Temporal overlap in the latent space.
|
539 |
+
embedding_dropout (float): Dropout rate for the embeddings.
|
540 |
+
insert_medusa_head (bool): Whether to insert the Medusa head.
|
541 |
+
ft_medusa_option (str): Options on which layers to finetune, choices like:
|
542 |
+
"fft": fully fine-tune both medusa heads and all LLM backbone;
|
543 |
+
"head": fine-tune medusa heads;
|
544 |
+
"head_out": fine-tune medusa heads, and the output layer;
|
545 |
+
"head_out_last_k_layer": fine-tune medusa heads, the output layer, and the last k layer(s) of the LLM backbone.
|
546 |
+
medusa_num_heads (int): Number of heads in the Medusa head.
|
547 |
+
medusa_num_layers (int): Number of layers in the Medusa head.
|
548 |
+
medusa_concat_heads (bool): Whether to concatenate multiple medusa heads into fused matrix, only applicable when medusa_num_layers = 1.
|
549 |
+
fuse_qkv (bool): Whether to fuse the QKV linear layers.
|
550 |
+
zero_init_cross_attn_proj (bool): Whether to zero-initialize the cross-attention projection weights (default False).
|
551 |
+
concat_action_to_context (bool): Whether to concatenate the action embedding to the context (default False).
|
552 |
+
Returns:
|
553 |
+
dict: A dictionary containing the model configuration representing the model object, can be instantiated.
|
554 |
+
"""
|
555 |
+
assert (
|
556 |
+
pixel_chunk_duration % compression_ratio[0] == 1
|
557 |
+
), f"pixel_chunk_duration({pixel_chunk_duration}) should be k*n + 1 (k={compression_ratio[0]})"
|
558 |
+
latent_chunk_duration = (pixel_chunk_duration - 1) // compression_ratio[0] + 1
|
559 |
+
latent_height = video_height // compression_ratio[1]
|
560 |
+
latent_width = video_width // compression_ratio[2]
|
561 |
+
# Compute the video latent shape and sequence length
|
562 |
+
if temporal_overlap == 0:
|
563 |
+
assert (
|
564 |
+
num_video_frames % pixel_chunk_duration == 0
|
565 |
+
), f"num_video_frames {num_video_frames} should be divisible by pixel_chunk_duration {pixel_chunk_duration}"
|
566 |
+
video_latent_shape = [
|
567 |
+
num_video_frames // pixel_chunk_duration * latent_chunk_duration,
|
568 |
+
latent_height,
|
569 |
+
latent_width,
|
570 |
+
]
|
571 |
+
|
572 |
+
else:
|
573 |
+
# Calculate temporal overlap in the latent space
|
574 |
+
temporal_overlap_latent = temporal_overlap // compression_ratio[0]
|
575 |
+
|
576 |
+
# Calculate the effective number of latent chunks for the video
|
577 |
+
latent_chunks = (num_video_frames - temporal_overlap) // (pixel_chunk_duration - temporal_overlap)
|
578 |
+
|
579 |
+
# Compute the total duration of the latent chunks, accounting for overlap
|
580 |
+
effective_latent_duration = (
|
581 |
+
latent_chunk_duration - temporal_overlap_latent
|
582 |
+
) * latent_chunks + temporal_overlap_latent
|
583 |
+
|
584 |
+
# Define the shape of the video in the latent space
|
585 |
+
video_latent_shape = [
|
586 |
+
effective_latent_duration, # Temporal dimension
|
587 |
+
latent_height, # Height in the latent space
|
588 |
+
latent_width, # Width in the latent space
|
589 |
+
]
|
590 |
+
|
591 |
+
# product of video_latent_shape
|
592 |
+
num_token_video_latent = video_latent_shape[0] * video_latent_shape[1] * video_latent_shape[2]
|
593 |
+
if add_special_tokens:
|
594 |
+
seq_len = num_token_video_latent + 3 # Sequence length per batch, max_seq_len + 3
|
595 |
+
seq_len = (seq_len + 63) // 64 * 64 # Round up to multiple of 64
|
596 |
+
# for text to video, we need to add <bov> token to indicate the start of the video
|
597 |
+
elif training_type == "text_to_video":
|
598 |
+
seq_len = num_token_video_latent + 1
|
599 |
+
else:
|
600 |
+
seq_len = num_token_video_latent
|
601 |
+
|
602 |
+
if seq_len % pad_to_multiple_of != 0:
|
603 |
+
# Round up to the nearest multiple of pad_to_multiple_of
|
604 |
+
seq_len = ((seq_len + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
|
605 |
+
|
606 |
+
# Model size specific parameters
|
607 |
+
model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=False)
|
608 |
+
|
609 |
+
inference = False # False for training, True for inference
|
610 |
+
# set_parallel_mode = True
|
611 |
+
set_parallel_mode = tensor_model_parallel_size > 1
|
612 |
+
attention_tp = True
|
613 |
+
|
614 |
+
if context_parallel_size > 1:
|
615 |
+
assert backend == "transformer_engine", "Context parallelism is only supported in transformer engine."
|
616 |
+
|
617 |
+
if tensor_model_parallel_size > 1:
|
618 |
+
assert set_parallel_mode, "Tensor model parallelism is only supported in parallel mode."
|
619 |
+
|
620 |
+
# Whether skip the loss for first chunk or not, note the first token is already skipped when computing the loss
|
621 |
+
# If num_tokens_to_ignore is specified, use it.
|
622 |
+
# Else compute it from num_condition_latents_t
|
623 |
+
if num_tokens_to_ignore < 0:
|
624 |
+
num_tokens_to_ignore = latent_height * latent_width * num_condition_latents_t
|
625 |
+
if not add_special_tokens and num_condition_latents_t > 0:
|
626 |
+
# If there are no special tokens (bov), do a -1 so that you can compute the loss
|
627 |
+
# from the first token of the next chunk
|
628 |
+
num_tokens_to_ignore -= 1
|
629 |
+
|
630 |
+
model_config = TrainingModelConfig(
|
631 |
+
video_height=video_height,
|
632 |
+
video_width=video_width,
|
633 |
+
max_seq_len=seq_len,
|
634 |
+
max_batch_size=batch_size,
|
635 |
+
inference=inference,
|
636 |
+
backend=backend,
|
637 |
+
precision="bfloat16",
|
638 |
+
ema=EMAConfig(enabled=False),
|
639 |
+
act_ckpt_enabled=act_ckpt_enabled,
|
640 |
+
fsdp_enabled=fsdp_enabled,
|
641 |
+
cache_dir=None,
|
642 |
+
ckpt_path="checkpoints/Cosmos-Predict1-4B/model.pt",
|
643 |
+
use_qk_normalization=use_qk_normalization,
|
644 |
+
vocab_size=64000,
|
645 |
+
ignore_first_num_tokens=num_tokens_to_ignore,
|
646 |
+
apply_yarn=apply_yarn,
|
647 |
+
yarn_beta_fast=yarn_beta_fast,
|
648 |
+
yarn_beta_slow=yarn_beta_slow,
|
649 |
+
original_seq_len=original_seq_len,
|
650 |
+
yarn_scale=yarn_scale,
|
651 |
+
context_parallel_size=context_parallel_size,
|
652 |
+
tensor_model_parallel_size=tensor_model_parallel_size,
|
653 |
+
set_parallel_mode=set_parallel_mode,
|
654 |
+
attention_tp=attention_tp,
|
655 |
+
video_latent_shape=video_latent_shape,
|
656 |
+
num_video_frames=num_video_frames,
|
657 |
+
rope_dim=rope_dim,
|
658 |
+
original_latent_shape=original_latent_shape,
|
659 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
660 |
+
sequence_parallel=sequence_parallel,
|
661 |
+
insert_cross_attn=insert_cross_attn,
|
662 |
+
insert_cross_attn_every_k_layers=insert_cross_attn_every_k_layers,
|
663 |
+
context_dim=context_dim,
|
664 |
+
finetune_layers_with_cross_attn=finetune_layers_with_cross_attn,
|
665 |
+
finetune_layers_without_cross_attn=finetune_layers_without_cross_attn,
|
666 |
+
use_action_condition=use_action_condition,
|
667 |
+
action_embedding_mode=action_embedding_mode,
|
668 |
+
action_dim=action_dim,
|
669 |
+
action_embedding_dim=action_embedding_dim,
|
670 |
+
group_causal_mask_mode=group_causal_mask_mode,
|
671 |
+
z_loss_coeff=z_loss_coeff,
|
672 |
+
embedding_dropout=embedding_dropout,
|
673 |
+
insert_medusa_head=insert_medusa_head,
|
674 |
+
ft_medusa_option=ft_medusa_option,
|
675 |
+
medusa_num_heads=medusa_num_heads,
|
676 |
+
medusa_num_layers=medusa_num_layers,
|
677 |
+
medusa_concat_heads=medusa_concat_heads,
|
678 |
+
fuse_qkv=fuse_qkv,
|
679 |
+
zero_init_cross_attn_proj=zero_init_cross_attn_proj,
|
680 |
+
concat_action_to_context=concat_action_to_context,
|
681 |
+
**model_arch_specs,
|
682 |
+
)
|
683 |
+
|
684 |
+
tokenizer_config = TokenizerConfig(
|
685 |
+
text_tokenizer=None,
|
686 |
+
video_tokenizer=VideoTokenizerConfig(
|
687 |
+
config=video_tokenizer_config_creator(
|
688 |
+
ckpt_path=tokenizer_ckpt_path, pixel_chunk_duration=pixel_chunk_duration
|
689 |
+
),
|
690 |
+
data_key="video",
|
691 |
+
tokenizer_offset=0,
|
692 |
+
vocab_size=64000,
|
693 |
+
tokenize_here=True,
|
694 |
+
max_seq_len=num_token_video_latent,
|
695 |
+
temporal_overlap=temporal_overlap,
|
696 |
+
),
|
697 |
+
seq_len="${model.model_config.max_seq_len}",
|
698 |
+
training_type=training_type,
|
699 |
+
add_special_tokens=add_special_tokens,
|
700 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
701 |
+
)
|
702 |
+
|
703 |
+
model_parallel = ModelParallelConfig(
|
704 |
+
bf16=True,
|
705 |
+
params_dtype=getattr(torch, "bfloat16"),
|
706 |
+
)
|
707 |
+
model_parallel.tensor_model_parallel_size = "${model.model_config.tensor_model_parallel_size}"
|
708 |
+
model_parallel.context_parallel_size = "${model.model_config.context_parallel_size}"
|
709 |
+
model_parallel.sequence_parallel = "${model.model_config.sequence_parallel}"
|
710 |
+
return L(AutoRegressiveTrainingModel.build)(
|
711 |
+
seed=0,
|
712 |
+
train_from_scratch=True,
|
713 |
+
model_config=model_config,
|
714 |
+
fsdp_checkpointer=None,
|
715 |
+
tokenizer_config=tokenizer_config,
|
716 |
+
model_parallel=model_parallel,
|
717 |
+
shard_checkpoint=shard_checkpoint,
|
718 |
+
)
|
cosmos_predict1/autoregressive/configs/base/model_parallel.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import torch
|
17 |
+
from megatron.core import ModelParallelConfig
|
18 |
+
|
19 |
+
from cosmos_predict1.utils.lazy_config import LazyDict
|
20 |
+
|
21 |
+
|
22 |
+
def create_model_parallel_config():
|
23 |
+
model_parallel = ModelParallelConfig(bf16=True, params_dtype=getattr(torch, "bfloat16"))
|
24 |
+
model_parallel.tensor_model_parallel_size = "${model.model_parallel.tensor_model_parallel_size}"
|
25 |
+
model_parallel.context_parallel_size = "${model.model_parallel.context_parallel_size}"
|
26 |
+
model_parallel.sequence_parallel = "${model.model_parallel.sequence_parallel}"
|
27 |
+
MODEL_PARALLELS = LazyDict(
|
28 |
+
dict(
|
29 |
+
model_parallel_bf16=model_parallel,
|
30 |
+
),
|
31 |
+
flags={"allow_objects": True},
|
32 |
+
)
|
33 |
+
return MODEL_PARALLELS["model_parallel_bf16"]
|
cosmos_predict1/autoregressive/configs/base/optim.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import torch
|
17 |
+
|
18 |
+
from cosmos_predict1.utils.lazy_config import LazyCall as L
|
19 |
+
|
20 |
+
|
21 |
+
class LambdaLinearWarmupScheduler:
|
22 |
+
"""
|
23 |
+
A learning rate scheduler that implements linear warm-up and cool-down.
|
24 |
+
|
25 |
+
This scheduler provides three phases:
|
26 |
+
1. Warm-up: Learning rate linearly increases from 0 to 1.
|
27 |
+
2. Constant: Learning rate remains at 1.
|
28 |
+
3. Cool-down: Learning rate linearly decreases from 1 to 0.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
warmup_steps (int): Number of steps for the warm-up phase.
|
32 |
+
warmup_offset (int): Starts warmup from this offset.
|
33 |
+
max_iter (int, optional): Total number of iterations. Required if cooldown_steps is provided.
|
34 |
+
cooldown_steps (int, optional): Number of steps for the cool-down phase.
|
35 |
+
|
36 |
+
Raises:
|
37 |
+
ValueError: If cooldown_steps is provided without max_iter, or if an invalid step is given.
|
38 |
+
"""
|
39 |
+
|
40 |
+
def __init__(self, warmup_steps: int, warmup_offset: int = 0, max_iter: int = None, cooldown_steps: int = None):
|
41 |
+
self.warmup_steps = warmup_steps
|
42 |
+
self.warmup_offset = warmup_offset
|
43 |
+
self.max_iter = max_iter
|
44 |
+
self.cooldown_steps = cooldown_steps
|
45 |
+
|
46 |
+
if cooldown_steps is not None:
|
47 |
+
if max_iter is None:
|
48 |
+
raise ValueError("max_iter must be specified when cooldown_steps is provided")
|
49 |
+
self.cooldown_start = max_iter - cooldown_steps
|
50 |
+
else:
|
51 |
+
self.cooldown_start = None
|
52 |
+
|
53 |
+
def __call__(self, step):
|
54 |
+
# Warm-up phase
|
55 |
+
if step < self.warmup_offset:
|
56 |
+
return 0
|
57 |
+
|
58 |
+
if step < self.warmup_steps + self.warmup_offset:
|
59 |
+
return float(step - self.warmup_offset) / float(max(1, self.warmup_steps))
|
60 |
+
|
61 |
+
# Constant phase (no cool-down)
|
62 |
+
elif self.cooldown_steps is None:
|
63 |
+
return 1.0
|
64 |
+
|
65 |
+
# Constant phase (before cool-down starts)
|
66 |
+
elif step < self.cooldown_start:
|
67 |
+
return 1.0
|
68 |
+
|
69 |
+
# Cool-down phase
|
70 |
+
elif self.cooldown_start <= step < self.max_iter:
|
71 |
+
cooldown_progress = (step - self.cooldown_start) / self.cooldown_steps
|
72 |
+
return 1.0 - cooldown_progress
|
73 |
+
|
74 |
+
# After max_iter
|
75 |
+
elif step >= self.max_iter:
|
76 |
+
return 0.0
|
77 |
+
|
78 |
+
# Unexpected case
|
79 |
+
else:
|
80 |
+
raise ValueError(f"Invalid step {step}")
|
81 |
+
|
82 |
+
|
83 |
+
LambdaLinearLR = L(torch.optim.lr_scheduler.LambdaLR)(
|
84 |
+
optimizer=None,
|
85 |
+
lr_lambda=L(LambdaLinearWarmupScheduler)(warmup_steps=5000),
|
86 |
+
)
|
cosmos_predict1/autoregressive/configs/base/tokenizer.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from typing import Optional
|
17 |
+
|
18 |
+
import attrs
|
19 |
+
|
20 |
+
from cosmos_predict1.autoregressive.tokenizer.discrete_video import DiscreteVideoFSQStateDictTokenizer
|
21 |
+
from cosmos_predict1.autoregressive.tokenizer.networks import CausalDiscreteVideoTokenizer
|
22 |
+
from cosmos_predict1.utils.lazy_config import LazyCall as L
|
23 |
+
from cosmos_predict1.utils.lazy_config import LazyDict
|
24 |
+
|
25 |
+
|
26 |
+
def create_discrete_video_fsq_tokenizer_state_dict_config(
|
27 |
+
ckpt_path, pixel_chunk_duration=33, compression_ratio=[8, 16, 16]
|
28 |
+
) -> LazyDict:
|
29 |
+
CausalDiscreteFactorizedVideoTokenizerConfig: LazyDict = L(CausalDiscreteVideoTokenizer)(
|
30 |
+
# The new causal discrete tokenizer, that is at least 2x more efficient in memory and runtime.
|
31 |
+
# - It relies on fully 3D discrete wavelet transform
|
32 |
+
# - Uses a layer norm instead of a group norm
|
33 |
+
# - Factorizes full convolutions into spatial and temporal convolutions
|
34 |
+
# - Factorizes full attention into spatial and temporal attention
|
35 |
+
# - Strictly causal, with flexible temporal length at inference.
|
36 |
+
attn_resolutions=[32],
|
37 |
+
channels=128,
|
38 |
+
channels_mult=[2, 4, 4],
|
39 |
+
dropout=0.0,
|
40 |
+
in_channels=3,
|
41 |
+
num_res_blocks=2,
|
42 |
+
out_channels=3,
|
43 |
+
resolution=1024,
|
44 |
+
patch_size=4,
|
45 |
+
patch_method="haar",
|
46 |
+
z_channels=16,
|
47 |
+
z_factor=1,
|
48 |
+
num_groups=1,
|
49 |
+
legacy_mode=False,
|
50 |
+
spatial_compression=16,
|
51 |
+
temporal_compression=8,
|
52 |
+
embedding_dim=6,
|
53 |
+
levels=[8, 8, 8, 5, 5, 5],
|
54 |
+
name="CausalDiscreteFactorizedVideoTokenizer",
|
55 |
+
)
|
56 |
+
|
57 |
+
return L(DiscreteVideoFSQStateDictTokenizer)(
|
58 |
+
enc_fp=ckpt_path.replace("ema.jit", "encoder.jit"),
|
59 |
+
dec_fp=ckpt_path.replace("ema.jit", "decoder.jit"),
|
60 |
+
tokenizer_module=CausalDiscreteFactorizedVideoTokenizerConfig,
|
61 |
+
name="discrete_video_fsq",
|
62 |
+
latent_ch=6,
|
63 |
+
is_bf16=True,
|
64 |
+
pixel_chunk_duration=pixel_chunk_duration,
|
65 |
+
latent_chunk_duration=1 + (pixel_chunk_duration - 1) // compression_ratio[0],
|
66 |
+
max_enc_batch_size=8,
|
67 |
+
max_dec_batch_size=4,
|
68 |
+
levels=[8, 8, 8, 5, 5, 5],
|
69 |
+
compression_ratio=compression_ratio,
|
70 |
+
)
|
71 |
+
|
72 |
+
|
73 |
+
@attrs.define(slots=False)
|
74 |
+
class TextTokenizerConfig:
|
75 |
+
"""
|
76 |
+
Text tokenizer config
|
77 |
+
|
78 |
+
Args:
|
79 |
+
config: Config file to define the text tokenizer class.
|
80 |
+
data_key (str): The input key from data_dict that will be passed to the text tokenizer.
|
81 |
+
tokenize_here (bool): Whether to use the tokenizer to perform online tokenization.
|
82 |
+
tokenizer_offset (int): Offset that is added to the tokens.
|
83 |
+
vocab_size (int): Vocabulary size of the tokenizer.
|
84 |
+
"""
|
85 |
+
|
86 |
+
config: LazyDict
|
87 |
+
data_key: str = ""
|
88 |
+
tokenize_here: bool = False
|
89 |
+
tokenizer_offset: int = 0
|
90 |
+
vocab_size: int = 0
|
91 |
+
|
92 |
+
|
93 |
+
@attrs.define(slots=False)
|
94 |
+
class VideoTokenizerConfig:
|
95 |
+
"""
|
96 |
+
Video tokenizer config
|
97 |
+
|
98 |
+
Args:
|
99 |
+
config: Config file to define the video tokenizer class.
|
100 |
+
data_key (str): The input key from data_dict that will be passed to the video tokenizer.
|
101 |
+
tokenize_here (bool): Whether to use the tokenizer to perform online tokenization.
|
102 |
+
tokenizer_offset (int): Offset that is added to the tokens. In case of joint text-video tokenizers, we
|
103 |
+
add an offset to make sure that video tokens and text tokens don't overlap.
|
104 |
+
vocab_size (int): Vocabulary size of the tokenizer.
|
105 |
+
max_seq_len (int): Maximum token length for an input video.
|
106 |
+
temporal_overlap (int): Overlap between consecutive video chunks.
|
107 |
+
"""
|
108 |
+
|
109 |
+
config: LazyDict
|
110 |
+
data_key: str = ""
|
111 |
+
tokenize_here: bool = True
|
112 |
+
tokenizer_offset: int = 0
|
113 |
+
vocab_size: int = 0
|
114 |
+
max_seq_len: int = -1
|
115 |
+
temporal_overlap: int = 0
|
116 |
+
|
117 |
+
|
118 |
+
@attrs.define(slots=False)
|
119 |
+
class TokenizerConfig:
|
120 |
+
"""
|
121 |
+
Joint tokenizer config
|
122 |
+
|
123 |
+
Args:
|
124 |
+
text_tokenizer (TextTokenizerConfig): Text tokenizer config file
|
125 |
+
class_tokenizer (ClassTokenizerConfig): Class tokenizer config file
|
126 |
+
video_tokenizer (VideoTokenizerConfig): Video tokenizer config file
|
127 |
+
image_tokenizer (ImageTokenizerConfig): Image tokenizer config file
|
128 |
+
seq_len (int): Final token sequence length
|
129 |
+
training_type (str): Type of training we use. Supports ["text_only", "text_to_video", "class_to_image", "image_text_interleaved"]
|
130 |
+
add_special_tokens (bool): Whether to add special tokens to the output tokens
|
131 |
+
pad_to_multiple_of (int): Pad the token sequence length to the nearest multiple of this number. Defaults to 64.
|
132 |
+
"""
|
133 |
+
|
134 |
+
text_tokenizer: Optional[TextTokenizerConfig] = None
|
135 |
+
video_tokenizer: Optional[VideoTokenizerConfig] = None
|
136 |
+
seq_len: int = 4096
|
137 |
+
training_type: str = None
|
138 |
+
add_special_tokens: bool = True
|
139 |
+
pad_to_multiple_of: Optional[int] = 64
|
cosmos_predict1/autoregressive/configs/config.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Default config for cosmos_ar project."""
|
17 |
+
|
18 |
+
import os
|
19 |
+
from typing import Any, List
|
20 |
+
|
21 |
+
import attrs
|
22 |
+
|
23 |
+
from cosmos_predict1.autoregressive.configs.registry import register_configs
|
24 |
+
from cosmos_predict1.autoregressive.trainer import Trainer
|
25 |
+
from cosmos_predict1.utils import config, log
|
26 |
+
from cosmos_predict1.utils.config_helper import import_all_modules_from_package
|
27 |
+
|
28 |
+
|
29 |
+
@attrs.define(slots=False)
|
30 |
+
class Config(config.Config):
|
31 |
+
defaults: List[Any] = attrs.field(
|
32 |
+
factory=lambda: [
|
33 |
+
"_self_",
|
34 |
+
{"model": None},
|
35 |
+
{"data_train": "mock_video"},
|
36 |
+
{"data_val": None},
|
37 |
+
{"optimizer": "fused_adamw"},
|
38 |
+
{"scheduler": "warmup_cosine_lr"},
|
39 |
+
{"checkpoint": "local"},
|
40 |
+
{"callbacks": "basic"},
|
41 |
+
{"global_config": None},
|
42 |
+
{"experiment": None},
|
43 |
+
]
|
44 |
+
)
|
45 |
+
|
46 |
+
def validate(self) -> None:
|
47 |
+
"""Validate that the config has all required fields."""
|
48 |
+
assert self.job.project != "", "job.project is not set"
|
49 |
+
assert self.job.group != "", "job.group is not set"
|
50 |
+
assert self.job.name != "", "job.name is not set"
|
51 |
+
log.info("Validating config for cosmos_autoregressive job")
|
52 |
+
# FSDP config check
|
53 |
+
if self.model.model_config.fsdp_enabled:
|
54 |
+
assert self.trainer.distributed_parallelism == "fsdp"
|
55 |
+
else:
|
56 |
+
assert self.trainer.distributed_parallelism == "ddp"
|
57 |
+
|
58 |
+
# Transformer Engine config check
|
59 |
+
if self.model.model_config.backend == "transformer_engine":
|
60 |
+
assert (
|
61 |
+
"NVTE_FLASH_ATTN" in os.environ and os.environ["NVTE_FLASH_ATTN"] == "1"
|
62 |
+
) # Enable Flash attention for transformer engine
|
63 |
+
|
64 |
+
# TP, CP config check
|
65 |
+
if self.model_parallel is not None:
|
66 |
+
if self.model_parallel.context_parallel_size > 1:
|
67 |
+
assert (
|
68 |
+
self.model.model_config.backend == "transformer_engine"
|
69 |
+
), "Context parallelism is only supported in transformer engine."
|
70 |
+
|
71 |
+
if self.model_parallel.tensor_model_parallel_size > 1:
|
72 |
+
assert (
|
73 |
+
self.model.model_config.set_parallel_mode
|
74 |
+
), "Tensor model parallelism is only supported in parallel mode."
|
75 |
+
|
76 |
+
if self.model_parallel.sequence_parallel:
|
77 |
+
assert (
|
78 |
+
self.model_parallel.tensor_model_parallel_size > 1
|
79 |
+
), "Sequence parallelism is only supported in tensor model parallelism."
|
80 |
+
assert (
|
81 |
+
self.model.model_config.backend == "transformer_engine"
|
82 |
+
), "Sequence parallelism is only supported in transformer engine."
|
83 |
+
|
84 |
+
|
85 |
+
def make_config():
|
86 |
+
c = Config(
|
87 |
+
model=None,
|
88 |
+
optimizer=None,
|
89 |
+
scheduler=None,
|
90 |
+
dataloader_train=None,
|
91 |
+
dataloader_val=None,
|
92 |
+
checkpoint=None,
|
93 |
+
)
|
94 |
+
|
95 |
+
c.job.project = "cosmos_autoregressive"
|
96 |
+
c.job.group = "debug"
|
97 |
+
c.job.name = "default_${now:%Y-%m-%d}_${now:%H-%M-%S}"
|
98 |
+
|
99 |
+
c.trainer.type = Trainer
|
100 |
+
c.trainer.run_validation = True
|
101 |
+
|
102 |
+
c.trainer.seed = 0
|
103 |
+
c.trainer.max_iter = 10
|
104 |
+
c.trainer.logging_iter = 1
|
105 |
+
|
106 |
+
c.trainer.callbacks = None
|
107 |
+
register_configs()
|
108 |
+
# experiment config are defined in the experiment folder
|
109 |
+
# call import_all_modules_from_package to register them
|
110 |
+
import_all_modules_from_package("cosmos_predict1.autoregressive.configs.experiment")
|
111 |
+
return c
|
cosmos_predict1/autoregressive/configs/experiment/video2video/__init__.py
ADDED
File without changes
|
cosmos_predict1/autoregressive/configs/experiment/video2video/basic.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""
|
17 |
+
This file contains a basic configuration for video2video experiments.
|
18 |
+
"""
|
19 |
+
|
20 |
+
from hydra.core.config_store import ConfigStore
|
21 |
+
|
22 |
+
from cosmos_predict1.autoregressive.configs.base.model_config import create_video2world_model
|
23 |
+
from cosmos_predict1.autoregressive.configs.base.model_parallel import create_model_parallel_config
|
24 |
+
from cosmos_predict1.utils import log
|
25 |
+
from cosmos_predict1.utils.lazy_config import LazyDict
|
26 |
+
|
27 |
+
cs = ConfigStore.instance()
|
28 |
+
|
29 |
+
|
30 |
+
"""
|
31 |
+
Finetune 4B model with TP=1, pytorch backend, low resolution tealrobot data, frames 33, chunk 33.
|
32 |
+
Usage:
|
33 |
+
torchrun --nproc_per_node=1 -m cosmos_predict1.autoregressive.train --config=cosmos_predict1/autoregressive/configs/config.py -- experiment=base_4b_example_tealrobotsmall_tp1
|
34 |
+
"""
|
35 |
+
base_4b_example_tealrobotsmall_tp1: LazyDict = LazyDict(
|
36 |
+
dict(
|
37 |
+
defaults=[
|
38 |
+
{"override /data_train": "tealrobot_video_small"},
|
39 |
+
{
|
40 |
+
"override /callbacks": [
|
41 |
+
"basic",
|
42 |
+
"video_teacher_forcing",
|
43 |
+
]
|
44 |
+
},
|
45 |
+
{"override /checkpoint": "local"},
|
46 |
+
{"override /optimizer": "fused_adamw"},
|
47 |
+
{"override /scheduler": "warmup_cosine_lr"},
|
48 |
+
"_self_",
|
49 |
+
],
|
50 |
+
job=dict(
|
51 |
+
project="posttraining",
|
52 |
+
group="autoregressive_base",
|
53 |
+
name="base_4b_example_tealrobotsmall_tp1",
|
54 |
+
),
|
55 |
+
model=create_video2world_model(
|
56 |
+
model_size="4b",
|
57 |
+
model_family="cosmos",
|
58 |
+
backend="pytorch",
|
59 |
+
tensor_model_parallel_size=1,
|
60 |
+
batch_size=1,
|
61 |
+
pixel_chunk_duration=33,
|
62 |
+
num_video_frames=33,
|
63 |
+
video_height=384,
|
64 |
+
video_width=640,
|
65 |
+
tokenizer_ckpt_path="checkpoints/Cosmos-Tokenize1-DV8x16x16-720p/ema.jit",
|
66 |
+
add_special_tokens=False,
|
67 |
+
),
|
68 |
+
trainer=dict(
|
69 |
+
max_iter=50000,
|
70 |
+
grad_accum_iter=1,
|
71 |
+
grad_scaler_args=dict(enabled=False),
|
72 |
+
run_validation=False, # No need for validation as epoch <= 1
|
73 |
+
distributed_parallelism="ddp",
|
74 |
+
callbacks=dict(
|
75 |
+
vid_sampling_tf=dict(
|
76 |
+
every_n=500,
|
77 |
+
),
|
78 |
+
),
|
79 |
+
),
|
80 |
+
checkpoint=dict(
|
81 |
+
load_path="checkpoints/Cosmos-Predict1-4B/model.pt",
|
82 |
+
load_training_state=False,
|
83 |
+
strict_resume=True,
|
84 |
+
save_iter=1000,
|
85 |
+
),
|
86 |
+
model_parallel=create_model_parallel_config(),
|
87 |
+
),
|
88 |
+
)
|
89 |
+
|
90 |
+
|
91 |
+
"""
|
92 |
+
Finetune 4B model with TP=4, pytorch backend, high resolution tealrobot data, frame 33, chunk 33.
|
93 |
+
Usage:
|
94 |
+
torchrun --nproc_per_node=4 -m cosmos_predict1.autoregressive.train --config=cosmos_predict1/autoregressive/configs/config.py -- experiment=base_4b_example_tealrobot_tp4
|
95 |
+
"""
|
96 |
+
base_4b_example_tealrobot_tp4: LazyDict = LazyDict(
|
97 |
+
dict(
|
98 |
+
defaults=[
|
99 |
+
{"override /data_train": "tealrobot_video"},
|
100 |
+
{
|
101 |
+
"override /callbacks": [
|
102 |
+
"basic",
|
103 |
+
"video_teacher_forcing",
|
104 |
+
]
|
105 |
+
},
|
106 |
+
{"override /checkpoint": "local"},
|
107 |
+
{"override /optimizer": "fused_adamw"},
|
108 |
+
{"override /scheduler": "warmup_cosine_lr"},
|
109 |
+
"_self_",
|
110 |
+
],
|
111 |
+
job=dict(
|
112 |
+
project="posttraining",
|
113 |
+
group="autoregressive_base",
|
114 |
+
name="base_4b_example_tealrobot_tp4",
|
115 |
+
),
|
116 |
+
model=create_video2world_model(
|
117 |
+
model_size="4b",
|
118 |
+
model_family="cosmos",
|
119 |
+
backend="pytorch",
|
120 |
+
tensor_model_parallel_size=4,
|
121 |
+
batch_size=1,
|
122 |
+
pixel_chunk_duration=33,
|
123 |
+
num_video_frames=33,
|
124 |
+
video_height=640,
|
125 |
+
video_width=848,
|
126 |
+
tokenizer_ckpt_path="checkpoints/Cosmos-Tokenize1-DV8x16x16-720p/ema.jit",
|
127 |
+
add_special_tokens=False,
|
128 |
+
),
|
129 |
+
trainer=dict(
|
130 |
+
max_iter=50000,
|
131 |
+
grad_accum_iter=1,
|
132 |
+
grad_scaler_args=dict(enabled=False),
|
133 |
+
run_validation=False, # No need for validation as epoch <= 1
|
134 |
+
distributed_parallelism="ddp",
|
135 |
+
callbacks=dict(
|
136 |
+
vid_sampling_tf=dict(
|
137 |
+
every_n=500,
|
138 |
+
),
|
139 |
+
),
|
140 |
+
),
|
141 |
+
checkpoint=dict(
|
142 |
+
load_path="checkpoints/Cosmos-Predict1-4B/model.pt",
|
143 |
+
load_training_state=False,
|
144 |
+
strict_resume=False,
|
145 |
+
save_iter=1000,
|
146 |
+
),
|
147 |
+
model_parallel=create_model_parallel_config(),
|
148 |
+
),
|
149 |
+
)
|
150 |
+
|
151 |
+
|
152 |
+
def register_experiments(cs):
|
153 |
+
# Register the experiments
|
154 |
+
for _item in [
|
155 |
+
base_4b_example_tealrobotsmall_tp1,
|
156 |
+
base_4b_example_tealrobot_tp4,
|
157 |
+
]:
|
158 |
+
cs.store(
|
159 |
+
group="experiment",
|
160 |
+
package="_global_",
|
161 |
+
name=_item["job"]["name"],
|
162 |
+
node=_item,
|
163 |
+
)
|
cosmos_predict1/autoregressive/configs/inference/inference_config.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from typing import Any, List, Optional, Union
|
17 |
+
|
18 |
+
import attrs
|
19 |
+
|
20 |
+
from cosmos_predict1.autoregressive.configs.base.model import ModelConfig, TokenizerConfig
|
21 |
+
|
22 |
+
|
23 |
+
@attrs.define(slots=False)
|
24 |
+
class DataShapeConfig:
|
25 |
+
latent_shape: list = []
|
26 |
+
num_video_frames: Union[None, int] = None
|
27 |
+
height: Union[None, int] = None
|
28 |
+
width: Union[None, int] = None
|
29 |
+
|
30 |
+
|
31 |
+
@attrs.define(slots=False)
|
32 |
+
class SamplingConfig:
|
33 |
+
"""
|
34 |
+
Sampling config
|
35 |
+
Args:
|
36 |
+
temperature (float): Temperature value for controlling randomness in sampling. Defaults to 0.6.
|
37 |
+
top_p (float): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
|
38 |
+
logprobs (bool): Flag indicating whether to compute token log probabilities. Defaults to False.
|
39 |
+
echo (bool): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
|
40 |
+
|
41 |
+
"""
|
42 |
+
|
43 |
+
temperature: float = 0.6
|
44 |
+
top_k: int = None
|
45 |
+
top_p: float = 0.9
|
46 |
+
compile_prefill: bool = False
|
47 |
+
compile_sampling: bool = True
|
48 |
+
logprobs: bool = False
|
49 |
+
echo: bool = False
|
50 |
+
|
51 |
+
|
52 |
+
@attrs.define(slots=False)
|
53 |
+
class DiffusionDecoderSamplingConfig:
|
54 |
+
"""
|
55 |
+
Diffusion decoder sampling config
|
56 |
+
Args:
|
57 |
+
guidance (float): Guidance scale for the diffusion process. Controls how much the model follows the conditioning. Defaults to 0.8.
|
58 |
+
sigma_min (float): Minimum noise level for the diffusion process. Defaults to 0.02.
|
59 |
+
sigma (float): Initial noise level for the diffusion process. Defaults to 8.
|
60 |
+
num_steps (int): Number of denoising steps to perform. Defaults to 35.
|
61 |
+
overlap (int): Number of overlapping frames between video chunks during processing. Defaults to 2.
|
62 |
+
continuous_tokenizer_channel (int): Number of channels in the continuous tokenizer of diffusion decoder. Defaults to 16.
|
63 |
+
continuous_tokenizer_spatial_compression_ratio (int): Spatial compression ratio for the continuous tokenizer of diffusion decoder. Defaults to 8.
|
64 |
+
dd_train_num_video_frames (int): Number of video frames used during training for diffusion decoder. Defaults to 57.
|
65 |
+
"""
|
66 |
+
|
67 |
+
guidance: float = 1.8
|
68 |
+
sigma_min: float = 0.02
|
69 |
+
sigma: float = 8
|
70 |
+
num_steps: int = 15
|
71 |
+
overlap: int = 2
|
72 |
+
continuous_tokenizer_channel = 16
|
73 |
+
continuous_tokenizer_spatial_compression_ratio = 8
|
74 |
+
dd_train_num_video_frames: int = 57
|
75 |
+
max_iter: int = 99
|
76 |
+
fps: int = 24
|
77 |
+
|
78 |
+
|
79 |
+
@attrs.define(slots=False)
|
80 |
+
class InferenceConfig:
|
81 |
+
"""
|
82 |
+
Inference config
|
83 |
+
Args:
|
84 |
+
model_config (ModelConfig): Model config
|
85 |
+
tokenizer_config (TokenizerConfig): Tokenizer config
|
86 |
+
ckpt_path (str): Path to the checkpoint
|
87 |
+
latent_shape (list): Shape of the latent
|
88 |
+
"""
|
89 |
+
|
90 |
+
model_config: ModelConfig = None
|
91 |
+
tokenizer_config: TokenizerConfig = None
|
92 |
+
ckpt_path: str = ""
|
93 |
+
data_shape_config: DataShapeConfig = None
|
94 |
+
|
95 |
+
defaults: List[Any] = attrs.field(
|
96 |
+
factory=lambda: [
|
97 |
+
"_self_",
|
98 |
+
{"data_val": None},
|
99 |
+
{"data_shape_config": "video_shape_as_model_config"},
|
100 |
+
{"eval_job": None},
|
101 |
+
]
|
102 |
+
)
|
cosmos_predict1/autoregressive/configs/registry.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import torch
|
17 |
+
from hydra.core.config_store import ConfigStore
|
18 |
+
|
19 |
+
from cosmos_predict1.autoregressive.configs.base.callbacks import BASIC_CALLBACKS, VIDEO_TEACHER_FORCING_CALLBACK
|
20 |
+
from cosmos_predict1.autoregressive.configs.base.dataloader import get_tealrobot_video
|
21 |
+
from cosmos_predict1.autoregressive.configs.base.optim import LambdaLinearLR
|
22 |
+
from cosmos_predict1.autoregressive.configs.experiment.video2video.basic import register_experiments
|
23 |
+
from cosmos_predict1.utils import config, log
|
24 |
+
from cosmos_predict1.utils.lazy_config import LazyCall as L
|
25 |
+
from cosmos_predict1.utils.scheduler import WarmupCosineLR
|
26 |
+
|
27 |
+
|
28 |
+
def register_checkpoint(cs):
|
29 |
+
checkpoint_local = config.CheckpointConfig(
|
30 |
+
save_iter=5000,
|
31 |
+
broadcast_via_filesystem=True,
|
32 |
+
)
|
33 |
+
cs.store(group="checkpoint", package="checkpoint", name="local", node=checkpoint_local)
|
34 |
+
|
35 |
+
|
36 |
+
def register_callbacks(cs):
|
37 |
+
cs.store(group="callbacks", package="trainer.callbacks", name="basic", node=BASIC_CALLBACKS)
|
38 |
+
cs.store(
|
39 |
+
group="callbacks",
|
40 |
+
package="trainer.callbacks",
|
41 |
+
name="video_teacher_forcing",
|
42 |
+
node=VIDEO_TEACHER_FORCING_CALLBACK,
|
43 |
+
)
|
44 |
+
|
45 |
+
|
46 |
+
def register_scheduler(cs):
|
47 |
+
cs.store(
|
48 |
+
group="scheduler",
|
49 |
+
package="scheduler",
|
50 |
+
name="warmup_cosine_lr",
|
51 |
+
node=L(WarmupCosineLR)(optimizer=None, warmup_iters=5000, lr_decay_iters="${trainer.max_iter}", min_lr=1e-8),
|
52 |
+
)
|
53 |
+
cs.store(group="scheduler", package="scheduler", name="lambdalinear", node=LambdaLinearLR)
|
54 |
+
|
55 |
+
|
56 |
+
def register_optimizer(cs):
|
57 |
+
cs.store(
|
58 |
+
group="optimizer",
|
59 |
+
package="optimizer",
|
60 |
+
name="fused_adamw",
|
61 |
+
node=L(torch.optim.AdamW)(params=None, lr=1e-3, weight_decay=0.05, fused=True),
|
62 |
+
)
|
63 |
+
cs.store(
|
64 |
+
group="optimizer",
|
65 |
+
package="optimizer",
|
66 |
+
name="sgd",
|
67 |
+
node=L(torch.optim.SGD)(params=None, lr=5e-6, momentum=0.9),
|
68 |
+
)
|
69 |
+
|
70 |
+
|
71 |
+
def register_training_data(cs):
|
72 |
+
cs.store(
|
73 |
+
group="data_train",
|
74 |
+
package="dataloader_train",
|
75 |
+
name="tealrobot_video_small",
|
76 |
+
node=get_tealrobot_video(num_frames=33, video_size=[384, 640]),
|
77 |
+
)
|
78 |
+
cs.store(group="data_train", package="dataloader_train", name="tealrobot_video", node=get_tealrobot_video())
|
79 |
+
|
80 |
+
|
81 |
+
def register_configs():
|
82 |
+
log.info("Registering configs for autoregressive_base")
|
83 |
+
cs = ConfigStore.instance()
|
84 |
+
register_callbacks(cs)
|
85 |
+
register_checkpoint(cs)
|
86 |
+
register_optimizer(cs)
|
87 |
+
register_scheduler(cs)
|
88 |
+
register_training_data(cs)
|
89 |
+
register_experiments(cs)
|