Commit · bcc039b
Parent(s): Initial commit
This view is limited to 50 files because the commit contains too many changes.
- .github/workflows/black.yml +12 -0
- .github/workflows/isort.yml +10 -0
- .gitignore +168 -0
- .prettierrc +8 -0
- CODE_OF_CONDUCT.md +80 -0
- CONTRIBUTING.md +36 -0
- LICENSE +28 -0
- README.md +117 -0
- apps/__init__.py +0 -0
- apps/main/__init__.py +0 -0
- apps/main/configs/eval.yaml +35 -0
- apps/main/configs/llama_1B.yaml +87 -0
- apps/main/configs/llama_7B.yaml +95 -0
- apps/main/eval.py +354 -0
- apps/main/generate.py +463 -0
- apps/main/lingua_train.py +654 -0
- blt-figure.jpg +0 -0
- blt-figure.pdf +0 -0
- bytelatent/.DS_Store +0 -0
- bytelatent/__init__.py +3 -0
- bytelatent/args.py +199 -0
- bytelatent/base_transformer.py +585 -0
- bytelatent/checkpoint.py +311 -0
- bytelatent/configs/debug.yaml +110 -0
- bytelatent/constants.py +5 -0
- bytelatent/data/__init__.py +1 -0
- bytelatent/data/data_types.py +115 -0
- bytelatent/data/iterators/__init__.py +1 -0
- bytelatent/data/iterators/abstract_iterator.py +23 -0
- bytelatent/data/iterators/arrow_iterator.py +216 -0
- bytelatent/data/iterators/looping_iterator.py +36 -0
- bytelatent/data/iterators/multiprocess_iterator.py +243 -0
- bytelatent/data/iterators/packing_iterator.py +226 -0
- bytelatent/data/iterators/preprocess_iterator.py +111 -0
- bytelatent/data/iterators/sampling_iterator.py +66 -0
- bytelatent/data/iterators/sequence_iterator.py +122 -0
- bytelatent/data/iterators/test_arrow_iterator.py +89 -0
- bytelatent/data/iterators/test_iters.py +162 -0
- bytelatent/data/ngram_processor.py +146 -0
- bytelatent/data/patcher.py +609 -0
- bytelatent/distributed.py +478 -0
- bytelatent/entropy_model.py +36 -0
- bytelatent/float8.py +152 -0
- bytelatent/logger.py +129 -0
- bytelatent/metrics.py +232 -0
- bytelatent/model/__init__.py +1 -0
- bytelatent/model/blt.py +1064 -0
- bytelatent/model/local_models.py +356 -0
- bytelatent/model/transformer.py +199 -0
- bytelatent/model/utils.py +116 -0
.github/workflows/black.yml
ADDED
@@ -0,0 +1,12 @@
+name: Lint with Black
+
+on: [push, pull_request]
+
+jobs:
+  lint:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: psf/black@stable
+        with:
+          version: "24.8.0"
.github/workflows/isort.yml
ADDED
@@ -0,0 +1,10 @@
+name: Lint with isort
+
+on: [push, pull_request]
+
+jobs:
+  lint:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: isort/isort-action@master
.gitignore
ADDED
@@ -0,0 +1,168 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+*.out
+
+figures/
+.vscode/
+.DS_Store
+
.prettierrc
ADDED
@@ -0,0 +1,8 @@
+{
+  "overrides": [
+    {
+      "files": "*.yaml",
+      "options": { "tabWidth": 2 }
+    }
+  ]
+}
CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,80 @@
+# Code of Conduct
+
+## Our Pledge
+
+In the interest of fostering an open and welcoming environment, we as
+contributors and maintainers pledge to make participation in our project and
+our community a harassment-free experience for everyone, regardless of age, body
+size, disability, ethnicity, sex characteristics, gender identity and expression,
+level of experience, education, socio-economic status, nationality, personal
+appearance, race, religion, or sexual identity and orientation.
+
+## Our Standards
+
+Examples of behavior that contributes to creating a positive environment
+include:
+
+* Using welcoming and inclusive language
+* Being respectful of differing viewpoints and experiences
+* Gracefully accepting constructive criticism
+* Focusing on what is best for the community
+* Showing empathy towards other community members
+
+Examples of unacceptable behavior by participants include:
+
+* The use of sexualized language or imagery and unwelcome sexual attention or
+  advances
+* Trolling, insulting/derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or electronic
+  address, without explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+  professional setting
+
+## Our Responsibilities
+
+Project maintainers are responsible for clarifying the standards of acceptable
+behavior and are expected to take appropriate and fair corrective action in
+response to any instances of unacceptable behavior.
+
+Project maintainers have the right and responsibility to remove, edit, or
+reject comments, commits, code, wiki edits, issues, and other contributions
+that are not aligned to this Code of Conduct, or to ban temporarily or
+permanently any contributor for other behaviors that they deem inappropriate,
+threatening, offensive, or harmful.
+
+## Scope
+
+This Code of Conduct applies within all project spaces, and it also applies when
+an individual is representing the project or its community in public spaces.
+Examples of representing a project or community include using an official
+project e-mail address, posting via an official social media account, or acting
+as an appointed representative at an online or offline event. Representation of
+a project may be further defined and clarified by project maintainers.
+
+This Code of Conduct also applies outside the project spaces when there is a
+reasonable belief that an individual's behavior may have a negative impact on
+the project or its community.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported by contacting the project team at <opensource-conduct@meta.com>. All
+complaints will be reviewed and investigated and will result in a response that
+is deemed necessary and appropriate to the circumstances. The project team is
+obligated to maintain confidentiality with regard to the reporter of an incident.
+Further details of specific enforcement policies may be posted separately.
+
+Project maintainers who do not follow or enforce the Code of Conduct in good
+faith may face temporary or permanent repercussions as determined by other
+members of the project's leadership.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
+available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see
+https://www.contributor-covenant.org/faq
CONTRIBUTING.md
ADDED
@@ -0,0 +1,36 @@
+# Contributing to BLT
+
+We want to make contributing to this project as easy and transparent as
+possible.
+
+## Pull Requests
+
+We actively welcome your pull requests.
+
+1. Fork the repo and create your branch from `main`.
+2. If you've added code that should be tested, add tests.
+3. If you've changed APIs, update the documentation.
+4. Ensure the test suite passes.
+5. Make sure your code lints.
+6. If you haven't already, complete the Contributor License Agreement ("CLA").
+
+## Contributor License Agreement ("CLA")
+
+In order to accept your pull request, we need you to submit a CLA. You only need
+to do this once to work on any of Meta's open source projects.
+
+Complete your CLA here: <https://code.facebook.com/cla>
+
+## Issues
+
+We use GitHub issues to track public bugs. Please ensure your description is
+clear and has sufficient instructions to be able to reproduce the issue.
+
+Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
+disclosure of security bugs. In those cases, please go through the process
+outlined on that page and do not file a public issue.
+
+## License
+
+By contributing to BLT, you agree that your contributions will be licensed
+under the LICENSE file in the root directory of this source tree.
LICENSE
ADDED
@@ -0,0 +1,28 @@
+BSD 3-Clause License
+
+Copyright 2024 Meta
+
+Redistribution and use in source and binary forms, with or without modification,
+are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this list
+of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice, this
+list of conditions and the following disclaimer in the documentation
+and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its contributors may
+be used to endorse or promote products derived from this software without specific
+prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY
+EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
+OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
+SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
+TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
+BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
+DAMAGE.
README.md
ADDED
@@ -0,0 +1,117 @@
+# Byte Latent Transformer
+
+This repository contains code for our paper: "Byte Latent Transformer: Patches Scale Better Than Tokens"
+
+- [Paper Link](https://dl.fbaipublicfiles.com/blt/BLT__Patches_Scale_Better_Than_Tokens.pdf)
+
+## Abstract
+
+We introduce the Byte Latent Transformer (BLT), a new byte-level LLM architecture that,
+for the first time, matches tokenization-based LLM performance at scale, with significant improvements
+in inference efficiency and robustness. BLT encodes bytes into dynamically sized patches, which serve
+as the primary units of computation. Patches are segmented dynamically based on the entropy of the
+next byte, allocating more compute and model capacity where there is more data complexity. The BLT
+architecture includes new attention mechanisms to maximize the information flow between byte and
+patch hidden representations and a new type of byte-sequence memory. We present the first scaling
+study of byte-level models up to 8B parameters and 8T training bytes, showing for the first time
+that we can train a model end-to-end at scale from bytes with no tokenization or other preprocessing.
+Scaling trends reveal training and inference efficiency benefits from dynamically selecting very long
+patches on average, along with qualitative improvements in reasoning and long-tail generalization
+from modeling byte sequences.
+
+
+
+## Development Status
+
+We are actively updating the BLT code to make it easier to reproduce our results.
+Please file an issue and/or be patient while we make more of our code public!
+
+## Quick start
+
+The following commands launch a SLURM job that creates an environment for Meta Lingua.
+The env creation should take around 5 minutes, not counting downloads.
+
+```bash
+git clone https://github.com/facebookresearch/blt
+cd blt
+
+bash setup/create_env.sh
+# or if you have access to a SLURM cluster
+sbatch setup/create_env.sh
+```
+
+Once that is done, you can activate the environment:
+
+```bash
+conda activate blt_<date>
+```
+
+Use the provided script to download and prepare data from Hugging Face (among `fineweb_edu`, `fineweb_edu_10bt`, or `dclm_baseline_1.0`).
+This command downloads the `fineweb_edu` dataset and prepares it for training in the `./data` directory, specifying the amount of memory `terashuf` (the tool used to shuffle samples) will be allocated. By default, the number of chunks (`nchunks`) is 32. If you are running on fewer than 32 GPUs, it is recommended to set `nchunks` to 1 or to match `nchunks` with the number of GPUs (`nchunks` = NGPUs). See [here](https://github.com/facebookresearch/lingua/issues/55#issuecomment-2483643076) for more details.
+
+```bash
+python setup/download_prepare_hf_data.py fineweb_edu <MEMORY> --data_dir ./data --seed 42 --nchunks <NCHUNKS>
+```
+
+To download the tokenizer (here llama3), use the following script:
+
+```bash
+python setup/download_tokenizer.py llama3 <SAVE_PATH> --api_key <HUGGINGFACE_TOKEN>
+```
+
+Now launch a debug job to check that everything works. **The provided configurations are templates; you need to adapt them (change `dump_dir`, `data.root_dir`, `data.tokenizer.path`, etc.) for them to work.**
+
+```bash
+# stool stands for SLURM tool!
+python -m bytelatent.stool script=bytelatent.train config=apps/bytelatent/configs/debug.yaml nodes=1 partition=<partition>
+# if you want to launch locally you can use torchrun
+torchrun --nproc-per-node 8 -m bytelatent.train config=apps/bytelatent/configs/debug.yaml
+# or you can also launch on 1 GPU
+python -m bytelatent.train config=apps/bytelatent/configs/debug.yaml
+```
+
+When using `stool`, if a job crashes, it can be relaunched using sbatch:
+
+```bash
+sbatch path/to/dump_dir/submit.slurm
+```
+
+## Linting
+
+To lint, run the following command:
+
+```
+bash dev/lint.sh
+```
+
+## Citation
+
+BLT is partially based on Meta Lingua, so consider citing Lingua in addition to our BLT paper if you re-use our work.
+
+BLT paper citation (will be updated to an arXiv reference soon):
+
+```
+@article{meta_blt,
+  author = {Artidoro Pagnoni, Ram Pasunuru, Pedro Rodriguez, John Nguyen, Benjamin Muller, Margaret Li, Chunting Zhou, Lili Yu, Jason Weston, Luke Zettlemoyer, Gargi Ghosh, Mike Lewis, Ari Holtzman†, Srinivasan Iyer},
+  title = {Byte Latent Transformer: Patches Scale Better Than Tokens},
+  url = {https://github.com/facebookresearch/blt},
+  year = {2024}
+}
+```
+
+Meta Lingua citation:
+
+```
+@misc{meta_lingua,
+  author = {Mathurin Videau, Badr Youbi Idrissi, Daniel Haziza, Luca Wehrstedt, Jade Copet, Olivier Teytaud, David Lopez-Paz},
+  title = {{Meta Lingua}: A minimal {PyTorch LLM} training library},
+  url = {https://github.com/facebookresearch/lingua},
+  year = {2024}
+}
+```
+
+## License
+
+The BLT code is partially based on Meta Lingua.
+
+Meta Lingua is licensed under the BSD-3-Clause license. Refer to the LICENSE file in the top level directory.
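The README abstract above describes the core mechanism: bytes are grouped into patches, and a new patch is opened wherever the entropy of the next-byte distribution is high. The repository's actual implementation lives in `bytelatent/data/patcher.py` (added in this commit but not shown in this truncated view). The snippet below is only a minimal illustrative sketch of that idea, not the repository's code: the `threshold` value is a made-up knob, and random logits stand in for a trained byte-level entropy model.

```python
import torch
import torch.nn.functional as F


def entropy_patch_starts(logits: torch.Tensor, threshold: float) -> torch.Tensor:
    """Mark byte positions that should start a new patch.

    logits: (seq_len, 256) next-byte logits from a small entropy model.
    threshold: entropy (in nats) above which a new patch is opened.
    Both are illustrative assumptions, not the repository's API.
    """
    probs = F.softmax(logits, dim=-1)
    entropy = -(probs * torch.log(probs.clamp_min(1e-10))).sum(dim=-1)  # (seq_len,)
    starts = entropy > threshold
    starts[0] = True  # the first byte always begins a patch
    return starts


def split_into_patches(byte_seq: bytes, starts: torch.Tensor) -> list:
    # Cut the byte sequence at every position flagged as a patch start.
    idx = torch.nonzero(starts).flatten().tolist() + [len(byte_seq)]
    return [byte_seq[a:b] for a, b in zip(idx[:-1], idx[1:])]


# Toy usage: random logits stand in for a trained entropy model.
seq = b"Byte Latent Transformer"
fake_logits = torch.randn(len(seq), 256)
print(split_into_patches(seq, entropy_patch_starts(fake_logits, threshold=5.0)))
```

Under this scheme, hard-to-predict (high-entropy) regions get many short patches and hence more compute, while predictable regions are covered by long patches, which is where the efficiency gains described in the abstract come from.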
apps/__init__.py
ADDED
File without changes
apps/main/__init__.py
ADDED
File without changes
apps/main/configs/eval.yaml
ADDED
@@ -0,0 +1,35 @@
+name: "debug_evals"
+# ckpt_dir: !!CHANGETHIS!!
+# dump_dir: !!CHANGETHIS!!
+generator:
+  max_tokens: 8192
+  dtype: bf16
+  temperature: 1.0
+  top_p: 0.95
+harness:
+  tasks:
+    - hellaswag
+    - task: boolq
+      dataset_kwargs:
+        trust_remote_code: true
+    - task: nq_open
+      num_fewshot: 5
+    - piqa
+    - task: social_iqa
+      dataset_kwargs:
+        trust_remote_code: true
+    - triviaqa
+    - winogrande
+    - openbookqa
+    - arc_easy
+    - arc_challenge
+    - race
+    - commonsense_qa
+    # - coqa
+    - copa
+    - gsm8k
+    - bbh
+    - mmlu
+    - mmlu_pro
+validation:
+  max_steps: 1000
apps/main/configs/llama_1B.yaml
ADDED
@@ -0,0 +1,87 @@
+# dump_dir: !!!CHANGE_THIS!!!
+name: large_lm
+steps: 60_000
+probe_freq: null
+seed: 777
+
+optim:
+  lr: 3e-3
+  weight_decay: 0.033
+  warmup: 5000
+  lr_min_ratio: 0.000001
+  clip: 1.0
+
+distributed:
+  fsdp_type: full_shard
+  compile: true
+  model_dtype: bf16
+  matmul_allow_tf32: false
+  selective_activation_checkpointing: false
+  tp_size: 1
+
+model:
+  dim: 2048
+  n_layers: 25
+  n_heads: 16
+
+data:
+  root_dir: data/shuffled
+  sources:
+    dclm_baseline_1.0: 100.0
+  batch_size: 4
+  prefetch_size: 1024
+  seq_len: 4096
+  n_views: 2
+  load_async: true
+  add_bos: true
+  add_eos: true
+  tokenizer:
+    name: tiktoken
+    path: tokenizers/cl_toplang_128k.tiktoken
+
+profiling:
+  run: true
+  mem_warmup: 0
+  mem_steps: 4
+  profile_warmup: 100
+  profile_steps: 4
+
+checkpoint:
+  dump:
+    every: 2500
+    keep: 3
+  eval:
+    every: 5000
+    keep: -1
+
+logging:
+  freq: 1
+
+async_eval_gpus: 8
+eval:
+  harness:
+    tasks:
+      - hellaswag
+      - task: boolq
+        dataset_kwargs:
+          trust_remote_code: true
+      - piqa
+      - task: social_iqa
+        dataset_kwargs:
+          trust_remote_code: true
+      - winogrande
+      - openbookqa
+      - arc_easy
+      - arc_challenge
+      - race
+      - commonsense_qa
+      - copa
+      # - coqa
+      # - task: nq_open
+      #   num_fewshot: 5
+      # - triviaqa
+  validation:
+    max_steps: 1000
+  generator:
+    max_tokens: 16384
+    dtype: bf16
apps/main/configs/llama_7B.yaml
ADDED
@@ -0,0 +1,95 @@
+#python -m lingua.stool config=apps/main/configs/llama2_7B.yaml nodes=32 account=fair_amaia_cw_codegen qos=lowest
+# dump_dir: !!!CHANGE_THIS!!!
+name: "7b_baseline"
+steps: 100_000
+grad_acc_steps: 1
+probe_freq: 100
+
+seed: 777
+optim:
+  lr: 1.0e-3
+  weight_decay: 0.1
+  warmup: 2000
+  lr_min_ratio: 0.000001
+  clip: 1.0
+
+distributed:
+  fsdp_type: full_shard
+  compile: true
+  model_dtype: bf16
+  matmul_allow_tf32: false
+  selective_activation_checkpointing: false
+  tp_size: 1
+
+model:
+  dim: 4096
+  n_layers: 32
+  n_heads: 32
+  rope_theta: 100_000
+  ffn_dim_multiplier: 1.0
+  multiple_of: 256
+
+data:
+  root_dir: data/shuffled
+  sources:
+    dclm_baseline_1.0: 1.0
+  batch_size: 2
+  prefetch_size: 1024
+  seq_len: 4096
+  n_views: 2
+  load_async: true
+  tokenizer:
+    name: tiktoken
+    path: tokenizers/cl_toplang_128k.tiktoken
+
+profiling:
+  run: true
+  mem_warmup: 0
+  mem_steps: 4
+  profile_warmup: 100
+  profile_steps: 4
+
+checkpoint:
+  dump:
+    every: 10000
+    keep: -1
+  eval:
+    every: 1000
+    keep: 3
+
+logging:
+  freq: 1
+
+async_eval_gpus: 8
+eval:
+  dataset_dir: datasets/eval
+  harness:
+    tasks:
+      - hellaswag
+      - task: boolq
+        dataset_kwargs:
+          trust_remote_code: true
+      - piqa
+      - task: social_iqa
+        dataset_kwargs:
+          trust_remote_code: true
+      - winogrande
+      - openbookqa
+      - arc_easy
+      - arc_challenge
+      - race
+      - commonsense_qa
+      # - coqa
+      - copa
+      - mmlu
+      - mmlu_pro
+      # - task: nq_open
+      #   num_fewshot: 5
+      # - triviaqa
+      # - gsm8k
+      # - bbh
+  validation:
+    max_steps: 1000
+  generator:
+    max_tokens: 8192
+    dtype: bf16
apps/main/eval.py
ADDED
@@ -0,0 +1,354 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import json
+import logging
+import os
+from collections import defaultdict
+from dataclasses import asdict, dataclass, field
+from datetime import datetime
+from pathlib import Path
+from typing import Any, List, Optional, Tuple, Union
+
+import torch
+from lingua.args import dump_config
+from lingua.data import init_choice_state, setup_sources
+from lm_eval import simple_evaluate
+from lm_eval.api.instance import Instance
+from lm_eval.api.model import LM
+from omegaconf import OmegaConf
+
+from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
+from bytelatent.distributed import (
+    DistributedArgs,
+    dist_mean_dict,
+    get_global_rank,
+    get_world_size,
+    setup_torch_distributed,
+)
+from bytelatent.transformer import LMTransformer, LMTransformerArgs
+
+from apps.main.generate import (
+    PackedCausalTransformerGenerator,
+    PackedCausalTransformerGeneratorArgs,
+    load_consolidated_model_and_tokenizer,
+)
+
+EVAL_FOLDER_NAME = "{:010d}"
+
+logger = logging.getLogger()
+
+
+@dataclass
+class LMHarnessArgs:
+    tasks: Optional[List[Any]] = None
+    num_fewshot: Optional[int] = None
+    device: Optional[str] = None
+    use_cache: Optional[str] = None
+    cache_requests: bool = False
+    rewrite_requests_cache: bool = False
+    delete_requests_cache: bool = False
+    limit: Optional[Union[int, float]] = None
+    bootstrap_iters: int = 100000
+    check_integrity: bool = False
+    write_out: bool = False
+    log_samples: bool = True
+    system_instruction: Optional[str] = None
+    apply_chat_template: Union[bool, str] = False
+    fewshot_as_multiturn: bool = False
+    gen_kwargs: Optional[str] = None
+    verbosity: str = "INFO"
+    predict_only: bool = False
+    random_seed: int = 0
+    numpy_random_seed: int = 1234
+    torch_random_seed: int = 1234
+    fewshot_random_seed: int = 1234
+
+
+@dataclass
+class ValidationArgs:
+    max_steps: Optional[int] = (
+        None  # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu)
+    )
+    use_val_from_train_src: bool = True  # Use the validation set from training sources
+    root_dir: str = ""
+    sources: List[str] = field(default_factory=list)  # Other sources to eval on
+
+
+@dataclass
+class EvalArgs:
+    name: str = "evals"
+    dump_dir: Optional[str] = None
+    metric_log_dir: Optional[str] = None
+    ckpt_dir: str = ""
+    generator: PackedCausalTransformerGeneratorArgs = field(
+        default_factory=PackedCausalTransformerGeneratorArgs
+    )
+    harness: Optional[LMHarnessArgs] = field(default_factory=LMHarnessArgs)
+    validation: Optional[ValidationArgs] = field(default_factory=ValidationArgs)
+
+    wandb: Optional[Any] = None
+
+    global_step: Optional[int] = None  # for in-training evaluation
+
+
+def all_dicts_same(dict_list):
+    if not dict_list:  # Check if the list is empty
+        return True
+
+    # Compare each dictionary to the first one
+    first_dict = dict_list[0]
+    return all(d == first_dict for d in dict_list)
+
+
+class MockAccelerator:
+    def gather(self, tensor):
+        l = [torch.zeros_like(tensor) for _ in range(get_world_size())]
+        torch.distributed.all_gather(l, tensor)
+        return torch.stack(l)
+
+    def wait_for_everyone(self):
+        torch.distributed.barrier()
+
+
+# Light wrapper around generator for lm-eval harness
+class EvalHarnessLM(LM):
+    def __init__(self, generator):
+        super().__init__()
+        self.generator = generator
+        self.accelerator = MockAccelerator()
+        self._rank = get_global_rank()
+        self._world_size = get_world_size()
+        self.device = generator.device
+
+    def generate_until(self, requests: List[Instance]) -> List[str]:
+        prompts, gen_args = zip(*[req.args for req in requests])
+        assert all_dicts_same(gen_args), "Doesn't support different gen args for now"
+        gen_args = gen_args[0]
+        temperature = gen_args.get("temperature", 0.0)
+        top_p = gen_args.get("top_p", None)
+        top_k = gen_args.get("top_k", None)
+        until = gen_args.get("until", [])
+
+        self.generator.temperature = temperature
+        self.generator.top_p = top_p
+        self.generator.top_k = top_k
+        self.generator.until = until
+        generations, _, _ = self.generator.generate(prompts)
+        filtered_gen = []
+        for g in generations:
+            for e in until:
+                g = g.replace(e, "")
+            filtered_gen.append(g)
+        return filtered_gen
+
+    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
+        prompts, continuations = zip(*[req.args for req in requests])
+        inputs = [req.args[0] + req.args[1] for req in requests]
+        max_gen_len = self.generator.max_gen_len
+        # We temporarily lower max gen len
+        self.generator.max_gen_len = 1
+        _, lls, greedy = self.generator.generate(inputs)
+        results = []
+        for p, ll, gr in zip(prompts, lls, greedy):
+            p_len = len(
+                self.generator.tokenizer.encode(p, add_bos=False, add_eos=False)
+            )
+            results.append((ll[p_len:].sum().item(), gr[p_len:].all().item()))
+
+        self.generator.max_gen_len = max_gen_len
+        return results
+
+    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
+        prompts = [req.args[0] for req in requests]
+        max_gen_len = self.generator.max_gen_len
+        # We temporarily lower max gen len
+        self.generator.max_gen_len = 1
+        _, lls, _ = self.generator.generate(prompts)
+        results = []
+        for ll in lls:
+            results.append((ll.sum().item(),))
+        self.generator.max_gen_len = max_gen_len
+
+        return results
+
+
+def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
+    srcs = {}
+    for src in val_args.sources:
+        path = os.path.join(val_args.root_dir, src)
+        srcs[path] = 1.0
+    for src in train_cfg.data.sources:
+        path = os.path.join(train_cfg.data.root_dir, src)
+        srcs[path] = 1.0
+
+    multi_state = init_choice_state(
+        "", srcs, 0, get_global_rank(), get_world_size(), "*.val.jsonl"
+    )
+    path_to_iter = setup_sources(multi_state)
+
+    max_gen_len = generator.max_gen_len
+    # We temporarily lower max gen len
+    generator.max_gen_len = 1
+
+    all_val_metrics = {}
+    for src in path_to_iter:
+        jsonl_iterator = path_to_iter[src]
+        texts = []
+        logger.info(f"Running validation on {src}...")
+        for step, (content, state) in enumerate(jsonl_iterator):
+            if state["current_iter"] > 0 or (
+                val_args.max_steps is not None and step >= val_args.max_steps
+            ):
+                break
+            content_key = "text" if ("text" in content) else "content"
+            texts.append(content[content_key])
+
+        _, loglikelihood, _ = generator.generate(texts)
+
+        metrics = defaultdict(list)
+        for i, ll in enumerate(loglikelihood):
+            tmp = ll.sum().item()
+            metrics["nll"].append(tmp)
+            metrics["nll_per_token"].append(tmp / len(ll))
+            metrics["nll_per_char"].append(tmp / len(texts[i]))
+
+            metrics["avg_seqlen"].append(len(ll))
+
+        for m in metrics:
+            metrics[m] = sum(metrics[m]) / len(metrics[m])
+        metrics.update(dist_mean_dict(metrics))
+        logger.info(f"Validation on {src} done. Metrics: {metrics}")
+
+        name = os.path.basename(src)
+        if name in all_val_metrics:
+            logger.warning(
+                f"Duplicate source name {name}, path {src} in validation sources, renaming to {name}_1"
+            )
+            name = f"{name}_1"
+        all_val_metrics[name] = metrics
+
+    generator.max_gen_len = max_gen_len
+
+    return all_val_metrics
+
+
+def launch_eval(cfg: EvalArgs):
+    if not torch.distributed.is_initialized():
+        setup_torch_distributed(DistributedArgs())
+    if (
+        Path(cfg.ckpt_dir).exists()
+        and (Path(cfg.ckpt_dir) / "params.json").exists()
+        and next(Path(cfg.ckpt_dir).glob("*.pth"), None) is not None
+    ):
+        consolidate_path = Path(cfg.ckpt_dir)
+    else:
+        consolidate_path = Path(cfg.ckpt_dir) / CONSOLIDATE_FOLDER
+        if not consolidate_path.exists() and get_global_rank() == 0:
+            consolidate_path = consolidate_checkpoints(cfg.ckpt_dir)
+
+    Path(cfg.dump_dir).mkdir(parents=True, exist_ok=True)
+    dump_config(cfg, Path(cfg.dump_dir) / "config.yaml", log_config=False)
+
+    consolidate_path = str(consolidate_path)
+    torch.distributed.barrier()
+    logger.info("Loading model")
+    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
+        consolidate_path,
+        model_cls=LMTransformer,
+        model_args_cls=LMTransformerArgs,
+    )
+    logger.info("Model loaded")
+    model.eval()
+    generator = PackedCausalTransformerGenerator(cfg.generator, model, tokenizer)
+
+    wrap = EvalHarnessLM(generator)
+    results = simple_evaluate(wrap, **asdict(cfg.harness))
+    val_results = None
+    if cfg.validation:
+        val_results = eval_on_val(generator, cfg.validation, train_cfg)
+    if get_global_rank() == 0:
+        with open(Path(cfg.dump_dir) / "results.json", "w") as f:
+            f.write(json.dumps(results))
+        logger.info(f"All evaluation results: {results['results']}")
+        if val_results is not None:
+            with open(Path(cfg.dump_dir) / "validation.json", "w") as f:
+                f.write(json.dumps(val_results))
+            logger.info(f"All validation results: {val_results}")
+    if cfg.metric_log_dir and get_global_rank() == 0:
+        metric_log_path = Path(cfg.metric_log_dir) / "metrics.eval.jsonl"
+
+        logger.info(f"Writing metric logs to {metric_log_path}")
+        timestamp = {
+            "created_at": datetime.utcnow().isoformat(),
+        }
+        if cfg.global_step is not None:
+            timestamp["global_step"] = cfg.global_step
+        print(
+            json.dumps(timestamp | results["results"]),
+            file=open(metric_log_path, mode="a"),
+            flush=True,
+        )
+
+        val_log_path = Path(cfg.metric_log_dir) / "metrics.validation.jsonl"
+        if val_results is not None:
+            print(
+                json.dumps(timestamp | val_results),
+                file=open(val_log_path, mode="a"),
+                flush=True,
+            )
+
+    del generator
+
+
+def main():
+    """
+    The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments
+    This accepts arguments as a dot list
+    So if the dataclass looks like
+
+    @dataclass
+    class DummyArgs:
+        name: str
+        model: LMTransformerArgs
+
+    @dataclass
+    class LMTransformerArgs:
+        dim: int
+
+    Then you can pass model.dim=32 to change values in LMTransformerArgs
+    or just name=tictac for top level attributes.
+
+    The behavior here is as follows:
+    1. We instantiate EvalArgs with its default values
+    2. We override those default values with the ones in the provided config file
+    3. We override the result with the additional arguments provided through command line
+
+    For example, if the config is the following
+
+    model:
+        dim: 128
+        n_layers: 4
+
+    and you call eval.py with eval.py model.dim=64
+
+    Then the final TrainArgs will have
+
+    model:
+        dim: 64
+        n_layers: 4
+
+    Plus all the default values in EvalArgs dataclass.
+    """
+    cli_args = OmegaConf.from_cli()
+    file_cfg = OmegaConf.load(cli_args.config)
+    # We remove 'config' attribute from config as the underlying DataClass does not have it
+    del cli_args.config
+
+    default_cfg = OmegaConf.structured(EvalArgs())
+    cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
+    cfg = OmegaConf.to_object(cfg)
+    launch_eval(cfg)
+
+
+if __name__ == "__main__":
+    main()
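The `main()` docstring above describes a three-step configuration merge: dataclass defaults, then the YAML file passed as `config=...`, then command-line dot-list overrides. Below is a minimal, self-contained sketch of that merge order using toy dataclasses standing in for the repository's `EvalArgs`; the file config is replaced by an inline dict so the example runs on its own.

```python
from dataclasses import dataclass, field

from omegaconf import OmegaConf


@dataclass
class ModelArgs:
    dim: int = 128
    n_layers: int = 4


@dataclass
class DummyArgs:
    name: str = "debug"
    model: ModelArgs = field(default_factory=ModelArgs)


default_cfg = OmegaConf.structured(DummyArgs())         # 1. dataclass defaults
file_cfg = OmegaConf.create({"model": {"dim": 256}})    # 2. stands in for OmegaConf.load(cli_args.config)
cli_cfg = OmegaConf.from_dotlist(["model.n_layers=8"])  # 3. stands in for OmegaConf.from_cli()

cfg = OmegaConf.merge(default_cfg, file_cfg, cli_cfg)
print(OmegaConf.to_yaml(cfg))
# name: debug
# model:
#   dim: 256
#   n_layers: 8
```

Later sources win on conflicting keys, so the command line always has the final say, exactly as steps 1-3 in the docstring describe.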
apps/main/generate.py
ADDED
@@ -0,0 +1,463 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
|
3 |
+
import time
|
4 |
+
from dataclasses import dataclass, field
|
5 |
+
from pathlib import Path
|
6 |
+
from typing import List, Optional
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from lingua.args import dataclass_from_dict
|
10 |
+
from lingua.tokenizers.abstract_tokenizer import Tokenizer
|
11 |
+
from lingua.tokenizers.build_tokenizer import build_tokenizer
|
12 |
+
from omegaconf import OmegaConf
|
13 |
+
from torch import nn
|
14 |
+
from torch.nn import functional as F
|
15 |
+
from torch.nn.attention.flex_attention import create_block_mask
|
16 |
+
from tqdm import tqdm
|
17 |
+
|
18 |
+
from bytelatent.base_transformer import (
|
19 |
+
Attention,
|
20 |
+
causal_mask,
|
21 |
+
generate_doc_mask_mod,
|
22 |
+
lengths_to_local_ids,
|
23 |
+
lengths_to_start_ids,
|
24 |
+
)
|
25 |
+
from bytelatent.checkpoint import CONSOLIDATE_NAME
|
26 |
+
from bytelatent.transformer import LMTransformer, LMTransformerArgs
|
27 |
+
|
28 |
+
|
29 |
+
def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
|
30 |
+
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
31 |
+
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
32 |
+
mask = probs_sum - probs_sort > p
|
33 |
+
probs_sort[mask] = 0.0
|
34 |
+
next_token = torch.multinomial(probs_sort, num_samples=1)
|
35 |
+
next_token = torch.gather(probs_idx, -1, next_token)
|
36 |
+
return next_token
|
37 |
+
|
38 |
+
|
39 |
+
def sample_top_k(probs, k):
|
40 |
+
topk_value, _ = torch.topk(probs, k) # batch_sz x topk
|
41 |
+
min_value_top_k = topk_value[:, [-1]]
|
42 |
+
probs[probs < min_value_top_k] = 0.0
|
43 |
+
probs.div_(probs.sum(dim=-1, keepdim=True))
|
44 |
+
next_token = torch.multinomial(probs, num_samples=1)
|
45 |
+
return next_token
|
46 |
+
|
47 |
+
|
48 |
+
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None):
|
49 |
+
shape = logits.shape
|
50 |
+
logits = logits.flatten(end_dim=-2)
|
51 |
+
if temperature > 0.0:
|
52 |
+
probs = torch.softmax(logits / temperature, dim=-1)
|
53 |
+
|
54 |
+
if top_p is not None:
|
55 |
+
next_token = sample_top_p(probs, top_p)
|
56 |
+
elif top_k is not None:
|
57 |
+
next_token = sample_top_k(probs, top_k)
|
58 |
+
else:
|
59 |
+
next_token = torch.multinomial(probs, num_samples=1)
|
60 |
+
else:
|
61 |
+
next_token = torch.argmax(logits, dim=-1)
|
62 |
+
return next_token.view(shape[:-1])
|
63 |
+
|
64 |
+
|
65 |
+
def pack_prompts(prompts: List[int]):
|
66 |
+
res = []
|
67 |
+
lengths = []
|
68 |
+
for i, p in enumerate(prompts):
|
69 |
+
p = torch.tensor(p, dtype=torch.long)
|
70 |
+
l = p.size(0)
|
71 |
+
res.append(p)
|
72 |
+
lengths.append(l)
|
73 |
+
lengths = torch.tensor(lengths, dtype=torch.long)
|
74 |
+
res = torch.cat(res)
|
75 |
+
return res, lengths
|
76 |
+
|
77 |
+
|
78 |
+
def batch_prompts(prompts, max_elements, lengths=None):
|
79 |
+
batches = []
|
80 |
+
current_batch = []
|
81 |
+
current_count = 0
|
82 |
+
|
83 |
+
for i in range(len(prompts)):
|
84 |
+
prt = prompts[i]
|
85 |
+
prompt_size = len(prt) if lengths is None else lengths[i]
|
86 |
+
if current_count + prompt_size <= max_elements:
|
87 |
+
current_batch.append(prt)
|
88 |
+
current_count += prompt_size
|
89 |
+
else:
|
90 |
+
if current_batch: # Add the current batch to batches
|
91 |
+
batches.append(current_batch)
|
92 |
+
# Start a new batch with the current prompt
|
93 |
+
current_batch = [prt]
|
94 |
+
current_count = prompt_size
|
95 |
+
|
96 |
+
# Add the last batch if it contains any prompts
|
97 |
+
if current_batch:
|
98 |
+
batches.append(current_batch)
|
99 |
+
|
100 |
+
return batches
|
101 |
+
|
102 |
+
|
103 |
+
class KVCache(nn.Module):
|
104 |
+
def __init__(self, bsz, seqlen, n_heads, head_dim, dtype, device):
|
105 |
+
super().__init__()
|
106 |
+
shape = (bsz, seqlen, n_heads, head_dim)
|
107 |
+
self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype, device=device))
|
108 |
+
self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype, device=device))
|
109 |
+
self.offset = 0
|
110 |
+
|
111 |
+
def reset(self):
|
112 |
+
self.k_cache.zero_()
|
113 |
+
self.v_cache.zero_()
|
114 |
+
self.offset = 0
|
115 |
+
|
116 |
+
def update(self, k_val, v_val, tok_idx):
|
117 |
+
# input_pos: [B], k_val: [B, S, H, D]
|
118 |
+
self.k_cache.index_copy_(1, self.offset + tok_idx, k_val)
|
119 |
+
self.v_cache.index_copy_(1, self.offset + tok_idx, v_val)
|
120 |
+
return self.k_cache, self.v_cache
|
121 |
+
|
122 |
+
|
123 |
+
@dataclass
|
124 |
+
class PackedCausalTransformerGeneratorArgs:
|
125 |
+
temperature: float = 0.0
|
126 |
+
top_p: Optional[float] = None
|
127 |
+
top_k: Optional[float] = None
|
128 |
+
max_gen_len: int = 512 # Maximum number of tokens to generate
|
129 |
+
max_tokens: int = 1024 # Maximum number of tokens that can go through the model
|
130 |
+
max_prompt_len: Optional[int] = None
|
131 |
+
until: List[str] = field(default_factory=list)
|
132 |
+
compile_prefilling: bool = False
|
133 |
+
reduce_generation_overhead: bool = False
|
134 |
+
show_progress: bool = False
|
135 |
+
dtype: Optional[str] = "bf16"
|
136 |
+
device: Optional[str] = "cuda"
|
137 |
+
|
138 |
+
|
139 |
+
class PackedCausalTransformerGenerator:
|
140 |
+
def __init__(
|
141 |
+
self,
|
142 |
+
cfg: PackedCausalTransformerGeneratorArgs,
|
143 |
+
model: nn.Module,
|
144 |
+
tokenizer: Tokenizer,
|
145 |
+
):
|
146 |
+
"""
|
147 |
+
This class wraps a causal transformer model with its corresponding tokenizer
|
148 |
+
and provides an efficient way to pack prompts together and do generation on
|
149 |
+
the packed sequence.
|
150 |
+
|
151 |
+
For example, if we had the prompts "Hello, I am a " and "Initiating calibration "
|
152 |
+
Then this class will concatenate those sequence (pack them together)
|
153 |
+
"Hello, I am a Initiating calibration"
|
154 |
+
And make the necessary attention masks such that a sequence only attends to itself
|
155 |
+
during prefilling and generation.
|
156 |
+
|
157 |
+
This class creates a fixed size cache of size max_tokens or sum of prompt sizes
|
158 |
+
+ the max number of generated tokens per sequence.
|
159 |
+
"""
|
160 |
+
self.model = model
|
161 |
+
self.tokenizer = tokenizer
|
162 |
+
self.temperature = cfg.temperature
|
163 |
+
self.top_p = cfg.top_p
|
164 |
+
self.top_k = cfg.top_k
|
165 |
+
|
166 |
+
self.max_gen_len = cfg.max_gen_len
|
167 |
+
self.max_tokens = cfg.max_tokens
|
168 |
+
self.max_prompt_len = cfg.max_prompt_len
|
169 |
+
self.until = cfg.until
|
170 |
+
self.max_until_size = max([len(e) for e in self.until]) if self.until else 1
|
171 |
+
self.device = cfg.device
|
172 |
+
|
173 |
+
# Compile if necessary
|
174 |
+
self.prefill = torch.compile(self.prefill, disable=not cfg.compile_prefilling)
|
175 |
+
self.generate_next_token = torch.compile(
|
176 |
+
self.generate_next_token,
|
177 |
+
mode="reduce-overhead",
|
178 |
+
disable=not cfg.reduce_generation_overhead,
|
179 |
+
)
|
180 |
+
|
181 |
+
self.show_progress = cfg.show_progress
|
182 |
+
self.dtype = dict(fp32=torch.float32, bf16=torch.bfloat16)[cfg.dtype]
|
183 |
+
|
184 |
+
self.prefill_doc_id, self.prefill_tok_id = None, None
|
185 |
+
self.padded_doc_id, self.padded_tok_id = None, None
|
186 |
+
self.current_doc_id, self.current_tok_id = None, None
|
187 |
+
self.padded_doc_start = None
|
188 |
+
self.prefill_mask = None
|
189 |
+
|
190 |
+
def clear_cache(self, offset):
|
191 |
+
for module in self.model.modules():
|
192 |
+
if isinstance(module, Attention):
|
193 |
+
if not hasattr(module, "kv_cache"):
|
194 |
+
module.kv_cache = KVCache(
|
195 |
+
1,
|
196 |
+
self.max_tokens,
|
197 |
+
module.n_kv_heads,
|
198 |
+
module.head_dim,
|
199 |
+
self.dtype,
|
200 |
+
self.device,
|
201 |
+
)
|
202 |
+
module.kv_cache.offset = offset
|
203 |
+
|
204 |
+
@torch.compiler.disable
|
205 |
+
def setup_prefilling(self, lengths: torch.Tensor):
|
206 |
+
# The KV cache is a fixed size tensor of size max_tokens that we need
|
207 |
+
# to update in order to do correct autoregressive generation.
|
208 |
+
|
209 |
+
# Here we will generate token by token but on multiple sequences
|
210 |
+
# at once. To do so, we need to have an attention mask that makes
|
211 |
+
# each sequence independent.
|
212 |
+
|
213 |
+
# Each sequence will write to its allocated space in the KV Cache.
|
214 |
+
# We allocate len(seq) + max_gen_len to each sequence in the cache.
|
215 |
+
|
216 |
+
# We will generate max_gen_len for each document
|
217 |
+
padded_lengths = lengths + self.max_gen_len
|
218 |
+
max_tokens = self.max_tokens or padded_lengths.sum().item()
|
219 |
+
# The last document might have more padding to fill up to max_tokens
|
220 |
+
padded_lengths[-1] += max_tokens - padded_lengths.sum()
|
221 |
+
|
222 |
+
# This is the start index in the cache for each document
|
223 |
+
self.padded_doc_start = lengths_to_start_ids(padded_lengths)
|
224 |
+
# For example with ab--123--cdef--
|
225 |
+
# this would be 0, 4, 9 if max_gen_len is 2
|
226 |
+
|
227 |
+
# We repeat interleave to align with tokens for prefilling
|
228 |
+
# Ex: ab--123--cdef--
|
229 |
+
# 000044444999999
|
230 |
+
prefill_offset = torch.repeat_interleave(self.padded_doc_start, lengths)
|
231 |
+
# This offset will make sure the tokens are written to the
|
232 |
+
# correct positions in the cache during prefilling
|
233 |
+
|
234 |
+
# We either init the cache or clear it by resetting the offset to prefill_offset
|
235 |
+
self.clear_cache(prefill_offset)
|
236 |
+
|
237 |
+
# The prefilling mask looks like the following for
|
238 |
+
# the two packed sequences ab and 123 : ab123
|
239 |
+
# Where spaces are empty cache positions
|
240 |
+
# keys
|
241 |
+
# ab---123---
|
242 |
+
# queries a 10000000000
|
243 |
+
# b 11000000000
|
244 |
+
# 1 00000100000
|
245 |
+
# 2 00000110000
|
246 |
+
# 3 00000111000
|
247 |
+
# We make sure to skip the empty cache positions
|
248 |
+
# and only attend to positions within the same sequence
|
249 |
+
doc_mask_mod = generate_doc_mask_mod(causal_mask, lengths, padded_lengths)
|
250 |
+
self.prefill_mask = create_block_mask(
|
251 |
+
doc_mask_mod, 1, None, lengths.sum(), max_tokens
|
252 |
+
)
|
253 |
+
|
254 |
+
# This creates the prefilling token ids which look like
|
255 |
+
# the following for the packed sequence abcdefg1234
|
256 |
+
# abcdefg1234
|
257 |
+
# 01234560123
|
258 |
+
# The token id gives us the position within each sequence
|
259 |
+
# This is used to compute ROPE and to update the cache
|
260 |
+
# At each forward pass the current tokens are written to
|
261 |
+
# offset + tok_id
|
262 |
+
self.prefill_doc_id, self.prefill_tok_id = lengths_to_local_ids(lengths)
|
263 |
+
|
264 |
+
# This creates the padded token and document ids
|
265 |
+
# which look like the following for the packed sequence ab123
|
266 |
+
# ab---123--- ab---123---
|
267 |
+
# padded_doc_id 00000111111 padded_tok_id 01234012345
|
268 |
+
# This will later be useful for the attention mask at generation
|
269 |
+
self.padded_doc_id, self.padded_tok_id = lengths_to_local_ids(padded_lengths)
|
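To make the offset bookkeeping in `setup_prefilling` concrete, here is a minimal sketch (assuming only PyTorch; not part of the committed file) that reproduces the `ab--123--cdef--` example from the comments above:

```python
import torch

# Prompts "ab", "123", "cdef" with max_gen_len = 2, as in the comments above.
lengths = torch.tensor([2, 3, 4])
max_gen_len = 2
padded_lengths = lengths + max_gen_len              # tensor([4, 5, 6])

# lengths_to_start_ids: exclusive cumulative sum of the padded lengths
padded_doc_start = padded_lengths.cumsum(0).roll(1)
padded_doc_start[0] = 0                             # tensor([0, 4, 9])

# One cache-write offset per prompt token ("ab123cdef" -> 9 tokens)
prefill_offset = torch.repeat_interleave(padded_doc_start, lengths)
# tensor([0, 0, 4, 4, 4, 9, 9, 9, 9])
```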
270 |
+
|
271 |
+
@torch.compiler.disable
|
272 |
+
def setup_generation(self, lengths):
|
273 |
+
# KV Cache offset is set to the start of the padded documents
|
274 |
+
for module in self.model.modules():
|
275 |
+
if isinstance(module, Attention):
|
276 |
+
module.kv_cache.offset = self.padded_doc_start
|
277 |
+
# The token ids during generation correspond to the lengths of each doc
|
278 |
+
# current_tok_id will be incremented during generation
|
279 |
+
self.current_tok_id = lengths.clone()
|
280 |
+
# Since we're generating one token per document
|
281 |
+
# the document id is just an arange
|
282 |
+
self.current_doc_id = torch.arange(lengths.size(0), device=lengths.device)
|
283 |
+
|
284 |
+
# From here on some methods for generation
|
285 |
+
def prefill(self, tokens: torch.Tensor, lengths: torch.Tensor):
|
286 |
+
# Prefilling is done by taking multiple packed sequences and
|
287 |
+
# doing block diagonal attention on them so they remain independent
|
288 |
+
self.setup_prefilling(lengths=lengths)
|
289 |
+
prefill_out = self.model.forward(
|
290 |
+
tokens,
|
291 |
+
tok_idx=self.prefill_tok_id,
|
292 |
+
mask=self.prefill_mask,
|
293 |
+
attn_impl="flex_attention",
|
294 |
+
)
|
295 |
+
self.setup_generation(lengths=lengths)
|
296 |
+
return prefill_out
|
297 |
+
|
298 |
+
def generate_next_token(self, current_token):
|
299 |
+
# Since we're doing generation with multiple sequences at once
|
300 |
+
# we need to ignore tokens and cache entries from other sequences
|
301 |
+
# or in the future.
|
302 |
+
# Example mask :
|
303 |
+
# keys
|
304 |
+
# abc--1234--
|
305 |
+
# queries c 11100000000
|
306 |
+
# 4 00000111100
|
307 |
+
|
308 |
+
# mask shape : (n_seqs, cache_size)
|
309 |
+
doc_mask = self.current_doc_id.unsqueeze(1) == self.padded_doc_id.unsqueeze(0)
|
310 |
+
caus_mask = self.current_tok_id.unsqueeze(1) >= self.padded_tok_id.unsqueeze(0)
|
311 |
+
mask = doc_mask & caus_mask
|
312 |
+
out = self.model.forward(
|
313 |
+
current_token,
|
314 |
+
tok_idx=self.current_tok_id, # n_seqs
|
315 |
+
mask=mask,
|
316 |
+
attn_impl="sdpa",
|
317 |
+
)
|
318 |
+
self.current_tok_id += 1
|
319 |
+
return out
|
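The boolean mask built in `generate_next_token` can be checked against the `abc--1234--` diagram above with a small standalone sketch (illustrative only, assuming PyTorch):

```python
import torch

# Two packed documents with prompt lengths 3 and 4, padded to 5 and 6 cache slots.
padded_lengths = torch.tensor([5, 6])
padded_doc_id = torch.repeat_interleave(torch.arange(2), padded_lengths)
padded_tok_id = torch.cat([torch.arange(5), torch.arange(6)])

current_doc_id = torch.arange(2)           # one query token per document
current_tok_id = torch.tensor([2, 3])      # positions of "c" and "4"

doc_mask = current_doc_id.unsqueeze(1) == padded_doc_id.unsqueeze(0)
caus_mask = current_tok_id.unsqueeze(1) >= padded_tok_id.unsqueeze(0)
mask = doc_mask & caus_mask                # shape (n_seqs, cache_size)
# row 0 -> 11100000000, row 1 -> 00000111100, matching the comment above
```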
320 |
+
|
321 |
+
@torch.inference_mode()
|
322 |
+
def generate(self, prompts):
|
323 |
+
# Tokenize
|
324 |
+
prompts = [
|
325 |
+
self.tokenizer.encode(p, add_bos=True, add_eos=False) for p in prompts
|
326 |
+
]
|
327 |
+
# Truncate
|
328 |
+
max_seqlen = (
|
329 |
+
self.max_tokens
|
330 |
+
if not hasattr(self.model, "max_seqlen")
|
331 |
+
else self.model.max_seqlen
|
332 |
+
)
|
333 |
+
max_prompt_len = self.max_prompt_len or min(
|
334 |
+
max_seqlen - self.max_gen_len, self.max_tokens - self.max_gen_len
|
335 |
+
)
|
336 |
+
prompts = [p[-max_prompt_len:] for p in prompts]
|
337 |
+
# Account for the generation in lengths
|
338 |
+
padded_lengths = [len(p) + self.max_gen_len for p in prompts]
|
339 |
+
generation = []
|
340 |
+
loglikelihood = []
|
341 |
+
greedy = []
|
342 |
+
it = batch_prompts(prompts, self.max_tokens, lengths=padded_lengths)
|
343 |
+
if self.show_progress:
|
344 |
+
it = tqdm(it)
|
345 |
+
for batch in it:
|
346 |
+
n_seqs = len(batch)
|
347 |
+
generated_tokens = [[] for _ in range(n_seqs)]
|
348 |
+
is_done = [False for _ in range(n_seqs)]
|
349 |
+
packed_batch, lengths = pack_prompts(batch)
|
350 |
+
packed_batch, lengths = packed_batch.cuda(), lengths.cuda()
|
351 |
+
n_seqs = lengths.size(0)
|
352 |
+
|
353 |
+
# Prefilling cache
|
354 |
+
prompt_logits = self.prefill(packed_batch.unsqueeze(0), lengths)
|
355 |
+
# Selecting last token in each prompt
|
356 |
+
all_tokens = sample_tokens(
|
357 |
+
prompt_logits, self.temperature, self.top_p, self.top_k
|
358 |
+
)
|
359 |
+
start_token = all_tokens[:, lengths.cumsum(0) - 1]
|
360 |
+
|
361 |
+
for seq_id, tok in enumerate(start_token.squeeze(0).tolist()):
|
362 |
+
generated_tokens[seq_id].append(tok)
|
363 |
+
|
364 |
+
current_token = start_token
|
365 |
+
for i in range(1, self.max_gen_len):
|
366 |
+
|
367 |
+
next_logits = self.generate_next_token(current_token)
|
368 |
+
next_token = sample_tokens(
|
369 |
+
next_logits.clone(), self.temperature, self.top_p, self.top_k
|
370 |
+
)
|
371 |
+
|
372 |
+
for seq_id, tok in enumerate(next_token.squeeze(0).tolist()):
|
373 |
+
if not is_done[seq_id]:
|
374 |
+
generated_tokens[seq_id].append(tok)
|
375 |
+
current_end_str = self.tokenizer.decode(
|
376 |
+
generated_tokens[seq_id][-self.max_until_size :]
|
377 |
+
)
|
378 |
+
contains_end_string = any(
|
379 |
+
[e in current_end_str for e in self.until]
|
380 |
+
)
|
381 |
+
is_done[seq_id] = (
|
382 |
+
contains_end_string or tok == self.tokenizer.eos_id
|
383 |
+
)
|
384 |
+
if all(is_done):
|
385 |
+
break
|
386 |
+
|
387 |
+
current_token = next_token
|
388 |
+
|
389 |
+
generation.extend([self.tokenizer.decode(g) for g in generated_tokens])
|
390 |
+
|
391 |
+
for p, logit in zip(
|
392 |
+
batch, prompt_logits.squeeze(0).split(lengths.tolist())
|
393 |
+
):
|
394 |
+
x = logit[:-1]
|
395 |
+
y = torch.tensor(p[1:], device=x.device)
|
396 |
+
loglikelihood.append(-F.cross_entropy(x, y, reduction="none").cpu())
|
397 |
+
greedy.append((x.argmax(dim=-1) == y).cpu())
|
398 |
+
|
399 |
+
return generation, loglikelihood, greedy
|
400 |
+
|
401 |
+
|
402 |
+
def load_consolidated_model_and_tokenizer(
|
403 |
+
consolidated_path,
|
404 |
+
model_cls=LMTransformer,
|
405 |
+
model_args_cls=LMTransformerArgs,
|
406 |
+
):
|
407 |
+
ckpt_path = Path(consolidated_path)
|
408 |
+
config = ckpt_path / "params.json"
|
409 |
+
config = OmegaConf.load(config)
|
410 |
+
|
411 |
+
param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
|
412 |
+
config.distributed.model_dtype
|
413 |
+
]
|
414 |
+
model_args = dataclass_from_dict(model_args_cls, config.model, strict=False)
|
415 |
+
tokenizer = build_tokenizer(config.data.tokenizer.name, config.data.tokenizer.path)
|
416 |
+
model = model_cls(model_args)
|
417 |
+
st_dict = torch.load(ckpt_path / CONSOLIDATE_NAME, weights_only=True)
|
418 |
+
model.load_state_dict(st_dict["model"])
|
419 |
+
model = model.cuda().eval()
|
420 |
+
for param in model.parameters():
|
421 |
+
param.data = param.data.to(dtype=param_dtype)
|
422 |
+
return model, tokenizer, config
|
423 |
+
|
424 |
+
|
425 |
+
def main():
|
426 |
+
# Load CLI arguments (overrides) and combine with a YAML config
|
427 |
+
cfg = OmegaConf.from_cli()
|
428 |
+
gen_cfg = dataclass_from_dict(
|
429 |
+
PackedCausalTransformerGeneratorArgs, cfg, strict=False
|
430 |
+
)
|
431 |
+
print(cfg)
|
432 |
+
|
433 |
+
model, tokenizer, _ = load_consolidated_model_and_tokenizer(cfg.ckpt)
|
434 |
+
|
435 |
+
generator = PackedCausalTransformerGenerator(gen_cfg, model, tokenizer)
|
436 |
+
|
437 |
+
# Allow multiple prompts
|
438 |
+
prompts = []
|
439 |
+
while True:
|
440 |
+
prompt = input("Enter a prompt (or press enter to finish): ")
|
441 |
+
if not prompt:
|
442 |
+
break
|
443 |
+
prompts.append(prompt)
|
444 |
+
|
445 |
+
# Start generation
|
446 |
+
start_time = time.time()
|
447 |
+
generation, loglikelihood, greedy = generator.generate(prompts)
|
448 |
+
end_time = time.time()
|
449 |
+
|
450 |
+
# Calculate tokens per second
|
451 |
+
total_tokens = sum(len(tokenizer.encode(gen, False, False)) for gen in generation)
|
452 |
+
tokens_per_second = total_tokens / (end_time - start_time)
|
453 |
+
|
454 |
+
# Display the results
|
455 |
+
for i, gen in enumerate(generation):
|
456 |
+
print(f"\nPrompt {i+1}: {prompts[i]}")
|
457 |
+
print(f"Generated Text: {gen}")
|
458 |
+
|
459 |
+
print(f"\nTokens per second: {tokens_per_second:.2f}")
|
460 |
+
|
461 |
+
|
462 |
+
if __name__ == "__main__":
|
463 |
+
main()
|
apps/main/lingua_train.py
ADDED
@@ -0,0 +1,654 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
3 |
+
|
4 |
+
import gc
|
5 |
+
import logging
|
6 |
+
import os
|
7 |
+
import sys
|
8 |
+
from contextlib import ExitStack
|
9 |
+
from copy import deepcopy
|
10 |
+
from dataclasses import asdict, dataclass, field
|
11 |
+
from pathlib import Path
|
12 |
+
from timeit import default_timer as timer
|
13 |
+
from typing import Any, Dict, Optional
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.distributed
|
17 |
+
import wandb
|
18 |
+
import xformers.profiler
|
19 |
+
from lingua.args import dataclass_from_dict, dump_config, flatten_dict
|
20 |
+
from lingua.data import (
|
21 |
+
DataArgs,
|
22 |
+
PackTokensState,
|
23 |
+
build_dataloader_from_args,
|
24 |
+
init_dataloader_state_from_args,
|
25 |
+
)
|
26 |
+
from lingua.tokenizers.build_tokenizer import TokenizerArgs
|
27 |
+
from omegaconf import OmegaConf
|
28 |
+
from pydantic import BaseModel
|
29 |
+
from torch.distributed._tensor import DTensor
|
30 |
+
from torch.distributed.checkpoint.stateful import Stateful
|
31 |
+
from torch.optim import lr_scheduler
|
32 |
+
|
33 |
+
from bytelatent.checkpoint import (
|
34 |
+
CheckpointArgs,
|
35 |
+
CheckpointManager,
|
36 |
+
load_from_checkpoint,
|
37 |
+
)
|
38 |
+
from bytelatent.distributed import (
|
39 |
+
DistributedArgs,
|
40 |
+
EnvironmentArgs,
|
41 |
+
check_model_value_range,
|
42 |
+
clean_env,
|
43 |
+
dist_mean_dict,
|
44 |
+
get_device_mesh,
|
45 |
+
get_is_master,
|
46 |
+
get_world_size,
|
47 |
+
init_signal_handler,
|
48 |
+
parallelize_model,
|
49 |
+
requeue_slurm_job,
|
50 |
+
setup_env,
|
51 |
+
setup_torch_distributed,
|
52 |
+
)
|
53 |
+
from bytelatent.logger import init_logger
|
54 |
+
from bytelatent.metrics import (
|
55 |
+
GPUMemoryMonitor,
|
56 |
+
LoggingArgs,
|
57 |
+
MetricLogger,
|
58 |
+
get_num_params,
|
59 |
+
)
|
60 |
+
from bytelatent.optim import OptimArgs, build_optimizer
|
61 |
+
from bytelatent.probe import AutoProbeD
|
62 |
+
from bytelatent.profiling import ProfilerArgs, maybe_run_profiler
|
63 |
+
from bytelatent.stool import StoolArgs, launch_job
|
64 |
+
from bytelatent.transformer import (
|
65 |
+
LMTransformer,
|
66 |
+
LMTransformerArgs,
|
67 |
+
build_fsdp_grouping_plan,
|
68 |
+
get_no_recompute_ops,
|
69 |
+
get_num_flop_per_token,
|
70 |
+
tp_parallelize,
|
71 |
+
)
|
72 |
+
|
73 |
+
logger = logging.getLogger()
|
74 |
+
|
75 |
+
|
76 |
+
class TrainArgs(BaseModel):
|
77 |
+
name: str = "lingua"
|
78 |
+
dump_dir: str = ""
|
79 |
+
|
80 |
+
seed: int = 42
|
81 |
+
|
82 |
+
# Number of gradient accumulation steps
|
83 |
+
# Total batch size is batch_size*grad_acc_steps
|
84 |
+
grad_acc_steps: int = 1
|
85 |
+
|
86 |
+
gc_collect_freq: int = 1000
|
87 |
+
probe_freq: int | None = None
|
88 |
+
|
89 |
+
# Nb optimizer steps to take
|
90 |
+
steps: int = 1000
|
91 |
+
|
92 |
+
data: DataArgs
|
93 |
+
optim: OptimArgs
|
94 |
+
model: LMTransformerArgs
|
95 |
+
distributed: DistributedArgs
|
96 |
+
env: EnvironmentArgs
|
97 |
+
|
98 |
+
checkpoint: CheckpointArgs
|
99 |
+
profiling: ProfilerArgs
|
100 |
+
logging: LoggingArgs
|
101 |
+
|
102 |
+
# If set to None, eval is run locally; otherwise it launches a new job with the given number of gpus
|
103 |
+
async_eval_gpus: int | None = None
|
104 |
+
eval: Any | None = None
|
105 |
+
|
106 |
+
|
107 |
+
@dataclass
|
108 |
+
class TrainState(Stateful):
|
109 |
+
step: int # Nb of steps taken by the optimizer
|
110 |
+
acc_step: int # Nb of accumulation steps done since last optimizer step
|
111 |
+
scheduler: lr_scheduler.LambdaLR
|
112 |
+
data_loader_state: PackTokensState
|
113 |
+
|
114 |
+
def state_dict(self) -> Dict[str, Any]:
|
115 |
+
return {
|
116 |
+
"step": self.step,
|
117 |
+
"acc_step": self.acc_step,
|
118 |
+
"data_loader_state": self.data_loader_state,
|
119 |
+
"scheduler": self.scheduler.state_dict(),
|
120 |
+
}
|
121 |
+
|
122 |
+
def load_state_dict(self, state_dict):
|
123 |
+
self.step = state_dict["step"]
|
124 |
+
self.acc_step = state_dict["acc_step"]
|
125 |
+
self.data_loader_state = PackTokensState(**state_dict["data_loader_state"])
|
126 |
+
self.scheduler.load_state_dict(state_dict["scheduler"])
|
127 |
+
|
128 |
+
|
129 |
+
def validate_train_args(args: TrainArgs, output_size: int):
|
130 |
+
if args.model.vocab_size < 0:
|
131 |
+
logger.info(f"Setting model output size to {args.model.vocab_size}")
|
132 |
+
args.model.vocab_size = output_size
|
133 |
+
assert (
|
134 |
+
args.model.vocab_size == output_size
|
135 |
+
), "Vocab size should be the same as output size"
|
136 |
+
|
137 |
+
assert args.dump_dir, "Dump dir not set"
|
138 |
+
|
139 |
+
if args.checkpoint.path is None:
|
140 |
+
logger.info(f"Setting checkpoint path to {args.checkpoint.path}")
|
141 |
+
args.checkpoint.path = str(Path(args.dump_dir) / "checkpoints")
|
142 |
+
|
143 |
+
for source in args.data.sources:
|
144 |
+
data_path = os.path.join(args.data.root_dir, source)
|
145 |
+
assert os.path.exists(data_path), f"{data_path} doesn't exist"
|
146 |
+
|
147 |
+
if (
|
148 |
+
args.distributed.dp_replicate
|
149 |
+
* args.distributed.dp_shard
|
150 |
+
* args.distributed.tp_size
|
151 |
+
!= get_world_size()
|
152 |
+
):
|
153 |
+
assert get_world_size() % args.distributed.dp_shard == 0
|
154 |
+
args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard
|
155 |
+
|
156 |
+
assert args.distributed.dp_replicate % args.distributed.tp_size == 0
|
157 |
+
args.distributed.dp_replicate = (
|
158 |
+
args.distributed.dp_replicate // args.distributed.tp_size
|
159 |
+
)
|
160 |
+
|
161 |
+
logger.warning(
|
162 |
+
f"Setting Data Parallel size to {args.distributed.dp_replicate * args.distributed.dp_shard}"
|
163 |
+
)
|
164 |
+
assert (
|
165 |
+
args.distributed.dp_replicate
|
166 |
+
* args.distributed.dp_shard
|
167 |
+
* args.distributed.tp_size
|
168 |
+
== get_world_size()
|
169 |
+
)
|
170 |
+
|
171 |
+
if args.distributed.fsdp_type == "no_shard":
|
172 |
+
assert (
|
173 |
+
args.distributed.dp_shard == 1
|
174 |
+
and args.distributed.dp_replicate == get_world_size()
|
175 |
+
)
|
176 |
+
|
177 |
+
args.model.max_seqlen = args.data.seq_len
|
178 |
+
|
179 |
+
if args.distributed.tp_size != 1:
|
180 |
+
logger.warning(
|
181 |
+
"Tensor parallelism has not been tested for a while, use at your own risk"
|
182 |
+
)
|
183 |
+
|
184 |
+
assert (
|
185 |
+
args.probe_freq != args.profiling.mem_steps
|
186 |
+
), "Don't profile during probe step"
|
187 |
+
assert (
|
188 |
+
args.probe_freq != args.profiling.profile_steps
|
189 |
+
), "Don't profile during probe step"
|
190 |
+
if args.logging.wandb is not None:
|
191 |
+
args.logging.wandb.name = args.name
|
192 |
+
|
193 |
+
if args.probe_freq is not None:
|
194 |
+
assert (
|
195 |
+
args.distributed.tp_size == 1
|
196 |
+
), "Probing not supported with tensor parallelism"
|
197 |
+
assert (
|
198 |
+
args.distributed.selective_activation_checkpointing is False
|
199 |
+
), "Probing not supported with selective activation checkpointing"
|
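The resharding branch in `validate_train_args` can be traced with concrete (hypothetical) sizes; this is just arithmetic mirroring the code above:

```python
# Hypothetical sizes, only to trace the branch above where
# dp_replicate * dp_shard * tp_size does not match the world size.
world_size, dp_shard, tp_size = 16, 2, 2
assert world_size % dp_shard == 0
dp_replicate = world_size // dp_shard                      # 8
assert dp_replicate % tp_size == 0
dp_replicate //= tp_size                                   # 4
assert dp_replicate * dp_shard * tp_size == world_size     # 4 * 2 * 2 == 16
```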
200 |
+
|
201 |
+
|
202 |
+
preemption_flag = dict(flag=False)
|
203 |
+
|
204 |
+
|
205 |
+
def set_preemption_flag(signum, frame):
|
206 |
+
logger.warning("Signal handler called with signal " + str(signum))
|
207 |
+
logger.warning("Preemption ! checkpointing asap and exiting.")
|
208 |
+
preemption_flag["flag"] = True
|
209 |
+
|
210 |
+
|
211 |
+
def every_n_steps(train_state, freq, acc_step=None, acc_freq=None):
|
212 |
+
test = train_state.step % freq == 0
|
213 |
+
if acc_step is not None:
|
214 |
+
test = test and (train_state.acc_step == acc_step)
|
215 |
+
elif acc_freq is not None:
|
216 |
+
test = test and ((train_state.acc_step % acc_freq) == 0)
|
217 |
+
return test
|
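A quick illustration of how `every_n_steps` gates work with gradient accumulation (a sketch; `SimpleNamespace` stands in for `TrainState`):

```python
from types import SimpleNamespace

# step counts optimizer steps, acc_step counts micro-steps inside the
# current gradient-accumulation cycle.
state = SimpleNamespace(step=1000, acc_step=0)
assert every_n_steps(state, freq=1000) is True              # fires every 1000 steps
assert every_n_steps(state, freq=1000, acc_step=0) is True  # only on the boundary micro-step
assert every_n_steps(state, freq=1000, acc_step=1) is False
```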
218 |
+
|
219 |
+
|
220 |
+
def train(args: TrainArgs):
|
221 |
+
with ExitStack() as context_stack:
|
222 |
+
tokenizer_args = TokenizerArgs(
|
223 |
+
name=args.data.name,
|
224 |
+
init_kwargs=args.data.tokenizer.init_kwargs,
|
225 |
+
)
|
226 |
+
tokenizer = tokenizer_args.build()
|
227 |
+
validate_train_args(
|
228 |
+
args,
|
229 |
+
tokenizer.n_words,
|
230 |
+
)
|
231 |
+
if get_is_master():
|
232 |
+
os.makedirs(args.dump_dir, exist_ok=True)
|
233 |
+
dump_config(args, Path(args.dump_dir) / "config.yaml")
|
234 |
+
init_logger(Path(args.dump_dir) / "train.log")
|
235 |
+
init_signal_handler(set_preemption_flag) # For handling preemption signals.
|
236 |
+
setup_env(args.env)
|
237 |
+
setup_torch_distributed(args.distributed)
|
238 |
+
world_mesh = get_device_mesh(args.distributed)
|
239 |
+
logger.info(f"Starting job: {args.name}")
|
240 |
+
|
241 |
+
# build dataloader
|
242 |
+
# need dp world size and rank
|
243 |
+
dp_mesh = world_mesh["dp_replicate"]
|
244 |
+
dp_degree = dp_mesh.size()
|
245 |
+
dp_rank = dp_mesh.get_local_rank()
|
246 |
+
if args.distributed.dp_shard > 1:
|
247 |
+
dp_rank = dp_rank * dp_degree + world_mesh["dp_shard"].get_local_rank()
|
248 |
+
dp_degree *= world_mesh["dp_shard"].size()
|
249 |
+
|
250 |
+
logger.info(f"Running on dp rank : {dp_rank}")
|
251 |
+
logger.info(f"Running on dp size : {dp_degree}")
|
252 |
+
|
253 |
+
torch.manual_seed(args.seed)
|
254 |
+
logger.info("Building model")
|
255 |
+
|
256 |
+
# Initializing the model on the meta device lets us build models much bigger than a single GPU's memory
|
257 |
+
with torch.device("meta"):
|
258 |
+
model = LMTransformer(args.model)
|
259 |
+
logger.info("Model is built !")
|
260 |
+
|
261 |
+
model_param_count = get_num_params(model)
|
262 |
+
|
263 |
+
model = parallelize_model(
|
264 |
+
model,
|
265 |
+
world_mesh,
|
266 |
+
args.model,
|
267 |
+
args.distributed,
|
268 |
+
fsdp_grouping_plan=build_fsdp_grouping_plan(args.model),
|
269 |
+
tp_parallelize=tp_parallelize,
|
270 |
+
no_recompute_ops=get_no_recompute_ops(),
|
271 |
+
)
|
272 |
+
|
273 |
+
# Once we shard the model on different gpus we can actually initialize the model
|
274 |
+
# First we create empty tensors of the correct shapes
|
275 |
+
model = model.to_empty(device="cuda")
|
276 |
+
# Then we init the model. Please make sure this function initializes *ALL* parameters
|
277 |
+
# and buffers, otherwise you will have random values in the uninitialized tensors
|
278 |
+
# which will silently fail (give nan gradients for example)
|
279 |
+
|
280 |
+
if args.checkpoint.init_ckpt_path:
|
281 |
+
logger.info(f"Loading initial model from {args.checkpoint.init_ckpt_path}")
|
282 |
+
load_from_checkpoint(
|
283 |
+
args.checkpoint.init_ckpt_path, model, model_key="model"
|
284 |
+
) # Put model_key="" if its directly the model checkpoint
|
285 |
+
model.rope_embeddings.reset_parameters()  # Re-init RoPE: it is a non-persistent buffer, so it might not be loaded from the checkpoint
|
286 |
+
else:
|
287 |
+
with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
|
288 |
+
torch.manual_seed(args.model.seed)
|
289 |
+
model.init_weights()
|
290 |
+
check_model_value_range(model, range=10.0, std=1.0)
|
291 |
+
|
292 |
+
# log model size
|
293 |
+
|
294 |
+
logger.info(f"Model size: {model_param_count:,} total parameters")
|
295 |
+
|
296 |
+
gpu_memory_monitor = GPUMemoryMonitor("cuda")
|
297 |
+
logger.info(
|
298 |
+
f"GPU capacity: {gpu_memory_monitor.device_name} ({gpu_memory_monitor.device_index}) "
|
299 |
+
f"with {gpu_memory_monitor.device_capacity_gib:.2f}GiB memory"
|
300 |
+
)
|
301 |
+
logger.info(f"GPU memory usage: {gpu_memory_monitor}")
|
302 |
+
|
303 |
+
# build optimizer after apply parallelisms to the model
|
304 |
+
optimizer, scheduler = build_optimizer(model, args.optim, args.steps)
|
305 |
+
data_loader_state = init_dataloader_state_from_args(
|
306 |
+
args.data, dp_rank, dp_degree
|
307 |
+
)
|
308 |
+
|
309 |
+
train_state = TrainState(
|
310 |
+
step=0,
|
311 |
+
acc_step=0,
|
312 |
+
data_loader_state=data_loader_state,
|
313 |
+
scheduler=scheduler,
|
314 |
+
)
|
315 |
+
|
316 |
+
checkpoint = CheckpointManager.instantiate_and_make_dir(args.checkpoint)
|
317 |
+
checkpoint.load(model, optimizer, train_state, world_mesh)
|
318 |
+
# Either load from latest checkpoint or start from scratch
|
319 |
+
if args.probe_freq is not None:
|
320 |
+
if get_is_master():
|
321 |
+
os.makedirs(Path(args.dump_dir) / "probe", exist_ok=True)
|
322 |
+
torch.distributed.barrier()
|
323 |
+
probe = AutoProbeD(
|
324 |
+
model,
|
325 |
+
(
|
326 |
+
Path(args.dump_dir) / "probe" / f"probe.{dp_rank}.jsonl"
|
327 |
+
if (dp_rank % 128 == 0)
|
328 |
+
else None
|
329 |
+
),
|
330 |
+
)
|
331 |
+
probe_mod = model._orig_mod if args.distributed.compile else model
|
332 |
+
|
333 |
+
gc.disable()
|
334 |
+
|
335 |
+
# train loop
|
336 |
+
model.train()
|
337 |
+
metric_logger = context_stack.enter_context(
|
338 |
+
MetricLogger(Path(args.dump_dir) / "metrics.jsonl", args)
|
339 |
+
)
|
340 |
+
data_loader = context_stack.enter_context(
|
341 |
+
build_dataloader_from_args(
|
342 |
+
args.data,
|
343 |
+
state=train_state.data_loader_state,
|
344 |
+
)
|
345 |
+
)
|
346 |
+
torch_profiler = context_stack.enter_context(
|
347 |
+
maybe_run_profiler(args.dump_dir, model, args.profiling)
|
348 |
+
)
|
349 |
+
|
350 |
+
nwords_since_last_log = 0
|
351 |
+
time_last_log = timer()
|
352 |
+
gc.collect()
|
353 |
+
while train_state.step < args.steps:
|
354 |
+
# We constrain train_state.acc_step to be in range 0 to args.grad_acc_steps - 1
|
355 |
+
train_state.acc_step += 1
|
356 |
+
train_state.acc_step = train_state.acc_step % args.grad_acc_steps
|
357 |
+
|
358 |
+
# get batch
|
359 |
+
curr_lr = float(optimizer.param_groups[0]["lr"])
|
360 |
+
data_load_start = timer()
|
361 |
+
batch, train_state.data_loader_state = next(data_loader)
|
362 |
+
batch = torch.tensor(
|
363 |
+
batch,
|
364 |
+
dtype=torch.long,
|
365 |
+
)
|
366 |
+
|
367 |
+
if every_n_steps(train_state, args.gc_collect_freq, acc_step=0):
|
368 |
+
logger.info("garbage collection")
|
369 |
+
# we do garbage collection manually otherwise different processes
|
370 |
+
# run the GC at different times so they slow down the whole pipeline
|
371 |
+
gc.collect()
|
372 |
+
|
373 |
+
input_ids = batch[:, :, 0].cuda()
|
374 |
+
labels = batch[:, :, 1].cuda()
|
375 |
+
data_load_time = round(timer() - data_load_start, 4)
|
376 |
+
nwords_since_last_log += input_ids.numel()
|
377 |
+
|
378 |
+
bsz, seqlen = labels.shape
|
379 |
+
|
380 |
+
# forward
|
381 |
+
start_timer = torch.cuda.Event(enable_timing=True)
|
382 |
+
end_timer = torch.cuda.Event(enable_timing=True)
|
383 |
+
start_timer.record()
|
384 |
+
|
385 |
+
# This is an automatic probe that will compute statistics
|
386 |
+
# of all linears' inputs, weights and outputs
|
387 |
+
# along with attention logits and entropy
|
388 |
+
# both in forward and backward pass
|
389 |
+
if (args.probe_freq is not None) and every_n_steps(
|
390 |
+
train_state, args.probe_freq, acc_step=1 % args.grad_acc_steps
|
391 |
+
):
|
392 |
+
# Here we do a fake forward and backward pass on a smaller
|
393 |
+
# batch size to avoid OOM
|
394 |
+
# This assumes the model has no stateful layers (batch norm..)
|
395 |
+
assert (
|
396 |
+
next(probe_mod.parameters()).grad is None
|
397 |
+
), "Can't probe model if grads are not reset"
|
398 |
+
|
399 |
+
with probe:
|
400 |
+
probe.metadata = {
|
401 |
+
"it": train_state.step,
|
402 |
+
"global_step": train_state.step,
|
403 |
+
"loop": "lingua",
|
404 |
+
}
|
405 |
+
# A non-compiled model uses roughly 2x memory in our experiments
|
406 |
+
# So we divide bsz by 2 or seqlen by 2
|
407 |
+
probe_bsz = max(1, bsz // 2)
|
408 |
+
probe_seq = seqlen if (bsz // 2 >= 1) else (seqlen // 2)
|
409 |
+
probe_loss = probe_mod(
|
410 |
+
input_ids[:probe_bsz, :probe_seq],
|
411 |
+
labels[:probe_bsz, :probe_seq],
|
412 |
+
)
|
413 |
+
probe_loss.backward()
|
414 |
+
# We zero grads to cancel this fake step
|
415 |
+
optimizer.zero_grad()
|
416 |
+
|
417 |
+
assert (
|
418 |
+
next(probe_mod.parameters()).grad is None
|
419 |
+
), "Probe model shouldn't have grads at this point"
|
420 |
+
loss = model(input_ids, labels)
|
421 |
+
|
422 |
+
# We scale loss with grad_acc_steps so the gradient is the same
|
423 |
+
# regardless of grad_acc_steps
|
424 |
+
loss = loss / args.grad_acc_steps
|
425 |
+
# backward on scaled loss to create scaled gradients
|
426 |
+
loss.backward()
|
427 |
+
# For logging we undo that scaling
|
428 |
+
loss = loss.detach() * args.grad_acc_steps
|
429 |
+
|
430 |
+
grad_norm = torch.nn.utils.clip_grad_norm_(
|
431 |
+
model.parameters(), max_norm=args.optim.clip, foreach=True
|
432 |
+
)
|
433 |
+
|
434 |
+
grad_norm = (
|
435 |
+
grad_norm.full_tensor() if isinstance(grad_norm, DTensor) else grad_norm
|
436 |
+
).item()
|
437 |
+
|
438 |
+
# optimizer step
|
439 |
+
if train_state.acc_step == 0:
|
440 |
+
optimizer.step()
|
441 |
+
scheduler.step()
|
442 |
+
optimizer.zero_grad()
|
443 |
+
train_state.step += 1
|
444 |
+
|
445 |
+
# training iteration complete
|
447 |
+
end_timer.record()
|
448 |
+
|
449 |
+
torch.cuda.synchronize()
|
450 |
+
|
451 |
+
curr_iter_time = round(start_timer.elapsed_time(end_timer) * 1e-3, 4)
|
452 |
+
|
453 |
+
# if profiler is active
|
454 |
+
if torch_profiler:
|
455 |
+
xformers.profiler.step()
|
456 |
+
|
457 |
+
# log metrics
|
458 |
+
if every_n_steps(
|
459 |
+
train_state,
|
460 |
+
args.logging.freq,
|
461 |
+
acc_step=None if args.logging.acc_freq else 0,
|
462 |
+
acc_freq=args.logging.acc_freq,
|
463 |
+
):
|
464 |
+
time_delta = timer() - time_last_log
|
465 |
+
wps = nwords_since_last_log / (time_delta * args.distributed.tp_size)
|
466 |
+
|
467 |
+
gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
|
468 |
+
|
469 |
+
total_acc_steps = (
|
470 |
+
args.grad_acc_steps * train_state.step + train_state.acc_step
|
471 |
+
)
|
472 |
+
tokens_per_gpu = (
|
473 |
+
total_acc_steps * args.data.batch_size * args.data.seq_len
|
474 |
+
)
|
475 |
+
total_tokens = dp_degree * tokens_per_gpu
|
476 |
+
# This is an estimate and the correct values may change
|
477 |
+
# if you change the architecture
|
478 |
+
# Use xformer's analyze profile trace to get actual measurement
|
479 |
+
FLOPS = (
|
480 |
+
get_num_flop_per_token(
|
481 |
+
model_param_count - args.model.vocab_size * args.model.dim,
|
482 |
+
args.model.n_layers,
|
483 |
+
args.model.dim,
|
484 |
+
args.data.seq_len,
|
485 |
+
)
|
486 |
+
* wps
|
487 |
+
)
|
488 |
+
metrics = flatten_dict(
|
489 |
+
{
|
490 |
+
"global_step": train_state.step,
|
491 |
+
"acc_step": train_state.acc_step,
|
492 |
+
"speed": {
|
493 |
+
"wps": wps,
|
494 |
+
"FLOPS": FLOPS,
|
495 |
+
"curr_iter_time": curr_iter_time,
|
496 |
+
"data_load_time": data_load_time,
|
497 |
+
},
|
498 |
+
"optim": {
|
499 |
+
"grad_norm": grad_norm,
|
500 |
+
"lr": curr_lr,
|
501 |
+
"total_tokens": total_tokens,
|
502 |
+
},
|
503 |
+
"memory": gpu_mem_stats._asdict(),
|
504 |
+
},
|
505 |
+
sep="/",
|
506 |
+
)
|
507 |
+
|
508 |
+
to_sync = {}
|
509 |
+
to_sync["loss/out"] = loss.item()
|
510 |
+
metrics.update(dist_mean_dict(to_sync))
|
511 |
+
|
512 |
+
if get_is_master():
|
513 |
+
metric_logger.log(metrics)
|
514 |
+
|
515 |
+
gpu_memory_monitor.reset_peak_stats()
|
516 |
+
nwords_since_last_log = 0
|
517 |
+
time_last_log = timer()
|
518 |
+
logger.info(
|
519 |
+
f"step: {train_state.step}"
|
520 |
+
f" acc: {train_state.acc_step}"
|
521 |
+
f" loss: {round(loss.item(),4):>7}"
|
522 |
+
f" grad: {grad_norm:.2e}"
|
523 |
+
f" flops: {FLOPS:.2e}"
|
524 |
+
f" wps: {wps:.2e}"
|
525 |
+
f" iter: {curr_iter_time:>7}"
|
526 |
+
f" data: {data_load_time:>5}"
|
527 |
+
f" lr: {curr_lr:.2e}"
|
528 |
+
f" mem: {gpu_mem_stats.max_active_pct:.0f}%"
|
529 |
+
f" pow: {gpu_mem_stats.power_draw/1000} W"
|
530 |
+
)
|
531 |
+
|
532 |
+
saved = False
|
533 |
+
if every_n_steps(
|
534 |
+
train_state, args.checkpoint.dump.every, acc_step=0
|
535 |
+
) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
|
536 |
+
saved = checkpoint.save(
|
537 |
+
model,
|
538 |
+
optimizer,
|
539 |
+
train_state,
|
540 |
+
args,
|
541 |
+
device_mesh=world_mesh,
|
542 |
+
)
|
543 |
+
|
544 |
+
if args.eval is not None and every_n_steps(
|
545 |
+
train_state, args.checkpoint.eval.every, acc_step=0
|
546 |
+
):
|
547 |
+
from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval
|
548 |
+
|
549 |
+
eval_args = dataclass_from_dict(EvalArgs, args.eval)
|
550 |
+
|
551 |
+
eval_args.global_step = train_state.step
|
552 |
+
eval_args.ckpt_dir = str(checkpoint.existing_saves[-1])
|
553 |
+
eval_args.dump_dir = str(
|
554 |
+
os.path.join(
|
555 |
+
args.dump_dir,
|
556 |
+
"evals",
|
557 |
+
EVAL_FOLDER_NAME.format(train_state.step),
|
558 |
+
)
|
559 |
+
)
|
560 |
+
eval_args.metric_log_dir = args.dump_dir
|
561 |
+
if args.async_eval_gpus is None:
|
562 |
+
launch_eval(eval_args)
|
563 |
+
elif get_is_master():
|
564 |
+
if wandb.run is not None and args.logging.wandb is not None:
|
565 |
+
eval_args.wandb = deepcopy(args.logging.wandb)
|
566 |
+
assert args.async_eval_gpus > 0
|
567 |
+
logger.info(f"Launching evals on {args.async_eval_gpus} gpus")
|
568 |
+
with clean_env():
|
569 |
+
launch_job(
|
570 |
+
StoolArgs(
|
571 |
+
asdict(eval_args),
|
572 |
+
script="apps.main.eval",
|
573 |
+
copy_code=False,
|
574 |
+
nodes=args.async_eval_gpus // 8,
|
575 |
+
qos="lowest",
|
576 |
+
)
|
577 |
+
)
|
578 |
+
|
579 |
+
if preemption_flag["flag"]:
|
580 |
+
if not saved:
|
581 |
+
checkpoint.save(
|
582 |
+
model,
|
583 |
+
optimizer,
|
584 |
+
train_state,
|
585 |
+
args,
|
586 |
+
device_mesh=world_mesh,
|
587 |
+
)
|
588 |
+
requeue_slurm_job()
|
589 |
+
sys.exit(0)
|
590 |
+
|
591 |
+
if not saved:
|
592 |
+
checkpoint.save(
|
593 |
+
model,
|
594 |
+
optimizer,
|
595 |
+
train_state,
|
596 |
+
args,
|
597 |
+
device_mesh=world_mesh,
|
598 |
+
)
|
599 |
+
gc.collect()
|
600 |
+
|
601 |
+
|
602 |
+
def main():
|
603 |
+
"""
|
604 |
+
The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments
|
605 |
+
This accepts arguments as a dot list
|
606 |
+
So if the dataclass looks like
|
607 |
+
|
608 |
+
@dataclass
|
609 |
+
class DummyArgs:
|
610 |
+
name: str
|
611 |
+
model: LMTransformerArgs
|
612 |
+
|
613 |
+
@dataclass
|
614 |
+
class LMTransformerArgs:
|
615 |
+
dim: int
|
616 |
+
|
617 |
+
Then you can pass model.dim=32 to change values in LMTransformerArgs
|
618 |
+
or just name=tictac for top level attributes.
|
619 |
+
|
620 |
+
The behavior here is as follows:
|
621 |
+
1. We instantiate TrainArgs with its default values
|
622 |
+
2. We override those default values with the ones in the provided config file
|
623 |
+
3. We override the result with the additional arguments provided through command line
|
624 |
+
|
625 |
+
For example, if the config is the following
|
626 |
+
|
627 |
+
model:
|
628 |
+
dim: 128
|
629 |
+
n_layers: 4
|
630 |
+
|
631 |
+
and you call train.py with train.py model.dim=64
|
632 |
+
|
633 |
+
Then the final TrainArgs will have
|
634 |
+
|
635 |
+
model:
|
636 |
+
dim: 64
|
637 |
+
n_layers: 4
|
638 |
+
|
639 |
+
Plus all the default values in TrainArgs dataclass.
|
640 |
+
"""
|
641 |
+
cli_args = OmegaConf.from_cli()
|
642 |
+
file_cfg = OmegaConf.load(cli_args.config)
|
643 |
+
# We remove 'config' attribute from config as the underlying DataClass does not have it
|
644 |
+
del cli_args.config
|
645 |
+
|
646 |
+
default_cfg = OmegaConf.structured(TrainArgs())
|
647 |
+
cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
|
648 |
+
cfg = OmegaConf.to_object(cfg)
|
649 |
+
|
650 |
+
train(cfg)
|
651 |
+
|
652 |
+
|
653 |
+
if __name__ == "__main__":
|
654 |
+
main()
|
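The default -> config file -> CLI precedence described in `main()`'s docstring can be reproduced in isolation; `DummyArgs` and `Model` below are illustrative stand-ins, not part of the repository:

```python
from dataclasses import dataclass, field
from omegaconf import OmegaConf

@dataclass
class Model:
    dim: int = 128
    n_layers: int = 4

@dataclass
class DummyArgs:
    name: str = "default"
    model: Model = field(default_factory=Model)

default_cfg = OmegaConf.structured(DummyArgs())
file_cfg = OmegaConf.create({"model": {"dim": 256}})               # stands in for the YAML file
cli_cfg = OmegaConf.from_dotlist(["model.dim=64", "name=tictac"])  # stands in for CLI overrides
cfg = OmegaConf.merge(default_cfg, file_cfg, cli_cfg)
assert cfg.name == "tictac" and cfg.model.dim == 64 and cfg.model.n_layers == 4
```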
blt-figure.jpg
ADDED
blt-figure.pdf
ADDED
Binary file (62.5 kB). View file
|
|
bytelatent/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
bytelatent/__init__.py
ADDED
@@ -0,0 +1,3 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
class ByteLatentError(Exception):
|
3 |
+
pass
|
bytelatent/args.py
ADDED
@@ -0,0 +1,199 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
from typing import Any
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import yaml
|
8 |
+
from pydantic import BaseModel, ConfigDict
|
9 |
+
|
10 |
+
from bytelatent.checkpoint import CheckpointArgs
|
11 |
+
from bytelatent.data.data_types import Batch
|
12 |
+
from bytelatent.data.iterators.abstract_iterator import StatefulIterator
|
13 |
+
from bytelatent.data.iterators.arrow_iterator import (
|
14 |
+
ArrowFileIterator,
|
15 |
+
find_and_sanitize_chunks,
|
16 |
+
)
|
17 |
+
from bytelatent.data.iterators.looping_iterator import LoopingIterator
|
18 |
+
from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
|
19 |
+
from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator
|
20 |
+
from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
|
21 |
+
from bytelatent.data.iterators.sampling_iterator import SamplingIterator
|
22 |
+
from bytelatent.data.iterators.sequence_iterator import (
|
23 |
+
SequenceIterator,
|
24 |
+
SequencePackingArgs,
|
25 |
+
)
|
26 |
+
from bytelatent.data.patcher import PatcherArgs
|
27 |
+
from bytelatent.distributed import DistributedArgs, EnvironmentArgs
|
28 |
+
from bytelatent.metrics import LoggingArgs
|
29 |
+
from bytelatent.model.blt import ByteLatentTransformerArgs
|
30 |
+
from bytelatent.optim import OptimArgs
|
31 |
+
from bytelatent.profiling import ProfilerArgs
|
32 |
+
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
|
33 |
+
|
34 |
+
logger = logging.getLogger()
|
35 |
+
|
36 |
+
|
37 |
+
def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]:
|
38 |
+
return np.random.default_rng((seed, rank, world_size)).bit_generator.state
|
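`get_rng_state` seeds a NumPy generator with the `(seed, rank, world_size)` triple so every data-parallel rank gets its own reproducible stream; restoring the saved state rewinds it (a minimal sketch):

```python
import numpy as np

state = get_rng_state(seed=42, rank=0, world_size=8)
rng = np.random.default_rng()
rng.bit_generator.state = state
first_draw = rng.random()

rng.bit_generator.state = state   # rewind to the checkpointed state
assert rng.random() == first_draw
```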
39 |
+
|
40 |
+
|
41 |
+
def distribute_data_to_rank(
|
42 |
+
*,
|
43 |
+
dataset_path: str,
|
44 |
+
preprocess_dir: str,
|
45 |
+
entropy_model_name: str | None,
|
46 |
+
arrow_batch_size: int,
|
47 |
+
rank: int,
|
48 |
+
world_size: int,
|
49 |
+
) -> ArrowFileIterator:
|
50 |
+
dataset_chunks = find_and_sanitize_chunks(dataset_path, world_size)
|
51 |
+
n_workers_per_chunk = world_size // len(dataset_chunks)
|
52 |
+
rank_to_arrow_iterator_params = []
|
53 |
+
for chunk_path in dataset_chunks:
|
54 |
+
for worker_id in range(n_workers_per_chunk):
|
55 |
+
rank_to_arrow_iterator_params.append(
|
56 |
+
ArrowFileIterator(
|
57 |
+
file_path=chunk_path,
|
58 |
+
worker_id=worker_id,
|
59 |
+
num_workers=n_workers_per_chunk,
|
60 |
+
preprocess_dir=preprocess_dir,
|
61 |
+
dataset_files=None,
|
62 |
+
entropy_model_name=entropy_model_name,
|
63 |
+
arrow_batch_size=arrow_batch_size,
|
64 |
+
)
|
65 |
+
)
|
66 |
+
return rank_to_arrow_iterator_params[rank]
|
67 |
+
|
68 |
+
|
69 |
+
class DataloaderArgs(BaseModel):
|
70 |
+
model_config = ConfigDict(extra="forbid")
|
71 |
+
root_dir: str | None = None
|
72 |
+
sources: dict[str, float] = {}
|
73 |
+
batch_size: int = 2
|
74 |
+
seq_len: int = 2048
|
75 |
+
seed: int = 42
|
76 |
+
add_bos: bool = True
|
77 |
+
add_eos: bool = True
|
78 |
+
load_async: bool = True
|
79 |
+
prefetch_size: int = 64
|
80 |
+
preprocess_dir: str | None = None
|
81 |
+
dataset_files: list[str] | None = None
|
82 |
+
entropy_model_name: str | None = "transformer_100m"
|
83 |
+
arrow_batch_size: int = 100
|
84 |
+
buffer_size: int = 64
|
85 |
+
|
86 |
+
pad_to_max_length: bool = True
|
87 |
+
max_encoder_seq_length: int = 12288
|
88 |
+
enable_byte_ngrams: bool = False
|
89 |
+
|
90 |
+
tokenizer_args: TokenizerArgs = TokenizerArgs()
|
91 |
+
patcher_args: PatcherArgs = PatcherArgs()
|
92 |
+
|
93 |
+
def _create_sequence_iterators(
|
94 |
+
self, rank: int, world_size: int
|
95 |
+
) -> dict[str, SequenceIterator]:
|
96 |
+
sequence_packing_args = SequencePackingArgs(
|
97 |
+
output_seq_len=self.seq_len,
|
98 |
+
buffer_size=self.buffer_size,
|
99 |
+
)
|
100 |
+
source_to_sequence_iterator: dict[str, SequenceIterator] = {}
|
101 |
+
for dataset_path in self.sources:
|
102 |
+
shuffle_rng_state = get_rng_state(self.seed + 1, rank, world_size)
|
103 |
+
arrow_iterator = distribute_data_to_rank(
|
104 |
+
dataset_path=os.path.join(self.root_dir, dataset_path),
|
105 |
+
preprocess_dir=self.preprocess_dir,
|
106 |
+
entropy_model_name=self.entropy_model_name,
|
107 |
+
arrow_batch_size=self.arrow_batch_size,
|
108 |
+
rank=rank,
|
109 |
+
world_size=world_size,
|
110 |
+
)
|
111 |
+
looping_iterator = LoopingIterator(arrow_iterator)
|
112 |
+
preprocess_iterator = PreprocessIterator(
|
113 |
+
looping_iterator,
|
114 |
+
patcher_args=self.patcher_args,
|
115 |
+
tokenizer_args=self.tokenizer_args,
|
116 |
+
)
|
117 |
+
sequence_iterator = SequenceIterator(
|
118 |
+
preprocess_iterator,
|
119 |
+
sequence_packing_args=sequence_packing_args,
|
120 |
+
rng_state=shuffle_rng_state,
|
121 |
+
)
|
122 |
+
|
123 |
+
source_to_sequence_iterator[dataset_path] = sequence_iterator
|
124 |
+
return source_to_sequence_iterator
|
125 |
+
|
126 |
+
def build_from_rank(
|
127 |
+
self, rank: int, world_size: int
|
128 |
+
) -> StatefulIterator[Batch, Any]:
|
129 |
+
source_to_sequence_iterators = self._create_sequence_iterators(rank, world_size)
|
130 |
+
weight_rng_state = get_rng_state(self.seed + 1, rank, world_size)
|
131 |
+
sampling_iterator = SamplingIterator(
|
132 |
+
rng_state=weight_rng_state,
|
133 |
+
source_to_weight=self.sources,
|
134 |
+
source_to_iterator=source_to_sequence_iterators,
|
135 |
+
)
|
136 |
+
tokenizer = self.tokenizer_args.build()
|
137 |
+
packing_args = PackingArgs(
|
138 |
+
batch_size=self.batch_size,
|
139 |
+
seq_len=self.seq_len,
|
140 |
+
pad_id=tokenizer.boe_id,
|
141 |
+
max_length=self.max_encoder_seq_length,
|
142 |
+
pad_to_max_length=self.pad_to_max_length,
|
143 |
+
enable_byte_ngrams=self.enable_byte_ngrams,
|
144 |
+
)
|
145 |
+
packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args)
|
146 |
+
mp_iterator = MultiprocessIterator(
|
147 |
+
packing_iterator, n_batches_to_prefetch=self.prefetch_size
|
148 |
+
)
|
149 |
+
|
150 |
+
return mp_iterator
|
151 |
+
|
152 |
+
|
153 |
+
class TrainArgs(BaseModel):
|
154 |
+
model_config = ConfigDict(extra="forbid")
|
155 |
+
name: str = "lingua"
|
156 |
+
dump_dir: str = ""
|
157 |
+
|
158 |
+
seed: int = 42
|
159 |
+
|
160 |
+
# Number of gradient accumulation steps
|
161 |
+
# Total batch size is batch_size*grad_acc_steps
|
162 |
+
grad_acc_steps: int = 1
|
163 |
+
|
164 |
+
gc_collect_freq: int = 1000
|
165 |
+
probe_freq: int | None = None
|
166 |
+
|
167 |
+
# Nb optimizer steps to take
|
168 |
+
steps: int = 1000
|
169 |
+
|
170 |
+
data: DataloaderArgs = DataloaderArgs()
|
171 |
+
optim: OptimArgs = OptimArgs()
|
172 |
+
model: ByteLatentTransformerArgs = ByteLatentTransformerArgs()
|
173 |
+
distributed: DistributedArgs = DistributedArgs()
|
174 |
+
env: EnvironmentArgs = EnvironmentArgs()
|
175 |
+
|
176 |
+
checkpoint: CheckpointArgs = CheckpointArgs()
|
177 |
+
profiling: ProfilerArgs = ProfilerArgs()
|
178 |
+
logging: LoggingArgs = LoggingArgs()
|
179 |
+
|
180 |
+
# If set to None, eval is run locally; otherwise it launches a new job with the given number of gpus
|
181 |
+
async_eval_gpus: int | None = None
|
182 |
+
eval: Any | None = None
|
183 |
+
eval_on_gpus: int | None = None
|
184 |
+
|
185 |
+
def dump_to_yaml_file(
|
186 |
+
self, path: str, log_config: bool = True, sort_keys: bool = True
|
187 |
+
):
|
188 |
+
model_dict = self.model_dump(mode="json")
|
189 |
+
yaml_str = yaml.dump(
|
190 |
+
model_dict,
|
191 |
+
allow_unicode=True,
|
192 |
+
sort_keys=sort_keys,
|
193 |
+
default_flow_style=False,
|
194 |
+
)
|
195 |
+
with open(path, "w") as f:
|
196 |
+
if log_config:
|
197 |
+
logger.info("Using the following config for this run:")
|
198 |
+
logger.info(yaml_str)
|
199 |
+
f.write(yaml_str)
|
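Assuming the nested `*Args` defaults all instantiate, `TrainArgs` can be round-tripped through `dump_to_yaml_file` and pydantic's `model_validate`; the dump path below is illustrative:

```python
import yaml

args = TrainArgs(name="blt_debug", steps=10)
args.dump_to_yaml_file("/tmp/blt_config.yaml", log_config=False)

with open("/tmp/blt_config.yaml") as f:
    reloaded = TrainArgs.model_validate(yaml.safe_load(f))
assert reloaded.name == "blt_debug" and reloaded.steps == 10
```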
bytelatent/base_transformer.py
ADDED
@@ -0,0 +1,585 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
|
3 |
+
from enum import Enum
|
4 |
+
from typing import Optional, Tuple, Union
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from pydantic import BaseModel
|
8 |
+
from torch import nn
|
9 |
+
from torch.nn import functional as F
|
10 |
+
from torch.nn.attention.flex_attention import (
|
11 |
+
BlockMask,
|
12 |
+
_mask_mod_signature,
|
13 |
+
flex_attention,
|
14 |
+
)
|
15 |
+
from xformers.ops import AttentionBias, fmha
|
16 |
+
|
17 |
+
from bytelatent import probe
|
18 |
+
|
19 |
+
flex_attention_comp = torch.compile(flex_attention)
|
20 |
+
|
21 |
+
|
22 |
+
class InitStdFactor(Enum):
|
23 |
+
DISABLED = "disabled" # Init std is divided by 1.0
|
24 |
+
GLOBAL_DEPTH = "global_depth" # Init std is divided by sqrt(2*n_layers)
|
25 |
+
CURRENT_DEPTH = "current_depth" # Init std is divided by sqrt(2*depth)
|
26 |
+
DIM_RATIO = "dim_ratio" # Init std is divided by model_dim/4096
|
27 |
+
|
28 |
+
|
29 |
+
class BaseTransformerArgs(BaseModel):
|
30 |
+
dim: int = 512
|
31 |
+
n_layers: int = 8
|
32 |
+
head_dim: Optional[int] = None
|
33 |
+
n_heads: Optional[int] = None
|
34 |
+
n_kv_heads: Optional[int] = None
|
35 |
+
|
36 |
+
ffn_dim_multiplier: Optional[float] = None
|
37 |
+
|
38 |
+
multiple_of: int = 256
|
39 |
+
|
40 |
+
norm_eps: float = 1e-5
|
41 |
+
|
42 |
+
rope_theta: float = 10000.0
|
43 |
+
|
44 |
+
init_base_std: Optional[float] = None
|
45 |
+
init_std_factor: InitStdFactor = InitStdFactor.DISABLED
|
46 |
+
|
47 |
+
max_seqlen: int = 1024
|
48 |
+
|
49 |
+
|
50 |
+
def cross_entropy(pred, target, **kwargs):
|
51 |
+
return F.nll_loss(
|
52 |
+
F.log_softmax(pred.flatten(end_dim=-2).float(), -1),
|
53 |
+
target.flatten(end_dim=-1),
|
54 |
+
**kwargs,
|
55 |
+
)
|
56 |
+
|
57 |
+
|
58 |
+
def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor:
|
59 |
+
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
|
60 |
+
assert dim == 2, "Only dim=2 is supported. Check the implementation for other dims."
|
61 |
+
bs, slen, n_kv_heads, head_dim = x.shape
|
62 |
+
if n_rep == 1:
|
63 |
+
return x
|
64 |
+
return (
|
65 |
+
x[:, :, :, None, :]
|
66 |
+
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
|
67 |
+
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
|
68 |
+
)
|
69 |
+
|
70 |
+
|
71 |
+
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
|
72 |
+
"""
|
73 |
+
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
|
74 |
+
|
75 |
+
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
|
76 |
+
and the end index 'end'. The 'theta' parameter scales the frequencies.
|
77 |
+
The returned tensor stores the cosine and sine values as real-valued 2x2 rotation matrices rather than complex64 values.
|
78 |
+
|
79 |
+
Args:
|
80 |
+
dim (int): Dimension of the frequency tensor.
|
81 |
+
end (int): End index for precomputing frequencies.
|
82 |
+
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
|
83 |
+
|
84 |
+
Returns:
|
85 |
+
torch.Tensor: Precomputed rotation matrices of shape (end, dim // 2, 2, 2).
|
86 |
+
"""
|
87 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
88 |
+
t = torch.arange(end, device=freqs.device)
|
89 |
+
freqs = torch.outer(t, freqs).float()
|
90 |
+
|
91 |
+
cos, sin = freqs.cos(), freqs.sin()
|
92 |
+
|
93 |
+
return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2)
|
94 |
+
|
95 |
+
|
96 |
+
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int):
|
97 |
+
"""
|
98 |
+
Reshape frequency tensor for broadcasting it with another tensor.
|
99 |
+
|
100 |
+
This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
|
101 |
+
for the purpose of broadcasting the frequency tensor during element-wise operations.
|
102 |
+
|
103 |
+
Args:
|
104 |
+
freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
|
105 |
+
x (torch.Tensor): Target tensor for broadcasting compatibility.
|
106 |
+
seq_dim (int): Sequence dimension index.
|
107 |
+
|
108 |
+
Returns:
|
109 |
+
torch.Tensor: Reshaped frequency tensor.
|
110 |
+
"""
|
111 |
+
ndim = x.ndim
|
112 |
+
assert 0 <= seq_dim < ndim
|
113 |
+
assert freqs_cis.shape == (
|
114 |
+
x.shape[seq_dim],
|
115 |
+
x.shape[-3],
|
116 |
+
2,
|
117 |
+
2,
|
118 |
+
), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}"
|
119 |
+
shape = [
|
120 |
+
d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2])
|
121 |
+
] + [2, 2]
|
122 |
+
return freqs_cis.view(*shape)
|
123 |
+
|
124 |
+
|
125 |
+
def apply_rotary_emb(
|
126 |
+
xq: torch.Tensor,
|
127 |
+
xk: torch.Tensor,
|
128 |
+
seq_dim: int,
|
129 |
+
freqs_cis: torch.Tensor,
|
130 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
131 |
+
xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2
|
132 |
+
xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2
|
133 |
+
freqs_cis = reshape_for_broadcast(
|
134 |
+
freqs_cis, xq_, seq_dim
|
135 |
+
).float() # S D/2 2 2 -> 1 S 1 D/2 2 2
|
136 |
+
xq_out = (xq_ * freqs_cis).sum(5).flatten(3)
|
137 |
+
xk_out = (xk_ * freqs_cis).sum(5).flatten(3)
|
138 |
+
return xq_out.type_as(xq), xk_out.type_as(xk)
|
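A small sanity-check sketch for the rotary helpers above (assuming they are importable from `bytelatent.base_transformer`): each (cos, sin) 2x2 block acts as a rotation, so norms are preserved and position 0 gets the identity.

```python
import torch
from bytelatent.base_transformer import apply_rotary_emb, precompute_freqs_cis

B, S, H, D = 2, 8, 4, 64
xq = torch.randn(B, S, H, D)
xk = torch.randn(B, S, H, D)
freqs_cis = precompute_freqs_cis(dim=D, end=S)          # (S, D // 2, 2, 2)

xq_rot, xk_rot = apply_rotary_emb(xq, xk, 1, freqs_cis)
# Rotations preserve the norm of every head vector...
assert torch.allclose(xq_rot.norm(dim=-1), xq.norm(dim=-1), atol=1e-4)
# ...and position 0 gets the identity rotation (cos=1, sin=0).
assert torch.allclose(xq_rot[:, 0], xq[:, 0], atol=1e-5)
```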
139 |
+
|
140 |
+
|
141 |
+
def causal_mask(b, h, q_idx, kv_idx):
|
142 |
+
return q_idx >= kv_idx
|
143 |
+
|
144 |
+
|
145 |
+
def lengths_to_start_ids(lengths):
|
146 |
+
doc_start = lengths.cumsum(0)
|
147 |
+
doc_start = doc_start.roll(1)
|
148 |
+
doc_start[0] = 0
|
149 |
+
return doc_start
|
150 |
+
|
151 |
+
|
152 |
+
def lengths_to_local_ids(lengths):
|
153 |
+
assert lengths.ndim == 1
|
154 |
+
nb_seqs = lengths.size(0)
|
155 |
+
total_seqlen = lengths.sum()
|
156 |
+
# This gives the document id of each token
|
157 |
+
doc_id = torch.repeat_interleave(lengths)
|
158 |
+
# Compute document start for each document
|
159 |
+
doc_start = lengths_to_start_ids(lengths)
|
160 |
+
# Compute document start for each token
|
161 |
+
doc_start = doc_start[doc_id]
|
162 |
+
# Compute the position of each token within each document
|
163 |
+
tok_id = torch.arange(total_seqlen, device=lengths.device) - doc_start
|
164 |
+
|
165 |
+
return doc_id, tok_id
|
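For example, with three documents of lengths 2, 3 and 2 (the layout used in the docstring further below):

```python
import torch

doc_id, tok_id = lengths_to_local_ids(torch.tensor([2, 3, 2]))
# doc_id -> tensor([0, 0, 1, 1, 1, 2, 2])   which document each token belongs to
# tok_id -> tensor([0, 1, 0, 1, 2, 0, 1])   position of each token inside its document
```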
166 |
+
|
167 |
+
|
168 |
+
def generate_doc_mask_mod(
|
169 |
+
mask_mod: _mask_mod_signature,
|
170 |
+
lengths: torch.Tensor,
|
171 |
+
kv_lengths: Optional[torch.Tensor] = None,
|
172 |
+
) -> _mask_mod_signature:
|
173 |
+
"""Generates mask mods that apply to inputs to flex attention in the sequence stacked
|
174 |
+
format.
|
175 |
+
|
176 |
+
Args:
|
177 |
+
mask_mod: The mask mod to apply to the documents
|
178 |
+
lengths: Lengths of each document
|
179 |
+
|
180 |
+
Note:
|
181 |
+
What is the sequence stacked format? When assembling batches of inputs, we
|
182 |
+
take multiple sequences and stack them together to form 1 large sequence. We then
|
183 |
+
use masking to ensure that the attention scores are only applied to tokens within
|
184 |
+
the same document.
|
185 |
+
|
186 |
+
Example:
|
187 |
+
|
188 |
+
- Square mask
|
189 |
+
doc_mask lengths
|
190 |
+
a a b b b c c 2 3 2
|
191 |
+
a 1 0 0 0 0 0 0
|
192 |
+
a 1 1 0 0 0 0 0
|
193 |
+
b 0 0 1 0 0 0 0
|
194 |
+
b 0 0 1 1 0 0 0
|
195 |
+
b 0 0 1 1 1 0 0
|
196 |
+
c 0 0 0 0 0 1 0
|
197 |
+
c 0 0 0 0 0 1 1
|
198 |
+
|
199 |
+
"""
|
200 |
+
kv_lengths = kv_lengths if kv_lengths is not None else lengths
|
201 |
+
q_document_id, q_token_id = lengths_to_local_ids(lengths)
|
202 |
+
kv_document_id, kv_token_id = lengths_to_local_ids(kv_lengths)
|
203 |
+
q_max_idx = lengths.sum() - 1
|
204 |
+
kv_max_idx = kv_lengths.sum() - 1
|
205 |
+
|
206 |
+
def doc_mask_mod(b, h, q_idx, kv_idx):
|
207 |
+
q_idx_cap = torch.minimum(q_max_idx, q_idx)
|
208 |
+
kv_idx_cap = torch.minimum(kv_max_idx, kv_idx)
|
209 |
+
valid_idx = (q_idx <= q_max_idx) & (kv_idx <= kv_max_idx)
|
210 |
+
same_doc = q_document_id[q_idx_cap] == kv_document_id[kv_idx_cap]
|
211 |
+
q_logical = q_token_id[q_idx_cap]
|
212 |
+
kv_logical = kv_token_id[kv_idx_cap]
|
213 |
+
inner_mask = mask_mod(b, h, q_logical, kv_logical)
|
214 |
+
return same_doc & inner_mask & valid_idx
|
215 |
+
|
216 |
+
return doc_mask_mod
|
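The block-causal pattern drawn in the docstring can be recovered by evaluating the mask mod densely (a sketch, assuming the helpers above are importable):

```python
import torch
from bytelatent.base_transformer import causal_mask, generate_doc_mask_mod

lengths = torch.tensor([2, 3, 2])
mask_mod = generate_doc_mask_mod(causal_mask, lengths)

n = int(lengths.sum())
q_idx = torch.arange(n).view(-1, 1).expand(n, n)
kv_idx = torch.arange(n).view(1, -1).expand(n, n)
dense = mask_mod(0, 0, q_idx, kv_idx)
print(dense.int())   # 7x7 block-diagonal causal mask, one block per document
```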
217 |
+
|
218 |
+
|
219 |
+
# Rotary embedding as in xformers; see if the torchtrain implementation is better. It might also be useful to make it work with batch*seqlen collapsed.
|
220 |
+
class RotaryEmbedding(torch.nn.Module):
|
221 |
+
"""
|
222 |
+
RotaryEmbedding Module
|
223 |
+
"""
|
224 |
+
|
225 |
+
def __init__(self, theta: float, head_dim: int, max_seqlen: int = 1024):
|
226 |
+
super().__init__()
|
227 |
+
|
228 |
+
self.theta = theta
|
229 |
+
self.head_dim = head_dim
|
230 |
+
self.max_seqlen = max_seqlen
|
231 |
+
|
232 |
+
self.register_buffer(
|
233 |
+
"freqs_cis",
|
234 |
+
precompute_freqs_cis(dim=head_dim, end=max_seqlen, theta=theta),
|
235 |
+
persistent=False,
|
236 |
+
)
|
237 |
+
|
238 |
+
def reset_parameters(self):
|
239 |
+
self.freqs_cis[...] = precompute_freqs_cis(
|
240 |
+
dim=self.head_dim, end=self.max_seqlen, theta=self.theta
|
241 |
+
)
|
242 |
+
|
243 |
+
def forward(
|
244 |
+
self, seqlen: Optional[int] = None, tok_idx: Optional[torch.Tensor] = None
|
245 |
+
):
|
246 |
+
"""
|
247 |
+
Return freqs_cis corresponding to consecutive seqlen positions or the corresponding tok_idx positions
|
248 |
+
Args:
|
249 |
+
seqlen (int): Contiguous sequence length
|
250 |
+
tok_idx (torch.Tensor[int]): Position indices of each token; overrides seqlen if provided
|
251 |
+
|
252 |
+
Returns:
|
253 |
+
torch.Tensor: freqs_cis for the requested positions
|
254 |
+
"""
|
255 |
+
test = (seqlen is not None) or (tok_idx is not None)
|
256 |
+
assert test, "Should provide atleast seqlen or tok_idx"
|
257 |
+
if tok_idx is not None:
|
258 |
+
return self.freqs_cis[tok_idx]
|
259 |
+
elif seqlen is not None:
|
260 |
+
return self.freqs_cis[0:seqlen]
|
261 |
+
|
262 |
+
|
263 |
+
class RMSNorm(nn.Module):
|
264 |
+
"""
|
265 |
+
Initialize the RMSNorm normalization layer.
|
266 |
+
|
267 |
+
Args:
|
268 |
+
dim (int): The dimension of the input tensor.
|
269 |
+
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
|
270 |
+
|
271 |
+
Attributes:
|
272 |
+
eps (float): A small value added to the denominator for numerical stability.
|
273 |
+
weight (nn.Parameter): Learnable scaling parameter.
|
274 |
+
|
275 |
+
"""
|
276 |
+
|
277 |
+
def __init__(self, dim: int, eps: float = 1e-6):
|
278 |
+
super().__init__()
|
279 |
+
self.eps = eps
|
280 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
281 |
+
|
282 |
+
def _norm(self, x: torch.Tensor):
|
283 |
+
return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
|
284 |
+
|
285 |
+
def forward(self, x: torch.Tensor):
|
286 |
+
x = probe.log_stats(x, "resid")
|
287 |
+
output = self._norm(x.float())
|
288 |
+
return (output * self.weight.float()).type_as(x)
|
289 |
+
|
290 |
+
def reset_parameters(self):
|
291 |
+
torch.nn.init.ones_(self.weight) # type: ignore
|
292 |
+
|
293 |
+
|
294 |
+
class Attention(nn.Module):
|
295 |
+
def __init__(
|
296 |
+
self,
|
297 |
+
dim: int,
|
298 |
+
head_dim: int,
|
299 |
+
n_heads: int,
|
300 |
+
n_kv_heads: int,
|
301 |
+
rope_theta: float,
|
302 |
+
):
|
303 |
+
super().__init__()
|
304 |
+
|
305 |
+
self.dim = dim
|
306 |
+
self.head_dim = head_dim
|
307 |
+
self.rope_theta = rope_theta
|
308 |
+
|
309 |
+
self.n_heads = n_heads
|
310 |
+
self.n_kv_heads = n_kv_heads
|
311 |
+
self.heads_per_group = self.n_heads // self.n_kv_heads
|
312 |
+
|
313 |
+
self.wq = nn.Linear(
|
314 |
+
dim,
|
315 |
+
n_heads * head_dim,
|
316 |
+
bias=False,
|
317 |
+
)
|
318 |
+
self.wk = nn.Linear(
|
319 |
+
dim,
|
320 |
+
n_kv_heads * head_dim,
|
321 |
+
bias=False,
|
322 |
+
)
|
323 |
+
self.wv = nn.Linear(
|
324 |
+
dim,
|
325 |
+
n_kv_heads * head_dim,
|
326 |
+
bias=False,
|
327 |
+
)
|
328 |
+
|
329 |
+
self.wo = nn.Linear(
|
330 |
+
n_heads * head_dim,
|
331 |
+
dim,
|
332 |
+
bias=False,
|
333 |
+
)
|
334 |
+
|
335 |
+
def forward(
|
336 |
+
self,
|
337 |
+
x: torch.Tensor,
|
338 |
+
freq_cis: torch.Tensor,
|
339 |
+
tok_idx: Optional[torch.Tensor] = None,
|
340 |
+
mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
|
341 |
+
attn_impl: str = "sdpa",
|
342 |
+
) -> torch.Tensor:
|
343 |
+
# B S D
|
344 |
+
bsz, seq_len, dim = x.shape
|
345 |
+
xq = self.wq(x.view_as(x))
|
346 |
+
xk = self.wk(x.view_as(x))
|
347 |
+
xv = self.wv(x.view_as(x))
|
348 |
+
|
349 |
+
output_shape = xq.shape
|
350 |
+
# B S D -> B S H D
|
351 |
+
xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
|
352 |
+
xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
|
353 |
+
xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
|
354 |
+
|
355 |
+
xq, xk = apply_rotary_emb(xq, xk, 1, freq_cis[0:seq_len])
|
356 |
+
|
357 |
+
# This condition helps us be easily compatible
|
358 |
+
# with inference by adding a pluggable KVCache
|
359 |
+
if hasattr(self, "kv_cache"):
|
360 |
+
xk, xv = self.kv_cache.update(xk, xv, tok_idx)
|
361 |
+
|
362 |
+
xk = repeat_kv(xk, self.heads_per_group, dim=2)
|
363 |
+
xv = repeat_kv(xv, self.heads_per_group, dim=2)
|
364 |
+
|
365 |
+
if attn_impl == "flex_attention":
|
366 |
+
assert mask is None or isinstance(mask, BlockMask)
|
367 |
+
xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv))
|
368 |
+
output = flex_attention_comp(xq, xk, xv, block_mask=mask)
|
369 |
+
output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
|
370 |
+
|
371 |
+
elif attn_impl == "fmha":
|
372 |
+
assert mask is None or isinstance(mask, AttentionBias)
|
373 |
+
output = fmha.memory_efficient_attention(xq, xk, xv, attn_bias=mask)
|
374 |
+
# This uses B S H D instead of B H S D of pytorch
|
375 |
+
|
376 |
+
elif attn_impl == "sdpa":
|
377 |
+
xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv))
|
378 |
+
assert mask is None or isinstance(mask, (str, torch.Tensor))
|
379 |
+
is_causal = (mask == "causal") if isinstance(mask, str) else False
|
380 |
+
mask = mask if isinstance(mask, torch.Tensor) else None
|
381 |
+
output = F.scaled_dot_product_attention(
|
382 |
+
xq,
|
383 |
+
xk,
|
384 |
+
xv,
|
385 |
+
is_causal=is_causal,
|
386 |
+
attn_mask=mask,
|
387 |
+
)
|
388 |
+
output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
|
389 |
+
else:
|
390 |
+
raise NotImplementedError(
|
391 |
+
f"Attention implementation {attn_impl} not supported"
|
392 |
+
)
|
393 |
+
|
394 |
+
output = self.wo(output.reshape(output_shape))
|
395 |
+
|
396 |
+
return output
|
397 |
+
|
398 |
+
def reset_parameters(self, init_std=None, factor=1.0):
|
399 |
+
init_std = init_std or (self.dim ** (-0.5))
|
400 |
+
|
401 |
+
for w in [self.wq, self.wk, self.wv]:
|
402 |
+
nn.init.trunc_normal_(
|
403 |
+
w.weight,
|
404 |
+
mean=0.0,
|
405 |
+
std=init_std,
|
406 |
+
a=-3 * init_std,
|
407 |
+
b=3 * init_std,
|
408 |
+
)
|
409 |
+
|
410 |
+
nn.init.trunc_normal_(
|
411 |
+
self.wo.weight,
|
412 |
+
mean=0.0,
|
413 |
+
std=init_std / factor,
|
414 |
+
a=-3 * init_std,
|
415 |
+
b=3 * init_std,
|
416 |
+
)
|
417 |
+
|
418 |
+
|
419 |
+
class FeedForward(nn.Module):
|
420 |
+
def __init__(
|
421 |
+
self,
|
422 |
+
dim: int,
|
423 |
+
hidden_dim: int,
|
424 |
+
multiple_of: int,
|
425 |
+
ffn_dim_multiplier: Optional[float],
|
426 |
+
mp_size: int = 1,
|
427 |
+
):
|
428 |
+
super().__init__()
|
429 |
+
|
430 |
+
hidden_dim = int(2 * hidden_dim / 3)
|
431 |
+
if ffn_dim_multiplier is not None:
|
432 |
+
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
433 |
+
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
434 |
+
assert hidden_dim % mp_size == 0
|
435 |
+
|
436 |
+
self.dim = dim
|
437 |
+
self.hidden_dim = hidden_dim
|
438 |
+
|
439 |
+
self.w1 = nn.Linear(
|
440 |
+
dim,
|
441 |
+
hidden_dim,
|
442 |
+
bias=False,
|
443 |
+
)
|
444 |
+
self.w3 = nn.Linear(
|
445 |
+
dim,
|
446 |
+
hidden_dim,
|
447 |
+
bias=False,
|
448 |
+
)
|
449 |
+
self.w2 = nn.Linear(
|
450 |
+
hidden_dim,
|
451 |
+
dim,
|
452 |
+
bias=False,
|
453 |
+
)
|
454 |
+
|
455 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
456 |
+
# B S D
|
457 |
+
x1 = self.w1(x.view_as(x))
|
458 |
+
x3 = self.w3(x.view_as(x))
|
459 |
+
output = self.w2(F.silu(x1) * x3)
|
460 |
+
return output
|
461 |
+
|
462 |
+
def reset_parameters(self, init_std=None, factor=1.0):
|
463 |
+
in_init_std = init_std or (self.dim ** (-0.5))
|
464 |
+
out_init_std = init_std or (self.hidden_dim ** (-0.5))
|
465 |
+
in_init_std = in_init_std
|
466 |
+
out_init_std = out_init_std / factor
|
467 |
+
for w in [self.w1, self.w3]:
|
468 |
+
nn.init.trunc_normal_(
|
469 |
+
w.weight,
|
470 |
+
mean=0.0,
|
471 |
+
std=in_init_std,
|
472 |
+
a=-3 * in_init_std,
|
473 |
+
b=3 * in_init_std,
|
474 |
+
)
|
475 |
+
nn.init.trunc_normal_(
|
476 |
+
self.w2.weight,
|
477 |
+
mean=0.0,
|
478 |
+
std=out_init_std,
|
479 |
+
a=-3 * out_init_std,
|
480 |
+
b=3 * out_init_std,
|
481 |
+
)
|
482 |
+
|
483 |
+
|
484 |
+
class TransformerBlock(nn.Module):
|
485 |
+
def __init__(self, args: BaseTransformerArgs):
|
486 |
+
super().__init__()
|
487 |
+
|
488 |
+
assert (args.head_dim is not None) or (
|
489 |
+
args.n_heads is not None
|
490 |
+
), "Should specify at least head_dim or n_heads"
|
491 |
+
self.head_dim = args.head_dim or args.dim // args.n_heads
|
492 |
+
self.n_heads = args.n_heads or args.dim // args.head_dim
|
493 |
+
self.n_kv_heads = args.n_kv_heads or self.n_heads
|
494 |
+
|
495 |
+
assert args.n_heads % self.n_kv_heads == 0
|
496 |
+
assert args.dim % args.n_heads == 0
|
497 |
+
|
498 |
+
self.attention = Attention(
|
499 |
+
dim=args.dim,
|
500 |
+
head_dim=self.head_dim,
|
501 |
+
n_heads=self.n_heads,
|
502 |
+
n_kv_heads=self.n_kv_heads,
|
503 |
+
rope_theta=args.rope_theta,
|
504 |
+
)
|
505 |
+
self.feed_forward = FeedForward(
|
506 |
+
dim=args.dim,
|
507 |
+
hidden_dim=4 * args.dim,
|
508 |
+
multiple_of=args.multiple_of,
|
509 |
+
ffn_dim_multiplier=args.ffn_dim_multiplier,
|
510 |
+
)
|
511 |
+
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
512 |
+
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
513 |
+
|
514 |
+
def forward(
|
515 |
+
self,
|
516 |
+
x: torch.Tensor,
|
517 |
+
freq_cis: torch.Tensor,
|
518 |
+
tok_idx: Optional[torch.Tensor] = None,
|
519 |
+
mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
|
520 |
+
attn_impl: str = "sdpa",
|
521 |
+
) -> torch.Tensor:
|
522 |
+
h = x + self.attention(
|
523 |
+
self.attention_norm(x),
|
524 |
+
freq_cis,
|
525 |
+
tok_idx=tok_idx,
|
526 |
+
mask=mask,
|
527 |
+
attn_impl=attn_impl,
|
528 |
+
)
|
529 |
+
out = h + self.feed_forward(self.ffn_norm(h))
|
530 |
+
return out
|
531 |
+
|
532 |
+
def init_weights(self, init_std=None, factor=1.0):
|
533 |
+
self.attention.reset_parameters(init_std, factor)
|
534 |
+
self.attention_norm.reset_parameters()
|
535 |
+
|
536 |
+
self.feed_forward.reset_parameters(init_std, factor)
|
537 |
+
self.ffn_norm.reset_parameters()
|
538 |
+
|
539 |
+
|
540 |
+
class BaseTransformer(nn.Module):
|
541 |
+
def __init__(self, args: BaseTransformerArgs):
|
542 |
+
super().__init__()
|
543 |
+
self.dim = args.dim
|
544 |
+
self.init_base_std = args.init_base_std
|
545 |
+
self.init_std_factor = InitStdFactor(args.init_std_factor)
|
546 |
+
self.max_seqlen = args.max_seqlen
|
547 |
+
self.rope_embeddings = RotaryEmbedding(
|
548 |
+
theta=args.rope_theta,
|
549 |
+
head_dim=args.head_dim or args.dim // args.n_heads,
|
550 |
+
max_seqlen=args.max_seqlen,
|
551 |
+
)
|
552 |
+
|
553 |
+
self.layers = nn.ModuleList()
|
554 |
+
for _ in range(args.n_layers):
|
555 |
+
self.layers.append(TransformerBlock(args))
|
556 |
+
|
557 |
+
def forward(
|
558 |
+
self,
|
559 |
+
h,
|
560 |
+
tok_idx: Optional[torch.Tensor] = None,
|
561 |
+
mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
|
562 |
+
attn_impl: str = "sdpa",
|
563 |
+
):
|
564 |
+
|
565 |
+
freq_cis = self.rope_embeddings(seqlen=self.max_seqlen, tok_idx=tok_idx)
|
566 |
+
|
567 |
+
for i, layer in enumerate(self.layers):
|
568 |
+
h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
|
569 |
+
return h
|
570 |
+
|
571 |
+
def reset_parameters(self):
|
572 |
+
# Either use fixed base std or sqrt model dim
|
573 |
+
self.rope_embeddings.reset_parameters()
|
574 |
+
|
575 |
+
def init_weights(self):
|
576 |
+
self.reset_parameters()
|
577 |
+
for depth, layer in enumerate(self.layers):
|
578 |
+
factor = {
|
579 |
+
InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
|
580 |
+
InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
|
581 |
+
InitStdFactor.DIM_RATIO: self.dim / 4096,
|
582 |
+
InitStdFactor.DISABLED: 1.0,
|
583 |
+
}[self.init_std_factor]
|
584 |
+
|
585 |
+
layer.init_weights(self.init_base_std, factor)
|
bytelatent/checkpoint.py
ADDED
@@ -0,0 +1,311 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import json
import logging
import os
import re
from pathlib import Path
from typing import List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
import torch.optim.optimizer
from pydantic import BaseModel, ConfigDict
from torch.distributed._tensor import DeviceMesh
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    get_state_dict,
    set_state_dict,
)

from bytelatent.distributed import get_is_master

logger = logging.getLogger("CHECKPOINT")

FOLDER_NAME = "{:010d}"
RE_FOLDER = r"\d{10}"

RE_CKPT = r"__\d_\d\.distcp"

CONSOLIDATE_FOLDER = "consolidated"
CONSOLIDATE_NAME = "consolidated.pth"

CONFIG_NAME = "params.json"
TRAIN_STATE_NAME = "train_state_{:05d}.json"
RE_DIGITS = re.compile(r"\d+")


class SaveEvery(BaseModel):
    model_config = ConfigDict(extra="forbid")
    every: int = 1000
    keep: int = 0


class CheckpointArgs(BaseModel):
    model_config = ConfigDict(extra="forbid")
    dump: SaveEvery = SaveEvery()
    eval: SaveEvery = SaveEvery()
    path: str | None = None
    init_ckpt_path: str | None = None
    continue_training_from_init: bool = False


def _get_key_step(name: str):
    return int(re.findall(RE_DIGITS, name)[-1])


def consolidate_checkpoints(ckpt_dir: str):
    """
    Consolidates all FSDP checkpoints in a directory into a single file.
    The consolidated checkpoint is saved in a subdirectory of ckpt_dir.

    Parameters:
        ckpt_dir: str - path to the directory containing the checkpoints

    Returns the path to the consolidated checkpoint
    """
    consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER
    if not (consolidate_path / CONSOLIDATE_NAME).exists():
        consolidate_path.mkdir(exist_ok=True)
        logger.info(f"Consolidating to: {str(consolidate_path)}")
        dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME))
        (consolidate_path / CONFIG_NAME).write_text(
            (Path(ckpt_dir) / CONFIG_NAME).read_text()
        )
        logger.info("Consolidated!")
    return consolidate_path


def load_from_checkpoint(
    ckpt_dir: str,
    model: nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
    model_key: str = "model",
    optim_key: str = "optim",
):
    if not (Path(ckpt_dir) / ".metadata").exists():
        raise ValueError(
            f"Please convert the checkpoint to distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it"
        )

    state_dict = {}
    if optimizer is not None:
        state_dict[model_key], state_dict[optim_key] = get_state_dict(model, optimizer)
    else:
        state_dict[model_key] = get_model_state_dict(model)
        if model_key == "":  # If only loading a model directly, the key should be empty
            state_dict = state_dict.pop(model_key)

    dcp.load(state_dict, checkpoint_id=ckpt_dir)


class CheckpointManager:
    def __init__(self, args: CheckpointArgs):
        self.path = args.path
        self.dump_every = args.dump
        self.eval_every = args.eval
        self.init_ckpt_path = args.init_ckpt_path
        self.continue_training_from_init = args.continue_training_from_init

        assert os.path.exists(
            self.path
        ), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)"

        self.existing_saves = self.get_existing_saves()

    def get_existing_saves(self) -> List[Path]:
        folders = [
            p
            for p in Path(self.path).iterdir()
            if p.is_dir() and re.match(RE_FOLDER, p.name)
        ]
        folders.sort(key=lambda p: _get_key_step(p.name))
        return folders

    def clean_up(self):
        logger.info("Cleaning up checkpoints...")
        dump_folders = []
        eval_folders = []
        other_folders = []
        for p in self.existing_saves:
            is_dump = _get_key_step(p.name) % self.dump_every.every == 0
            is_eval = _get_key_step(p.name) % self.eval_every.every == 0
            if is_dump:
                dump_folders.append(p)
            if is_eval:
                eval_folders.append(p)
            if not (is_dump or is_eval):
                other_folders.append(p)

        logger.info(f"Dump folders: {dump_folders}")
        logger.info(f"Eval folders: {eval_folders}")
        logger.info(f"Other folders: {other_folders}")

        if self.dump_every.keep > 0:
            dump_folders = dump_folders[-self.dump_every.keep :]
        if self.eval_every.keep > 0:
            eval_folders = eval_folders[-self.eval_every.keep :]

        folder_to_keep = set(other_folders + dump_folders + eval_folders)
        folder_to_remove = set(self.existing_saves) - folder_to_keep

        logger.info(f"Removing folders: {folder_to_remove}")

        if dist.get_rank() == 0:
            for folder in folder_to_remove:
                for file in folder.iterdir():
                    if file.is_file():
                        file.unlink()
                    elif file.is_dir():
                        assert file.name in [CONSOLIDATE_FOLDER]
                        for f in file.iterdir():
                            f.unlink()
                        file.rmdir()
                folder.rmdir()

        dist.barrier()

        self.existing_saves = list(folder_to_keep)
        self.existing_saves.sort(key=lambda p: _get_key_step(p.name))

    def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]:
        path = None
        for p in reversed(self.existing_saves):
            if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file():
                path = p
                break
        return path

    def _create_folder(self, base_path: Path, folder_name: str) -> Path:
        folder = base_path / folder_name
        if get_is_master():
            folder.mkdir(parents=False, exist_ok=True)
        if dist.is_initialized():
            dist.barrier()
        return folder

    def _get_dp_tp_mesh(
        self, device_mesh: Optional[DeviceMesh] = None
    ) -> Tuple[int, int]:
        dp_rank = 0
        tp_rank = 0
        if device_mesh is not None:
            if "dp_replicate" in device_mesh.mesh_dim_names:
                dp_rank = device_mesh.get_local_rank("dp_replicate")
                if "dp_shard" in device_mesh.mesh_dim_names:
                    dp_rank = dp_rank * device_mesh[
                        "dp_replicate"
                    ].size() + device_mesh.get_local_rank("dp_shard")
            if "tp" in device_mesh.mesh_dim_names:
                tp_rank = device_mesh.get_local_rank("tp")
        return dp_rank, tp_rank

    @torch.no_grad()
    def get_state_dict(
        self,
        model,
        optimizer,
    ):
        model_sd, optim_sd = get_state_dict(model, optimizer)
        return {"model": model_sd, "optim": optim_sd}

    def save(
        self,
        model,
        optimizer,
        train_state,
        config,
        device_mesh: Optional[DeviceMesh] = None,
    ) -> bool:

        # When creating the directory, check whether only rank0 should do it or if there is another solution
        path = Path(self.path)
        curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step))
        logger.info(f"Saving to: {str(curr_save_dir)}")

        if dist.is_initialized():
            dist.barrier()

        logger.info("Saving...")
        state_dict = self.get_state_dict(model, optimizer)
        dcp.save(state_dict, checkpoint_id=curr_save_dir)
        logger.info("State dict saved!")

        if dist.is_initialized():
            dist.barrier()

        if get_is_master():
            config.dump_to_yaml_file(curr_save_dir / CONFIG_NAME)

        # Add json dump here
        dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
        if tp_rank == 0:
            train_state_name = TRAIN_STATE_NAME.format(dp_rank)
            logger.info(
                f"Saving train state to: {str(curr_save_dir / train_state_name)}"
            )
            with open(curr_save_dir / train_state_name, "w") as f:
                json.dump(train_state.state_dict(), f)
            logger.info("Train state saved!")

        self.existing_saves.append(curr_save_dir)

        self.clean_up()

        if dist.is_initialized():
            dist.barrier()
        return True

    @torch.no_grad()
    def load(
        self,
        model: nn.Module,
        optimizer,
        train_state,
        device_mesh: DeviceMesh,
        path: Optional[Path] = None,
    ):
        dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
        # Loading tries the provided path, then the last saved step, and finally the init path
        path = path or self.get_last_step_path(dp_rank=dp_rank)
        # If none of those are available don't do anything
        if path is None:
            # If no checkpoints exist do nothing
            return

        # Only load train state if it's provided, the files exist and we're not loading from the init path
        train_state_name = TRAIN_STATE_NAME.format(dp_rank)
        logger.info("Reloading train state")
        with open(path / train_state_name, "r") as f:
            train_state_dict = json.load(f)
        train_state.load_state_dict(train_state_dict)
        logger.info("Train state reloaded")

        logger.info(f"Loading from: {str(path)}")
        state_dict = self.get_state_dict(
            model=model,
            optimizer=optimizer,
        )
        dcp.load(state_dict, checkpoint_id=path)
        logger.info("State dict loaded.")

        logger.info("Reloading model and optim")

        set_state_dict(
            model,
            optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
        )
        logger.info("Model and optim reloaded")

    @classmethod
    def instantiate_and_make_dir(cls, args: CheckpointArgs):
        if get_is_master():
            os.makedirs(args.path, exist_ok=True)
        dist.barrier()

        return cls(args)
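A small, self-contained sketch of the folder naming and retention logic used by CheckpointManager above. FOLDER_NAME and _get_key_step are copied from the file; the step values and keep=3 setting are assumptions for illustration.

    # Minimal sketch (assumed steps and keep value) of the checkpoint retention rule.
    import re

    FOLDER_NAME = "{:010d}"
    RE_DIGITS = re.compile(r"\d+")

    def _get_key_step(name: str) -> int:
        return int(re.findall(RE_DIGITS, name)[-1])

    folders = [FOLDER_NAME.format(s) for s in (500, 1000, 1500, 2000, 2500)]
    keep = 3  # dump.keep: only the most recent 3 dump checkpoints survive clean_up()
    kept = sorted(folders, key=_get_key_step)[-keep:]
    print(kept)  # ['0000001500', '0000002000', '0000002500']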
bytelatent/configs/debug.yaml
ADDED
@@ -0,0 +1,110 @@
# Template config, need to change dump_dir, data.root_dir and tokenizer.path
# Evals can be activated by uncommenting its config
# python -m launchers.stool config=apps/main/configs/debug.yaml nodes=8 account=fair_amaia_cw_codegen qos=lowest

dump_dir: /tmp/
name: "debug"
steps: 100_000
probe_freq: null
seed: 777
optim:
  lr: 4e-04
  warmup: 500
  lr_min_ratio: 0.1
  clip: 10.0

distributed:
  fsdp_type: full_shard
  compile: true
  model_dtype: bf16
  matmul_allow_tf32: false
  selective_activation_checkpointing: false
  tp_size: 1

model:
  n_heads: 8
  dim: 512
  vocab_size: 260
  dim_token: 256
  patch_size: 6
  tokenization_mode: "bytes"
  patching_mode: "space"
  tie_local_encoder_decoder_logits: false
  data_loader_patching: true
  max_encoder_seq_length: 12288
  pad_to_max_length: true
  patching_threshold: 3.1439168453216553
  encoder_hash_byte_group_size: [4]
  encoder_hash_byte_group_vocab: 50002
  encoder_hash_byte_group_nb_functions: 3
  encoder_enable_byte_ngrams: false
  cross_attn_encoder: true # assuming cross_attention is true
  cross_attn_decoder: true # assuming cross_attention is true
  cross_attn_window_encoder: 512
  cross_attn_window_decoder: 512
  dim_local_encoder: 256
  dim_local_decoder: 256
  cross_attn_k: 8
  cross_attn_nheads: 4
  cross_attn_all_layers_decoder: true
  cross_attn_all_layers_encoder: true
  cross_attn_use_flex_attention: true
  cross_attn_init_by_pooling: true
  log_patch_lengths: true
  non_linearity: "swiglu"
  use_rope: true
  recompute_fc1_out: false
  recompute_fc3_out: false
  recompute_attn: false
  custom_bwd: false
  layer_ckpt: "none"
  efficient_attn: "sdpa"
  patch_only_encoder: false
  patch_only_decoder: false
  use_local_encoder_transformer: true
  init_use_gaussian: true
  init_use_depth: "current"
  attn_bias_type: "block_causal"
  alpha_depth: "disabled"
  max_length: 256
  local_attention_window_len: 512
  max_seqlen: 12288
  downsampling_by_pooling: "max"

data:
  root_dir: ???
  sources:
    dclm_baseline_1.0: 1.0
  batch_size: 2
  prefetch_size: 64
  seq_len: 4096
  load_async: true
  preprocess_dir: ???
  tokenizer_args:
    name: blt
    init_kwargs:
      bpe_tokenizer_path: ???

profiling:
  run: false

checkpoint:
  dump:
    every: 500
    keep: 3
  eval:
    every: 1000
    keep: -1

logging:
  freq: 10

eval_on_gpus: 8
eval:
  dataset_dir: /checkpoint/amaia/codegen/datasets/eval
  tasks: boolq,hellaswag,nq,piqa,siqa,tqa,winogrande,obqa,arc_easy,arc_challenge,race.middle,race.high,gsm8k,math,bbh,copa,human_eval_plus,mbpp,mmlu
  generator:
    max_tokens: 65536
    dtype: bf16

  mp_size: 1
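A hedged sketch of inspecting this config as plain YAML. It assumes PyYAML is installed and the file is on disk at the path shown; the `???` values are placeholder entries (root_dir, preprocess_dir, bpe_tokenizer_path) that the user must override before launching.

    # Minimal sketch (assumes PyYAML and that the file exists at this path).
    import yaml

    with open("bytelatent/configs/debug.yaml") as f:
        cfg = yaml.safe_load(f)

    print(cfg["model"]["patching_mode"])        # "space"
    print(cfg["data"]["root_dir"])              # "???" -> must be set by the user
    print(cfg["checkpoint"]["dump"]["every"])   # 500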
bytelatent/constants.py
ADDED
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import os
from pathlib import Path

BLT_DATA = Path(os.environ.get("BLT_DATA", "data"))
bytelatent/data/__init__.py
ADDED
@@ -0,0 +1 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
bytelatent/data/data_types.py
ADDED
@@ -0,0 +1,115 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import json
from dataclasses import dataclass
from typing import Any, Iterator

import numpy as np
from pydantic import BaseModel, ConfigDict


class BltExample(BaseModel):
    model_config = ConfigDict(extra="forbid")
    sample_id: str
    text: str
    tokens: list[int] | None
    entropies: list[float] | None
    patch_lengths: list[int] | None
    mask: list[bool] | None


class MultiChoiceState(BaseModel):
    model_config = ConfigDict(extra="forbid")
    root_dir: str
    sources: dict[str, float]
    source_to_state: dict[str, Any]
    rng_state: dict[str, Any]


class PrefetchState(BaseModel):
    model_config = ConfigDict(extra="forbid")
    seq_idx: int
    rng_state: dict[str, Any]
    prefetch_size: int
    batch_size: int


class BltPackTokensState(BaseModel):
    model_config = ConfigDict(extra="forbid")
    start_token: int
    output_seq_len: int
    n_views: int = 2


class DataLoaderState(BaseModel):
    model_config = ConfigDict(extra="forbid")
    multi_choice_state: MultiChoiceState
    pack_tokens_state: BltPackTokensState
    prefetch_state: PrefetchState


BltIterator = Iterator[tuple[BltExample, DataLoaderState]]


class BltSequence(BaseModel):
    tokens: list[int]
    mask: list[bool]
    patch_lengths: list[int]


@dataclass
class Batch:
    x: np.ndarray
    y: np.ndarray
    mask: np.ndarray | None = None
    patch_lengths: np.ndarray | None = None
    ngram_ids: np.ndarray | None = None
    is_final: bool = False

    def to_python_dict(self) -> dict:
        x = self.x.tolist()
        y = self.y.tolist()
        if self.mask is None:
            mask = None
        else:
            mask = self.mask.tolist()
        if self.patch_lengths is None:
            patch_lengths = None
        else:
            patch_lengths = self.patch_lengths.tolist()
        if self.ngram_ids is None:
            ngram_ids = None
        else:
            ngram_ids = self.ngram_ids.tolist()
        return {
            "x": x,
            "y": y,
            "mask": mask,
            "patch_lengths": patch_lengths,
            "ngram_ids": ngram_ids,
            "is_final": self.is_final,
        }

    @classmethod
    def from_python_dict(cls, data: dict) -> "Batch":
        x = np.array(data["x"])
        y = np.array(data["y"])
        if data["mask"] is None:
            mask = None
        else:
            mask = np.array(data["mask"])
        if data["patch_lengths"] is None:
            patch_lengths = None
        else:
            patch_lengths = np.array(data["patch_lengths"])
        if data["ngram_ids"] is None:
            ngram_ids = None
        else:
            ngram_ids = np.array(data["ngram_ids"])
        return Batch(
            x=x,
            y=y,
            mask=mask,
            patch_lengths=patch_lengths,
            ngram_ids=ngram_ids,
            is_final=data["is_final"],
        )
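A short round-trip sketch of the Batch serialization helpers defined above; this is the same to_python_dict/from_python_dict path the multiprocess iterator later uses to store its prefetch buffer as JSON. It assumes the bytelatent package is importable.

    # Minimal round-trip sketch (assumes `bytelatent` is installed/importable).
    import json
    import numpy as np
    from bytelatent.data.data_types import Batch

    batch = Batch(x=np.array([[1, 2, 3]]), y=np.array([[2, 3, 4]]))
    payload = json.dumps([batch.to_python_dict()])            # JSON-safe: plain lists only
    restored = Batch.from_python_dict(json.loads(payload)[0])
    assert np.array_equal(batch.x, restored.x) and restored.mask is None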
bytelatent/data/iterators/__init__.py
ADDED
@@ -0,0 +1 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
bytelatent/data/iterators/abstract_iterator.py
ADDED
@@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import abc
from typing import Any, Generator, Generic, TypeVar

T = TypeVar("T")
C = TypeVar("C")


class StatefulIterator(Generic[T, C], abc.ABC):

    @abc.abstractmethod
    def get_state(self) -> C:
        pass

    @abc.abstractmethod
    def create_iter(self) -> Generator[T, Any, None]:
        pass


class IteratorState(Generic[C]):
    @abc.abstractmethod
    def build(self) -> StatefulIterator[T, C]:
        pass
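A toy sketch of the StatefulIterator/IteratorState contract above: the state object must be able to rebuild its iterator, and get_state() must capture enough to resume mid-stream. RangeIterator/RangeState are hypothetical and not part of the repo.

    # Toy sketch (hypothetical RangeIterator) of the iterator/state contract.
    from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator

    class RangeState(IteratorState):
        def __init__(self, pos: int, end: int):
            self.pos, self.end = pos, end

        def build(self) -> "RangeIterator":
            return RangeIterator(self.pos, self.end)

    class RangeIterator(StatefulIterator):
        def __init__(self, pos: int, end: int):
            self.pos, self.end = pos, end

        def get_state(self) -> RangeState:
            return RangeState(self.pos, self.end)

        def create_iter(self):
            while self.pos < self.end:
                self.pos += 1
                yield self.pos - 1

    it = RangeIterator(0, 5)
    gen = it.create_iter()
    print(next(gen), next(gen))        # 0 1
    resumed = it.get_state().build()   # state captured at pos=2
    print(list(resumed.create_iter())) # [2, 3, 4]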
bytelatent/data/iterators/arrow_iterator.py
ADDED
@@ -0,0 +1,216 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import re
from logging import getLogger
from pathlib import Path
from typing import Any, Generator

import pyarrow as pa

# pyarrow needs the initialization from this import
import pyarrow.dataset  # pyright: ignore
from pydantic import BaseModel, ConfigDict

from bytelatent import ByteLatentError
from bytelatent.data.data_types import BltExample
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator

logger = getLogger(__name__)


class ArrowFileIteratorState(BaseModel, IteratorState):
    model_config = ConfigDict(extra="forbid")
    file_path: str | None
    row_num: int
    num_workers: int
    worker_id: int
    preprocess_dir: str | None
    dataset_files: list[str] | None
    entropy_model_name: str | None
    arrow_batch_size: int = 100

    def build(self) -> "ArrowFileIterator":
        arrow_file = ArrowFileIterator(
            file_path=self.file_path,
            worker_id=self.worker_id,
            num_workers=self.num_workers,
            preprocess_dir=self.preprocess_dir,
            entropy_model_name=self.entropy_model_name,
            arrow_batch_size=self.arrow_batch_size,
            dataset_files=self.dataset_files,
        )
        if self.row_num != 0:
            arrow_file._set_row_num(self.row_num)
        return arrow_file


def shard_sort_key(file: str | Path):
    match = re.search(r".+\.shard_([0-9]+)\.arrow", str(file))
    shard_number = int(match.group(1))
    return shard_number


class ArrowFileIterator(StatefulIterator):
    def __init__(
        self,
        *,
        file_path: str | None,
        worker_id: int,
        num_workers: int,
        preprocess_dir: str | None,
        entropy_model_name: str | None,
        arrow_batch_size: int,
        dataset_files: list[str] | None = None,
    ):
        assert 0 <= worker_id < num_workers, (worker_id, num_workers)
        if file_path is None and dataset_files is None:
            raise ByteLatentError("file_path and dataset_files cannot both be None")
        self.row_num = 0
        self.iter_id = 0
        self.batch_iterator = None
        self.batch_to_consume = None
        self.dataset = None
        self.file_path = file_path
        self.worker_id = worker_id
        self.num_workers = num_workers
        self.preprocess_dir = preprocess_dir
        self.entropy_model_name = entropy_model_name
        self.arrow_batch_size = arrow_batch_size
        if dataset_files is None:
            # Prepare arrow shards
            jsonl_file = Path(file_path)
            parts = re.match(r"(.+)\.chunk\.[0-9]+\.jsonl", jsonl_file.name)
            assert parts is not None
            dataset = parts.group(1)
            data_dir = Path(preprocess_dir) / dataset / entropy_model_name
            shard_files = list(data_dir.glob(f"{jsonl_file.name}.shard_*.arrow"))
            for s in shard_files:
                if not (data_dir / f"{s.name}.complete").exists():
                    raise ValueError(f"Missing .complete for input file: {s}")

            shard_files = sorted(shard_files, key=shard_sort_key)
            if len(shard_files) == 0:
                raise ByteLatentError(
                    f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow"
                )
            self.dataset_files = [str(f) for f in shard_files]
        else:
            self.preprocess_dir = None
            self.dataset_files = dataset_files

    def get_state(self) -> ArrowFileIteratorState:
        return ArrowFileIteratorState(
            file_path=self.file_path,
            row_num=self.row_num,
            worker_id=self.worker_id,
            num_workers=self.num_workers,
            preprocess_dir=self.preprocess_dir,
            entropy_model_name=self.entropy_model_name,
            arrow_batch_size=self.arrow_batch_size,
            dataset_files=self.dataset_files,
        )

    def create_iter(
        self,
    ) -> Generator[BltExample, Any, None]:
        if self.dataset is None:
            self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow")
            self.batch_iterator = self.dataset.to_batches(
                batch_size=self.arrow_batch_size
            )
        self.iter_id += 1
        if self.batch_to_consume is not None:
            batch_columns: dict[str, list] = self.batch_to_consume
            self.batch_to_consume = None
            sample_ids = batch_columns["sample_id"]
            texts = batch_columns["text"]
            entropies = batch_columns["entropies"]
            for i in range(len(sample_ids)):
                out = BltExample(
                    sample_id=sample_ids[i],
                    entropies=entropies[i],
                    text=texts[i],
                    tokens=None,
                    mask=None,
                    patch_lengths=None,
                )
                self.row_num += 1
                if (self.row_num - 1) % self.num_workers == self.worker_id:
                    yield out

        for batch in self.batch_iterator:
            batch_columns = batch.to_pydict()
            sample_ids = batch_columns["sample_id"]
            texts = batch_columns["text"]
            entropies = batch_columns["entropies"]
            for i in range(len(sample_ids)):
                out = BltExample(
                    sample_id=sample_ids[i],
                    entropies=entropies[i],
                    text=texts[i],
                    tokens=None,
                    mask=None,
                    patch_lengths=None,
                )
                self.row_num += 1
                if (self.row_num - 1) % self.num_workers == self.worker_id:
                    yield out

    def _set_row_num(self, target_row_num: int):
        logger.info(
            f"Setting arrow position to {target_row_num} for {self.dataset_files}"
        )
        if target_row_num is None or target_row_num == 0:
            self.row_num = 0
            self.dataset = None
            self.batch_iterator = None
            self.batch_to_consume = None
        else:
            self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow")
            self.batch_iterator = self.dataset.to_batches(
                batch_size=self.arrow_batch_size
            )
            curr_remaining = target_row_num
            for batch in self.batch_iterator:
                if len(batch) > curr_remaining:
                    batch_columns: dict[str, list] = batch.to_pydict()
                    batch_columns["sample_id"] = batch_columns["sample_id"][
                        curr_remaining:
                    ]
                    batch_columns["entropies"] = batch_columns["entropies"][
                        curr_remaining:
                    ]
                    batch_columns["text"] = batch_columns["text"][curr_remaining:]
                    self.batch_to_consume = batch_columns
                    break
                elif len(batch) == curr_remaining:
                    # We are exactly at the end of the batch,
                    # so the next batch is the right spot
                    break
                else:
                    curr_remaining -= len(batch)
            self.row_num = target_row_num
        logger.info(
            f"Finished setting arrow position to {target_row_num} for {self.dataset_files}"
        )


TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"


def find_and_sanitize_chunks(
    dataset_path: str, world_size: int, file_pattern: str = TRAIN_DATA_FILE_PATTERN
):
    dataset_chunks = [str(p) for p in Path(dataset_path).glob(file_pattern)]
    n_chunks = len(dataset_chunks)

    if n_chunks > world_size:
        n_discard = n_chunks - world_size
        dataset_chunks = dataset_chunks[:world_size]
    else:
        assert (
            world_size % n_chunks == 0
        ), "World size should be a multiple of number of chunks"

    assert n_chunks > 0, f"No valid chunks in {dataset_path}"

    return dataset_chunks
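A hedged usage sketch of ArrowFileIteratorState defined above. The shard path is hypothetical; it must point at a preprocessed Arrow file with sample_id / text / entropies columns, which is why this is only a sketch and not a runnable test out of the box.

    # Usage sketch (hypothetical shard path; requires preprocessed .arrow data).
    from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState

    state = ArrowFileIteratorState(
        file_path=None,
        dataset_files=["data/dclm.chunk.00.jsonl.shard_0.arrow"],
        row_num=0,
        num_workers=2,
        worker_id=0,            # this worker yields rows 0, 2, 4, ... of the shards
        preprocess_dir=None,
        entropy_model_name=None,
        arrow_batch_size=100,
    )
    iterator = state.build()
    for example in iterator.create_iter():
        print(example.sample_id, len(example.text))
        break
    # iterator.get_state().row_num records how far we read, so training can resume here.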
bytelatent/data/iterators/looping_iterator.py
ADDED
@@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
from pydantic import BaseModel

from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.data.iterators.arrow_iterator import (
    ArrowFileIterator,
    ArrowFileIteratorState,
)


class LoopingIteratorState(BaseModel, IteratorState):
    file_iterator_state: ArrowFileIteratorState
    epoch: int

    def build(self) -> "LoopingIterator":
        return LoopingIterator(
            file_iterator=self.file_iterator_state.build(),
            epoch=self.epoch,
        )


class LoopingIterator(StatefulIterator):
    def __init__(self, file_iterator: ArrowFileIterator, epoch: int = -1):
        self.file_iterator = file_iterator
        self.epoch = epoch

    def get_state(self):
        return LoopingIteratorState(
            file_iterator_state=self.file_iterator.get_state(), epoch=self.epoch
        )

    def create_iter(self):
        while True:
            self.epoch += 1
            iterator = self.file_iterator.create_iter()
            yield from iterator
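A toy sketch of the epoch loop above. Iteration only requires an object with create_iter(), so a tiny stand-in source (hypothetical, not an ArrowFileIterator) is enough to show how the epoch counter advances each time the underlying file iterator is exhausted.

    # Toy sketch (hypothetical TwoItemSource) of LoopingIterator's epoch behaviour.
    from itertools import islice
    from bytelatent.data.iterators.looping_iterator import LoopingIterator

    class TwoItemSource:
        def create_iter(self):
            yield "a"
            yield "b"

    looper = LoopingIterator(TwoItemSource())
    print(list(islice(looper.create_iter(), 5)))  # ['a', 'b', 'a', 'b', 'a']
    print(looper.epoch)                           # 2 -> the third pass has started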
bytelatent/data/iterators/multiprocess_iterator.py
ADDED
@@ -0,0 +1,243 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import json
import logging
import multiprocessing as mp
from multiprocessing.synchronize import Event as EventClass
from queue import Empty, Full

import numpy as np
from pydantic import BaseModel, ConfigDict

from bytelatent.data.data_types import Batch
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.data.iterators.packing_iterator import PackingIteratorState

logger = logging.getLogger()


class MultiprocessIteratorState(BaseModel, IteratorState):
    model_config = ConfigDict(extra="forbid")
    base_iterator_state: PackingIteratorState
    n_batches_to_prefetch: int
    serialized_prefetch_buffer: str

    def build(self):
        base_iterator = self.base_iterator_state.build()
        data = json.loads(self.serialized_prefetch_buffer)
        prefetch_buffer = [Batch.from_python_dict(item) for item in data]
        return MultiprocessIterator(
            base_iterator,
            n_batches_to_prefetch=self.n_batches_to_prefetch,
            prefetch_buffer=prefetch_buffer,
        )


def start_work_from_state(
    batch_queue: mp.Queue,
    state_queue: mp.Queue,
    stop_event: EventClass,
    state_dumped_event: EventClass,
    state: IteratorState,
):
    logging.info("Worker thread: Starting base_iterator work")
    stateful_iterator = state.build()
    iterator = stateful_iterator.create_iter()
    for item in iterator:
        while not stop_event.is_set():
            try:
                # Attempt to put on queue or timeout to try again (maybe main thread is busy)
                batch_queue.put(item, timeout=0.1)
                # On success, stop trying
                break
            except Full:
                pass
        if stop_event.is_set():
            # Signal the end of output; this ensures that even if the queue takes a while to
            # buffer, the main thread receives everything (and tosses this fake batch)
            logging.info(
                "Worker thread: Stop event detected, outputting is_final=True batch"
            )
            batch_queue.put(
                Batch(
                    x=np.zeros((1, 1)),
                    y=np.zeros((1, 1)),
                    is_final=True,
                    mask=None,
                    patch_lengths=None,
                    ngram_ids=None,
                )
            )
            break

    try:
        logging.info("Worker thread: outputting state")
        state_queue.put(iterator.get_state(), timeout=1)
        logging.info("Worker thread: state dump complete")
        state_dumped_event.set()
        logging.info("Worker thread: set state_dump_event")
    except Full:
        raise ValueError(
            "Attempted to dump state into the state queue, but it was full"
        )


class MultiprocessIterator(StatefulIterator):
    """
    Design sketch of the multiprocess iterator:

    Given the base_iterator, the only thing we do with it is call get_state()
    so that we can pass that through to the background worker process.

    The background process will receive this, rebuild the iterator, then start yielding from it.

    However, in order to implement MultiprocessIterator.get_state(), we need to be able to accurately get
    (1) the state of the iterator in the worker process
    (2) the currently buffered items in the Queue

    To do this, we use:
    - batch_queue: This is the prefetch buffer the worker yields to and the main loop yields from
    - state_queue: This size 1 queue will be how the worker sends the iterator state once it has halted iterating.
        It must hold the state in addition to the last batch, if the queue was full at the time the stop event is sent.
    - stop_iterating_event: Once this is issued from the main loop, the worker will stop iterating and enter cleanup.
        During cleanup, the iterator will send the state of the current iterator to the main loop,
        in addition to possibly the last batch if the batch_queue was full at the time
    - state_dumped_event: When the main loop issues the stop_iterating_event, it will wait until the state_dumped_event to attempt
        to get state from the state_queue. It must do this since the worker may take some time to create and send the state.
        Once received by the main loop, the main loop can safely store the Queue (plus maybe the last batch) as the prefetch buffer,
        get the worker iterator's state, and terminate the background process + delete associated objects.

    At this point, calling create_iter() again will bootstrap everything from the stored state, and the old iterator will throw an error
    since it will not iterate anymore (so the caller must call create_iter() again to get a python iterator).

    """

    def __init__(
        self,
        base_iterator: StatefulIterator,
        *,
        n_batches_to_prefetch: int,
        prefetch_buffer: list | None = None
    ):
        self.base_iterator = base_iterator
        self.n_batches_to_prefetch = n_batches_to_prefetch
        if prefetch_buffer is None:
            prefetch_buffer = []
        self.prefetch_buffer = prefetch_buffer
        self.batch_queue = None
        self.state_queue = None
        self.producer = None
        self.stop_iterating_event = None
        self.state_dumped_event = None

    def get_state(self) -> MultiprocessIteratorState:
        """
        This is slightly unusual in effectively destroying the current iterator; it's necessary
        to halt the background process and allow it to write the state to the main loop
        in order to not lose data.
        """
        if self.producer is None:
            serialized_prefetch_buffer = json.dumps(
                [b.to_python_dict() for b in self.prefetch_buffer]
            )
            return MultiprocessIteratorState(
                base_iterator_state=self.base_iterator.get_state(),
                n_batches_to_prefetch=self.n_batches_to_prefetch,
                serialized_prefetch_buffer=serialized_prefetch_buffer,
            )
        else:
            logging.info("Main thread: Sending stop iteration event")
            self.stop_iterating_event.set()
            logging.info("Main thread: Waiting for state_dumped event")
            self.state_dumped_event.wait()
            self.prefetch_buffer = []
            final_batch_received = False
            while True:
                try:
                    batch = self.batch_queue.get(timeout=1)
                    if batch.is_final:
                        final_batch_received = True
                        break
                    self.prefetch_buffer.append(batch)
                except Empty:
                    logging.warning("Main thread: batch_queue is abnormally empty")
            assert final_batch_received

            try:
                base_iterator_state = self.state_queue.get(timeout=1)
                assert isinstance(base_iterator_state, IteratorState)
            except Empty:
                raise ValueError(
                    "Attempted to get the state, but it was unexpectedly missing"
                )

            self.base_iterator = base_iterator_state.build()
            self.producer.close()
            self.producer = None
            self.batch_queue = None
            self.state_queue = None
            self.stop_iterating_event = None
            self.state_dumped_event = None

            return MultiprocessIteratorState(
                base_iterator_state=self.base_iterator.get_state(),
                n_batches_to_prefetch=self.n_batches_to_prefetch,
                serialized_prefetch_buffer=json.dumps(
                    [b.to_python_dict() for b in self.prefetch_buffer]
                ),
            )

    def create_iter(self):
        logging.info("Main thread: Creating MP iterator")
        # First yield from the stored prefetch buffer.
        if self.prefetch_buffer is not None:
            while len(self.prefetch_buffer) > 0:
                item = self.prefetch_buffer.pop(0)
                yield item
            self.prefetch_buffer = None

        assert (
            self.producer is None
        ), "Cannot create two parallel iterators at once, call get_state() then remake to have two."

        # using the mp context manager avoids excessive CPU loading
        ctx = mp.get_context("forkserver")
        self.batch_queue = ctx.Manager().Queue(maxsize=self.n_batches_to_prefetch)

        # We should only ever have one state, which is output at the detection of a stop event
        self.state_queue = ctx.Manager().Queue(maxsize=1)

        self.stop_iterating_event = ctx.Event()
        self.state_dumped_event = ctx.Event()

        self.producer = mp.Process(
            name="blt_data_loader",
            target=start_work_from_state,
            args=(
                self.batch_queue,
                self.state_queue,
                self.stop_iterating_event,
                self.state_dumped_event,
                self.base_iterator.get_state(),
            ),
        )
        logger.info("Async dataloader started")
        self.producer.start()

        while True:
            if self.producer.exitcode is not None:
                raise RuntimeError(
                    "Data loader quit unexpectedly, real error has been raised previously"
                )
            try:
                batch = self.batch_queue.get(timeout=0.1)
                assert isinstance(batch, Batch)
                assert (
                    not batch.is_final
                ), "is_final should only be used during get_state() being called"
                yield batch
            except Empty:
                pass
            if self.producer is None:
                raise ValueError(
                    "Attempted to call this iterator after calling get_state(). You must call create_iter() to make a new iterator instead."
                )
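A call-pattern sketch mirroring the design docstring above. `packing_state` is assumed to be a PackingIteratorState built elsewhere by the training dataloader setup (not shown here), so this is an illustration of the handshake rather than a standalone script.

    # Call-pattern sketch (packing_state is assumed to come from the dataloader setup).
    from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator

    mp_it = MultiprocessIterator(packing_state.build(), n_batches_to_prefetch=4)

    batches = mp_it.create_iter()   # spawns the background producer process
    first = next(batches)           # Batch objects stream through batch_queue

    state = mp_it.get_state()       # halts the worker, drains the queue into
                                    # serialized_prefetch_buffer, tears down the process
    resumed = state.build()         # later (e.g. after a checkpoint reload):
    batches = resumed.create_iter() # replays the buffered batches, then continues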
bytelatent/data/iterators/packing_iterator.py
ADDED
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
from typing import Any
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
from pydantic import BaseModel, ConfigDict
|
6 |
+
|
7 |
+
from bytelatent.data.data_types import Batch, BltSequence
|
8 |
+
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
|
9 |
+
from bytelatent.data.iterators.sampling_iterator import SamplingIteratorState
|
10 |
+
|
11 |
+
|
12 |
+
class PackingArgs(BaseModel):
|
13 |
+
model_config = ConfigDict(extra="forbid")
|
14 |
+
batch_size: int
|
15 |
+
seq_len: int
|
16 |
+
pad_id: int
|
17 |
+
max_length: int | None
|
18 |
+
pad_to_max_length: bool
|
19 |
+
enable_byte_ngrams: bool
|
20 |
+
|
21 |
+
|
22 |
+
class PackingIteratorState(BaseModel, IteratorState):
|
23 |
+
model_config = ConfigDict(extra="forbid")
|
24 |
+
sequence_iterator_state: SamplingIteratorState
|
25 |
+
packing_args: PackingArgs
|
26 |
+
|
27 |
+
def build(self) -> "PackingIterator":
|
28 |
+
return PackingIterator(
|
29 |
+
sequence_iterator=self.sequence_iterator_state.build(),
|
30 |
+
packing_args=self.packing_args,
|
31 |
+
)
|
32 |
+
|
33 |
+
|
34 |
+
def _merge_patch_seq_masks(bs, slen: int, mask_seqs: list[list[bool]]):
|
35 |
+
assert len(mask_seqs) == bs
|
36 |
+
lens = [len(m) for m in mask_seqs]
|
37 |
+
if all(all(m) for m in mask_seqs) and all(lens[0] == l for l in lens):
|
38 |
+
return None
|
39 |
+
assert slen == max(lens) - 1
|
40 |
+
mask = np.zeros((bs, slen), dtype=bool)
|
41 |
+
for i, m in enumerate(mask_seqs):
|
42 |
+
if m is None:
|
43 |
+
print(
|
44 |
+
"Did not implement None mask, the mask should be True for all toks, so we need to pass that to this function."
|
45 |
+
)
|
46 |
+
raise NotImplementedError
|
47 |
+
mask[i][: len(mask_seqs[i]) - 1] = mask_seqs[i][1:]
|
48 |
+
return mask
|
49 |
+
|
50 |
+
|
51 |
+
def truncate_batch(
|
52 |
+
batch: Batch,
|
53 |
+
max_length: int,
|
54 |
+
pad_id: int,
|
55 |
+
pad_to_max_length: bool = False,
|
56 |
+
*,
|
57 |
+
enable_byte_ngrams: bool,
|
58 |
+
):
|
59 |
+
"""
|
60 |
+
Truncate the x to a given size, making sure we remove the corresponding patch sizes in patch_lenghts
|
61 |
+
and fixing the batch.mask.
|
62 |
+
|
63 |
+
batch.patch_lengths has unchanged shape
|
64 |
+
x,y, and mask may reduce in size
|
65 |
+
"""
|
66 |
+
if batch.patch_lengths is None:
|
67 |
+
return batch
|
68 |
+
|
69 |
+
seq_lengths = batch.patch_lengths.sum(axis=1)
|
70 |
+
max_length_adj = max_length + 1
|
71 |
+
if np.any(seq_lengths > max_length_adj):
|
72 |
+
for i in range(batch.x.shape[0]):
|
73 |
+
if seq_lengths[i] > max_length_adj:
|
74 |
+
# Find id of patch that tips over max_length + 1
|
75 |
+
count, j = 0, 0
|
76 |
+
while count + batch.patch_lengths[i, j] <= max_length_adj:
|
77 |
+
count += batch.patch_lengths[i, j]
|
78 |
+
j += 1
|
79 |
+
# Edit the batch
|
80 |
+
assert j < batch.patch_lengths.shape[1]
|
81 |
+
batch.x[i, max_length:] = pad_id
|
82 |
+
batch.y[i, max_length:] = pad_id
|
83 |
+
if batch.mask is not None:
|
84 |
+
batch.mask[i, max_length:] = False
|
85 |
+
batch.patch_lengths[i, j:] = 0
|
86 |
+
batch.patch_lengths[i, j] = max_length_adj - count
|
87 |
+
|
88 |
+
# Truncate if necessary.
|
89 |
+
if max_length < batch.x.shape[1]:
|
90 |
+
batch.x = batch.x[:, :max_length]
|
91 |
+
batch.y = batch.y[:, :max_length]
|
92 |
+
if batch.mask is not None:
|
93 |
+
batch.mask = batch.mask[:, :max_length]
|
94 |
+
|
95 |
+
# Right pad to max_length if necessary
|
96 |
+
elif pad_to_max_length:
|
97 |
+
if batch.x.shape[1] < max_length:
|
98 |
+
# NOTE: this has to be done on an actual patch.
|
99 |
+
non_zero_indices = (batch.patch_lengths != 0).sum(axis=1) - 1
|
100 |
+
non_zero_indices = np.maximum(0, non_zero_indices)
|
101 |
+
batch.patch_lengths[range(len(batch.patch_lengths)), non_zero_indices] += (
|
102 |
+
max_length - batch.x.shape[1]
|
103 |
+
)
|
104 |
+
# TODO: We could get rid of many of these complications by moving this function directly into the dataloader.
|
105 |
+
x = np.full((batch.x.shape[0], max_length), pad_id, dtype=batch.x.dtype)
|
106 |
+
x[:, : batch.x.shape[1]] = batch.x
|
107 |
+
batch.x = x
|
108 |
+
if batch.y.shape[1] < max_length:
|
109 |
+
y = np.full((batch.y.shape[0], max_length), pad_id, dtype=batch.y.dtype)
|
110 |
+
y[:, : batch.y.shape[1]] = batch.y
|
111 |
+
batch.y = y
|
112 |
+
if batch.mask is not None and batch.mask.shape[1] < max_length:
|
113 |
+
mask = np.full(
|
114 |
+
(batch.mask.shape[0], max_length), False, dtype=batch.mask.dtype
|
115 |
+
)
|
116 |
+
mask[:, : batch.mask.shape[1]] = batch.mask
|
117 |
+
batch.mask = mask
|
118 |
+
|
119 |
+
assert batch.x.shape[1] <= max_length
|
120 |
+
assert batch.y.shape[1] <= max_length
|
121 |
+
assert batch.mask is None or batch.mask.shape[1] <= max_length
|
122 |
+
assert np.all(max_length_adj - batch.patch_lengths.sum(axis=1) == 0)
|
123 |
+
if pad_to_max_length:
|
124 |
+
assert batch.x.shape[1] == max_length
|
125 |
+
assert batch.y.shape[1] == max_length
|
126 |
+
assert batch.mask is None or batch.mask.shape[1] == max_length
|
127 |
+
if enable_byte_ngrams:
|
128 |
+
raise NotImplementedError()
|
129 |
+
# (num_ngram, batch_size, seq_len)
|
130 |
+
ngram_ids = np.array(tokenizer.encode_token_ngrams(batch.x))
|
131 |
+
assert ngram_ids.shape[2] == batch.x.shape[1]
|
132 |
+
else:
|
133 |
+
ngram_ids = None
|
134 |
+
batch.ngram_ids = ngram_ids
|
135 |
+
|
136 |
+
|
137 |
+
class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
|
138 |
+
def __init__(
|
139 |
+
self,
|
140 |
+
sequence_iterator: StatefulIterator[BltSequence, Any],
|
141 |
+
*,
|
142 |
+
packing_args: PackingArgs,
|
143 |
+
):
|
144 |
+
self.sequence_iterator = sequence_iterator
|
145 |
+
self.packing_args = packing_args
|
146 |
+
|
147 |
+
def get_state(self):
|
148 |
+
return PackingIteratorState(
|
149 |
+
sequence_iterator_state=self.sequence_iterator.get_state(),
|
150 |
+
packing_args=self.packing_args,
|
151 |
+
)
|
152 |
+
|
153 |
+
def create_iter(self):
|
154 |
+
sequence_iter = self.sequence_iterator.create_iter()
|
155 |
+
batch_size = self.packing_args.batch_size
|
156 |
+
pad_id = self.packing_args.pad_id
|
157 |
+
seq_len = self.packing_args.seq_len
|
158 |
+
pad_to_max_length = self.packing_args.pad_to_max_length
|
159 |
+
enable_byte_ngrams = self.packing_args.enable_byte_ngrams
|
160 |
+
max_length = self.packing_args.max_length
|
161 |
+
while True:
|
162 |
+
tokens: list[list[int]] = []
|
163 |
+
masks: list[list[bool]] = []
|
164 |
+
patch_lengths: list[list[int]] = []
|
165 |
+
|
166 |
+
for _ in range(self.packing_args.batch_size):
|
167 |
+
sequence = next(sequence_iter)
|
168 |
+
_tokens = sequence.tokens
|
169 |
+
_mask = sequence.mask
|
170 |
+
_patch_lengths = sequence.patch_lengths
|
171 |
+
assert len(sequence.patch_lengths) == self.packing_args.seq_len
|
172 |
+
last_patch_length = 0
|
173 |
+
if _patch_lengths[0] > 1:
|
174 |
+
last_patch_length = _patch_lengths[-1]
|
175 |
+
_patch_lengths[0] -= 1
|
176 |
+
_patch_lengths = [1] + _patch_lengths[:-1]
|
177 |
+
tokens.append(_tokens[: len(_tokens) - last_patch_length])
|
178 |
+
masks.append(_mask[: len(_mask) - last_patch_length])
|
179 |
+
patch_lengths.append(_patch_lengths)
|
180 |
+
|
181 |
+
x_patch_lengths = np.array(patch_lengths)
|
182 |
+
# pad batch to same length
|
183 |
+
tok_seq_len = max([len(toks) for toks in tokens]) - 1
|
184 |
+
x = np.full((batch_size, tok_seq_len), fill_value=pad_id)
|
185 |
+
y = np.full((batch_size, tok_seq_len), fill_value=pad_id)
|
186 |
+
|
187 |
+
for i, tok_seq in enumerate(tokens):
|
188 |
+
x[i, : len(tok_seq) - 1] = tok_seq[:-1]
|
189 |
+
y[i, : len(tok_seq) - 1] = tok_seq[1:]
|
190 |
+
# Adjust patch lengths to match x
|
191 |
+
x_patch_lengths[i, -1] += tok_seq_len - (len(tok_seq) - 1)
|
192 |
+
|
193 |
+
assert x_patch_lengths.shape == (batch_size, seq_len)
|
194 |
+
|
195 |
+
if enable_byte_ngrams:
|
196 |
+
raise NotImplementedError()
|
197 |
+
else:
|
198 |
+
ngram_ids = None
|
199 |
+
|
200 |
+
batch = Batch(
|
201 |
+
x=x,
|
202 |
+
y=y,
|
203 |
+
patch_lengths=x_patch_lengths,
|
204 |
+
ngram_ids=ngram_ids,
|
205 |
+
mask=_merge_patch_seq_masks(batch_size, tok_seq_len, masks),
|
206 |
+
)
|
207 |
+
assert (
|
208 |
+
x_patch_lengths.sum() == x.size + batch_size
|
209 |
+
), f"{x_patch_lengths.sum()} != {x.size + batch_size}"
|
210 |
+
assert (
|
211 |
+
batch.mask is None or np.sum(x != pad_id) == batch.mask.sum()
|
212 |
+
), f"{np.sum(x != pad_id)} != {batch.mask.sum()}"
|
213 |
+
assert np.all(
|
214 |
+
x_patch_lengths[:, 0] == 1
|
215 |
+
), f"first patch should always be 1, {x_patch_lengths[:, 0]}"
|
216 |
+
# cuda_gb_allocated = (torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024)
|
217 |
+
# cuda_gb_reserved = torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024
|
218 |
+
# print(f"dataloader cuda_gb_allocated: {cuda_gb_allocated}, cuda_gb_reserved: {cuda_gb_reserved}")
|
219 |
+
truncate_batch(
|
220 |
+
batch,
|
221 |
+
max_length=max_length,
|
222 |
+
pad_id=pad_id,
|
223 |
+
pad_to_max_length=pad_to_max_length,
|
224 |
+
enable_byte_ngrams=enable_byte_ngrams,
|
225 |
+
)
|
226 |
+
yield batch
|
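The packing logic above relies on _merge_patch_seq_masks shifting every per-sequence mask by one position so it lines up with the next-token targets y. A minimal self-contained sketch of that shift, using made-up toy masks (numpy only, nothing from the repo beyond the idea):

import numpy as np

mask_seqs = [
    [True, True, True, False, False],  # 5 tokens -> 4 target positions
    [True, True, True, True, True],
]
bs = len(mask_seqs)
slen = max(len(m) for m in mask_seqs) - 1

merged = np.zeros((bs, slen), dtype=bool)
for i, m in enumerate(mask_seqs):
    # Drop the first element so the mask aligns with y = tokens shifted by one,
    # then right-pad with False up to the batch sequence length.
    merged[i, : len(m) - 1] = m[1:]

print(merged)
# [[ True  True False False]
#  [ True  True  True  True]]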
bytelatent/data/iterators/preprocess_iterator.py
ADDED
@@ -0,0 +1,111 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
from typing import Any, Generator
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from pydantic import BaseModel, ConfigDict
|
6 |
+
|
7 |
+
from bytelatent.data.data_types import BltExample
|
8 |
+
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
|
9 |
+
from bytelatent.data.iterators.arrow_iterator import (
|
10 |
+
ArrowFileIterator,
|
11 |
+
ArrowFileIteratorState,
|
12 |
+
)
|
13 |
+
from bytelatent.data.iterators.looping_iterator import LoopingIteratorState
|
14 |
+
from bytelatent.data.patcher import Patcher, PatcherArgs, PatchingModeEnum
|
15 |
+
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
|
16 |
+
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
|
17 |
+
|
18 |
+
|
19 |
+
class PreprocessIteratorState(BaseModel, IteratorState):
|
20 |
+
model_config = ConfigDict(extra="forbid")
|
21 |
+
arrow_file_iterator_state: ArrowFileIteratorState | LoopingIteratorState
|
22 |
+
add_tokens: bool
|
23 |
+
add_patches: bool
|
24 |
+
tokenizer_args: TokenizerArgs
|
25 |
+
patcher_args: PatcherArgs
|
26 |
+
|
27 |
+
def build(self):
|
28 |
+
arrow_iterator = self.arrow_file_iterator_state.build()
|
29 |
+
return PreprocessIterator(
|
30 |
+
arrow_iterator,
|
31 |
+
patcher_args=self.patcher_args,
|
32 |
+
tokenizer_args=self.tokenizer_args,
|
33 |
+
add_tokens=self.add_tokens,
|
34 |
+
add_patches=self.add_patches,
|
35 |
+
)
|
36 |
+
|
37 |
+
|
38 |
+
class PreprocessIterator(StatefulIterator):
|
39 |
+
"""
|
40 |
+
Take BltExamples whose fields are filled in only by the ArrowFileIterator, and fill in the fields that require
preprocessing, such as tokenization and patching.
|
42 |
+
"""
|
43 |
+
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
arrow_iterator: ArrowFileIterator,
|
47 |
+
*,
|
48 |
+
patcher_args: PatcherArgs,
|
49 |
+
tokenizer_args: TokenizerArgs,
|
50 |
+
add_tokens: bool = True,
|
51 |
+
add_patches: bool = True,
|
52 |
+
):
|
53 |
+
self.arrow_iterator = arrow_iterator
|
54 |
+
self.tokenizer_args = tokenizer_args
|
55 |
+
self.patcher_args = patcher_args
|
56 |
+
self.add_tokens = add_tokens
|
57 |
+
self.add_patches = add_patches
|
58 |
+
self.tokenizer: BltTokenizer | None = None
|
59 |
+
self.patcher: Patcher | None = None
|
60 |
+
|
61 |
+
def get_state(self) -> PreprocessIteratorState:
|
62 |
+
"""
|
63 |
+
The only state to maintain here comes from the arrow iterator;
this iterator has no internal state of its own.
|
65 |
+
"""
|
66 |
+
return PreprocessIteratorState(
|
67 |
+
arrow_file_iterator_state=self.arrow_iterator.get_state(),
|
68 |
+
tokenizer_args=self.tokenizer_args,
|
69 |
+
patcher_args=self.patcher_args,
|
70 |
+
add_tokens=self.add_tokens,
|
71 |
+
add_patches=self.add_patches,
|
72 |
+
)
|
73 |
+
|
74 |
+
def create_iter(self) -> Generator[BltExample, Any, None]:
|
75 |
+
if self.tokenizer is None and self.add_tokens:
|
76 |
+
self.tokenizer = self.tokenizer_args.build()
|
77 |
+
if self.patcher is None and self.add_patches:
|
78 |
+
self.patcher = self.patcher_args.build()
|
79 |
+
|
80 |
+
example_iter = self.arrow_iterator.create_iter()
|
81 |
+
for example in example_iter:
|
82 |
+
if self.add_tokens:
|
83 |
+
tokens = self.tokenizer.encode(example.text)
|
84 |
+
else:
|
85 |
+
tokens = example.tokens
|
86 |
+
if (
|
87 |
+
self.patcher is not None
|
88 |
+
and self.patcher.patching_mode == PatchingModeEnum.entropy
|
89 |
+
):
|
90 |
+
assert (
|
91 |
+
example.entropies is not None
|
92 |
+
), "For patching, entropies cannot be None"
|
93 |
+
entropies = torch.tensor(example.entropies).unsqueeze(0)
|
94 |
+
else:
|
95 |
+
entropies = None
|
96 |
+
if self.patcher is None:
|
97 |
+
patch_lengths = None
|
98 |
+
else:
|
99 |
+
patch_lengths = self.patcher.patch(
|
100 |
+
torch.tensor(tokens).unsqueeze(0),
|
101 |
+
include_next_token=False,
|
102 |
+
entropies=entropies,
|
103 |
+
)[0][0].tolist()
|
104 |
+
yield BltExample(
|
105 |
+
sample_id=example.sample_id,
|
106 |
+
text=example.text,
|
107 |
+
tokens=tokens,
|
108 |
+
mask=[True] * len(tokens),
|
109 |
+
patch_lengths=patch_lengths,
|
110 |
+
entropies=example.entropies,
|
111 |
+
)
|
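To see PreprocessIterator in context, here is a hedged sketch of wiring it on top of an ArrowFileIteratorState, mirroring the state objects used elsewhere in this commit. The dataset path, tokenizer path and entropy model name below are placeholders, not files shipped with the repo:

from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState
from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs

arrow_state = ArrowFileIteratorState(
    file_path=None,
    num_workers=1,
    worker_id=0,
    preprocess_dir=None,
    entropy_model_name="transformer_100m",
    dataset_files=["/path/to/shard_00.arrow"],  # placeholder path
    row_num=0,
    arrow_batch_size=100,
)
preprocess_it = PreprocessIterator(
    arrow_state.build(),
    patcher_args=PatcherArgs(patching_mode=PatchingModeEnum.space),
    tokenizer_args=TokenizerArgs(
        name="blt",
        init_kwargs={"bpe_tokenizer_path": "/path/to/tokenizer.model"},  # placeholder path
    ),
)
for example in preprocess_it.create_iter():
    # Tokens and patch lengths are filled in lazily, one BltExample at a time.
    print(example.sample_id, len(example.tokens), sum(example.patch_lengths))
    break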
bytelatent/data/iterators/sampling_iterator.py
ADDED
@@ -0,0 +1,66 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
from typing import Any
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
from pydantic import BaseModel, ConfigDict
|
6 |
+
|
7 |
+
from bytelatent.data.iterators.abstract_iterator import StatefulIterator
|
8 |
+
from bytelatent.data.iterators.sequence_iterator import SequenceIteratorState
|
9 |
+
|
10 |
+
|
11 |
+
class SamplingIteratorState(BaseModel):
|
12 |
+
model_config = ConfigDict(extra="forbid")
|
13 |
+
rng_state: dict[str, Any]
|
14 |
+
source_to_weight: dict[str, float]
|
15 |
+
source_to_iterator_state: dict[str, SequenceIteratorState]
|
16 |
+
|
17 |
+
def build(self) -> "SamplingIterator":
|
18 |
+
return SamplingIterator(
|
19 |
+
rng_state=self.rng_state,
|
20 |
+
source_to_weight=self.source_to_weight,
|
21 |
+
source_to_iterator={
|
22 |
+
source: state.build()
|
23 |
+
for source, state in self.source_to_iterator_state.items()
|
24 |
+
},
|
25 |
+
)
|
26 |
+
|
27 |
+
|
28 |
+
class SamplingIterator(StatefulIterator):
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
*,
|
32 |
+
rng_state: dict[str, Any],
|
33 |
+
source_to_weight: dict[str, float],
|
34 |
+
source_to_iterator: dict[str, StatefulIterator],
|
35 |
+
):
|
36 |
+
self.rng = np.random.default_rng()
|
37 |
+
self.rng.bit_generator.state = rng_state
|
38 |
+
self.source_to_weight = source_to_weight
|
39 |
+
self.source_to_iterator = source_to_iterator
|
40 |
+
|
41 |
+
def get_state(self) -> SamplingIteratorState:
|
42 |
+
return SamplingIteratorState(
|
43 |
+
rng_state=self.rng.bit_generator.state,
|
44 |
+
source_to_weight=self.source_to_weight,
|
45 |
+
source_to_iterator_state={
|
46 |
+
source: iterator.get_state()
|
47 |
+
for source, iterator in self.source_to_iterator.items()
|
48 |
+
},
|
49 |
+
)
|
50 |
+
|
51 |
+
def create_iter(self):
|
52 |
+
n_sources = len(self.source_to_weight)
|
53 |
+
possible_sources = []
|
54 |
+
weights = []
|
55 |
+
for source, w in self.source_to_weight.items():
|
56 |
+
possible_sources.append(source)
|
57 |
+
weights.append(w)
|
58 |
+
|
59 |
+
source_to_python_iter = {
|
60 |
+
source: self.source_to_iterator[source].create_iter()
|
61 |
+
for source in possible_sources
|
62 |
+
}
|
63 |
+
while True:
|
64 |
+
norm_weights = np.array(weights) / np.array(weights).sum()
|
65 |
+
source_choice = possible_sources[self.rng.choice(n_sources, p=norm_weights)]
|
66 |
+
yield next(source_to_python_iter[source_choice])
|
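The rng_state carried in SamplingIteratorState is just the serializable state dict of a numpy Generator, and the source weights are normalized on every draw. A small standalone illustration of both pieces (the source names and weights are made up):

import numpy as np

source_to_weight = {"wikipedia": 1.0, "stackexchange": 3.0}

# The bit_generator state is a plain dict, so it can be stored in the pydantic state.
rng_state = np.random.default_rng(42).bit_generator.state

rng = np.random.default_rng()
rng.bit_generator.state = rng_state

sources = list(source_to_weight)
weights = np.array([source_to_weight[s] for s in sources])
norm_weights = weights / weights.sum()  # [0.25, 0.75]

draws = [sources[rng.choice(len(sources), p=norm_weights)] for _ in range(8)]
print(draws)  # about three "stackexchange" draws for every "wikipedia" draw, in expectation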
bytelatent/data/iterators/sequence_iterator.py
ADDED
@@ -0,0 +1,122 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
from logging import getLogger
|
3 |
+
from typing import Any
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
from pydantic import BaseModel, ConfigDict
|
7 |
+
|
8 |
+
from bytelatent.data.data_types import BltSequence
|
9 |
+
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
|
10 |
+
from bytelatent.data.iterators.preprocess_iterator import (
|
11 |
+
PreprocessIterator,
|
12 |
+
PreprocessIteratorState,
|
13 |
+
)
|
14 |
+
|
15 |
+
logger = getLogger()
|
16 |
+
|
17 |
+
|
18 |
+
class SequencePackingArgs(BaseModel):
|
19 |
+
model_config = ConfigDict(extra="forbid")
|
20 |
+
output_seq_len: int
|
21 |
+
buffer_size: int
|
22 |
+
|
23 |
+
|
24 |
+
class SequenceIteratorState(BaseModel, IteratorState):
|
25 |
+
model_config = ConfigDict(extra="forbid")
|
26 |
+
sequence_packing_args: SequencePackingArgs
|
27 |
+
preprocess_iterator_state: PreprocessIteratorState
|
28 |
+
rng_state: dict[str, Any]
|
29 |
+
|
30 |
+
def build(self):
|
31 |
+
preprocess_iterator = self.preprocess_iterator_state.build()
|
32 |
+
return SequenceIterator(
|
33 |
+
preprocess_iterator,
|
34 |
+
sequence_packing_args=self.sequence_packing_args,
|
35 |
+
rng_state=self.rng_state,
|
36 |
+
)
|
37 |
+
|
38 |
+
|
39 |
+
class SequenceIterator(StatefulIterator):
|
40 |
+
def __init__(
|
41 |
+
self,
|
42 |
+
preprocess_iterator: PreprocessIterator,
|
43 |
+
*,
|
44 |
+
rng_state: dict[str, Any],
|
45 |
+
sequence_packing_args: SequencePackingArgs,
|
46 |
+
):
|
47 |
+
self.preprocess_iterator = preprocess_iterator
|
48 |
+
self.sequence_packing_args = sequence_packing_args
|
49 |
+
self.output_seq_len = sequence_packing_args.output_seq_len
|
50 |
+
self.buffer_size = sequence_packing_args.buffer_size
|
51 |
+
self.rng = np.random.default_rng()
|
52 |
+
self.rng.bit_generator.state = rng_state
|
53 |
+
|
54 |
+
def get_state(self):
|
55 |
+
# TODO: need to also persist the current shuffle buffer
|
56 |
+
return SequenceIteratorState(
|
57 |
+
sequence_packing_args=self.sequence_packing_args,
|
58 |
+
preprocess_iterator_state=self.preprocess_iterator.get_state(),
|
59 |
+
rng_state=self.rng.bit_generator.state,
|
60 |
+
)
|
61 |
+
|
62 |
+
def create_iter(self):
|
63 |
+
example_iter = self.preprocess_iterator.create_iter()
|
64 |
+
n_buffer_patches = self.buffer_size * self.output_seq_len
|
65 |
+
|
66 |
+
patch_lengths: list[int] = []
|
67 |
+
tokens: list[int] = []
|
68 |
+
mask: list[bool] = []
|
69 |
+
first = True
|
70 |
+
for example in example_iter:
|
71 |
+
assert example.tokens is not None
|
72 |
+
assert example.mask is not None
|
73 |
+
assert example.patch_lengths is not None
|
74 |
+
assert len(example.tokens) != 0
|
75 |
+
assert len(example.mask) != 0
|
76 |
+
assert len(example.tokens) == len(example.mask)
|
77 |
+
assert len(example.tokens) == sum(example.patch_lengths)
|
78 |
+
|
79 |
+
tokens.extend(example.tokens)
|
80 |
+
mask.extend(example.mask)
|
81 |
+
patch_lengths.extend(example.patch_lengths)
|
82 |
+
|
83 |
+
while len(patch_lengths) >= n_buffer_patches:
|
84 |
+
if first:
|
85 |
+
first = False
|
86 |
+
logger.info("First buffer complete")
|
87 |
+
|
88 |
+
x_patches = np.array(patch_lengths[:n_buffer_patches]).reshape(
|
89 |
+
self.buffer_size, self.output_seq_len
|
90 |
+
)
|
91 |
+
seq_tokens = []
|
92 |
+
seq_mask = []
|
93 |
+
start_id = 0
|
94 |
+
# We fix the number of patches (and therefore global steps) per batch,
# so the number of tokens per sequence varies and has to be accounted for.
|
96 |
+
for num_tokens in x_patches.sum(axis=-1):
|
97 |
+
seq_tokens.append(tokens[start_id : start_id + num_tokens])
|
98 |
+
seq_mask.append(mask[start_id : start_id + num_tokens])
|
99 |
+
start_id += num_tokens
|
100 |
+
|
101 |
+
assert start_id == x_patches.sum()
|
102 |
+
|
103 |
+
# Remove what we just added from the buffer
|
104 |
+
patch_lengths = patch_lengths[n_buffer_patches:]
|
105 |
+
tokens = tokens[x_patches.sum() :]
|
106 |
+
mask = mask[x_patches.sum() :]
|
107 |
+
|
108 |
+
seq_patch_lengths: list[list[int]] = x_patches.tolist()
|
109 |
+
assert len(seq_patch_lengths) == self.buffer_size
|
110 |
+
for idx in self.rng.permutation(len(seq_patch_lengths)):
|
111 |
+
assert len(seq_patch_lengths[idx]) == self.output_seq_len
|
112 |
+
assert (
|
113 |
+
sum(seq_patch_lengths[idx])
|
114 |
+
== len(seq_tokens[idx])
|
115 |
+
== len(seq_mask[idx])
|
116 |
+
), f"{sum(seq_patch_lengths[idx])}, {len(seq_tokens[idx])} {len(seq_mask[idx])}, idx={idx}"
|
117 |
+
assert seq_patch_lengths[idx][0] > 0, f"{seq_patch_lengths[idx]}"
|
118 |
+
yield BltSequence(
|
119 |
+
tokens=seq_tokens[idx],
|
120 |
+
mask=seq_mask[idx],
|
121 |
+
patch_lengths=seq_patch_lengths[idx],
|
122 |
+
)
|
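The buffer arithmetic in SequenceIterator.create_iter can be easier to follow on toy numbers: buffer_size * output_seq_len patch lengths are reshaped into one row per output sequence, and each row then claims exactly sum(row) tokens from the token buffer. A self-contained sketch with made-up values:

import numpy as np

buffer_size, output_seq_len = 2, 3
patch_lengths = [1, 2, 3, 1, 1, 4]        # buffer_size * output_seq_len = 6 patches
tokens = list(range(sum(patch_lengths)))  # 12 tokens in the buffer

x_patches = np.array(patch_lengths).reshape(buffer_size, output_seq_len)
seq_tokens, start = [], 0
for num_tokens in x_patches.sum(axis=-1):  # [6, 6]
    seq_tokens.append(tokens[start : start + num_tokens])
    start += num_tokens

print(x_patches.tolist())  # [[1, 2, 3], [1, 1, 4]]
print(seq_tokens)          # [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]]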
bytelatent/data/iterators/test_arrow_iterator.py
ADDED
@@ -0,0 +1,89 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
import numpy as np
|
3 |
+
import pyarrow as pa
|
4 |
+
|
5 |
+
# pyarrow needs the initialization from this import
|
6 |
+
import pyarrow.dataset # pyright: ignore
|
7 |
+
|
8 |
+
from bytelatent.constants import BLT_DATA
|
9 |
+
from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState
|
10 |
+
|
11 |
+
ENTROPY_MODEL = "transformer_100m"
|
12 |
+
ARROW_TEST_DATA_1 = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_00.arrow")
|
13 |
+
ARROW_TEST_DATA_2 = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_01.arrow")
|
14 |
+
|
15 |
+
|
16 |
+
def test_basic_arrow_file():
|
17 |
+
dataset = pa.dataset.dataset(ARROW_TEST_DATA_1, format="arrow")
|
18 |
+
n_head = 1000
|
19 |
+
head_df = dataset.head(n_head).to_pandas()
|
20 |
+
|
21 |
+
initial_state = ArrowFileIteratorState(
|
22 |
+
file_path=None,
|
23 |
+
num_workers=1,
|
24 |
+
worker_id=0,
|
25 |
+
preprocess_dir=None,
|
26 |
+
entropy_model_name=ENTROPY_MODEL,
|
27 |
+
dataset_files=[ARROW_TEST_DATA_1],
|
28 |
+
row_num=0,
|
29 |
+
arrow_batch_size=100,
|
30 |
+
)
|
31 |
+
arrow_file = initial_state.build()
|
32 |
+
start_state = arrow_file.get_state()
|
33 |
+
assert start_state.row_num == initial_state.row_num
|
34 |
+
|
35 |
+
sample_id = None
|
36 |
+
for example in arrow_file.create_iter():
|
37 |
+
sample_id = example.sample_id
|
38 |
+
assert head_df.iloc[0]["sample_id"] == sample_id
|
39 |
+
break
|
40 |
+
|
41 |
+
assert arrow_file.get_state().row_num == 1
|
42 |
+
arrow_file = initial_state.build()
|
43 |
+
for example in arrow_file.create_iter():
|
44 |
+
assert example.sample_id == sample_id
|
45 |
+
assert head_df.iloc[0]["sample_id"] == sample_id
|
46 |
+
break
|
47 |
+
|
48 |
+
# Test resume far enough in to be past the batch size of 100
|
49 |
+
resumed_state = ArrowFileIteratorState(
|
50 |
+
file_path=None,
|
51 |
+
num_workers=1,
|
52 |
+
worker_id=0,
|
53 |
+
preprocess_dir=None,
|
54 |
+
entropy_model_name=ENTROPY_MODEL,
|
55 |
+
dataset_files=[ARROW_TEST_DATA_1],
|
56 |
+
row_num=251,
|
57 |
+
arrow_batch_size=100,
|
58 |
+
)
|
59 |
+
arrow_file = resumed_state.build()
|
60 |
+
for example in arrow_file.create_iter():
|
61 |
+
assert example.sample_id == head_df.iloc[251]["sample_id"]
|
62 |
+
assert arrow_file.get_state().row_num == 252
|
63 |
+
break
|
64 |
+
|
65 |
+
world_rank = 1
|
66 |
+
world_size = 4
|
67 |
+
# Test World Size and Rank
|
68 |
+
rank_state = ArrowFileIteratorState(
|
69 |
+
file_path=None,
|
70 |
+
num_workers=world_size,
|
71 |
+
worker_id=world_rank,
|
72 |
+
preprocess_dir=None,
|
73 |
+
entropy_model_name=ENTROPY_MODEL,
|
74 |
+
dataset_files=[ARROW_TEST_DATA_1],
|
75 |
+
row_num=0,
|
76 |
+
arrow_batch_size=100,
|
77 |
+
)
|
78 |
+
arrow_file = rank_state.build()
|
79 |
+
expected_ids = []
|
80 |
+
for i in range(n_head):
|
81 |
+
if i % world_size == world_rank:
|
82 |
+
expected_ids.append(head_df.iloc[i]["sample_id"])
|
83 |
+
print(len(expected_ids))
|
84 |
+
i = 0
|
85 |
+
for example in arrow_file.create_iter():
|
86 |
+
assert example.sample_id == expected_ids[i]
|
87 |
+
i += 1
|
88 |
+
if i >= len(expected_ids):
|
89 |
+
break
|
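The worker sharding that the last part of this test exercises is a simple round-robin split: worker worker_id out of num_workers sees every row i with i % num_workers == worker_id. A tiny sketch of that assignment on made-up row indices:

num_workers, worker_id = 4, 1
rows = list(range(10))
assigned = [row for row in rows if row % num_workers == worker_id]
print(assigned)  # [1, 5, 9]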
bytelatent/data/iterators/test_iters.py
ADDED
@@ -0,0 +1,162 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
import pandas as pd
|
3 |
+
from pydantic import BaseModel
|
4 |
+
|
5 |
+
from bytelatent.constants import BLT_DATA
|
6 |
+
from bytelatent.data.data_types import BltExample
|
7 |
+
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
|
8 |
+
from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
|
9 |
+
from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum
|
10 |
+
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
|
11 |
+
|
12 |
+
|
13 |
+
class BltTestIteratorState(BaseModel, IteratorState):
|
14 |
+
position: int
|
15 |
+
total: int
|
16 |
+
|
17 |
+
def build(self):
|
18 |
+
blt_iter = BltTestIteratorState(total=self.total)
|
19 |
+
blt_iter.position = self.position
|
20 |
+
return blt_iter
|
21 |
+
|
22 |
+
|
23 |
+
class BltTestIterator(StatefulIterator):
|
24 |
+
def __init__(self, total: int):
|
25 |
+
self.position = 0
|
26 |
+
self.total = total
|
27 |
+
|
28 |
+
def get_state(self):
|
29 |
+
return BltTestIteratorState(position=self.position, total=self.total)
|
30 |
+
|
31 |
+
def create_iter(self):
|
32 |
+
for i in range(self.total):
|
33 |
+
self.position += 1
|
34 |
+
yield BltExample(
|
35 |
+
sample_id=f"test_{i}",
|
36 |
+
text=f"This is some test {i} text.",
|
37 |
+
tokens=None,
|
38 |
+
mask=None,
|
39 |
+
entropies=None,
|
40 |
+
patch_lengths=None,
|
41 |
+
)
|
42 |
+
|
43 |
+
|
44 |
+
class BltTestWithEntropiesIteratorState(BaseModel, IteratorState):
|
45 |
+
position: int
|
46 |
+
total: int
|
47 |
+
|
48 |
+
def build(self):
|
49 |
+
blt_iter = BltTestWithEntropiesIteratorState(total=self.total)
|
50 |
+
blt_iter.position = self.position
|
51 |
+
return blt_iter
|
52 |
+
|
53 |
+
|
54 |
+
class BltTestWithEntropiesIterator(StatefulIterator):
|
55 |
+
def __init__(self, total: int):
|
56 |
+
self.position = 0
|
57 |
+
self.total = total
|
58 |
+
|
59 |
+
def get_state(self):
|
60 |
+
return BltTestIteratorState(position=self.position, total=self.total)
|
61 |
+
|
62 |
+
def create_iter(self):
|
63 |
+
text = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin."
|
64 |
+
df = pd.read_json("fixtures/tokens_with_entropies.json")
|
65 |
+
tokens = df["token_ids"].tolist()
|
66 |
+
entropies = df["entropies"].tolist()
|
67 |
+
# BOS and EOS
|
68 |
+
assert len(tokens) == len(text) + 2
|
69 |
+
for i in range(self.total):
|
70 |
+
self.position += 1
|
71 |
+
yield BltExample(
|
72 |
+
sample_id=f"test_{i}",
|
73 |
+
text=text,
|
74 |
+
tokens=tokens,
|
75 |
+
mask=[True] * len(tokens),
|
76 |
+
entropies=entropies,
|
77 |
+
patch_lengths=None,
|
78 |
+
)
|
79 |
+
|
80 |
+
|
81 |
+
def test_preprocess_iter():
|
82 |
+
total = 3
|
83 |
+
tokenizer_args = TokenizerArgs(
|
84 |
+
name="blt",
|
85 |
+
init_kwargs={
|
86 |
+
"bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
|
87 |
+
},
|
88 |
+
)
|
89 |
+
for mode in [
|
90 |
+
PatchingModeEnum.bpe,
|
91 |
+
PatchingModeEnum.space,
|
92 |
+
]:
|
93 |
+
data_it = BltTestIterator(total)
|
94 |
+
patcher_args = PatcherArgs(patching_mode=mode)
|
95 |
+
example_it = PreprocessIterator(
|
96 |
+
data_it, tokenizer_args=tokenizer_args, patcher_args=patcher_args
|
97 |
+
)
|
98 |
+
count = 0
|
99 |
+
for example in example_it.create_iter():
|
100 |
+
assert isinstance(example.tokens, list)
|
101 |
+
assert isinstance(example.tokens[0], int)
|
102 |
+
# BOS and EOS
|
103 |
+
assert len(example.tokens) == len(example.text) + 2
|
104 |
+
assert example.mask is not None
|
105 |
+
assert len(example.tokens) == len(example.mask)
|
106 |
+
count += 1
|
107 |
+
|
108 |
+
assert count == total
|
109 |
+
|
110 |
+
|
111 |
+
def test_non_entropy_patch_iter():
|
112 |
+
total = 3
|
113 |
+
tokenizer_args = TokenizerArgs(
|
114 |
+
name="blt",
|
115 |
+
init_kwargs={
|
116 |
+
"bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
|
117 |
+
},
|
118 |
+
)
|
119 |
+
for mode in [
|
120 |
+
PatchingModeEnum.bpe,
|
121 |
+
PatchingModeEnum.space,
|
122 |
+
]:
|
123 |
+
patcher_args = PatcherArgs(patching_mode=mode)
|
124 |
+
data_it = BltTestIterator(total)
|
125 |
+
example_it = PreprocessIterator(
|
126 |
+
data_it, tokenizer_args=tokenizer_args, patcher_args=patcher_args
|
127 |
+
)
|
128 |
+
|
129 |
+
count = 0
|
130 |
+
for example in example_it.create_iter():
|
131 |
+
assert isinstance(example.patch_lengths, list)
|
132 |
+
assert isinstance(example.patch_lengths[0], int)
|
133 |
+
assert len(example.tokens) == sum(example.patch_lengths)
|
134 |
+
count += 1
|
135 |
+
|
136 |
+
assert count == total
|
137 |
+
|
138 |
+
|
139 |
+
def test_entropy_patch_iter():
|
140 |
+
total = 2
|
141 |
+
patcher_args = PatcherArgs(
|
142 |
+
patching_mode=PatchingModeEnum.entropy, threshold=1.335442066192627
|
143 |
+
)
|
144 |
+
tokenizer_args = TokenizerArgs(
|
145 |
+
name="blt",
|
146 |
+
init_kwargs={
|
147 |
+
"bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
|
148 |
+
},
|
149 |
+
)
|
150 |
+
data_it = BltTestWithEntropiesIterator(total)
|
151 |
+
example_it = PreprocessIterator(
|
152 |
+
data_it, tokenizer_args=tokenizer_args, patcher_args=patcher_args
|
153 |
+
)
|
154 |
+
|
155 |
+
count = 0
|
156 |
+
for example in example_it.create_iter():
|
157 |
+
assert isinstance(example.patch_lengths, list)
|
158 |
+
assert isinstance(example.patch_lengths[0], int)
|
159 |
+
assert len(example.tokens) == sum(example.patch_lengths)
|
160 |
+
count += 1
|
161 |
+
|
162 |
+
assert count == total
|
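These tests assume the BLT_DATA tokenizer model and fixtures/tokens_with_entropies.json are present locally. Assuming pytest is installed, a subset can be run programmatically, for example only the two tests that do not need the entropy fixture:

import pytest

# -k filters by test name; this selects test_preprocess_iter and test_non_entropy_patch_iter.
pytest.main(["-q", "bytelatent/data/iterators/test_iters.py", "-k", "preprocess or non_entropy"])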
bytelatent/data/ngram_processor.py
ADDED
@@ -0,0 +1,146 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
import pickle
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from bytelatent import ByteLatentError
|
8 |
+
|
9 |
+
LOOKUP_OFFSET = 4
|
10 |
+
|
11 |
+
|
12 |
+
def apply_lookup_table_wrapper(ngram_to_idx: dict[tuple, int], lookup_offset=1):
|
13 |
+
"""
|
14 |
+
Wrapper function for applying the lookup table to each n-gram.
|
15 |
+
|
16 |
+
:param ngram_to_idx: Dictionary where keys are n-gram tuples and values are their indices.
:param lookup_offset: Offset to add to each looked-up index.
:return: A function that maps an n-gram to its index plus lookup_offset, or 0 if the n-gram is not in the table.
|
20 |
+
"""
|
21 |
+
|
22 |
+
def apply_lookup_table(ngram):
|
23 |
+
"""
|
24 |
+
Function to apply to each n-gram: converts it to a tuple and looks it up in a dictionary.
|
25 |
+
|
26 |
+
:param ngram: Array of numbers representing an n-gram.
|
27 |
+
:return: The value associated with the n-gram tuple in the dictionary, or None if not found.
|
28 |
+
"""
|
29 |
+
# Convert the n-gram to a tuple
|
30 |
+
ngram_tuple = tuple(ngram)
|
31 |
+
|
32 |
+
if ngram_tuple not in ngram_to_idx:
|
33 |
+
return 0
|
34 |
+
else:
|
35 |
+
return ngram_to_idx[ngram_tuple] + lookup_offset
|
36 |
+
|
37 |
+
return apply_lookup_table
|
38 |
+
|
39 |
+
|
40 |
+
def get_byte_ngrams_ids(
|
41 |
+
byte_array: np.ndarray, n: int, ngram_to_idx: dict[tuple, int], pad_value=0
|
42 |
+
):
|
43 |
+
"""
|
44 |
+
Generate n-grams from a 2D numpy array.
|
45 |
+
|
46 |
+
:param n: The length of each n-gram.
|
47 |
+
:param pad_value: The value used for padding of the byte values to maintain the same dimensions for the n-grams.
|
48 |
+
:return: A 2D numpy array where each element is the ID of an n-gram offset by LOOKUP_OFFSET.
|
49 |
+
"""
|
50 |
+
num_rows, num_cols = byte_array.shape
|
51 |
+
|
52 |
+
# Create an array to hold the padded version of the original array
|
53 |
+
padded_array = np.pad(
|
54 |
+
byte_array, ((0, 0), (n - 1, 0)), mode="constant", constant_values=pad_value
|
55 |
+
)
|
56 |
+
|
57 |
+
# Use stride tricks to avoid explicit looping
|
58 |
+
strided = np.lib.stride_tricks.as_strided
|
59 |
+
shape = (num_rows, num_cols, n)
|
60 |
+
strides = padded_array.strides[:2] + (padded_array.strides[1],)
|
61 |
+
ngrams = strided(padded_array, shape=shape, strides=strides)
|
62 |
+
|
63 |
+
ngram_ids = np.apply_along_axis(
|
64 |
+
apply_lookup_table_wrapper(ngram_to_idx, lookup_offset=LOOKUP_OFFSET), 2, ngrams
|
65 |
+
)
|
66 |
+
assert ngram_ids.shape == byte_array.shape
|
67 |
+
return ngram_ids
|
68 |
+
|
69 |
+
|
70 |
+
def reload_tables(
|
71 |
+
ngram_table_dir: str, ngram_to_size: dict[int, int], offset: int = LOOKUP_OFFSET
|
72 |
+
) -> tuple[dict[int, dict[tuple, int]], dict[int, list], dict[int, int]]:
|
73 |
+
"""
|
74 |
+
Reload lookup tables from a directory. Reload only the ngrams in the dictionary and per ngram,
|
75 |
+
only load up to the max specified size. Return the actual number of ngrams taken per ngram size.
|
76 |
+
"""
|
77 |
+
idx_to_ngram_tables = {}
|
78 |
+
ngram_to_idx_tables = {}
|
79 |
+
vocab_sizes = {}
|
80 |
+
for ngram, size in ngram_to_size.items():
|
81 |
+
with open(Path(ngram_table_dir) / f"ngram-{ngram}.pickle", "rb") as f:
|
82 |
+
# These are already sorted by count
|
83 |
+
# Value: tuple of: count, ngram, dataset
|
84 |
+
ngram_data: list[tuple[tuple, tuple[int, int, str]]] = pickle.load(f)[
|
85 |
+
"counts"
|
86 |
+
]
|
87 |
+
table = [ngram for ngram, _ in ngram_data][:size]
|
88 |
+
if len(table) != size:
|
89 |
+
raise ValueError(
|
90 |
+
f"Ngram table for {ngram}-gram is not large enough to get {size} ngrams, max size is {len(ngram_data)}"
|
91 |
+
)
|
92 |
+
ngram_to_idx = {ngram: idx for idx, ngram in enumerate(table)}
|
93 |
+
actual_size = len(table)
|
94 |
+
idx_to_ngram_tables[ngram] = table
|
95 |
+
ngram_to_idx_tables[ngram] = ngram_to_idx
|
96 |
+
vocab_sizes[ngram] = actual_size + offset
|
97 |
+
return ngram_to_idx_tables, idx_to_ngram_tables, vocab_sizes
|
98 |
+
|
99 |
+
|
100 |
+
def parse_ngram_to_size(ngram_to_size_str: str | None) -> dict[int, int]:
|
101 |
+
if ngram_to_size_str is None:
|
102 |
+
return None
|
103 |
+
ngram_to_size = {}
|
104 |
+
for entry in ngram_to_size_str.split(","):
|
105 |
+
ngram, size = entry.split(":")
|
106 |
+
ngram = int(ngram)
|
107 |
+
size = int(size)
|
108 |
+
ngram_to_size[ngram] = size
|
109 |
+
return ngram_to_size
|
110 |
+
|
111 |
+
|
112 |
+
class NgramProcessor:
|
113 |
+
def __init__(
|
114 |
+
self,
|
115 |
+
ngram_table_dir: str | None = None,
|
116 |
+
ngram_to_size: dict[int, int] | None = None,
|
117 |
+
):
|
118 |
+
if ngram_table_dir is None or ngram_to_size is None:
|
119 |
+
raise ByteLatentError(
|
120 |
+
"ngram_table_dir and ngram_to_size cannot be none if enable_byte_ngrams is True"
|
121 |
+
)
|
122 |
+
(
|
123 |
+
self.ngram_to_idx_tables,
|
124 |
+
self.idx_to_ngram_tables,
|
125 |
+
self.ngram_vocab_sizes,
|
126 |
+
) = reload_tables(ngram_table_dir, ngram_to_size)
|
127 |
+
# Lowest to highest ngram
|
128 |
+
self.ngram_sizes = sorted(list(self.ngram_to_idx_tables.keys()))
|
129 |
+
# Although the model might not use all the ngrams, we need the tokenizer
|
130 |
+
# to produce ngram_ids such that index zero is the 2-gram, later on in
|
131 |
+
# src.model.megabyte.Megabyte.forward
|
132 |
+
assert self.ngram_sizes[0] == 2
|
133 |
+
|
134 |
+
def encode_single_ngram_table(self, data: np.ndarray, n: int):
|
135 |
+
"""
|
136 |
+
Return the n-grams of the input data for a given n
|
137 |
+
numpy array with ids of shape data.shape
|
138 |
+
"""
|
139 |
+
return get_byte_ngrams_ids(data, n, self.ngram_to_idx_tables[n], pad_value=0)
|
140 |
+
|
141 |
+
def encode_token_ngrams(self, data: np.ndarray):
|
142 |
+
"""
|
143 |
+
Return the n-grams of the input data.
|
144 |
+
output shape: [ids with data.shape for n in self.ngram_sizes]
|
145 |
+
"""
|
146 |
+
return [self.encode_single_ngram_table(data, n) for n in self.ngram_sizes]
|
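A self-contained toy version of the sliding-window lookup that get_byte_ngrams_ids performs, with a handmade 2-gram table instead of one reloaded from a pickle (the table and input bytes below are made up for illustration):

import numpy as np

LOOKUP_OFFSET = 4
ngram_to_idx = {(1, 2): 0, (2, 3): 1}   # toy table; normally built by reload_tables
byte_array = np.array([[1, 2, 3, 9]])
n = 2

# Left-pad so every position gets a window, then slide a window of size n.
padded = np.pad(byte_array, ((0, 0), (n - 1, 0)), constant_values=0)
windows = np.lib.stride_tricks.sliding_window_view(padded, n, axis=1)

# Unknown n-grams map to 0, known ones to index + LOOKUP_OFFSET, as in apply_lookup_table.
ids = np.array(
    [[ngram_to_idx.get(tuple(w), -LOOKUP_OFFSET) + LOOKUP_OFFSET for w in row] for row in windows]
)
print(ids)  # [[0 4 5 0]], same shape as byte_array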
bytelatent/data/patcher.py
ADDED
@@ -0,0 +1,609 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
import math
|
3 |
+
import time
|
4 |
+
from collections import defaultdict
|
5 |
+
from enum import Enum
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from pydantic import BaseModel
|
9 |
+
from torch.nn import functional as F
|
10 |
+
|
11 |
+
from bytelatent.distributed import get_local_rank
|
12 |
+
from bytelatent.entropy_model import load_entropy_model
|
13 |
+
|
14 |
+
# from src.slurm import get_local_rank
|
15 |
+
from bytelatent.tokenizers.constants import BPE_ID, OFFSET
|
17 |
+
|
18 |
+
|
19 |
+
class PatchingModeEnum(str, Enum):
|
20 |
+
entropy = "entropy"
|
21 |
+
bpe = "bpe"
|
22 |
+
bpe_patcher = "bpe_patcher"
|
23 |
+
space = "space"
|
24 |
+
|
25 |
+
|
26 |
+
class PatcherArgs(BaseModel):
|
27 |
+
patching_mode: PatchingModeEnum = PatchingModeEnum.entropy
|
28 |
+
patching_device: str = "cuda"
|
29 |
+
entropy_model_checkpoint_dir: str | None = None
|
30 |
+
realtime_patching: bool = False
|
31 |
+
threshold: float = 1.335442066192627
|
32 |
+
threshold_add: float | None = None
|
33 |
+
max_patch_length: int | None = None
|
34 |
+
patch_size: float = 4.5
|
35 |
+
patching_batch_size: int = 1
|
36 |
+
data_loader_patching: bool = False
|
37 |
+
device: str = "cuda"
|
38 |
+
monotonicity: bool = False
|
39 |
+
log_time: bool = False
|
40 |
+
|
41 |
+
def build(self) -> "Patcher":
|
42 |
+
return Patcher(self)
|
43 |
+
|
44 |
+
|
45 |
+
def entropy(scores):
|
46 |
+
"""
|
47 |
+
scores: [bs, seq_len, vocab]
|
48 |
+
returns [bs, seq_len]
|
49 |
+
|
50 |
+
Computes the entropy for each token in the batch.
|
51 |
+
Note: uses natural log.
|
52 |
+
"""
|
53 |
+
log_probs = F.log_softmax(scores, dim=-1)
|
54 |
+
probs = torch.exp(log_probs)
|
55 |
+
p_log_p = log_probs * probs
|
56 |
+
entropy = -p_log_p.sum(dim=-1)
|
57 |
+
return entropy
|
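As an aside (not part of patcher.py), a quick sanity check of the entropy helper above: all-zero logits become a uniform distribution after softmax, and the entropy of a uniform distribution over V symbols is ln(V). A minimal sketch, assuming the bytelatent package is importable:

import math

import torch

from bytelatent.data.patcher import entropy

scores = torch.zeros(2, 5, 260)  # [bs, seq_len, vocab]
ent = entropy(scores)            # [bs, seq_len], every entry equals ln(260)
assert torch.allclose(ent, torch.full((2, 5), math.log(260)))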
58 |
+
|
59 |
+
|
60 |
+
def calculate_entropies(
|
61 |
+
tokens: torch.tensor, entropy_model, patching_batch_size, device: str | None = None
|
62 |
+
):
|
63 |
+
"""
|
64 |
+
tokens: 2D tensor of shape [batch_size, seq_len]
|
65 |
+
Return 2D tensor of shape [batch_size, seq_len] with entropies for each token.
|
66 |
+
|
67 |
+
Splits the tokens into chunks of size max_length and calculates entropies for each chunk.
|
68 |
+
Entropy model can be executed on cpu or gpu, specify either 'cuda' or 'cpu' in the device argument.
|
69 |
+
"""
|
70 |
+
with torch.no_grad():
|
71 |
+
entropies = []
|
72 |
+
max_length = getattr(entropy_model, "max_length", 8192)
|
73 |
+
batch_numel = max_length * patching_batch_size
|
74 |
+
splits = torch.split(tokens.flatten(), batch_numel)
|
75 |
+
for split in splits:
|
76 |
+
pad_size = (max_length - (split.numel() % max_length)) % max_length
|
77 |
+
pad = torch.zeros(
|
78 |
+
pad_size, dtype=split.dtype, device=split.device, requires_grad=False
|
79 |
+
)
|
80 |
+
split = torch.cat((split, pad), dim=0)
|
81 |
+
split = split.reshape(-1, max_length)
|
82 |
+
if device is not None:
|
83 |
+
split = split.to(device)
|
84 |
+
assert torch.all(split >= 0) and torch.all(split < 260)
|
85 |
+
pred, _ = entropy_model(split)
|
86 |
+
pred = pred.reshape(-1, pred.shape[-1])[
|
87 |
+
: split.numel() - pad_size, :
|
88 |
+
] # [batch_size * seq_len, vocab]
|
89 |
+
pred_entropies = entropy(pred)
|
90 |
+
entropies.append(pred_entropies)
|
91 |
+
|
92 |
+
entropies = torch.cat(entropies, dim=0)
|
93 |
+
entropies = entropies.reshape(tokens.shape)
|
94 |
+
return entropies
|
95 |
+
|
96 |
+
|
97 |
+
def patch_start_mask_from_entropy_with_monotonicity(entropies, t):
|
98 |
+
"""
|
99 |
+
entropies: [bs, seq_len] torch tensor of entropies
|
100 |
+
t: threshold
|
101 |
+
returns [bs, seq_len] mask where True indicates the start of a patch
|
102 |
+
"""
|
103 |
+
bs, seq_len = entropies.shape
|
104 |
+
mask = torch.zeros_like(entropies, dtype=torch.bool)
|
105 |
+
mask[:, 0] = True
|
106 |
+
|
107 |
+
# Calculate differences between consecutive elements along the sequence length
|
108 |
+
differences = entropies[:, 1:] - entropies[:, :-1]
|
109 |
+
|
110 |
+
# Calculate conditions for all elements except the first one in each sequence
|
111 |
+
condition = differences > t
|
112 |
+
|
113 |
+
# Update the mask based on the condition
|
114 |
+
mask[:, 1:] = condition
|
115 |
+
|
116 |
+
return mask
|
117 |
+
|
118 |
+
|
119 |
+
def patch_start_mask_global_and_monotonicity(entropies, t, t_add=0):
|
120 |
+
"""
|
121 |
+
entropies: [bs, seq_len] torch tensor of entropies
|
122 |
+
t: threshold
|
123 |
+
returns [bs, seq_len] mask where True indicates the start of a patch
|
124 |
+
"""
|
125 |
+
bs, seq_len = entropies.shape
|
126 |
+
mask = torch.zeros_like(entropies, dtype=torch.bool)
|
127 |
+
mask[:, 0] = True
|
128 |
+
|
129 |
+
# Calculate differences between consecutive elements along the sequence length
|
130 |
+
differences = entropies[:, 1:] - entropies[:, :-1]
|
131 |
+
|
132 |
+
# Calculate conditions for all elements except the first one in each sequence
|
133 |
+
condition = (differences > t_add) & (entropies[:, 1:] > t) & (~mask[:, :-1])
|
134 |
+
|
135 |
+
# Update the mask based on the condition
|
136 |
+
mask[:, 1:] = condition
|
137 |
+
|
138 |
+
return mask
|
139 |
+
|
140 |
+
|
141 |
+
def patch_start_ids_from_patch_start_mask(patch_start_mask):
|
142 |
+
bs, trunc_seq_len = patch_start_mask.shape
|
143 |
+
max_patches = patch_start_mask.sum(dim=1).max()
|
144 |
+
if max_patches == 0:
|
145 |
+
patch_start_ids = torch.full(
|
146 |
+
(bs, trunc_seq_len),
|
147 |
+
trunc_seq_len,
|
148 |
+
dtype=torch.long,
|
149 |
+
device=patch_start_mask.device,
|
150 |
+
)
|
151 |
+
else:
|
152 |
+
patch_ids = (
|
153 |
+
torch.arange(trunc_seq_len, device=patch_start_mask.device)
|
154 |
+
.unsqueeze(0)
|
155 |
+
.repeat(bs, 1)
|
156 |
+
)
|
157 |
+
extra_patch_ids = torch.full(
|
158 |
+
(bs, trunc_seq_len),
|
159 |
+
trunc_seq_len,
|
160 |
+
dtype=torch.long,
|
161 |
+
device=patch_start_mask.device,
|
162 |
+
)
|
163 |
+
all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1)
|
164 |
+
patch_start_mask_padded = torch.cat(
|
165 |
+
(patch_start_mask, ~patch_start_mask), dim=1
|
166 |
+
)
|
167 |
+
patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(
|
168 |
+
bs, trunc_seq_len
|
169 |
+
)[:, :max_patches]
|
170 |
+
return patch_start_ids
|
171 |
+
|
172 |
+
|
173 |
+
def check_non_zero_after_zero(tensor):
|
174 |
+
zero_mask = tensor == 0
|
175 |
+
shifted_mask = torch.cat(
|
176 |
+
[
|
177 |
+
torch.zeros(tensor.shape[0], 1, dtype=torch.bool, device=tensor.device),
|
178 |
+
zero_mask[:, :-1],
|
179 |
+
],
|
180 |
+
dim=1,
|
181 |
+
)
|
182 |
+
non_zero_after_zero = (tensor != 0) & shifted_mask
|
183 |
+
return non_zero_after_zero.any()
|
184 |
+
|
185 |
+
|
186 |
+
def patch_lengths_from_start_ids(patch_start_ids, seq_len):
|
187 |
+
"""
|
188 |
+
Calculate patch lengths from start ids.
|
189 |
+
start ids: ex: [0, 1, 7, 7, 7, 7, 7], it has the start ids of the patches (here 0, 1), and then
|
190 |
+
the rest are filled to the seq len.
|
191 |
+
seq_len: ex: 7 length of the sequence
|
192 |
+
|
193 |
+
returns the patch lengths:
|
194 |
+
[1, 6] for the above example.
|
195 |
+
"""
|
196 |
+
last_ids = torch.full_like(patch_start_ids[:, :1], seq_len - 1)
|
197 |
+
patch_end_ids = torch.cat((patch_start_ids[:, 1:] - 1, last_ids), dim=1)
|
198 |
+
patch_lengths = patch_end_ids - patch_start_ids + 1
|
199 |
+
assert torch.all(patch_lengths >= 0), f"{patch_lengths}"
|
200 |
+
assert not check_non_zero_after_zero(patch_lengths), f"{patch_lengths}"
|
201 |
+
return patch_lengths
|
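As an aside (not part of patcher.py), the docstring example above can be checked directly, assuming the bytelatent package is importable; note that the result is right-padded with zeros to the width of the input:

import torch

from bytelatent.data.patcher import patch_lengths_from_start_ids

start_ids = torch.tensor([[0, 1, 7, 7, 7, 7, 7]])  # patches start at 0 and 1, rest is fill
print(patch_lengths_from_start_ids(start_ids, seq_len=7))
# tensor([[1, 6, 0, 0, 0, 0, 0]]), i.e. the [1, 6] from the docstring plus zero padding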
202 |
+
|
203 |
+
|
204 |
+
def find_space_patch_start_ids(tokens):
|
205 |
+
bs, seq_len = tokens.shape
|
206 |
+
tokens_no_offset = tokens - OFFSET
|
207 |
+
patch_end_mask = (
|
208 |
+
(tokens_no_offset < ord("0"))
|
209 |
+
| ((ord("9") < tokens_no_offset) & (tokens_no_offset < ord("A")))
|
210 |
+
| ((ord("Z") < tokens_no_offset) & (tokens_no_offset < ord("a")))
|
211 |
+
| ((ord("z") < tokens_no_offset) & (tokens_no_offset < 0b1000_0000))
|
212 |
+
| (0b1100_0000 <= tokens_no_offset)
|
213 |
+
)
|
214 |
+
patch_end_mask[:, 1:] &= patch_end_mask[:, :-1].bitwise_not()
|
215 |
+
patch_end_mask |= tokens < OFFSET
|
216 |
+
|
217 |
+
patch_start_mask = torch.cat(
|
218 |
+
[
|
219 |
+
torch.tensor([1, 1], device=tokens.device, dtype=torch.bool)
|
220 |
+
.unsqueeze(0)
|
221 |
+
.repeat(bs, 1),
|
222 |
+
patch_end_mask[:, 1:],
|
223 |
+
],
|
224 |
+
dim=1,
|
225 |
+
)
|
226 |
+
max_patches = patch_start_mask.sum(dim=1).max()
|
227 |
+
|
228 |
+
patch_ids = (
|
229 |
+
torch.arange(seq_len + 1, device=tokens.device).unsqueeze(0).repeat(bs, 1)
|
230 |
+
)
|
231 |
+
extra_patch_ids = torch.full(
|
232 |
+
(bs, seq_len + 1), seq_len + 1, dtype=torch.long, device=tokens.device
|
233 |
+
)
|
234 |
+
all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1)
|
235 |
+
patch_start_mask_padded = torch.cat((patch_start_mask, ~patch_start_mask), dim=1)
|
236 |
+
|
237 |
+
patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(bs, -1)[
|
238 |
+
:, :max_patches
|
239 |
+
]
|
240 |
+
return patch_start_ids
|
241 |
+
|
242 |
+
|
243 |
+
def to_device(entropy_model, device=None):
|
244 |
+
if device == "cuda":
|
245 |
+
rank = get_local_rank()
|
246 |
+
device = f"cuda:{rank}"
|
247 |
+
entropy_model = entropy_model.to(device)
|
248 |
+
return entropy_model, device
|
249 |
+
|
250 |
+
|
251 |
+
def model_pred_to_bpe_patching_pred(pred):
|
252 |
+
_, indices = torch.max(pred, dim=1)
|
253 |
+
return indices == BPE_ID
|
254 |
+
|
255 |
+
|
256 |
+
def apply_bpe_patcher(tokens, bpe_patcher, patching_batch_size, device=None):
|
257 |
+
assert tokens.device == torch.device(
|
258 |
+
"cpu"
|
259 |
+
), f"{tokens.device} != cpu expects tokens to be on cpu"
|
260 |
+
with torch.no_grad():
|
261 |
+
bpe_patcher_device, device = to_device(
|
262 |
+
bpe_patcher, device
|
263 |
+
) # Get entropy model to right rank device.
|
264 |
+
bpe_patching_mask = []
|
265 |
+
max_length = getattr(bpe_patcher, "max_length", 8192)
|
266 |
+
batch_numel = max_length * patching_batch_size
|
267 |
+
splits = torch.split(tokens.flatten(), batch_numel)
|
268 |
+
for split in splits:
|
269 |
+
pad_size = (max_length - (split.numel() % max_length)) % max_length
|
270 |
+
pad = torch.zeros(
|
271 |
+
pad_size, dtype=split.dtype, device=split.device, requires_grad=False
|
272 |
+
)
|
273 |
+
split = torch.cat((split, pad), dim=0)
|
274 |
+
split = split.reshape(-1, max_length).to(device)
|
275 |
+
assert torch.all(split >= 0) and torch.all(split < 260)
|
276 |
+
pred = bpe_patcher_device(split)
|
277 |
+
pred_cpu = pred[0].cpu()
|
278 |
+
pred_cpu = pred_cpu.reshape(-1, pred_cpu.shape[-1])[
|
279 |
+
: split.numel() - pad_size, :
|
280 |
+
] # [batch_size * seq_len, vocab]
|
281 |
+
bpe_patching_pred = model_pred_to_bpe_patching_pred(pred_cpu)
|
282 |
+
bpe_patching_mask.append(bpe_patching_pred)
|
283 |
+
bpe_patching_mask = torch.cat(bpe_patching_mask, dim=0)
|
284 |
+
bpe_patching_mask = bpe_patching_mask.reshape(tokens.shape)
|
285 |
+
return bpe_patching_mask
|
286 |
+
|
287 |
+
|
288 |
+
def find_bpe_patcher_patch_start_ids(
|
289 |
+
tokens, bpe_patcher, patching_batch_size, device=None, include_next_token=True
|
290 |
+
):
|
291 |
+
bs, seq_len = tokens.shape
|
292 |
+
|
293 |
+
first_ids = (
|
294 |
+
torch.tensor([0, 1], dtype=torch.long, device=tokens.device)
|
295 |
+
.unsqueeze(0)
|
296 |
+
.repeat(bs, 1)
|
297 |
+
)
|
298 |
+
preds_truncation_len = first_ids.shape[1]
|
299 |
+
token_input = tokens[:, 1:] if include_next_token else tokens[:, 1:-1]
|
300 |
+
if token_input.shape[1] >= 1:
|
301 |
+
patch_start_mask = apply_bpe_patcher(
|
302 |
+
token_input, bpe_patcher, patching_batch_size, device
|
303 |
+
)
|
304 |
+
assert (
|
305 |
+
patch_start_mask.shape[1]
|
306 |
+
== tokens.shape[1] + include_next_token - preds_truncation_len
|
307 |
+
), f"{patch_start_mask.shape[1]} != {tokens.shape[1] + include_next_token - preds_truncation_len}"
|
308 |
+
patch_start_ids = patch_start_ids_from_patch_start_mask(patch_start_mask)
|
309 |
+
patch_start_ids = torch.cat(
|
310 |
+
(first_ids, patch_start_ids + preds_truncation_len), dim=1
|
311 |
+
)
|
312 |
+
else:
|
313 |
+
patch_start_ids = first_ids
|
314 |
+
return patch_start_ids
|
315 |
+
|
316 |
+
|
317 |
+
def find_entropy_patch_start_ids(
|
318 |
+
entropies,
|
319 |
+
patch_size=None,
|
320 |
+
threshold=None,
|
321 |
+
threshold_add=None,
|
322 |
+
monotonicity=False,
|
323 |
+
include_next_token=True,
|
324 |
+
):
|
325 |
+
"""
|
326 |
+
Use entropies to find the start ids of each patch.
|
327 |
+
Use patch_size or threshold to figure out the total number of patches to allocate.
|
328 |
+
|
329 |
+
When threshold is not None the number of patches is not constant between
|
330 |
+
different sequences, but patches can be identified incrementally rather than
|
331 |
+
decided globally using the entire sequence.
|
332 |
+
"""
|
333 |
+
bs, seq_len = entropies.shape[:2]
|
334 |
+
|
335 |
+
first_ids = (
|
336 |
+
torch.tensor([0, 1], dtype=torch.long, device=entropies.device)
|
337 |
+
.unsqueeze(0)
|
338 |
+
.repeat(bs, 1)
|
339 |
+
)
|
340 |
+
preds_truncation_len = first_ids.shape[
|
341 |
+
1
|
342 |
+
] # remove the first preds because they will be start of patches.
|
343 |
+
entropies = entropies[:, 1:]
|
344 |
+
if threshold is None:
|
345 |
+
num_patches = seq_len // patch_size
|
346 |
+
patch_start_ids = entropies.topk(num_patches - 2, dim=1).indices
|
347 |
+
patch_start_ids = patch_start_ids.sort(dim=1).values
|
348 |
+
else:
|
349 |
+
# Assumes that there is at least one token going over the threshold
|
350 |
+
if monotonicity:
|
351 |
+
patch_start_mask = patch_start_mask_from_entropy_with_monotonicity(
|
352 |
+
entropies, threshold
|
353 |
+
)
|
354 |
+
elif threshold_add is not None and threshold is not None:
|
355 |
+
patch_start_mask = patch_start_mask_global_and_monotonicity(
|
356 |
+
entropies, threshold, threshold_add
|
357 |
+
)
|
358 |
+
else:
|
359 |
+
patch_start_mask = entropies > threshold
|
360 |
+
if not include_next_token:
|
361 |
+
patch_start_mask = patch_start_mask[:, :-1]
|
362 |
+
# patch_start_mask[1:] |= tokens[:-1] < OFFSET
|
363 |
+
patch_start_ids = patch_start_ids_from_patch_start_mask(patch_start_mask)
|
364 |
+
|
365 |
+
patch_start_ids = torch.cat(
|
366 |
+
(first_ids, patch_start_ids + preds_truncation_len), dim=1
|
367 |
+
)
|
368 |
+
return patch_start_ids
|
369 |
+
|
370 |
+
|
371 |
+
def rightpad(seq, pad_id, max_len):
|
372 |
+
return seq + [pad_id] * (max_len - len(seq))
|
373 |
+
|
374 |
+
|
375 |
+
def find_bpe_delim_patch_start_ids(tokens, delim):
|
376 |
+
ids = (tokens[:, :-1] == delim).nonzero(as_tuple=False)
|
377 |
+
out = [[0, 1] for _ in range(tokens.shape[0])]
|
378 |
+
for x, y in ids:
|
379 |
+
# start is at delim + 1, delim should be the last element in the patch.
|
380 |
+
out[x.item()].append(y.item() + 1)
|
381 |
+
max_len = max([len(elt) for elt in out])
|
382 |
+
out = [rightpad(elt, tokens.shape[1], max_len) for elt in out]
|
383 |
+
patch_start_ids = torch.tensor(out, dtype=tokens.dtype, device=tokens.device)
|
384 |
+
return patch_start_ids
|
385 |
+
|
386 |
+
|
387 |
+
def find_lookup_table_start_mask(
|
388 |
+
tokens: torch.Tensor, lookup_table: torch.Tensor, include_next_token=True
|
389 |
+
):
|
390 |
+
window_size = lookup_table.ndim
|
391 |
+
# Unfold the tensor to get sliding windows
|
392 |
+
unfolded = tokens.unfold(1, window_size, 1)
|
393 |
+
# Gather indices for each dimension
|
394 |
+
indices = [unfolded[..., i] for i in range(window_size)]
|
395 |
+
# Access the lookup table using the gathered indices
|
396 |
+
result = lookup_table[indices]
|
397 |
+
return result
|
398 |
+
|
399 |
+
|
400 |
+
def find_lookup_table_patch_start_ids(
|
401 |
+
tokens: torch.Tensor, lookup_table: torch.Tensor, include_next_token=True
|
402 |
+
):
|
403 |
+
bs, seq_len = tokens.shape
|
404 |
+
|
405 |
+
first_ids = (
|
406 |
+
torch.tensor([0, 1], dtype=torch.long, device=tokens.device)
|
407 |
+
.unsqueeze(0)
|
408 |
+
.repeat(bs, 1)
|
409 |
+
)
|
410 |
+
preds_truncation_len = first_ids.shape[1]
|
411 |
+
window_size = lookup_table.ndim
|
412 |
+
assert window_size == 2, f"{window_size} != 2"
|
413 |
+
# find_lookup_table_start_mask yields token_input_len - window_size + 1 positions; together with the two first_ids this covers tokens.shape[1] + 1 when include_next_token is True, otherwise tokens.shape[1].
|
414 |
+
token_input = (
|
415 |
+
tokens if include_next_token else tokens[:, : -preds_truncation_len + 1]
|
416 |
+
)
|
417 |
+
if token_input.shape[1] >= window_size:
|
418 |
+
patch_start_mask = find_lookup_table_start_mask(
|
419 |
+
token_input, lookup_table, include_next_token
|
420 |
+
)
|
421 |
+
assert (
|
422 |
+
patch_start_mask.shape[1]
|
423 |
+
== tokens.shape[1] + include_next_token - preds_truncation_len
|
424 |
+
), f"{patch_start_mask.shape[1]} != {tokens.shape[1] + include_next_token - preds_truncation_len}"
|
425 |
+
patch_start_ids = patch_start_ids_from_patch_start_mask(patch_start_mask)
|
426 |
+
patch_start_ids = torch.cat(
|
427 |
+
(first_ids, patch_start_ids + preds_truncation_len), dim=1
|
428 |
+
)
|
429 |
+
else:
|
430 |
+
patch_start_ids = first_ids
|
431 |
+
return patch_start_ids
|
432 |
+
|
433 |
+
|
434 |
+
def split_large_numbers(lst, m):
|
435 |
+
new_lst = []
|
436 |
+
for i in lst:
|
437 |
+
if i > m:
|
438 |
+
while i > m:
|
439 |
+
new_lst.append(m)
|
440 |
+
i -= m
|
441 |
+
new_lst.append(i)
|
442 |
+
else:
|
443 |
+
new_lst.append(i)
|
444 |
+
assert sum(new_lst) == sum(lst), f"{sum(new_lst)} != {sum(lst)}"
|
445 |
+
return new_lst
|
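As an aside (not part of patcher.py), split_large_numbers appears intended to cap individual patch lengths (for example at max_patch_length) while preserving their total; a minimal sketch, assuming the bytelatent package is importable:

from bytelatent.data.patcher import split_large_numbers

lengths = [1, 9, 2]
print(split_large_numbers(lengths, m=4))  # [1, 4, 4, 1, 2], same sum as the input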
446 |
+
|
447 |
+
|
448 |
+
class Patcher:
|
449 |
+
def __init__(self, patcher_args: PatcherArgs):
|
450 |
+
self.patcher_args = patcher_args
|
451 |
+
self.patching_mode = patcher_args.patching_mode
|
452 |
+
self.realtime_patching = patcher_args.realtime_patching
|
453 |
+
if self.realtime_patching:
|
454 |
+
assert (
|
455 |
+
patcher_args.entropy_model_checkpoint_dir is not None
|
456 |
+
), "Cannot require realtime patching without an entropy model checkpoint"
|
457 |
+
entropy_model = load_entropy_model(
|
458 |
+
patcher_args.entropy_model_checkpoint_dir
|
459 |
+
)
|
460 |
+
entropy_model, _ = to_device(entropy_model, patcher_args.patching_device)
|
461 |
+
self.entropy_model = entropy_model
|
462 |
+
else:
|
463 |
+
self.entropy_model = None
|
464 |
+
self.threshold = patcher_args.threshold
|
465 |
+
self.threshold_add = patcher_args.threshold_add
|
466 |
+
self.max_patch_length = patcher_args.max_patch_length
|
467 |
+
self.patch_size = patcher_args.patch_size
|
468 |
+
self.patching_batch_size = patcher_args.patching_batch_size
|
469 |
+
self.data_loader_patching = patcher_args.data_loader_patching
|
470 |
+
self.device = patcher_args.device
|
471 |
+
self.monotonicity = patcher_args.monotonicity
|
472 |
+
self.log_time = patcher_args.log_time
|
473 |
+
if self.log_time:
|
474 |
+
self.log = defaultdict(float)
|
475 |
+
|
476 |
+
def patch(
|
477 |
+
self,
|
478 |
+
tokens: torch.Tensor,
|
479 |
+
include_next_token: bool = False,
|
480 |
+
preds: torch.Tensor | None = None,
|
481 |
+
entropies: torch.Tensor | None = None,
|
482 |
+
threshold: float = None,
|
483 |
+
) -> torch.Tensor:
|
484 |
+
"""
|
485 |
+
tokens: 2D tensor of shape [batch_size, seq_len] that needs to be patched
|
486 |
+
Returns patch lengths and optionally scores associated with the tokens (i.e. entropies, logprobs etc.)
|
487 |
+
-> output tensor: [batch_size, max_num_patches]
|
488 |
+
each tensor is processed independently and gets right padded with zeros.
|
489 |
+
|
490 |
+
Patching with the following modes:
|
491 |
+
1. patching_mode = None: static patch size
|
492 |
+
2. patching_mode = "entropy":
|
493 |
+
calculate entropy of each token, allocate patches so that the total
|
494 |
+
number of patches is the same as static patching but choose to begin
|
495 |
+
patches on tokens where the model is most uncertain (highest entropy).
|
496 |
+
|
497 |
+
When threshold is provided, it uses the threshold to decide when to
|
498 |
+
start a new patch.
|
499 |
+
3. patching_mode = "space":
|
500 |
+
use space like tokens to define the patches.
|
501 |
+
4. patching_mode = "bpe":
|
502 |
+
use bpe delim tokens to define the patches.
|
503 |
+
|
504 |
+
To correctly patch the last token, it may be necessary to include the next token in the patch
|
505 |
+
lengths calculations. This is controlled by the include_next_token argument.
|
506 |
+
"""
|
507 |
+
bs, seq_len = tokens.shape
|
508 |
+
seq_len_next_tok = seq_len + 1 if include_next_token else seq_len
|
509 |
+
scores = None
|
510 |
+
# STATIC
|
511 |
+
if self.patching_mode is None:
|
512 |
+
patch_lengths = torch.zeros(
|
513 |
+
(bs, math.ceil(seq_len_next_tok / self.patch_size)),
|
514 |
+
dtype=tokens.dtype,
|
515 |
+
device=tokens.device,
|
516 |
+
).fill_(self.patch_size)
|
517 |
+
if seq_len_next_tok % self.patch_size != 0:
|
518 |
+
patch_lengths[:, -1] = seq_len_next_tok % self.patch_size
|
519 |
+
# ENTROPY
|
520 |
+
elif self.patching_mode == PatchingModeEnum.entropy:
|
521 |
+
if self.log_time:
|
522 |
+
s = time.time()
|
523 |
+
if entropies is not None:
|
524 |
+
scores = torch.tensor(entropies, dtype=torch.float32)
|
525 |
+
elif preds is not None:
|
526 |
+
scores = entropy(preds)
|
527 |
+
else:
|
528 |
+
start_entropies = time.time()
|
529 |
+
scores = calculate_entropies(
|
530 |
+
tokens,
|
531 |
+
self.entropy_model,
|
532 |
+
self.patching_batch_size,
|
533 |
+
self.device,
|
534 |
+
)
|
535 |
+
if self.log_time:
|
536 |
+
self.log["calculate_entropies"] += time.time() - s
|
537 |
+
s = time.time()
|
538 |
+
patch_start_ids = find_entropy_patch_start_ids(
|
539 |
+
scores,
|
540 |
+
self.patch_size,
|
541 |
+
include_next_token=include_next_token,
|
542 |
+
threshold=threshold if threshold is not None else self.threshold,
|
543 |
+
threshold_add=self.threshold_add,
|
544 |
+
monotonicity=self.monotonicity,
|
545 |
+
)
|
546 |
+
if self.log_time:
|
547 |
+
self.log["find_entropy_patch_start_ids"] += time.time() - s
|
548 |
+
s = time.time()
|
549 |
+
patch_lengths = patch_lengths_from_start_ids(
|
550 |
+
patch_start_ids, seq_len_next_tok
|
551 |
+
)
|
552 |
+
if self.log_time:
|
553 |
+
self.log["patch_lengths_from_start_ids"] += time.time() - s
|
554 |
+
s = time.time()
|
555 |
+
# BPE
|
556 |
+
elif self.patching_mode == PatchingModeEnum.bpe:
|
557 |
+
patch_start_ids = find_bpe_delim_patch_start_ids(tokens, delim=BPE_ID)
|
558 |
+
patch_lengths = patch_lengths_from_start_ids(
|
559 |
+
patch_start_ids, seq_len_next_tok
|
560 |
+
)
|
561 |
+
elif self.patching_mode == PatchingModeEnum.bpe_patcher:
|
562 |
+
patch_start_ids = find_bpe_patcher_patch_start_ids(
|
563 |
+
tokens,
|
564 |
+
self.entropy_model,
|
565 |
+
self.patching_batch_size,
|
566 |
+
self.device,
|
567 |
+
include_next_token,
|
568 |
+
)
|
569 |
+
patch_lengths = patch_lengths_from_start_ids(
|
570 |
+
patch_start_ids, seq_len_next_tok
|
571 |
+
)
|
572 |
+
# SPACE
|
573 |
+
elif self.patching_mode == PatchingModeEnum.space:
|
574 |
+
patch_start_ids = find_space_patch_start_ids(tokens)
|
575 |
+
patch_lengths = patch_lengths_from_start_ids(
|
576 |
+
patch_start_ids, seq_len_next_tok
|
577 |
+
)
|
578 |
+
else:
|
579 |
+
raise NotImplementedError(f"self.patching_mode {self.patching_mode}")
|
580 |
+
|
581 |
+
# Apply any processing to patch lengths
|
582 |
+
if self.max_patch_length is not None:
|
583 |
+
# TODO: avoid going back to a list here.
|
584 |
+
patch_lengths = [
|
585 |
+
split_large_numbers(pl, self.max_patch_length)
|
586 |
+
for pl in patch_lengths.tolist()
|
587 |
+
]
|
588 |
+
max_len = max([len(pl) for pl in patch_lengths])
|
589 |
+
patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths]
|
590 |
+
patch_lengths = torch.tensor(
|
591 |
+
patch_lengths, dtype=tokens.dtype, device=tokens.device
|
592 |
+
)
|
593 |
+
assert not check_non_zero_after_zero(patch_lengths)
|
594 |
+
# Find the last non-zero column index using argmax on a reversed version of the tensor
|
595 |
+
last_non_zero_col_reversed = (
|
596 |
+
(patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min()
|
597 |
+
)
|
598 |
+
# Slice the tensor up to the last non-zero column
|
599 |
+
patch_lengths = patch_lengths[
|
600 |
+
:, : patch_lengths.shape[1] - last_non_zero_col_reversed
|
601 |
+
]
|
602 |
+
assert (
|
603 |
+
torch.sum(patch_lengths)
|
604 |
+
== tokens.numel() + include_next_token * tokens.shape[0]
|
605 |
+
), f"{torch.sum(patch_lengths)} != {tokens.numel() + include_next_token * tokens.shape[0]}"
|
606 |
+
if self.log_time:
|
607 |
+
self.log["postprocessing_patch_lengths"] += time.time() - s
|
608 |
+
self.log["tokens"] += patch_lengths.sum().item()
|
609 |
+
return patch_lengths, scores
|
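A minimal usage sketch (not part of the commit) of split_large_numbers, which Patcher.patch uses to cap patch lengths at max_patch_length; the sample list and the cap m=4 below are made-up values.

# Usage sketch (not part of the commit): capping patch lengths with split_large_numbers.
# The input list and the cap m=4 are made-up values.
lengths = [3, 9, 1, 5]
capped = split_large_numbers(lengths, m=4)
# 9 is split into 4 + 4 + 1 and 5 into 4 + 1, so the total is preserved:
assert capped == [3, 4, 4, 1, 1, 4, 1]
assert sum(capped) == sum(lengths)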
bytelatent/distributed.py
ADDED
@@ -0,0 +1,478 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import atexit
import contextlib
import logging
import multiprocessing as mp
import os
import random
import shutil
import signal
import socket
import subprocess
import sys
import tempfile
from dataclasses import asdict, dataclass
from functools import lru_cache, partial, reduce
from itertools import chain
from typing import List, Optional, Tuple, Union

import torch

# for no recompute ops
import xformers.ops
from pydantic import BaseModel, ConfigDict
from torch import distributed as dist
from torch.distributed import ReduceOp
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard
from torch.distributed._tensor import DTensor
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.checkpoint import (
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)

from bytelatent.float8 import convert_linears_to_fp8

logger = logging.getLogger()

# for selective AC
default_no_recompute_ops = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_mm.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    torch.ops.c10d_functional.reduce_scatter_tensor.default,
    torch.ops.xformers_flash.flash_fwd.default,
    torch.ops.xformers.efficient_attention_forward_cutlass.default,
}


class DistributedArgs(BaseModel):
    model_config = ConfigDict(extra="forbid")
    dp_shard: int = (
        1  # In how many shard to split the model weight. Typically number gpu in a node.
    )
    dp_replicate: int = (
        1  # How many times to replicate the model weight. Typically number of nodes.
    )
    tp_size: int = 1
    selective_activation_checkpointing: bool = False
    compile: bool = False
    fsdp_type: str = "no_shard"
    model_dtype: str = "bf16"
    float8_recipe: str | None = None
    float8_filter: str = r"layers\.[0-9]+\."

    matmul_allow_tf32: bool = False
    allow_bf16_reduced_precision_reduction: bool = True
    detect_anomaly: bool = False

    compile_cache_size_limit: int = 8

    spawn_method: str = "forkserver"


class EnvironmentArgs(BaseModel):
    model_config = ConfigDict(extra="forbid")
    # Use GNU openMP (GOMP) instead of Intel OpenMP [Intel Math Kernel Library (MKL)]
    MKL_SERVICE_FORCE_INTEL: str = "GNU"
    OMP_NUM_THREADS: str = "1"
    MKL_NUM_THREADS: str = "1"
    # faster intra-node collectives, seems to be a cluster specific flag
    ENABLE_INTRA_NODE_COMM: str = "1"
    # avoids OOMs with long context
    TORCH_NCCL_AVOID_RECORD_STREAMS: str = "1"
    # increasing NCCL timeout time before having some NCCL error 22 should give a 16s timeout
    NCCL_IB_TIMEOUT: str = "22"
    NCCL_DEBUG: str = "INFO"
    TORCH_NCCL_ASYNC_ERROR_HANDLING: str = "1"


def get_device_mesh(distributed_args: DistributedArgs):
    tp_size = distributed_args.tp_size
    dp_replicate = distributed_args.dp_replicate
    dp_shard = distributed_args.dp_shard

    assert (
        dp_replicate * dp_shard * tp_size == get_world_size()
    ), f"dp_replicate * dp_shard * tp_size ({dp_replicate} * {dp_shard} * {tp_size}) != world_size ({get_world_size()})"

    dims = []
    names = []
    if dp_replicate >= 1:
        dims.append(dp_replicate)
        names.append("dp_replicate")
    if dp_shard > 1 or distributed_args.fsdp_type == "no_shard":
        dims.append(dp_shard)
        names.append("dp_shard")
    if tp_size > 1:
        dims.append(tp_size)
        names.append("tp")
    dims = tuple(dims)
    names = tuple(names)

    return init_device_mesh("cuda", mesh_shape=dims, mesh_dim_names=names)


def dist_max(x: Union[int, float], mesh: DeviceMesh = None):
    tensor = torch.tensor(x).cuda()
    dist.all_reduce(tensor, op=ReduceOp.MAX, group=mesh.get_group() if mesh else None)
    return tensor


def dist_mean(x: Union[int, float], mesh: DeviceMesh = None):
    tensor = torch.tensor(x).cuda()
    dist.all_reduce(tensor, op=ReduceOp.AVG, group=mesh.get_group() if mesh else None)
    return tensor


def dist_mean_dict(x):
    r = dict()
    for k in x:
        r[k] = dist_mean(x[k])
        r[k] = r[k].item() if (r[k].dim() == 0) else r[k].tolist()
    return r


@lru_cache()
def get_is_torch_run() -> bool:
    return os.environ.get("LOCAL_RANK") is not None


@lru_cache()
def get_is_slurm_job() -> bool:
    return "SLURM_JOB_ID" in os.environ and not get_is_torch_run()


@lru_cache()
def get_global_rank() -> int:
    if get_is_torch_run():
        return int(os.environ["RANK"])
    elif get_is_slurm_job():
        return int(os.environ["SLURM_PROCID"])
    else:
        return 0


@lru_cache()
def get_local_rank() -> int:
    if get_is_torch_run():
        return int(os.environ["LOCAL_RANK"])
    elif get_is_slurm_job():
        return int(os.environ["SLURM_LOCALID"])
    else:
        return 0


@lru_cache()
def get_world_size() -> int:
    if get_is_torch_run():
        return int(os.environ["WORLD_SIZE"])
    elif get_is_slurm_job():
        return int(os.environ["SLURM_NTASKS"])
    else:
        return 1


@lru_cache()
def get_is_master() -> bool:
    return get_global_rank() == 0


@lru_cache()
def get_master_port(job_id: int) -> int:
    if get_is_torch_run():
        return int(os.environ["MASTER_PORT"])
    else:
        MIN_MASTER_PORT, MAX_MASTER_PORT = (20000, 60000)
        rng = random.Random(job_id)
        return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT)


@lru_cache()
def get_master_addr() -> str:
    if get_is_torch_run():
        return os.environ["MASTER_ADDR"]
    elif get_is_slurm_job():
        hostnames = subprocess.check_output(
            ["scontrol", "show", "hostnames", os.environ["SLURM_JOB_NODELIST"]]
        )
        return hostnames.split()[0].decode("utf-8")
    else:
        return "127.0.0.1"


def setup_env(env_args: EnvironmentArgs):
    env_vars = env_args.model_dump()

    # When using Triton, it attempts to locate prebuilt kernels in a cache
    # located at ~/.triton/cache, but when that's backed by NFS this can fail
    # with a "OSError: [Errno 116] Stale file handle" error. If we were to set
    # it to a local directory it would belong to the first user who created it
    # and it would fail for the job of any other successive user assigned to
    # that machine. To avoid all this mess we use a temporary per-process cache.
    triton_cache_dir = tempfile.mkdtemp()
    atexit.register(shutil.rmtree, triton_cache_dir, ignore_errors=True)
    env_vars["TRITON_CACHE_DIR"] = triton_cache_dir

    # We change the tmp dir to /scratch in case it's slurm job
    # This avoids filling up the host's usually limited tmpfs
    # A full tmpfs leads to very slow creation of processes and weird bugs
    if get_is_slurm_job():
        new_tmp = f"/scratch/slurm_tmpdir/{os.environ['SLURM_JOB_ID']}"
        if os.path.exists(new_tmp):
            env_vars["TMP_DIR"] = new_tmp

    for name, value in env_vars.items():
        if os.environ.get(name) != str(value):
            os.environ[name] = str(value)
            logger.warning(f"WARNING: Setting {name} to {value}")


def setup_torch_distributed(dist_args):
    """
    Handle single and multi-GPU / multi-node / SLURM jobs.
    Initialize the following variables:
        - global_rank
        - world_size
    """
    mp.set_start_method(dist_args.spawn_method)
    with mp.Manager():
        pass

    local_rank = get_local_rank()

    os.environ["RANK"] = str(get_global_rank())
    os.environ["WORLD_SIZE"] = str(get_world_size())
    os.environ["MASTER_ADDR"] = get_master_addr()
    os.environ["MASTER_PORT"] = str(
        get_master_port(job_id=int(os.environ.get("SLURM_JOB_ID", -1)))
    )

    if get_is_torch_run():
        logger.info(f"Run launched with torchrun, local rank: {local_rank}")
    elif get_is_slurm_job():
        logger.info(f"Run launched with slurm, local rank: {local_rank}")
    else:
        logger.info("Single GPU job")

    logger.info(f"ENV: {os.environ}")

    # set GPU device
    assert 0 <= local_rank < 8
    if dist_args.matmul_allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        logger.warning(
            f"WARNING: Setting torch.backends.matmul.allow_tf32 to True. This is faster but less accurate."
        )
    torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
        dist_args.allow_bf16_reduced_precision_reduction
    )
    if torch.cuda.device_count() > 1:
        torch.cuda.set_device(local_rank)
    torch.distributed.init_process_group(init_method="env://", backend="nccl")
    torch.autograd.set_detect_anomaly(dist_args.detect_anomaly)


def get_module(module, access_string):
    names = access_string.split(sep=".")
    return reduce(getattr, names, module)


def set_module(module, access_string, value):
    names = access_string.split(sep=".")
    parent = reduce(getattr, names[:-1], module)
    setattr(parent, names[-1], value)


def default_fsdp_grouping_plan(n_layers: int) -> List[Tuple[str, bool]]:
    return [(f"layers.{i}", i < n_layers - 1) for i in range(n_layers)]


def get_default_policy(no_recompute_ops=None):
    no_recompute_ops = no_recompute_ops or default_no_recompute_ops

    def default_policy(ctx, func, *args, **kwargs):
        return (
            CheckpointPolicy.MUST_SAVE
            if func in no_recompute_ops
            else CheckpointPolicy.PREFER_RECOMPUTE
        )

    return default_policy


@torch.no_grad()
def check_model_value_range(
    model: torch.nn.Module, range: float = 1e3, std: float = 1e3
):
    for name, param in chain(model.named_parameters(), model.named_buffers()):
        if isinstance(param, DTensor):
            param = param.to_local()

        if param.numel() == 0:
            logger.warning(
                f"Model parameter {name} is empty, probably because of FSDP sharding"
            )
            continue

        if torch.isnan(param).any() or torch.isinf(param).any():
            logger.warning(f"Model parameter {name} contains NaN or Inf")

        param_range = param.max() - param.min()
        param_std = param.std()
        if param_range > range:
            logger.warning(
                f"Model parameter {name} has a suspiciously large range ({param_range}): please check initialization and init_weights is defined and called"
            )
        if param_std > std:
            logger.warning(
                f"Model parameter {name} has a suspiciously large standard deviation ({param_std}): please check initialization and init_weights is defined and called"
            )
        if (param == 0).all():
            logger.warning(
                f"Model parameter {name} is all zeros: it might be because of a missing initialization"
            )


def init_signal_handler(callable):
    """
    Handle signals sent by SLURM for time limit / pre-emption.
    """
    signal.signal(signal.SIGUSR2, callable)
    logger.warning("Signal handler installed.")


def requeue_slurm_job():
    prod_id = int(os.environ["SLURM_PROCID"])
    logger.warning("Host: %s - Global rank: %i" % (socket.gethostname(), prod_id))
    if prod_id == 0 and os.environ.get("LAUNCH_WITH", "") != "DORA":
        logger.warning("Requeuing job " + os.environ["SLURM_JOB_ID"])
        os.system("scontrol requeue " + os.environ["SLURM_JOB_ID"])
    else:
        logger.warning("Not the master process, no need to requeue.")
    sys.exit(0)


@contextlib.contextmanager
def clean_env():
    distrib_names = (
        "MASTER_ADDR",
        "MASTER_PORT",
        "RANK",
        "WORLD_SIZE",
        "LOCAL_RANK",
        "LOCAL_WORLD_SIZE",
        "TORCHELASTIC_RUN_ID",
        "DORA_FORCE_DISTRIB",
    )
    cluster_env = {
        x: os.environ.pop(x)
        for x in os.environ
        if x.startswith(
            ("SLURM_", "SLURMD_", "SRUN_", "SBATCH_", "SUBMITIT_", "WANDB_")
        )
        or x in distrib_names
    }
    try:
        yield
    finally:
        os.environ.update(cluster_env)


def parallelize_model(
    model,
    device_mesh,
    model_args,
    distributed_args: DistributedArgs,
    fsdp_grouping_plan: Optional[List[Tuple[str, bool]]] = None,
    tp_parallelize=None,
    no_recompute_ops=None,
):
    if distributed_args.tp_size > 1:
        assert (
            distributed_args.fsdp_type == "full_shard"
        ), "Only full shard is supported for TP parallelism"
        assert tp_parallelize is not None, "TP plan is required for TP parallelism"
        assert (
            distributed_args.compile == False
        ), "Compile is not supported for TP parallelism"

        tp_parallelize(model, device_mesh["tp"], model_args, distributed_args)

    if distributed_args.float8_recipe is not None:
        if distributed_args.tp_size > 1:
            raise RuntimeError("float8 is incompatible with tensor-parallelism for now")
        model = convert_linears_to_fp8(
            model, distributed_args.float8_recipe, distributed_args.float8_filter
        )

    param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
        distributed_args.model_dtype
    ]
    if (
        distributed_args.fsdp_type == "full_shard"
        or distributed_args.fsdp_type == "no_shard"
    ):
        if distributed_args.fsdp_type == "no_shard":
            assert (
                distributed_args.dp_shard == 1
            ), "dp_shard must be 1 for no_shard fsdp_type"
            assert (
                device_mesh["dp_shard"].size() == 1
            ), "dp_shard must be 1 for no_shard fsdp_type"

        fsdp_config = dict(
            mp_policy=(
                MixedPrecisionPolicy(
                    param_dtype=param_dtype,
                    reduce_dtype=torch.float32,
                )
            ),
            mesh=(
                device_mesh["dp_replicate", "dp_shard"]
                if distributed_args.dp_shard > 1
                or distributed_args.fsdp_type == "no_shard"
                else device_mesh["dp_replicate"]
            ),
        )

        if fsdp_grouping_plan is None:
            # Assume that the model has list of layers and group around it
            fsdp_grouping_plan = default_fsdp_grouping_plan(len(model.layers))

        for path, reshard_after_forward in fsdp_grouping_plan:
            module = get_module(model, path)
            set_module(
                model,
                path,
                fully_shard(
                    module, **fsdp_config, reshard_after_forward=reshard_after_forward
                ),
            )

        model = fully_shard(model, **fsdp_config, reshard_after_forward=True)
    else:
        raise ValueError(f"Invalid fsdp_type: {distributed_args.fsdp_type}")

    if distributed_args.selective_activation_checkpointing:
        model = checkpoint_wrapper(
            model,
            context_fn=partial(
                create_selective_checkpoint_contexts,
                get_default_policy(no_recompute_ops),
            ),
        )

    if distributed_args.compile:
        torch._dynamo.config.cache_size_limit = (
            distributed_args.compile_cache_size_limit
        )
        model = torch.compile(model)

    return model
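A brief sketch (not part of the commit) of the mesh arithmetic that get_device_mesh enforces; the 2-node x 8-GPU layout below is a hypothetical example.

# Usage sketch (not part of the commit): a hypothetical 2-node x 8-GPU layout.
# get_device_mesh() asserts dp_replicate * dp_shard * tp_size == world size.
args = DistributedArgs(dp_replicate=2, dp_shard=8, tp_size=1, fsdp_type="full_shard")
assert args.dp_replicate * args.dp_shard * args.tp_size == 16  # expects WORLD_SIZE == 16
# mesh = get_device_mesh(args)  # would build a ("dp_replicate", "dp_shard") device mesh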
bytelatent/entropy_model.py
ADDED
@@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import json
import os
import re

import torch

from bytelatent.transformer import LMTransformer, LMTransformerArgs


def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"):
    with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr:
        reloaded = json.loads(fr.read())

    torch.set_default_dtype(torch.bfloat16)
    model_params = reloaded["model"]
    entropy_model = LMTransformer(
        LMTransformerArgs(
            dim=model_params["dim"],
            n_layers=model_params["n_layers"],
            n_heads=model_params["n_heads"],
            max_seqlen=model_params["max_length"],
            ffn_dim_multiplier=model_params["ffn_dim_multiplier"],
            vocab_size=model_params["vocab_size"],
        )
    )

    entropy_model.load_state_dict(
        torch.load(state_dict_path, map_location=device), strict=False
    )
    entropy_model.to(device)
    entropy_model = entropy_model.eval()
    # no grads for the model:
    for param in entropy_model.parameters():
        param.requires_grad = False
    return entropy_model
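A usage sketch (not part of the commit) of load_entropy_model as defined above; the checkpoint directory and state-dict filename are placeholders.

# Usage sketch (not part of the commit): loading a frozen entropy model.
# The directory and state-dict path below are placeholder names.
entropy_model = load_entropy_model(
    "checkpoints/entropy_model",                    # must contain params.json
    "checkpoints/entropy_model/consolidated.pth",   # weights read with torch.load
    device="cpu",
)
assert not any(p.requires_grad for p in entropy_model.parameters())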
bytelatent/float8.py
ADDED
@@ -0,0 +1,152 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import re
import warnings
from typing import Callable

import torch

# avoid division by zero when calculating scale
EPS = 1e-12


def scale(t, amax_t, dtype_t):
    min_v, max_v = torch.finfo(dtype_t).min, torch.finfo(dtype_t).max
    scale_t = torch.clamp(amax_t.float(), min=EPS) / max_v
    t_fp8 = (t / scale_t).clamp(min=min_v, max=max_v).to(dtype_t)
    return t_fp8, scale_t


def matmul(
    first, amax_first, dtype_first, second_t, amax_second_t, dtype_second_t, bias
):
    first_fp8, scale_first = scale(first, amax_first, dtype_first)
    second_t_fp8, scale_second_t = scale(second_t, amax_second_t, dtype_second_t)
    output = torch._scaled_mm(
        first_fp8,
        second_t_fp8.t(),
        scale_a=scale_first,
        scale_b=scale_second_t.t(),
        bias=bias,
        out_dtype=torch.bfloat16,
        use_fast_accum=True,
    )
    return output


@torch._dynamo.allow_in_graph
class Fp8LinearFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b_t, bias):
        amax_a = a.abs().amax(dim=-1, keepdim=True)
        amax_b_t = b_t.abs().amax(dim=-1, keepdim=True)
        out = matmul(
            a, amax_a, torch.float8_e4m3fn, b_t, amax_b_t, torch.float8_e4m3fn, bias
        )

        ctx.a_requires_grad = a.requires_grad
        ctx.b_requires_grad = b_t.requires_grad
        ctx.bias_requires_grad = bias.requires_grad if bias is not None else False

        ctx.save_for_backward(a, b_t, amax_b_t.max())

        return out

    @staticmethod
    def backward(ctx, grad_out):
        a, b_t, amax_b = ctx.saved_tensors

        if ctx.a_requires_grad:
            b = b_t.t().contiguous()
            amax_grad_out = grad_out.abs().amax(dim=-1, keepdim=True)
            amax_b = amax_b.repeat(b.shape[0], 1)
            grad_a = matmul(
                grad_out,
                amax_grad_out,
                torch.float8_e4m3fn,
                b,
                amax_b,
                torch.float8_e4m3fn,
                None,
            )
        else:
            grad_a = None
        if ctx.b_requires_grad:
            grad_b = grad_out.t() @ a
        else:
            grad_b = None
        if ctx.bias_requires_grad:
            grad_bias = grad_out.sum(dim=0)
        else:
            grad_bias = None

        return grad_a, grad_b, grad_bias


class Fp8Linear(torch.nn.Linear):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out = Fp8LinearFn.apply(input.flatten(end_dim=-2), self.weight, self.bias)
        out = out.unflatten(0, input.shape[:-1])
        return out


def named_replace(
    fn: Callable[[torch.nn.Module, str], torch.nn.Module],
    module: torch.nn.Module,
    name="",
) -> torch.nn.Module:
    for child_name, child_module in list(module.named_children()):
        full_name = f"{name}.{child_name}" if name else child_name
        new_child_module = named_replace(fn, child_module, full_name)
        setattr(module, child_name, new_child_module)
    module = fn(module, name)
    return module


def convert_linears_to_fp8(
    root_module: torch.nn.Module, recipe: str, filter: str
) -> torch.nn.Module:
    if recipe not in ["rowwise"]:
        raise RuntimeError(f"Unknown float8 recipe {recipe!r}")

    if recipe == "rowwise" and torch.__version__ < "2.5":
        # We need https://github.com/pytorch/pytorch/pull/134781.
        warnings.warn("Float8 row-wise scaling is slow in PyTorch prior to v2.5.0")

    # Multi-kernel makes Inductor auto-tune between a regular "streaming"-based
    # reduction kernel and a "persistent" reduction kernel. Since fp8 has some
    # multi-pass steps (e.g., first get amax, then scale), persistent kernels
    # should perform better.
    torch._inductor.config.triton.multi_kernel = 1

    filter_re = re.compile(filter)

    def replace(module: torch.nn.Module, name: str) -> torch.nn.Module:
        if not isinstance(module, torch.nn.Linear) or not filter_re.search(name):
            return module
        if type(module) == torch.nn.Linear:
            if recipe == "rowwise":
                new_module = Fp8Linear(
                    in_features=module.in_features,
                    out_features=module.out_features,
                    bias=module.bias is not None,
                    dtype=module.weight.dtype,
                    device=module.weight.device,
                )
                new_module.weight = module.weight
                new_module.bias = module.bias
            else:
                assert False, recipe
        else:
            assert False, str(type(module))
        return new_module

    out = named_replace(replace, root_module)

    # Force re-compile everything
    torch._dynamo.reset_code_caches()
    from torch._inductor.cudagraph_trees import reset_cudagraph_trees

    reset_cudagraph_trees()

    return out
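A usage sketch (not part of the commit) of convert_linears_to_fp8; the toy module and the ".*" filter are made up, and actually running the forward pass needs a GPU that supports torch._scaled_mm with float8 inputs.

# Usage sketch (not part of the commit): swapping a plain nn.Linear for Fp8Linear.
# The toy module and the ".*" filter are hypothetical; the default filter
# r"layers\.[0-9]+\." targets transformer blocks instead.
import torch

toy = torch.nn.Sequential(torch.nn.Linear(64, 64, dtype=torch.bfloat16, device="cuda"))
toy = convert_linears_to_fp8(toy, recipe="rowwise", filter=".*")
print(type(toy[0]))  # -> Fp8Linear, sharing the original weight and bias tensors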
bytelatent/logger.py
ADDED
@@ -0,0 +1,129 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import logging
import math
import sys
import time
from datetime import timedelta

from bytelatent.distributed import get_global_rank, get_is_slurm_job


class LogFormatter(logging.Formatter):
    """
    Custom logger for distributed jobs, displaying rank
    and preserving indent from the custom prefix format.
    """

    def __init__(self):
        self.start_time = time.time()
        self.rank = get_global_rank()
        self.show_rank = not get_is_slurm_job()  # srun has --label

    def formatTime(self, record):
        subsecond, seconds = math.modf(record.created)
        curr_date = (
            time.strftime("%y-%m-%d %H:%M:%S", time.localtime(seconds))
            + f".{int(subsecond * 1_000_000):06d}"
        )
        delta = timedelta(seconds=round(record.created - self.start_time))
        return f"{curr_date} - {delta}"

    def formatPrefix(self, record):
        fmt_time = self.formatTime(record)
        if self.show_rank:
            return f"{self.rank}: {record.levelname:<7} {fmt_time} - "
        else:
            return f"{record.levelname:<7} {fmt_time} - "

    def formatMessage(self, record, indent: str):
        content = record.getMessage()
        content = content.replace("\n", "\n" + indent)
        # Exception handling as in the default formatter, albeit with indenting
        # according to our custom prefix
        if record.exc_info:
            # Cache the traceback text to avoid converting it multiple times
            # (it's constant anyway)
            if not record.exc_text:
                record.exc_text = self.formatException(record.exc_info)
            if record.exc_text:
                if content[-1:] != "\n":
                    content = content + "\n" + indent
                content = content + indent.join(
                    [l + "\n" for l in record.exc_text.splitlines()]
                )
                if content[-1:] == "\n":
                    content = content[:-1]
        if record.stack_info:
            if content[-1:] != "\n":
                content = content + "\n" + indent
            stack_text = self.formatStack(record.stack_info)
            content = content + indent.join([l + "\n" for l in stack_text.splitlines()])
            if content[-1:] == "\n":
                content = content[:-1]

        return content

    def format(self, record):
        prefix = self.formatPrefix(record)
        indent = " " * len(prefix)
        content = self.formatMessage(record, indent)
        return prefix + content


def set_root_log_level(log_level: str):
    logger = logging.getLogger()
    level: int | str = log_level.upper()
    try:
        level = int(log_level)
    except ValueError:
        pass
    try:
        logger.setLevel(level)  # type: ignore
    except Exception:
        logger.warning(
            f"Failed to set logging level to {log_level}, using default 'NOTSET'"
        )
        logger.setLevel(logging.NOTSET)


def init_logger(
    log_file: str | None = None,
    *,
    name: str | None = None,
    level: str = "NOTSET",
):
    """
    Setup logging.

    Args:
        log_file: A file name to save file logs to.
        name: The name of the logger to configure, by default the root logger.
        level: The logging level to use.
    """
    set_root_log_level(level)
    logger = logging.getLogger(name)

    # stdout: everything
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setLevel(logging.NOTSET)
    stdout_handler.setFormatter(LogFormatter())

    # stderr: warnings / errors and above
    stderr_handler = logging.StreamHandler(sys.stderr)
    stderr_handler.setLevel(logging.WARNING)
    stderr_handler.setFormatter(LogFormatter())

    # set stream handlers
    logger.handlers.clear()
    logger.handlers.append(stdout_handler)
    logger.handlers.append(stderr_handler)

    if log_file is not None and get_global_rank() == 0:
        # build file handler
        file_handler = logging.FileHandler(log_file, "a")
        file_handler.setLevel(logging.NOTSET)
        file_handler.setFormatter(LogFormatter())
        # update logger
        logger = logging.getLogger()
        logger.addHandler(file_handler)
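A usage sketch (not part of the commit) of init_logger; "train.log" is a placeholder path.

# Usage sketch (not part of the commit): configure stdout/stderr handlers and,
# on rank 0, mirror logs to a file. "train.log" is a placeholder path.
import logging

init_logger(log_file="train.log", level="INFO")
logging.getLogger(__name__).info("logging initialized")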
bytelatent/metrics.py
ADDED
@@ -0,0 +1,232 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import json
import logging
from collections import namedtuple
from dataclasses import asdict
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Union

import torch
import torch.nn as nn
import wandb
from pydantic import BaseModel, ConfigDict

from bytelatent.distributed import get_is_master

logger = logging.getLogger()


class WandbArgs(BaseModel):
    model_config = ConfigDict(extra="forbid")
    job_type: str | None = None
    dir: str | None = None
    project: str | None = None
    entity: str | None = None
    tags: list | None = None
    group: str | None = None
    name: str | None = None
    notes: str | None = None
    config_exclude_keys: list[str] | None = None
    config_include_keys: list[str] | None = None
    anonymous: str | None = None
    mode: str | None = None
    allow_val_change: bool | None = None
    resume: Union[bool, str] | None = None
    force: bool | None = None
    tensorboard: bool | None = None
    sync_tensorboard: bool | None = None
    monitor_gym: bool | None = None
    save_code: bool | None = None
    id: str | None = None
    fork_from: str | None = None
    resume_from: str | None = None


class LoggingArgs(BaseModel):
    model_config = ConfigDict(extra="forbid")
    freq: int = 10  # Log every freq optimizer steps
    acc_freq: int | None = None  # Log every acc_freq gradient accumulation steps

    wandb: WandbArgs | None = None


class MetricLogger:
    def __init__(self, outdir: Path, args: Any | None = None):
        self.outdir = outdir
        self.jsonl_writer = None
        self.args = args

    def open(self):
        if self.jsonl_writer is None:
            self.jsonl_writer = open(self.outdir, "a")
        if (
            self.args is not None
            and self.args.logging.wandb is not None
            and get_is_master()
        ):
            run = wandb.init(
                config=asdict(self.args),
                **asdict(self.args.logging.wandb),
            )

    def log(self, metrics: dict[str, Any]):
        if (
            self.args is not None
            and self.args.logging.wandb is not None
            and (wandb.run is not None)
        ):
            wandb.log(metrics, step=metrics["global_step"])

        metrics.update({"created_at": datetime.now(timezone.utc).isoformat()})
        print(json.dumps(metrics), file=self.jsonl_writer, flush=True)

    def close(self):
        if self.jsonl_writer is not None:
            self.jsonl_writer.close()
            self.jsonl_writer = None

    def __enter__(self):
        self.open()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __del__(self):
        self.close()


GPUMemStats = namedtuple(
    "GPUMemStats",
    [
        "max_active_gib",
        "max_active_pct",
        "max_reserved_gib",
        "max_reserved_pct",
        "num_alloc_retries",
        "num_ooms",
        "power_draw",
    ],
)


class GPUMemoryMonitor:
    """
    Class to monitor GPU memory usage
    """

    def __init__(self, device: str = "cuda:0"):
        self.device = torch.device(device)  # device object
        self.device_name = torch.cuda.get_device_name(self.device)
        self.device_index = torch.cuda.current_device()
        self.device_capacity = torch.cuda.get_device_properties(
            self.device
        ).total_memory
        self.device_capacity_gib = self._to_gib(self.device_capacity)

        # reset stats, clear cache
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()

    def _to_gib(self, memory_in_bytes):
        # NOTE: GiB (gibibyte) is 1024, vs GB is 1000
        _gib_in_bytes = 1024 * 1024 * 1024
        memory_in_gib = memory_in_bytes / _gib_in_bytes
        return memory_in_gib

    def _to_pct(self, memory):
        return 100 * memory / self.device_capacity

    def get_peak_stats(self):
        cuda_info = torch.cuda.memory_stats(self.device)

        max_active = cuda_info["active_bytes.all.peak"]
        max_active_gib = self._to_gib(max_active)
        max_active_pct = self._to_pct(max_active)

        max_reserved = cuda_info["reserved_bytes.all.peak"]
        max_reserved_gib = self._to_gib(max_reserved)
        max_reserved_pct = self._to_pct(max_reserved)

        num_retries = cuda_info["num_alloc_retries"]
        num_ooms = cuda_info["num_ooms"]
        power_draw = torch.cuda.power_draw()

        if num_retries > 0:
            logger.warning(f"{num_retries} CUDA memory allocation retries.")
        if num_ooms > 0:
            logger.warning(f"{num_ooms} CUDA OOM errors thrown.")

        return GPUMemStats(
            max_active_gib,
            max_active_pct,
            max_reserved_gib,
            max_reserved_pct,
            num_retries,
            num_ooms,
            power_draw,
        )

    def reset_peak_stats(self):
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.reset_accumulated_memory_stats()

    def __str__(self):
        mem_stats = self.get_peak_stats()
        display_str = f"{self.device_name} ({self.device_index}): {self.device_capacity_gib} GiB capacity, "
        display_str += (
            f"{mem_stats.max_reserved_gib} GiB peak, {mem_stats.max_reserved_pct}% peak"
        )
        return f"{display_str}"


def upload_train_to_wandb(
    ckpt_dir, project="lingua", entity="codegen-team", train=True, eval=True
):
    import json
    from pathlib import Path

    import wandb
    from omegaconf import OmegaConf

    cfg = OmegaConf.load(Path(ckpt_dir) / "config.yaml")
    cfg = OmegaConf.to_container(cfg)

    if train:
        wandb.init(config=cfg, name=cfg["name"], project=project, entity=entity)

        with open(Path(ckpt_dir) / "metrics.jsonl") as f:
            for l in f:
                m = json.loads(l)
                wandb.log(m, step=m["global_step"])

        wandb.finish()

    if eval:
        wandb.init(config=cfg, name=cfg["name"], project=project, entity=entity)

        with open(Path(ckpt_dir) / "metrics.eval.jsonl") as f:
            for l in f:
                m = json.loads(l)
                wandb.log(
                    {
                        f"evals/{name.replace('/','.')}": value
                        for name, value in m.items()
                        if "/" in name
                    },
                    step=m["global_step"],
                )

        wandb.finish()


def get_num_params(model: nn.Module) -> int:
    """
    Get the total model params
    Args : only_trainable: whether to only count trainable params
    """
    numel = {n: p.numel() for n, p in model.named_parameters()}
    return sum(numel.values())
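A usage sketch (not part of the commit) of MetricLogger in its JSONL-only mode; "metrics.jsonl" is a placeholder path.

# Usage sketch (not part of the commit): JSONL-only metric logging
# (args=None, so the wandb branches in open() and log() are skipped).
from pathlib import Path

with MetricLogger(Path("metrics.jsonl")) as ml:
    ml.log({"global_step": 1, "loss": 2.5})
# metrics.jsonl now holds one JSON line with a "created_at" timestamp appended.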
bytelatent/model/__init__.py
ADDED
@@ -0,0 +1 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
bytelatent/model/blt.py
ADDED
@@ -0,0 +1,1064 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
|
3 |
+
from enum import Enum, auto
|
4 |
+
from typing import Any, Optional
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from pydantic import ConfigDict, model_validator
|
8 |
+
from torch import nn
|
9 |
+
from torch.nn.attention.flex_attention import create_block_mask
|
10 |
+
from typing_extensions import Self
|
11 |
+
|
12 |
+
from bytelatent.base_transformer import (
|
13 |
+
BaseTransformerArgs,
|
14 |
+
InitStdFactor,
|
15 |
+
TransformerBlock,
|
16 |
+
)
|
17 |
+
from bytelatent.data.patcher import Patcher, PatcherArgs
|
18 |
+
from bytelatent.model.local_models import LocalDecoder, LocalEncoder
|
19 |
+
from bytelatent.model.transformer import GlobalTransformer
|
20 |
+
from bytelatent.model.utils import downsample
|
21 |
+
from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID
|
22 |
+
|
23 |
+
|
24 |
+
def attention_flops_per_token(n_layers, seq_len, dim, causal):
|
25 |
+
# Formula from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py#L27-L30
|
26 |
+
return 3.5 * (4 * n_layers * seq_len * dim // (2 if causal else 1))
|
27 |
+
|
28 |
+
|
29 |
+
def get_num_flop_per_token(
|
30 |
+
num_non_embed_params: int, n_layers: int, dim: int, seq_len: int
|
31 |
+
) -> int:
|
32 |
+
return 6 * num_non_embed_params + attention_flops_per_token(
|
33 |
+
n_layers, seq_len, dim, True
|
34 |
+
)
|
35 |
+
|
36 |
+
|
37 |
+
def causal_mask(b, h, q_idx, kv_idx):
|
38 |
+
return q_idx >= kv_idx
|
39 |
+
|
40 |
+
|
41 |
+
def setattrs(_self, **kwargs):
|
42 |
+
for k, v in kwargs.items():
|
43 |
+
setattr(_self, k, v)
|
44 |
+
|
45 |
+
|
46 |
+
def get_encoder_dim_token_emb(args):
|
47 |
+
if args.dim_token is not None:
|
48 |
+
dim_token_emb = args.dim_token
|
49 |
+
elif args.use_local_encoder_transformer:
|
50 |
+
dim_token_emb = args.dim_local_encoder
|
51 |
+
else:
|
52 |
+
dim_token_emb = args.dim_global // args.patch_size
|
53 |
+
return dim_token_emb
|
54 |
+
|
55 |
+
|
56 |
+
def get_encoder_dim_patch_emb(args):
|
57 |
+
dim_patch_emb = None
|
58 |
+
if args.cross_attn_encoder:
|
59 |
+
if args.cross_attn_init_by_pooling:
|
60 |
+
dim_patch_emb = args.dim_local_encoder
|
61 |
+
else:
|
62 |
+
dim_patch_emb = args.dim_global
|
63 |
+
return dim_patch_emb
|
64 |
+
|
65 |
+
|
66 |
+
def get_global_dim_patch_emb(args):
|
67 |
+
dim_token_emb = get_encoder_dim_token_emb(args)
|
68 |
+
if args.cross_attn_encoder:
|
69 |
+
dim_patch_emb = dim_token_emb * args.cross_attn_k
|
70 |
+
elif (
|
71 |
+
args.downsampling_by_pooling is None
|
72 |
+
or not args.downsampling_by_pooling
|
73 |
+
or len(args.downsampling_by_pooling) == 0
|
74 |
+
):
|
75 |
+
dim_patch_emb = dim_token_emb * args.patch_size
|
76 |
+
else:
|
77 |
+
dim_patch_emb = dim_token_emb * sum(
|
78 |
+
[
|
79 |
+
pooling in args.downsampling_by_pooling
|
80 |
+
for pooling in ["avg", "min", "max"]
|
81 |
+
]
|
82 |
+
)
|
83 |
+
return dim_patch_emb
|
84 |
+
|
85 |
+
|
86 |
+
def get_decoder_dim_token_emb(args):
|
87 |
+
if args.share_encoder_decoder_emb:
|
88 |
+
dim_token_emb = get_encoder_dim_token_emb(args)
|
89 |
+
elif args.dim_token is not None:
|
90 |
+
dim_token_emb = args.dim_token
|
91 |
+
else:
|
92 |
+
dim_token_emb = args.dim_local_decoder
|
93 |
+
return dim_token_emb
|
94 |
+
|
95 |
+
|
96 |
+
def parse_ngram_to_size(ngram_to_size_str: str | None) -> dict[int, int]:
|
97 |
+
if ngram_to_size_str is None:
|
98 |
+
return None
|
99 |
+
ngram_to_size = {}
|
100 |
+
for entry in ngram_to_size_str.split(","):
|
101 |
+
ngram, size = entry.split(":")
|
102 |
+
ngram = int(ngram)
|
103 |
+
size = int(size)
|
104 |
+
ngram_to_size[ngram] = size
|
105 |
+
return ngram_to_size
|
106 |
+
|
107 |
+
|
108 |
+
def fill_tokens(tokens, patch_size, fill_id):
|
109 |
+
batch_size, seq_len = tokens.shape
|
110 |
+
if seq_len % patch_size == 0:
|
111 |
+
return tokens
|
112 |
+
else:
|
113 |
+
remaining = patch_size - seq_len % patch_size
|
114 |
+
final_padding = tokens.new(batch_size, remaining).fill_(fill_id)
|
115 |
+
return torch.cat((tokens, final_padding), dim=1)
|
116 |
+
|
117 |
+
|
118 |
+
def decoder_patch_ids_from_lengths(patch_lengths, nb_boe, seq_len):
    first_patch_length = patch_lengths[0, 0]
    assert torch.all(
        first_patch_length == patch_lengths[:, 0]
    ), "first patch should always be the same size (1 for dynamic, patch_size for static)."
    assert (
        first_patch_length - nb_boe == 1
    ), f"First patch (patch length: {first_patch_length}) should have one non-boe token (boe toks: {nb_boe})"
    # Remove first patch from patch_ids for local decoder inputs and shift the last patch.
    # decoder_patch_lengths = patch_lengths[:, 1:].clone()
    # decoder_patch_lengths = add_to_last_nonzero_patch(decoder_patch_lengths, 1)
    decoder_patch_lengths = patch_lengths[:, 1:]
    assert (
        decoder_patch_lengths.sum() + (nb_boe + 1) * patch_lengths.shape[0]
        == patch_lengths.sum()
    ), f"{decoder_patch_lengths.sum() + (nb_boe + 1) * patch_lengths.shape[0]} != {patch_lengths.sum()}"
    assert torch.all(decoder_patch_lengths >= 0), f"{decoder_patch_lengths}"
    decoder_patch_ids = patch_ids_from_lengths(
        patch_lengths=decoder_patch_lengths, seq_len=seq_len
    )
    return decoder_patch_ids


primes = [
    1000000007,
    5915587277,
    1500450271,
    3267000013,
    5754853343,
    4093082899,
    9576890767,
    3628273133,
    2860486313,
    5463458053,
    3367900313,
]


def rolling_polynomial_hash(t, hash_func_nb: int = 0):
    prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=t.device)
    prime_powers = torch.stack([prime**i for i in range(t.shape[-1])])
    return torch.sum(t * prime_powers, dim=-1)


def get_rolling_polynomial_hash_fn(hash_func_nb: int = 0, group_size: int = 2):
    prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64)
    prime_powers = torch.stack([prime**i for i in range(group_size)])

    def rolling_polynomial_hash_fn(t):
        return torch.sum(t * prime_powers, dim=-1)

    return rolling_polynomial_hash_fn

def byte_group_hash_function(
    x: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000
):
    """
    Returns a rolling polynomial hash of the input x, mapped into the range [0, max_hash).

    expects: x of shape (batch_size, seq_len) with values as ids in the token vocab.
    returns a tensor of shape (batch_size, seq_len) with values in the range [0, max_hash).

    Note: max_hash can make a big difference on the number of collisions.
    """
    with torch.no_grad():
        bs, seq_len = x.shape
        # x_numpy = x.numpy()
        # hash_values = torch.zeros(bs, seq_len, dtype=torch.int64, requires_grad=False)
        # for i in range(bs):
        #     for j in range(seq_len):
        #         start = max(j, j - group_size + 1)
        #         end = j + 1
        #         hash_values[i, j] = hash_array(x_numpy[i, start:end], max_hash)

        prefix = torch.zeros(bs, group_size - 1, dtype=torch.int64, device=x.device)
        x = torch.cat([prefix, x], dim=1)
        windows = x.unfold(1, group_size, 1)
        # hashes = get_rolling_polynomial_hash_fn(hash_func_nb, group_size)(windows)
        hashes = rolling_polynomial_hash(windows, hash_func_nb)
        hash_values_range = hashes % max_hash
    hash_values_range.requires_grad = False
    return hash_values_range

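# Example (illustrative, not part of the original commit): hash sliding windows
# of 2 bytes; the ids keep the input shape and land in [0, max_hash).
#
#     toy = torch.tensor([[3, 4, 5, 6]])
#     ids = byte_group_hash_function(toy, group_size=2, hash_func_nb=0, max_hash=30000)
#     assert ids.shape == toy.shape and int(ids.max()) < 30000
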
def create_patch_mask_from_ids(
    patch_ids, num_patches, window=None, patches_as_queries=False
):
    """
    Creates a boolean tensor of shape [bs, seq_len, num_patches] where each element at
    position (i, j, k) is True if the patch id at position (i, j) equals k (or, when
    `window` is set, falls within a window of size `window` starting at k).
    Args:
        patch_ids (torch.Tensor): Tensor of shape [bs, seq_len] containing patch ids.
        num_patches (int): Total number of patches.
        window (int): If not None, only considers patches within a window of size window.
        patches_as_queries (bool): If True, the patches are used as queries.
    Returns:
        torch.Tensor: Tensor of shape [bs, q_len, kv_len] with the desired mask.
    """
    bs, seq_len = patch_ids.shape
    if not patches_as_queries:
        q_ids = patch_ids.unsqueeze(-1).expand(bs, seq_len, num_patches)
        kv_ids = (
            torch.arange(num_patches, device=patch_ids.device)
            .unsqueeze(0)
            .unsqueeze(0)
            .expand(bs, seq_len, num_patches)
        )
    else:
        kv_ids = patch_ids.unsqueeze(1).expand(bs, num_patches, seq_len)
        q_ids = (
            torch.arange(num_patches, device=patch_ids.device)
            .unsqueeze(0)
            .unsqueeze(-1)
            .expand(bs, num_patches, seq_len)
        )
    if window is None:
        mask = q_ids == kv_ids
    else:
        mask = (kv_ids <= q_ids) & (q_ids < kv_ids + window)
    return mask

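# Example (illustrative, not part of the original commit): four tokens split
# into patches of lengths [2, 2] have patch_ids [[0, 0, 1, 1]]; each token is
# then only allowed to attend to its own patch.
#
#     pids = torch.tensor([[0, 0, 1, 1]])
#     m = create_patch_mask_from_ids(pids, num_patches=2)
#     assert m.shape == (1, 4, 2) and m[0, 0, 0].item() and not m[0, 0, 1].item()
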
def cross_attn_mask(
    patch_ids,
    patch_lengths,
    N,
    patches_as_queries=False,
    cross_attn_k=1,
    window=None,
    block_mask=True,
):
    bs = patch_ids.shape[0]
    with torch.no_grad():
        # Create the patch mask
        cross_mask = create_patch_mask_from_ids(
            patch_ids,
            patch_lengths.shape[1],
            window=window,
            patches_as_queries=patches_as_queries,
        ).repeat_interleave(cross_attn_k, dim=1 if patches_as_queries else -1)
        q_len = patch_lengths.shape[1] * cross_attn_k if patches_as_queries else N
        kv_len = N if patches_as_queries else patch_lengths.shape[1] * cross_attn_k
        assert cross_mask.shape == (
            bs,
            q_len,
            kv_len,
        ), f"{cross_mask.shape} != {(bs, q_len, kv_len)}"
        if block_mask:

            def patch_mask(b, h, q_idx, kv_idx):
                return cross_mask[b, q_idx, kv_idx]

            block_mask = create_block_mask(
                patch_mask,
                B=bs,
                H=None,
                Q_LEN=q_len,
                KV_LEN=kv_len,
                _compile=True,
            )
            return block_mask
        else:
            return torch.where(
                cross_mask, torch.tensor(0.0), torch.tensor(float("-inf"))
            ).unsqueeze(
                1
            )  # [bs, 1, q_len, kv_len]


def get_blt_input(
    tokens: torch.Tensor,
    enforce_patch_size_multiple: bool,
    nb_boe: torch.Tensor,
    patch_size: int,
    boe_id: int,
):
    """
    This function returns X_et, X_gt and X_dt, the encoder, global, and decoder
    tokens respectively.

    Consider the input and target sequences:
    X=[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13]
    Y=[4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13,14]
    with patch_size=4

    Note 1: no special tokens are introduced at the patch level.
    Note 2: X_e needs to be trimmed to be passed to Global

    Current without boe:
    X_et = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]]
    X_g = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]] # remove last glob patch
    X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]]
    Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]]

    --> lag fix:
    X_et = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11] [12,13,pad,pad]]
    X_g = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11]]
    X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]]
    Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]]

    Dynamic (current):
    X = [3,4,5,6,7,eos,bos,8,9,10,eos,bos]
    Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11]

    entropy patching:
    input: 7, bos, 9, 10
    pred (high entropy): eos, 8, 10, eos

    X_et = [[boe,3,4,5,6,7,eos,bos,8,9,10,eos,bos]
    X_g = [[boe], [3,4,5,6], [7,eos],[bos,8],[9], [10,eos]]
    X_dt = [[3,4,5,6], [7,eos], [bos,8],[9], [10,eos],[bos]]
    Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11]

    --> lag fix no boe (force single byte first patch):
    X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12]
    X_g = [[3], [4,5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch
    X_dt = [[3,4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]]
    Y = [4,5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13]

    input: 4, 7, bos, 9, 10
    pred (high entropy): 5, eos, 8, 10, eos

    X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12]
    X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch
    X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]]
    Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13]

    Handle the last byte properly.
    patch_lengths = [1, 1, 3, 2, 2 1 2 2 1]
    X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12]
    X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # do not remove last global patch
    X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11] [12]]
    Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12, 13]]

    bpe delim
    X_et = [[3,4,5,6,7,<d>,eos,bos,<d>,8,9,<d>,10,<d>,eos,bos,11,12]
    X_g = [[3], [4,5,6,7,<d>], [eos,bos,<d>], ..
    X_dt = [[3,4,5,6,7], [<d>,eos,bos], [<d>,bos,8], ..
    Y = [4,5,6,7,<d>, eos,bos,<d> 8,9,<d>, ..

    Note 1: no special tokens are introduced at the patch level.
    Note 2: X_e needs to be trimmed to be passed to Global
    """
    batch_size, seq_len = tokens.shape
    local_encoder_tokens = tokens
    local_decoder_tokens = tokens

    if nb_boe > 0:
        padded_patch = tokens.new(batch_size, nb_boe).fill_(boe_id)
        local_encoder_tokens = torch.cat((padded_patch, local_encoder_tokens), dim=1)
    # global_tokens = tokens.new(batch_size, ((seq_len-1) // patch_size)+1).fill_(boe_id)

    # create global tokens, contains boe tokens and eos
    # padded_local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id)
    # patches = padded_local_encoder_tokens.view(batch_size, -1, patch_size)
    # global_tokens = (patches.eq(eos_id).any(dim=2).int() * eos_id)[:, 1:]
    # global_tokens += global_tokens.eq(0).int() * boe_id
    # TODO: fix this when we want to use block causal in the global.

    if enforce_patch_size_multiple and local_encoder_tokens.shape[-1] % patch_size != 0:
        local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id)

    return local_encoder_tokens, None, local_decoder_tokens

def patch_ids_from_lengths(patch_lengths, seq_len):
    bs, num_patches = patch_lengths.shape
    # Create a tensor of cumulative sums of the patch lengths
    cum_d = torch.cat(
        [
            torch.zeros(bs, 1, dtype=patch_lengths.dtype, device=patch_lengths.device),
            patch_lengths.cumsum(dim=-1),
        ],
        dim=-1,
    )
    patch_ids = (cum_d.unsqueeze(-1) <= torch.arange(seq_len, device=cum_d.device)).sum(
        dim=-2
    ) - 1
    assert not (
        torch.max(patch_ids) > patch_lengths.shape[-1] or torch.min(patch_ids) < 0
    ), f"{torch.max(patch_ids)} > {patch_lengths.shape[-1]} or {torch.min(patch_ids)} < 0"
    return patch_ids

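# Example (illustrative, not part of the original commit): patch lengths
# [1, 4, 2] over a sequence of 7 tokens map each token to its patch index.
#
#     lengths = torch.tensor([[1, 4, 2]])
#     patch_ids_from_lengths(lengths, seq_len=7)
#     # tensor([[0, 1, 1, 1, 1, 2, 2]])
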
class ByteLatentTransformerArgs(BaseTransformerArgs):
    model_config = ConfigDict(extra="forbid")
    # Basic model configuration
    seed: int = 42
    vocab_size: int = -1
    dim: int = 512
    n_layers: int = 8
    n_heads: int = 8
    # TODO: What is the purpose of this parameter?
    weight_tying: bool = False
    sliding_window: Optional[int] = None

    # Architecture and dimensions
    dim_token: int = 256
    dim_global: int = 512
    dim_local_decoder: int = 512
    dim_local_encoder: int = 512
    n_layers_global: int = 8
    n_layers_local_decoder: int = 8
    n_layers_local_encoder: int = 8

    # Tokenization and patching
    tokenization_mode: str = "bpe"
    patch_size: float | None = None
    patching_mode: str | None = None
    patching_threshold: float | None = None
    patching_threshold_add: float | None = None
    monotonicity: bool = False
    patching_batch_size: int = 1
    patching_device: str = "cuda"
    data_loader_patching: bool = False
    max_patch_length: int | None = None

    # Encoder/Decoder configuration
    tie_local_encoder_decoder_logits: bool = False
    use_local_encoder_transformer: bool = False
    encoder_lm_loss: bool = False
    max_encoder_seq_length: int | None = None
    pad_to_max_length: bool = False
    encoder_enable_byte_ngrams: bool = False
    encoder_enable_byte_group_hash: bool = False
    ngram_vocab_sizes: int | None = None

    # Cross attention configurations
    cross_attn_encoder: bool = False
    cross_attn_decoder: bool = False
    cross_attn_window_encoder: int | None = None
    cross_attn_window_decoder: int | None = None
    cross_attn_k: int | None = None
    cross_attn_nheads: int | None = None
    cross_attn_all_layers_decoder: bool = False
    cross_attn_all_layers_encoder: bool = False
    cross_attn_use_flex_attention: bool = True
    cross_attn_init_by_pooling: bool = False

    # Encoder hash configurations
    encoder_hash_byte_group_size: Any | None = None
    encoder_hash_byte_group_vocab: int = 30000
    encoder_hash_byte_group_nb_functions: int = 3

    # Model behavior and optimization
    log_patch_lengths: bool = False
    non_linearity: str = "swiglu"
    use_rope: bool = True
    recompute_fc1_out: bool = False
    recompute_fc3_out: bool = False
    recompute_attn: bool = True
    custom_bwd: bool = False
    layer_ckpt: str = "all"
    efficient_attn: str | None = None

    # Architecture options
    patch_only_encoder: bool = False
    patch_only_decoder: bool = False

    # Initialization and attention
    init_use_gaussian: bool = True
    init_use_depth: str = "current"
    attn_bias_type: str = "causal"
    alpha_depth: str = "disabled"
    max_length: int = 2048

    # Norm configuration
    norm_eps: float = 1e-5
    norm_affine: bool = True
    pre_norm: bool = True
    norm_type: str = "rmsnorm"

    # Additional configurations
    multiple_of: int = 256
    ffn_dim_multiplier: float = 1.0
    dropout: float = 0
    output_size: int = -1

    # Additional parameters from ModelArgs
    architecture: str = "vanilla"
    share_encoder_decoder_emb: bool = True
    global_local_decoder_residual_layer: str | None = None

    tokenize_with_bpe_delimiter: bool = False
    patching_thresholds_str: str | None = None
    tie_local_encoder_decoder: bool = False
    encoder_preds_low_entropy_toks: float | None = None
    encoder_preds_random_toks: float | None = None
    dim_token_emb: int | None = None
    dim_patch_emb: int | None = None

    encoder_ngram_table_dir: str | None = None
    encoder_ngram_to_size_str: str | None = None

    # Model architecture params
    entropy_model_checkpoint_dir: str | None = None
    entropy_model_is_ngram_model: bool = False
    downsampling_by_pooling: str | None = None
    n_heads_global: int = 8
    n_heads_local_decoder: int = 8
    n_heads_local_encoder: int = 8
    n_kv_heads: int | None = None
    n_kv_heads_global: int | None = None
    conv_kernel_size: int | None = None
    local_attention_window_len: int | None = None

    # Performance optimization
    sequence_parallel: bool = False
    loss_parallel: bool = False
    fuse_sequence_parallel: bool = False
    use_fsdp: bool = True
    attn_to_keep: str = "all"

    # RoPE parameters
    rope_theta: float = 10000.0
    rope_use_fp32_in_outer_product: bool = False

    # Parameter mixing
    pm_size: int = 0

    # Logging
    full_logging_n_layers: int = 4

    # Special token config
    eos_id: int | None = None

    @model_validator(mode="after")
    def check_hash_byte_sizes(self) -> Self:
        if (
            self.encoder_hash_byte_group_size is not None
            and type(self.encoder_hash_byte_group_size) == str
        ):
            self.encoder_hash_byte_group_size = [
                int(x)
                for x in self.encoder_hash_byte_group_size.split(",")
                if len(x) > 0
            ]
        return self

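# Example (illustrative, not part of the original commit): the validator above
# lets configs pass encoder_hash_byte_group_size either as a list or as a
# comma-separated string, e.g. the string "3,4,5" is normalized to [3, 4, 5]
# after model validation.
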
class LocalEncoderArgs(ByteLatentTransformerArgs):
    # Local encoder specific dimensions
    n_heads_local_encoder: int = 8
    dim_token_emb: int | None = None
    dim_patch_emb: int | None = None

    def __post_init__(self):
        # Override base args with local encoder specific values
        self.dim = self.dim_local_encoder
        self.n_layers = self.n_layers_local_encoder
        self.n_heads = self.n_heads_local_encoder
        self.cross_attn_decoder = False
        self.cross_attn_k = self.cross_attn_k if self.cross_attn_encoder else None
        self.attn_bias_type = "local_block_causal"


class GlobalTransformerArgs(ByteLatentTransformerArgs):
    # Global encoder specific dimensions
    dim_token_emb: int | None = None
    dim_patch_emb: int | None = None

    def __post_init__(self):
        # Override base args with global encoder specific values
        self.dim = self.dim_global
        self.n_layers = self.n_layers_global
        self.n_heads = self.n_heads_global
        self.n_kv_heads = self.n_kv_heads_global
        self.local_attention_window_len = None
        self.cross_attn_encoder = False
        self.cross_attn_decoder = False


class LocalDecoderArgs(ByteLatentTransformerArgs):
    # Local decoder specific dimensions
    dim_token_emb: int | None = None
    dim_patch_emb: int | None = None

    def __post_init__(self):
        # Override base args with local decoder specific values
        self.dim = self.dim_local_decoder
        self.n_layers = self.n_layers_local_decoder
        self.n_heads = self.n_heads_local_decoder
        self.cross_attn_encoder = False
        self.cross_attn_init_by_pooling = False
        self.attn_bias_type = "local_block_causal"


def create_global_transformer(args: ByteLatentTransformerArgs) -> GlobalTransformer:
    global_args = args.model_copy(
        deep=True,
        update=dict(
            dim=args.dim_global,
            n_layers=args.n_layers_global,
            n_heads=args.n_heads_global,
            n_kv_heads=args.n_kv_heads_global,
            local_attention_window_len=None,
            dim_token_emb=get_global_dim_patch_emb(args),
            dim_patch_emb=None,
            cross_attn_encoder=False,
            cross_attn_decoder=False,
        ),
    )

    return GlobalTransformer(global_args)


def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder:
    # First deep copy the original args
    # Replace with local encoder specific values
    local_encoder_args = args.model_copy(
        deep=True,
        update=dict(
            dim=args.dim_local_encoder,
            n_layers=args.n_layers_local_encoder,
            n_heads=args.n_heads_local_encoder,
            dim_token_emb=get_encoder_dim_token_emb(args),
            dim_patch_emb=get_encoder_dim_patch_emb(args),
            cross_attn_decoder=False,
            cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None,
            attn_bias_type="local_block_causal",
        ),
    )

    return LocalEncoder(local_encoder_args)


def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder:
    # First deep copy the original args
    local_decoder_args = args.model_copy(
        deep=True,
        update=dict(
            dim=args.dim_local_decoder,
            n_layers=args.n_layers_local_decoder,
            n_heads=args.n_heads_local_decoder,
            cross_attn_encoder=False,
            cross_attn_init_by_pooling=False,  # states are already defined
            dim_token_emb=get_decoder_dim_token_emb(args),
            dim_patch_emb=args.dim_global,
            cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None,
        ),
    )

    return LocalDecoder(local_decoder_args)


class EmbeddingType(Enum):
    HASH_TOK = auto()
    NGRAM = auto()


def init_embeddings(
    args,
    embedding_type: EmbeddingType,
    local_encoder_dim: int,
    encoder_hash_byte_group_size: list = None,
):
    if (
        embedding_type == EmbeddingType.HASH_TOK
        and args.encoder_hash_byte_group_size is None
    ):
        return None
    if embedding_type == EmbeddingType.NGRAM and args.encoder_ngram_to_size_str is None:
        return None

    embeddings = []

    if embedding_type == EmbeddingType.HASH_TOK:
        emb_dim = local_encoder_dim
        encoder_hash_byte_group_vocab = args.encoder_hash_byte_group_vocab
        for _ in range(args.encoder_hash_byte_group_nb_functions):
            for _ in encoder_hash_byte_group_size:
                embeddings.append(
                    nn.Embedding(
                        encoder_hash_byte_group_vocab,
                        emb_dim,
                    )
                )

    elif embedding_type == EmbeddingType.NGRAM:
        encoder_ngram_to_size = parse_ngram_to_size(args.encoder_ngram_to_size_str)
        emb_dim = local_encoder_dim
        OFFSET = 4  # This should be passed as parameter if it's variable
        for ngram_vocab_size in encoder_ngram_to_size.values():
            embeddings.append(nn.Embedding(ngram_vocab_size + OFFSET, emb_dim))

    return nn.ModuleList(embeddings)


def compute_hash_embeddings(
    local_encoder_tokens: torch.Tensor,
    local_encoder,
    encoder_hash_tok_embedding: nn.ModuleList,
    encoder_hash_byte_group_nb_functions: int,
    encoder_hash_byte_group_size: list,
    encoder_hash_byte_group_vocab: int,
) -> torch.Tensor:
    """
    Compute embeddings using hash token embeddings.

    Args:
        local_encoder_tokens: Input tokens tensor
        local_encoder: Encoder object with tok_embeddings method
        encoder_hash_tok_embedding: ModuleList of hash token embeddings
        encoder_hash_byte_group_nb_functions: Number of hash functions
        encoder_hash_byte_group_size: List of byte group sizes
        encoder_hash_byte_group_vocab: Vocabulary size for hash embeddings

    Returns:
        torch.Tensor: Combined embeddings
    """
    if encoder_hash_tok_embedding is None:
        return None

    local_encoder_embeds = local_encoder.tok_embeddings(local_encoder_tokens)

    i = 0
    for func_nb in range(encoder_hash_byte_group_nb_functions):
        for byte_group_size in encoder_hash_byte_group_size:
            hash_ids = byte_group_hash_function(
                local_encoder_tokens,
                byte_group_size,
                hash_func_nb=func_nb,
                max_hash=encoder_hash_byte_group_vocab,
            )
            hash_tok_embedding = encoder_hash_tok_embedding[i]
            local_encoder_embeds = local_encoder_embeds + hash_tok_embedding(hash_ids)
            i += 1

    assert i == len(encoder_hash_tok_embedding)
    return local_encoder_embeds

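# Sketch of how the pieces above fit together (illustrative, not part of the
# original commit): with 3 hash functions and byte group sizes [3, 4],
# init_embeddings builds a ModuleList of 3 * 2 = 6 embedding tables, and
# compute_hash_embeddings adds one lookup per (function, group size) pair on
# top of the plain byte embeddings.
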
class ByteLatentTransformer(nn.Module):
    """
    The ByteLatentTransformer (BLT) is a byte-level language model architecture that processes byte sequences
    by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers,
    and local decoders to efficiently encode and decode byte sequences, leveraging patch-based processing for
    improved performance and inference efficiency.
    """

    def __init__(self, args: ByteLatentTransformerArgs):
        super().__init__()

        # General configuration
        self.weight_tying = args.weight_tying
        self.sliding_window = args.sliding_window
        self.patch_size = args.patch_size
        self.patching_mode = args.patching_mode
        self.boe_id, self.bos_id, self.pad_id, self.eos_id = (
            BOE_ID,
            BOS_ID,
            PAD_ID,
            EOS_ID,
        )
        self.downsampling_by_pooling = args.downsampling_by_pooling
        self.patching_threshold = args.patching_threshold
        self.dim = args.dim
        self.init_base_std = args.init_base_std
        self.init_std_factor = InitStdFactor(args.init_std_factor)
        self.max_seqlen = args.max_seqlen

        # Cross attention configuration
        self.cross_attn_encoder = args.cross_attn_encoder
        self.cross_attn_decoder = args.cross_attn_decoder
        self.cross_attn_k = args.cross_attn_k
        self.cross_attn_window_encoder = args.cross_attn_window_encoder
        self.cross_attn_window_decoder = args.cross_attn_window_decoder
        self.cross_attn_use_flex_attention = args.cross_attn_use_flex_attention

        # Encoder hash configuration
        self.encoder_hash_byte_group_size = args.encoder_hash_byte_group_size
        self.encoder_hash_byte_group_vocab = args.encoder_hash_byte_group_vocab
        self.encoder_hash_byte_group_nb_functions = (
            args.encoder_hash_byte_group_nb_functions
        )

        # ByteLatent modules
        self.local_encoder = create_local_encoder(args)
        self.global_transformer = create_global_transformer(args)
        self.local_decoder = create_local_decoder(args)
        self.encoder_hash_tok_embedding = init_embeddings(
            args,
            EmbeddingType.HASH_TOK,
            local_encoder_dim=self.local_encoder.dim,
            encoder_hash_byte_group_size=self.encoder_hash_byte_group_size,
        )
        self.encoder_ngram_embedding = init_embeddings(
            args,
            EmbeddingType.NGRAM,
            local_encoder_dim=self.local_encoder.dim,
            encoder_hash_byte_group_size=None,
        )
        self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim)

        # Transformer layers
        self.layers = nn.ModuleList(
            [TransformerBlock(args) for _ in range(args.n_layers)]
        )

        # Encoder ngram embedding tables
        self.encoder_ngram_embedding = None
        if args.encoder_enable_byte_ngrams:
            self.encoder_ngram_embedding = nn.ModuleList()
            assert args.ngram_vocab_sizes is not None
            self.encoder_ngram_to_size = parse_ngram_to_size(
                args.encoder_ngram_to_size_str
            )
            ngram_emb_dim = self.local_encoder.dim
            for ngram_vocab_size in self.encoder_ngram_to_size.values():
                self.encoder_ngram_embedding.append(
                    nn.Embedding(ngram_vocab_size + OFFSET, ngram_emb_dim)
                )

        # Output layer
        assert args.vocab_size > 0, "vocab_size must be greater than 0"
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
        if args.weight_tying:
            self.output.weight = self.tok_embeddings.weight

        # Patcher module
        if not args.data_loader_patching:
            self.patcher = Patcher(
                PatcherArgs(
                    patch_size=args.patch_size,
                    patching_mode=args.patching_mode,
                    patching_threshold=args.patching_threshold,
                    patching_threshold_add=args.patching_threshold_add,
                    monotonicity=args.monotonicity,
                    max_patch_length=args.max_patch_length,
                )
            )

    def forward(
        self,
        tokens: torch.Tensor,
        patch_lengths: Optional[torch.Tensor] = None,
        ngram_ids: Optional[torch.Tensor] = None,
    ):
        # Ensure ngram_ids is either a tensor or None
        assert (
            isinstance(ngram_ids, torch.Tensor) or ngram_ids is None
        ), f"ngram_ids must be a tensor or None, but was: {type(ngram_ids)}"

        bs, N = tokens.shape  # Batch size and sequence length

        # Get megabyte inputs
        nb_boe = int(0 if self.patching_mode != "" else self.patch_size - 1)
        local_encoder_tokens, _, local_decoder_tokens = get_blt_input(
            tokens=tokens,
            enforce_patch_size_multiple=False,
            nb_boe=nb_boe,
            patch_size=self.patch_size,
            boe_id=self.boe_id,
        )

        # Patching
        if patch_lengths is None:
            assert (
                getattr(self, "patcher", None) is not None
            ), "Patcher not defined and no patch_lengths passed."
            patch_lengths, tok_scores = self.patcher.patch(
                local_encoder_tokens,
                include_next_token=True,
                threshold=self.patcher.threshold,
            )
        else:
            if nb_boe > 0:
                patch_lengths[:, 0] += nb_boe

        assert torch.min(patch_lengths) >= 0

        # Generate patch IDs from patch_lengths
        patch_ids = patch_ids_from_lengths(
            patch_lengths, local_encoder_tokens.shape[-1]
        )
        assert torch.max(patch_ids) + 1 <= torch.max(
            (patch_lengths != 0).sum(dim=-1)
        ), f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}"

        cross_attn_mask_enc = None
        # Cross-attention encoder
        if self.cross_attn_encoder:
            cross_attn_mask_enc = cross_attn_mask(
                patch_ids,
                patch_lengths,
                N,
                patches_as_queries=True,
                cross_attn_k=self.cross_attn_k,
                window=self.cross_attn_window_encoder,
                block_mask=self.cross_attn_use_flex_attention,
            )

        # Hashing and embedding
        local_encoder_embeds = compute_hash_embeddings(
            local_encoder_tokens=local_encoder_tokens,
            local_encoder=self.local_encoder,
            encoder_hash_tok_embedding=self.encoder_hash_tok_embedding,
            encoder_hash_byte_group_nb_functions=self.encoder_hash_byte_group_nb_functions,
            encoder_hash_byte_group_size=self.encoder_hash_byte_group_size,
            encoder_hash_byte_group_vocab=self.encoder_hash_byte_group_vocab,
        )

        # N-gram table embeddings
        if self.encoder_ngram_embedding is not None:
            assert ngram_ids is not None, "ngram_ids must be provided"
            if local_encoder_embeds is None:
                local_encoder_embeds = self.local_encoder.tok_embeddings(
                    local_encoder_tokens
                )
            assert len(ngram_ids) == len(
                self.encoder_ngram_embedding
            ), f"ngram_ids.shape[0]={ngram_ids.shape[0]} versus len(encoder_ngram_embedding)={len(self.encoder_ngram_embedding)}, ngram_ids.shape={ngram_ids.shape}"
            for i in range(ngram_ids.shape[0]):
                ngram_embedding = self.encoder_ngram_embedding[i]
                ngram_embeds = ngram_embedding(ngram_ids[i])
                assert (
                    local_encoder_embeds.shape == ngram_embeds.shape
                ), f"Shape mismatch: {local_encoder_embeds.shape} vs {ngram_embeds.shape}, ngram_ids.shape={ngram_ids.shape}"
                local_encoder_embeds = local_encoder_embeds + ngram_embeds

        # Local encoder
        h_cross = None
        (h_encoder, h_cross), cache_encoder = self.local_encoder(
            tokens=local_encoder_tokens,
            embeds=local_encoder_embeds,
            patch_embeds=h_cross if self.cross_attn_encoder else None,
            cross_mask=cross_attn_mask_enc,
            num_patches=patch_lengths.shape[1],
            patch_ids=patch_ids,
        )

        # Downsampling
        if not self.cross_attn_encoder:
            assert (
                patch_ids.shape[1] == h_encoder.shape[1]
            ), f"{patch_ids.shape[1]} != {h_encoder.shape[1]}"
            h = downsample(
                h_encoder,
                patch_lengths.shape[1],
                patch_lengths,
                patch_ids,
                downsampling_by_pooling=self.downsampling_by_pooling,
                patch_size=self.patch_size,
            )
        else:
            # Reshape h_cross
            h = h_cross.view(bs, patch_lengths.shape[1], -1)

        # Global transformer
        global_tokens = tokens.new(h.shape[0], h.shape[1]).fill_(self.boe_id)
        rows, cols = torch.where(local_encoder_tokens == self.eos_id)
        eos_patch_ids = patch_ids[rows, cols]
        global_tokens[rows, eos_patch_ids] = self.eos_id

        h, _ = self.global_transformer(
            embeds=h,
            tokens=global_tokens,
        )

        # Unpatching
        dec_embeds = h_encoder[:, nb_boe : nb_boe + N, :]

        # Generate decoder patch IDs
        decoder_patch_ids = decoder_patch_ids_from_lengths(
            patch_lengths, nb_boe, local_decoder_tokens.shape[-1]
        )
        assert (
            torch.max(decoder_patch_ids) + 1 <= h.shape[1]
        ), f"{torch.max(decoder_patch_ids) + 1} > {h.shape[1]}"
        assert (
            decoder_patch_ids.shape[1] == dec_embeds.shape[1]
        ), f"{decoder_patch_ids.shape[1]} != {dec_embeds.shape[1]}"

        # Cross-attention decoder
        if not self.cross_attn_decoder:
            h = torch.gather(
                h, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1])
            )
            cross_attn_mask_dec = None
            assert local_decoder_tokens.shape == h.shape[:-1]
        else:
            cross_attn_mask_dec = cross_attn_mask(
                decoder_patch_ids,
                patch_lengths,
                N,
                patches_as_queries=False,
                cross_attn_k=self.cross_attn_k,
                window=self.cross_attn_window_decoder,
                block_mask=self.cross_attn_use_flex_attention,
            )

        # Local decoder
        output, _ = self.local_decoder(
            embeds=dec_embeds,
            patch_embeds=h,
            tokens=local_decoder_tokens,
            cross_mask=cross_attn_mask_dec,
        )
        return output

    def reset_parameters(self, init_std=None):
        # Either use fixed base std or sqrt model dim
        init_std = init_std or (self.dim ** (-0.5))
        nn.init.trunc_normal_(
            self.tok_embeddings.weight,
            mean=0.0,
            std=init_std,
            a=-3 * init_std,
            b=3 * init_std,
        )
        if not self.weight_tying:
            nn.init.trunc_normal_(
                self.output.weight,
                mean=0.0,
                std=init_std,
                a=-3 * init_std,
                b=3 * init_std,
            )

    def init_weights(self):
        self.reset_parameters()
        self.init_base_std = self.init_base_std or (self.dim ** (-0.5))
        for depth, layer in enumerate(self.layers):
            factor = {
                InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
                InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
                InitStdFactor.DIM_RATIO: self.dim / 4096,
                InitStdFactor.DISABLED: 1.0,
            }[self.init_std_factor]

            layer.init_weights(self.init_base_std, factor)

        self.local_decoder.init_weights(self.init_base_std)
        self.global_transformer.init_weights(self.init_base_std)
        self.local_encoder.init_weights(self.init_base_std)

        for emb in self.encoder_hash_tok_embedding:
            nn.init.trunc_normal_(
                emb.weight,
                mean=0.0,
                std=self.init_base_std,
                a=-3 * self.init_base_std,
                b=3 * self.init_base_std,
            )
bytelatent/model/local_models.py
ADDED
@@ -0,0 +1,356 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import logging
from typing import List, Optional, Tuple, Union

import torch
import torch.nn
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.attention.flex_attention import BlockMask
from xformers.ops import AttentionBias

from bytelatent.base_transformer import (
    InitStdFactor,
    RMSNorm,
    RotaryEmbedding,
    TransformerBlock,
)
from bytelatent.model.transformer import CrossAttention
from bytelatent.model.utils import create_causal_mask, downsample
from bytelatent.tokenizers.blt_tokenizer import BOE_ID

logger = logging.getLogger()


class LocalModelBase(nn.Module):
    def __init__(self, args):
        super().__init__()

        self.dim = args.dim
        self.dropout = args.dropout
        self.vocab_size = args.vocab_size + args.pm_size
        self.patch_size = args.patch_size

        self.efficient_attn = args.efficient_attn
        self.sliding_window = args.sliding_window
        self.use_rope = args.use_rope
        self.init_std_factor = args.init_std_factor
        self.cross_attn_encoder = getattr(args, "cross_attn_encoder", None)
        self.cross_attn_decoder = getattr(args, "cross_attn_decoder", None)
        self.cross_attn_k = getattr(args, "cross_attn_k", None)

        self.boe_id = BOE_ID

        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.layers = nn.ModuleList(
            [TransformerBlock(args) for _ in range(args.n_layers)]
        )

        self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)
        if not self.use_rope:
            self.pos_embeddings = nn.Embedding(args.max_length, args.dim)
        else:
            self.rope = RotaryEmbedding(
                theta=args.rope_theta,
                head_dim=args.head_dim or args.dim // args.n_heads,
                max_seqlen=getattr(args, "max_encoder_seq_length", args.max_length),
            )
            self.pos_embeddings = None

        self.token_embedding_projection = (
            nn.Linear(args.dim_token_emb, args.dim, bias=False)
            if hasattr(args, "dim_token_emb") and args.dim_token_emb != self.dim
            else None
        )

        self.patch_embedding_projection = self._create_patch_projection(args)

    def _should_create_patch_projection(self, args):
        dimension_mismatch = (
            getattr(args, "dim_patch_emb") and args.dim_patch_emb != self.dim
        )

        # Check cross attention conditions
        cross_attn_conditions = (
            hasattr(args, "cross_attn_encoder")
            and args.cross_attn_encoder
            and getattr(args, "cross_attn_init_by_pooling")
        ) or (
            hasattr(args, "cross_attn_decoder")
            and args.cross_attn_decoder
            and getattr(args, "cross_attn_init_by_pooling")
        )

        return dimension_mismatch or cross_attn_conditions

    def _create_patch_projection(self, args):
        if not self._should_create_patch_projection(args):
            return None

        output_dim = args.dim_token_emb * (self.cross_attn_k or 1)

        return nn.Linear(
            in_features=args.dim_patch_emb,
            out_features=output_dim,
            bias=False,
        )

    def apply_embedding(self, tokens, embeds):
        if embeds is not None:
            return embeds
        else:
            return self.tok_embeddings(tokens)

    def init_weights(self, init_std=None):
        self.rope.reset_parameters()

        init_std = init_std or (self.dim ** (-0.5))
        nn.init.trunc_normal_(
            self.tok_embeddings.weight,
            mean=0.0,
            std=init_std,
            a=-3 * init_std,
            b=3 * init_std,
        )
        if self.pos_embeddings is not None:
            nn.init.trunc_normal_(
                self.pos_embeddings.weight,
                mean=0.0,
                std=init_std,
                a=-3 * init_std,
                b=3 * init_std,
            )

        for depth, layer in enumerate(self.layers):
            factor = {
                InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
                InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
                InitStdFactor.DIM_RATIO: self.dim / 4096,
                InitStdFactor.DISABLED: 1.0,
            }[self.init_std_factor]

            layer.init_weights(init_std, factor)

        if self.token_embedding_projection is not None:
            nn.init.trunc_normal_(
                self.token_embedding_projection.weight,
                mean=0.0,
                std=init_std,
                a=-3 * init_std,
                b=3 * init_std,
            )

        if self.patch_embedding_projection is not None:
            nn.init.trunc_normal_(
                self.patch_embedding_projection.weight,
                mean=0.0,
                std=init_std,
                a=-3 * init_std,
                b=3 * init_std,
            )

        if hasattr(self, "output"):
            nn.init.trunc_normal_(
                self.output.weight,
                mean=0.0,
                std=init_std,
                a=-3 * init_std,
                b=3 * init_std,
            )

        if self.cross_attn_layers is not None:
            for depth, layer in enumerate(self.cross_attn_layers):
                factor = {
                    InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
                    InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
                    InitStdFactor.DIM_RATIO: self.dim / 4096,
                    InitStdFactor.DISABLED: 1.0,
                }[self.init_std_factor]

                layer.init_weights(init_std, factor)


class LocalEncoder(LocalModelBase):
    def __init__(self, args):
        super().__init__(args)
        self.output_proj = (
            args.patching_mode in ["entropy", "probmax"]
        ) and args.entropy_model_checkpoint_dir is None

        self.apply_transformer = args.use_local_encoder_transformer
        self.downsampling_by_pooling = args.downsampling_by_pooling
        self.patch_only = args.patch_only_encoder
        self.expects_hash_embeddings = args.encoder_hash_byte_group_size is not None
        self.cross_attn_encoder = args.cross_attn_encoder
        self.cross_attn_all_layers_encoder = args.cross_attn_all_layers_encoder
        self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling
        self.cross_attn_nheads = args.cross_attn_nheads

        if self.cross_attn_encoder:
            self.cross_attn_layers = torch.nn.ModuleList()
            layers_to_add = args.n_layers if self.cross_attn_all_layers_encoder else 1
            for _ in range(layers_to_add):
                self.cross_attn_layers.append(
                    CrossAttention(
                        dim=self.dim,
                        head_dim=self.dim // self.cross_attn_nheads,
                        n_heads=self.cross_attn_nheads,
                        n_kv_heads=self.cross_attn_nheads,
                        norm_eps=args.norm_eps,
                    )
                )

    def apply_embedding(self, tokens, embeds):
        if embeds is not None:
            assert (
                self.expects_hash_embeddings
            ), "Not expecting embeddings to be passed."
            return embeds
        else:
            return self.tok_embeddings(tokens)

    def forward(
        self,
        tokens: torch.Tensor,
        embeds: Optional[torch.Tensor] = None,
        patch_embeds: Optional[torch.Tensor] = None,
        mask: Optional[Union["BlockMask", "AttentionBias", torch.Tensor, str]] = None,
        cross_mask: Optional[torch.Tensor] = None,
        num_patches: Optional[int] = None,
        patch_ids: Optional[torch.Tensor] = None,
        cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
    ):
        """ """
        bs, seqlen = tokens.shape
        if mask is None:
            mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window)

        h = self.apply_embedding(tokens, embeds)
        freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None

        h = F.dropout(h, p=self.dropout, training=self.training)

        for i, layer in enumerate(self.layers):
            h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn)
            # check if cross attention should be applied to either all layer or only the last layer
            if self.cross_attn_encoder and (
                i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder
            ):
                patch_embeds = self.apply_cross_attention(
                    h, patch_embeds, i, bs, num_patches, patch_ids, cross_mask
                )

        h_residual = patch_embeds if self.cross_attn_encoder else None
        return (h, h_residual), cache

    def apply_cross_attention(
        self, h, patch_embeds, layer_idx, bs, num_patches, patch_ids, cross_mask
    ):
        # apply pooling and project
        if self.cross_attn_init_by_pooling and patch_embeds is None:
            patch_embeds = downsample(
                h,
                num_patches,
                patch_ids=patch_ids,
                downsampling_by_pooling=self.downsampling_by_pooling,
                patch_size=self.patch_size,
            )
            if self.patch_embedding_projection is not None:
                patch_embeds = self.patch_embedding_projection(patch_embeds)
                patch_embeds = patch_embeds.reshape(
                    bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim
                )

        layer_idx = layer_idx if self.cross_attn_all_layers_encoder else 0
        patch_embeds_cross = self.cross_attn_layers[layer_idx](
            x=patch_embeds,
            kv=h,
            mask=cross_mask,
        )
        patch_embeds += patch_embeds_cross
        return patch_embeds


class LocalDecoder(LocalModelBase):
    def __init__(self, args):
        super().__init__(args)

        # Model configuration flags
        self.patch_only = args.patch_only_decoder
        self.expects_embeddings = args.share_encoder_decoder_emb
        self.cross_attn_decoder = args.cross_attn_decoder
        self.cross_attn_all_layers_decoder = args.cross_attn_all_layers_decoder
        self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling
        self.cross_attn_nheads = args.cross_attn_nheads

        if self.cross_attn_decoder:
            self.cross_attn_layers = torch.nn.ModuleList()
            layers_to_add = args.n_layers if self.cross_attn_all_layers_decoder else 1
            for _ in range(layers_to_add):
                self.cross_attn_layers.append(
                    CrossAttention(
                        dim=self.dim,
                        head_dim=self.dim // self.cross_attn_nheads,
                        n_heads=self.cross_attn_nheads,
                        n_kv_heads=self.cross_attn_nheads,
                        norm_eps=args.norm_eps,
                    )
                )

        self.output = nn.Linear(
            self.dim,
            args.vocab_size,
            bias=False,
        )

    def forward(
        self,
        tokens: torch.Tensor,
        embeds: Optional[torch.Tensor],
        patch_embeds: Optional[torch.Tensor] = None,
        mask: Optional[Union["BlockMask", "AttentionBias", torch.Tensor, str]] = None,
        cross_mask: Optional[torch.Tensor] = None,
        cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
    ):
        bs, seqlen = tokens.shape
        assert embeds is not None, "Embeddings must be provided"

        if mask is None:
            mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window)

        h = embeds

        if self.patch_embedding_projection is not None:
            assert patch_embeds is not None, "Patch embeddings must be passed."
            patch_embeds = self.patch_embedding_projection(patch_embeds)
            if self.cross_attn_k is not None:
                patch_embeds = patch_embeds.reshape(
                    bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim
                )

        if patch_embeds is not None and not self.cross_attn_decoder:
            h = h + patch_embeds

        freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None

        h = F.dropout(h, p=self.dropout, training=self.training)
        for i, layer in enumerate(self.layers):
            if self.cross_attn_decoder and (
                i == 0 or self.cross_attn_all_layers_decoder
            ):
                # Use cross attention to extract info from patch_embeds into h
                h_cross = self.cross_attn_layers[i](
                    x=h,
                    kv=patch_embeds,
                    mask=cross_mask,
                )
                h = h + h_cross

            h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn)

        h_preds = self.norm(h)
        h_preds = F.dropout(h_preds, p=self.dropout, training=self.training)
        h_preds = self.output(h_preds)
        h_preds = h_preds.float()
        return h_preds, cache
bytelatent/model/transformer.py
ADDED
@@ -0,0 +1,199 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import logging
from typing import List, Optional, Tuple, Union

import torch
import torch.nn
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.attention.flex_attention import BlockMask
from xformers.ops import AttentionBias

from bytelatent.base_transformer import (
    BaseTransformer,
    RMSNorm,
    flex_attention_comp,
    repeat_kv,
)
from bytelatent.model.utils import create_causal_mask

logger = logging.getLogger()


class CrossAttention(nn.Module):
    """
    CrossAttention block to attend to the encoder states from the decoder.
    Rope is not supported.
    """

    def __init__(
        self,
        dim: int,
        head_dim: int,
        n_heads: int,
        n_kv_heads: int,
        norm_eps: float,
    ):
        super().__init__()

        self.dim = dim
        self.head_dim = head_dim

        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.heads_per_group = self.n_heads // self.n_kv_heads

        self.cross_attn_norm_q = RMSNorm(dim, eps=norm_eps)
        self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps)

        self.wq = nn.Linear(
            dim,
            n_heads * head_dim,
            bias=False,
        )
        self.wk = nn.Linear(
            dim,
            n_kv_heads * head_dim,
            bias=False,
        )
        self.wv = nn.Linear(
            dim,
            n_kv_heads * head_dim,
            bias=False,
        )

        self.wo = nn.Linear(
            n_heads * head_dim,
            dim,
            bias=False,
        )

    def forward(
        self,
        x: torch.Tensor,
        kv: torch.Tensor,
        mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
    ) -> torch.Tensor:
        # B S D
        bsz, seq_len, _ = x.shape
        _, slen_kv, _ = kv.shape
        x = self.cross_attn_norm_q(x)
        kv = self.cross_attn_norm_kv(kv)

        xq = self.wq(x)
        xk = self.wk(kv)
        xv = self.wv(kv)

        output_shape = xq.shape
        # B S D -> B S H D
        xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
        xk = xk.view(bsz, slen_kv, self.n_kv_heads, self.head_dim)
        xv = xv.view(bsz, slen_kv, self.n_kv_heads, self.head_dim)

        xk = repeat_kv(xk, self.heads_per_group, dim=2)
        xv = repeat_kv(xv, self.heads_per_group, dim=2)

        assert mask is None or isinstance(mask, BlockMask)
        xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv))
        output = flex_attention_comp(xq, xk, xv, block_mask=mask)
        output = output.transpose(1, 2).contiguous()  # B H S D -> B S H D

        output = self.wo(output.reshape(output_shape))

        return x + output

    def init_weights(self, base_std: float, factor: float = 1.0):
        std = base_std * factor

        nn.init.trunc_normal_(
            self.wq.weight,
            mean=0.0,
            std=std,
            a=-3 * std,
            b=3 * std,
        )

        nn.init.trunc_normal_(
            self.wk.weight,
            mean=0.0,
            std=std,
            a=-3 * std,
            b=3 * std,
        )

        nn.init.trunc_normal_(
            self.wv.weight,
            mean=0.0,
            std=std,
            a=-3 * std,
            b=3 * std,
        )

        output_std = std / (2**0.5)
        nn.init.trunc_normal_(
            self.wo.weight,
            mean=0.0,
            std=output_std,
            a=-3 * output_std,
            b=3 * output_std,
        )
        self.cross_attn_norm_q.reset_parameters()
        self.cross_attn_norm_kv.reset_parameters()

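# Example (illustrative, not part of the original commit): a CrossAttention
# block with 8 heads over dim=512 maps queries of shape (bs, q_len, 512) and
# key/value states of shape (bs, kv_len, 512) to an output of shape
# (bs, q_len, 512). Note the forward pass relies on flex attention.
#
#     xattn = CrossAttention(dim=512, head_dim=64, n_heads=8, n_kv_heads=8, norm_eps=1e-5)
#     out = xattn(x=torch.randn(2, 6, 512), kv=torch.randn(2, 24, 512), mask=None)
#     assert out.shape == (2, 6, 512)
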
class GlobalTransformer(BaseTransformer):
    def __init__(self, args):
        super().__init__(args)
        self.dropout = args.dropout
        self.sliding_window = args.sliding_window
        self.efficient_attn = args.efficient_attn

        self.token_embedding_projection = None
        if args.dim_token_emb is not None and args.dim_token_emb != self.dim:
            self.token_embedding_projection = nn.Linear(
                args.dim_token_emb,
                args.dim,
                bias=False,
            )

    def forward(
        self,
        tokens: torch.Tensor,
        tok_idx: Optional[torch.Tensor] = None,
        embeds: Optional[torch.Tensor] = None,
        mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None,
        cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
    ):
        """
        Similar to BaseTransformer.forward, but with an additional embeds argument
        and projection to the token space.
        """
        bs, seqlen = tokens.shape
        attn_impl = self.efficient_attn

        h = embeds

        mask = (
            mask
            if mask is not None
            else create_causal_mask(seqlen, attn_impl, self.sliding_window)
        )

        if self.token_embedding_projection is not None and h.shape[-1] != self.dim:
            h = self.token_embedding_projection(h)

        h = F.dropout(h, p=self.dropout, training=self.training)

        h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
        return h, cache

    def init_weights(self, init_base_std: float):
        super().init_weights()
        if self.token_embedding_projection is not None:
            nn.init.trunc_normal_(
                self.token_embedding_projection.weight,
                mean=0.0,
                std=init_base_std,
                a=-3 * init_base_std,
                b=3 * init_base_std,
            )
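Both init_weights methods above follow the same convention: weights are drawn from a truncated normal clipped at ±3 standard deviations, and the attention output projection is scaled down by 1/√2, presumably to temper the variance fed into the residual stream. Below is a minimal sketch of that convention on standalone layers; base_std = 0.02 and the layer sizes are assumed values for illustration, not taken from the configs.

# Sketch only: illustrates the truncated-normal init convention with assumed values.
import torch.nn as nn

base_std, factor = 0.02, 1.0  # assumed for illustration
std = base_std * factor

wq = nn.Linear(512, 512, bias=False)
nn.init.trunc_normal_(wq.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)

wo = nn.Linear(512, 512, bias=False)  # output projection feeding the residual stream
output_std = std / (2**0.5)
nn.init.trunc_normal_(wo.weight, mean=0.0, std=output_std, a=-3 * output_std, b=3 * output_std)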
bytelatent/model/utils.py
ADDED
@@ -0,0 +1,116 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import torch
from torch.nn.attention.flex_attention import create_block_mask
from xformers.ops import fmha


def patch_reduce(h, max_num_patches, reduction, patch_ids):
    """
    Reduce variable-length patches to a single embedding per patch.
    Note: this works with a variable number of patches for different sequences in the batch.
    It handles variable-length patches by assuming that patch_lengths will be 0 for any
    extra patches on the *right*. Any embeddings on the right that are not allocated to a
    patch (i.e. if sum(patch_lengths[i]) < seq_len for any i) are sent to a dummy patch,
    which is trimmed before returning.
    """
    bs, seq_len, emb_dim = h.shape

    patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1])

    reduced_embs = torch.zeros(
        (bs, max_num_patches, emb_dim), dtype=h.dtype, device=h.device
    )
    reduced_embs = reduced_embs.scatter_reduce(
        src=h,
        dim=1,
        index=patch_ids,
        reduce=reduction,
        include_self=False,
    )
    reduced_embs = reduced_embs[:, :max_num_patches, :]

    return reduced_embs

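As a concrete illustration of patch_reduce, here is a toy example (hypothetical, assuming the function above is in scope): six byte embeddings are assigned to three patches via patch_ids, and scatter_reduce averages the embeddings that share a patch id.

# Toy example, not part of the file: bs=1, seq_len=6, emb_dim=2, 3 patches.
import torch

h = torch.arange(12, dtype=torch.float32).reshape(1, 6, 2)  # byte embeddings
patch_ids = torch.tensor([[0, 0, 0, 1, 1, 2]])              # byte -> patch assignment
pooled = patch_reduce(h, max_num_patches=3, reduction="mean", patch_ids=patch_ids)
print(pooled.shape)  # torch.Size([1, 3, 2]); pooled[0, 0] is the mean of the first three bytes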
def concat_downsample(h, patch_lengths, patch_size):
    # The assumption in this function is that seq_len = patch_size * num_patches.
    bs, seq_len, emb_dim = h.shape
    patch_end_ids = torch.cumsum(patch_lengths, dim=1)
    patch_ids = patch_end_ids.unsqueeze(-1) - torch.arange(patch_size, 0, -1).to(
        patch_end_ids.device
    )
    # Is clamp ok here?
    patch_ids = patch_ids.clamp(min=0).unsqueeze(-1).expand(-1, -1, -1, h.shape[-1])
    patch_ids = patch_ids.view(bs, -1, emb_dim)
    # after gather h.shape = [batch_size, seq_len, dim]
    h = torch.gather(h, 1, patch_ids)
    h = h.reshape(bs, patch_lengths.shape[1], patch_size * h.size(-1))
    return h


def pooling_downsample(h, max_num_patches, pooling_mode, patch_ids):
    cat = []
    if "avg" in pooling_mode or "mean" in pooling_mode:
        cat.append(patch_reduce(h, max_num_patches, "mean", patch_ids))
    if "min" in pooling_mode:
        cat.append(patch_reduce(h, max_num_patches, "amin", patch_ids))
    if "max" in pooling_mode:
        cat.append(patch_reduce(h, max_num_patches, "amax", patch_ids))
    assert len(cat) > 0
    h = torch.cat(cat, dim=-1)
    return h

def downsample(
    h,
    num_patches,
    patch_lengths=None,
    patch_ids=None,
    downsampling_by_pooling=None,
    patch_size=4,
):
    """
    Downsampling:
        a. concatenating embeddings in the patch
            Note: with dynamic patching, patch the last patch_size tokens.
        b. pooling embeddings in the patch
    """
    # input: h.shape = [batch_size, seq_len, dim]
    # output: pooled h.shape = [batch_size, seq_len / patch_size, dim]
    # if we don't use cross_attn, we pool so that we convert the byte rep to a patch rep
    if downsampling_by_pooling is not None and len(downsampling_by_pooling) > 0:
        # By pooling
        max_num_patches = num_patches
        assert patch_ids is not None
        h = pooling_downsample(h, max_num_patches, downsampling_by_pooling, patch_ids)
    else:
        # TODO: remove this condition
        # By concatenating (fixed-length patching)
        assert patch_lengths is not None
        h = concat_downsample(h, patch_lengths, patch_size)
    return h

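To make the two branches concrete, here is a small hypothetical example (assuming the functions defined above): the pooling branch keeps the embedding dim per reduction mode, while the concatenation branch stacks patch_size embeddings along the feature dim.

# Toy comparison of the two branches, not part of the file.
import torch

bs, seq_len, dim, patch_size = 2, 8, 4, 4
h = torch.randn(bs, seq_len, dim)

# a) pooling: needs patch_ids mapping each byte position to its patch
patch_ids = torch.arange(seq_len).unsqueeze(0).expand(bs, -1) // patch_size
pooled = downsample(h, num_patches=2, patch_ids=patch_ids, downsampling_by_pooling="avg")
print(pooled.shape)   # torch.Size([2, 2, 4])

# b) concatenation: needs fixed patch_lengths summing to seq_len
patch_lengths = torch.full((bs, seq_len // patch_size), patch_size)
stacked = downsample(h, num_patches=2, patch_lengths=patch_lengths, patch_size=patch_size)
print(stacked.shape)  # torch.Size([2, 2, 16])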
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


def create_causal_mask(seqlen, attn_impl, sliding_window):
    if sliding_window is not None and attn_impl == "xformers":
        return fmha.attn_bias.LocalAttentionFromBottomRightMask(
            window_left=sliding_window - 1, window_right=0
        )
    elif attn_impl == "xformers":
        return fmha.attn_bias.LowerTriangularMask()
    elif attn_impl == "sdpa":
        return "causal"
    elif attn_impl == "flex_attention":
        return create_block_mask(causal_mask, None, None, seqlen, seqlen)
    elif attn_impl == "fmha":
        return None
    else:
        raise NotImplementedError(
            f"Attention {attn_impl} with {sliding_window} sliding window not implemented"
        )
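For the flex_attention branch, the returned BlockMask is the object consumed by torch's flex_attention kernel (flex_attention_comp used earlier is presumably a compiled wrapper around it). A minimal sketch, assuming PyTorch 2.5+ and a CUDA device (create_block_mask defaults to device="cuda"):

# Sketch only: build a causal BlockMask and run flex_attention on it.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

seqlen = 128
q = torch.randn(1, 2, seqlen, 16, device="cuda")  # B, H, S, head_dim
k = torch.randn(1, 2, seqlen, 16, device="cuda")
v = torch.randn(1, 2, seqlen, 16, device="cuda")

block_mask = create_block_mask(causal, None, None, seqlen, seqlen)
out = flex_attention(q, k, v, block_mask=block_mask)
print(out.shape)  # torch.Size([1, 2, 128, 16])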