par-meta committed
Commit bcc039b · 0 parent(s)

Initial commit

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .github/workflows/black.yml +12 -0
  2. .github/workflows/isort.yml +10 -0
  3. .gitignore +168 -0
  4. .prettierrc +8 -0
  5. CODE_OF_CONDUCT.md +80 -0
  6. CONTRIBUTING.md +36 -0
  7. LICENSE +28 -0
  8. README.md +117 -0
  9. apps/__init__.py +0 -0
  10. apps/main/__init__.py +0 -0
  11. apps/main/configs/eval.yaml +35 -0
  12. apps/main/configs/llama_1B.yaml +87 -0
  13. apps/main/configs/llama_7B.yaml +95 -0
  14. apps/main/eval.py +354 -0
  15. apps/main/generate.py +463 -0
  16. apps/main/lingua_train.py +654 -0
  17. blt-figure.jpg +0 -0
  18. blt-figure.pdf +0 -0
  19. bytelatent/.DS_Store +0 -0
  20. bytelatent/__init__.py +3 -0
  21. bytelatent/args.py +199 -0
  22. bytelatent/base_transformer.py +585 -0
  23. bytelatent/checkpoint.py +311 -0
  24. bytelatent/configs/debug.yaml +110 -0
  25. bytelatent/constants.py +5 -0
  26. bytelatent/data/__init__.py +1 -0
  27. bytelatent/data/data_types.py +115 -0
  28. bytelatent/data/iterators/__init__.py +1 -0
  29. bytelatent/data/iterators/abstract_iterator.py +23 -0
  30. bytelatent/data/iterators/arrow_iterator.py +216 -0
  31. bytelatent/data/iterators/looping_iterator.py +36 -0
  32. bytelatent/data/iterators/multiprocess_iterator.py +243 -0
  33. bytelatent/data/iterators/packing_iterator.py +226 -0
  34. bytelatent/data/iterators/preprocess_iterator.py +111 -0
  35. bytelatent/data/iterators/sampling_iterator.py +66 -0
  36. bytelatent/data/iterators/sequence_iterator.py +122 -0
  37. bytelatent/data/iterators/test_arrow_iterator.py +89 -0
  38. bytelatent/data/iterators/test_iters.py +162 -0
  39. bytelatent/data/ngram_processor.py +146 -0
  40. bytelatent/data/patcher.py +609 -0
  41. bytelatent/distributed.py +478 -0
  42. bytelatent/entropy_model.py +36 -0
  43. bytelatent/float8.py +152 -0
  44. bytelatent/logger.py +129 -0
  45. bytelatent/metrics.py +232 -0
  46. bytelatent/model/__init__.py +1 -0
  47. bytelatent/model/blt.py +1064 -0
  48. bytelatent/model/local_models.py +356 -0
  49. bytelatent/model/transformer.py +199 -0
  50. bytelatent/model/utils.py +116 -0
.github/workflows/black.yml ADDED
@@ -0,0 +1,12 @@
name: Lint with Black

on: [push, pull_request]

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: psf/black@stable
        with:
          version: "24.8.0"
.github/workflows/isort.yml ADDED
@@ -0,0 +1,10 @@
name: Lint with isort

on: [push, pull_request]

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: isort/isort-action@master
.gitignore ADDED
@@ -0,0 +1,168 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
*.out

figures/
.vscode/
.DS_Store
.prettierrc ADDED
@@ -0,0 +1,8 @@
{
  "overrides": [
    {
      "files": "*.yaml",
      "options": { "tabWidth": 2 }
    }
  ]
}
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,80 @@
# Code of Conduct

## Our Pledge

In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.

## Our Standards

Examples of behavior that contributes to creating a positive environment
include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or
  advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
  address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.

## Scope

This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.

This Code of Conduct also applies outside the project spaces when there is a
reasonable belief that an individual's behavior may have a negative impact on
the project or its community.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at <opensource-conduct@meta.com>. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq
CONTRIBUTING.md ADDED
@@ -0,0 +1,36 @@
# Contributing to BLT

We want to make contributing to this project as easy and transparent as
possible.

## Pull Requests

We actively welcome your pull requests.

1. Fork the repo and create your branch from `main`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").

## Contributor License Agreement ("CLA")

In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Meta's open source projects.

Complete your CLA here: <https://code.facebook.com/cla>

## Issues

We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.

## License

By contributing to BLT, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.
LICENSE ADDED
@@ -0,0 +1,28 @@
BSD 3-Clause License

Copyright 2024 Meta

Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this list
of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice, this
list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its contributors may
be used to endorse or promote products derived from this software without specific
prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY
EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
DAMAGE.
README.md ADDED
@@ -0,0 +1,117 @@
# Byte Latent Transformer

This repository contains code for our paper: "Byte Latent Transformer: Patches Scale Better Than Tokens"

- [Paper Link](https://dl.fbaipublicfiles.com/blt/BLT__Patches_Scale_Better_Than_Tokens.pdf)

## Abstract

We introduce the Byte Latent Transformer (BLT), a new byte-level LLM architecture that,
for the first time, matches tokenization-based LLM performance at scale, with significant improvements
in inference efficiency and robustness. BLT encodes bytes into dynamically sized patches, which serve
as the primary units of computation. Patches are segmented dynamically based on the entropy of the
next byte, allocating more compute and model capacity where there is more data complexity. The BLT
architecture includes new attention mechanisms to maximize the information flow between byte and
patch hidden representations and a new type of byte-sequence memory. We present the first scaling
study of byte-level models up to 8B parameters and 8T training bytes, showing for the first time
that we can train a model end-to-end at scale from bytes with no tokenization or other preprocessing.
Scaling trends reveal training and inference efficiency benefits from dynamically selecting very long
patches on average, along with qualitative improvements in reasoning and long-tail generalization
from modeling byte sequences.

![BLT Architecture Diagram](blt-figure.jpg)

## Development Status

We are actively updating the BLT code to make it easier to reproduce our results.
Please file an issue and/or be patient while we make more of our code public!

## Quick start

The following commands launch a SLURM job that creates an environment for Meta Lingua.
Creating the environment should take around 5 minutes, not counting downloads.

```bash
git clone https://github.com/facebookresearch/blt
cd blt

bash setup/create_env.sh
# or if you have access to a SLURM cluster
sbatch setup/create_env.sh
```

Once that is done, you can activate the environment:

```bash
conda activate blt_<date>
```

Use the provided script to download and prepare data from Hugging Face (among `fineweb_edu`, `fineweb_edu_10bt`, or `dclm_baseline_1.0`).
This command downloads the `fineweb_edu` dataset and prepares it for training in the `./data` directory, specifying the amount of memory `terashuf` (the tool used to shuffle samples) will be allocated. By default, the number of chunks (`nchunks`) is 32. If you are running on fewer than 32 GPUs, it is recommended to set `nchunks` to 1 or to match `nchunks` with the number of GPUs (`nchunks` = NGPUs). See [here](https://github.com/facebookresearch/lingua/issues/55#issuecomment-2483643076) for more details.

```bash
python setup/download_prepare_hf_data.py fineweb_edu <MEMORY> --data_dir ./data --seed 42 --nchunks <NCHUNKS>
```

To download the tokenizer (here llama3), use the following script:

```bash
python setup/download_tokenizer.py llama3 <SAVE_PATH> --api_key <HUGGINGFACE_TOKEN>
```

Now launch a debug job to check that everything works. **The provided configurations are templates; you need to adapt them before they will work (change `dump_dir`, `data.root_dir`, `data.tokenizer.path`, etc.).**

```bash
# stool stands for SLURM tool !
python -m bytelatent.stool script=bytelatent.train config=apps/bytelatent/configs/debug.yaml nodes=1 partition=<partition>
# if you want to launch locally you can use torchrun
torchrun --nproc-per-node 8 -m bytelatent.train config=apps/bytelatent/configs/debug.yaml
# or you can also launch on 1 GPU
python -m bytelatent.train config=apps/bytelatent/configs/debug.yaml
```

When using `stool`, if a job crashes, it can be relaunched using sbatch:

```bash
sbatch path/to/dump_dir/submit.slurm
```

## Linting

To lint, run the following command:

```
bash dev/lint.sh
```

## Citation

BLT is partially based on Meta Lingua, so consider citing it in addition to our BLT paper if you re-use our work.

BLT paper citation (will be updated to arXiv soon):

```
@article{meta_blt,
  author = {Artidoro Pagnoni, Ram Pasunuru, Pedro Rodriguez, John Nguyen, Benjamin Muller, Margaret Li, Chunting Zhou, Lili Yu, Jason Weston, Luke Zettlemoyer, Gargi Ghosh, Mike Lewis, Ari Holtzman†, Srinivasan Iyer},
  title = {Byte Latent Transformer: Patches Scale Better Than Tokens},
  url = {https://github.com/facebookresearch/blt},
  year = {2024}
}
```

Lingua code:

```
@misc{meta_lingua,
  author = {Mathurin Videau, Badr Youbi Idrissi, Daniel Haziza, Luca Wehrstedt, Jade Copet, Olivier Teytaud, David Lopez-Paz},
  title = {{Meta Lingua}: A minimal {PyTorch LLM} training library},
  url = {https://github.com/facebookresearch/lingua},
  year = {2024}
}
```

## License

The BLT code is partially based on Meta Lingua.

Meta Lingua is licensed under the BSD-3-Clause license. Refer to the LICENSE file in the top-level directory.
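The abstract above describes segmenting bytes into patches based on next-byte entropy. As a toy illustration of that idea only (the repository's actual logic lives in `bytelatent/data/patcher.py`; the function names and threshold below are made up), a new patch can be started whenever the predicted next-byte distribution is high-entropy:

```python
# Toy sketch of entropy-threshold patching, not the repo's patcher.py implementation.
import math
from typing import List, Sequence


def entropy(probs: Sequence[float]) -> float:
    """Shannon entropy (in nats) of a next-byte distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)


def patch_starts(next_byte_probs: List[Sequence[float]], threshold: float = 2.0) -> List[int]:
    """Indices where a new patch begins: positions whose next-byte
    distribution has entropy above `threshold` (hypothetical value)."""
    starts = [0]
    for i, probs in enumerate(next_byte_probs[1:], start=1):
        if entropy(probs) > threshold:
            starts.append(i)
    return starts
```

Positions where the model is uncertain about the next byte start new, shorter patches, so more compute is spent where the data is more complex.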
apps/__init__.py ADDED
File without changes
apps/main/__init__.py ADDED
File without changes
apps/main/configs/eval.yaml ADDED
@@ -0,0 +1,35 @@
name: "debug_evals"
# ckpt_dir: !!CHANGETHIS!!
# dump_dir: !!CHANGETHIS!!
generator:
  max_tokens: 8192
  dtype: bf16
  temperature: 1.0
  top_p: 0.95
harness:
  tasks:
    - hellaswag
    - task: boolq
      dataset_kwargs:
        trust_remote_code: true
    - task: nq_open
      num_fewshot: 5
    - piqa
    - task: social_iqa
      dataset_kwargs:
        trust_remote_code: true
    - triviaqa
    - winogrande
    - openbookqa
    - arc_easy
    - arc_challenge
    - race
    - commonsense_qa
    # - coqa
    - copa
    - gsm8k
    - bbh
    - mmlu
    - mmlu_pro
validation:
  max_steps: 1000
apps/main/configs/llama_1B.yaml ADDED
@@ -0,0 +1,87 @@
# dump_dir: !!!CHANGE_THIS!!!
name: large_lm
steps: 60_000
probe_freq: null
seed: 777

optim:
  lr: 3e-3
  weight_decay: 0.033
  warmup: 5000
  lr_min_ratio: 0.000001
  clip: 1.0

distributed:
  fsdp_type: full_shard
  compile: true
  model_dtype: bf16
  matmul_allow_tf32: false
  selective_activation_checkpointing: false
  tp_size: 1

model:
  dim: 2048
  n_layers: 25
  n_heads: 16

data:
  root_dir: data/shuffled
  sources:
    dclm_baseline_1.0: 100.0
  batch_size: 4
  prefetch_size: 1024
  seq_len: 4096
  n_views: 2
  load_async: true
  add_bos: true
  add_eos: true
  tokenizer:
    name: tiktoken
    path: tokenizers/cl_toplang_128k.tiktoken

profiling:
  run: true
  mem_warmup: 0
  mem_steps: 4
  profile_warmup: 100
  profile_steps: 4

checkpoint:
  dump:
    every: 2500
    keep: 3
  eval:
    every: 5000
    keep: -1

logging:
  freq: 1

async_eval_gpus: 8
eval:
  harness:
    tasks:
      - hellaswag
      - task: boolq
        dataset_kwargs:
          trust_remote_code: true
      - piqa
      - task: social_iqa
        dataset_kwargs:
          trust_remote_code: true
      - winogrande
      - openbookqa
      - arc_easy
      - arc_challenge
      - race
      - commonsense_qa
      - copa
      # - coqa
      # - task: nq_open
      #   num_fewshot: 5
      # - triviaqa
  validation:
    max_steps: 1000
  generator:
    max_tokens: 16384
    dtype: bf16
apps/main/configs/llama_7B.yaml ADDED
@@ -0,0 +1,95 @@
#python -m lingua.stool config=apps/main/configs/llama2_7B.yaml nodes=32 account=fair_amaia_cw_codegen qos=lowest
# dump_dir: !!!CHANGE_THIS!!!
name: "7b_baseline"
steps: 100_000
grad_acc_steps: 1
probe_freq: 100

seed: 777
optim:
  lr: 1.0e-3
  weight_decay: 0.1
  warmup: 2000
  lr_min_ratio: 0.000001
  clip: 1.0

distributed:
  fsdp_type: full_shard
  compile: true
  model_dtype: bf16
  matmul_allow_tf32: false
  selective_activation_checkpointing: false
  tp_size: 1

model:
  dim: 4096
  n_layers: 32
  n_heads: 32
  rope_theta: 100_000
  ffn_dim_multiplier: 1.0
  multiple_of: 256

data:
  root_dir: data/shuffled
  sources:
    dclm_baseline_1.0: 1.0
  batch_size: 2
  prefetch_size: 1024
  seq_len: 4096
  n_views: 2
  load_async: true
  tokenizer:
    name: tiktoken
    path: tokenizers/cl_toplang_128k.tiktoken

profiling:
  run: true
  mem_warmup: 0
  mem_steps: 4
  profile_warmup: 100
  profile_steps: 4

checkpoint:
  dump:
    every: 10000
    keep: -1
  eval:
    every: 1000
    keep: 3

logging:
  freq: 1

async_eval_gpus: 8
eval:
  dataset_dir: datasets/eval
  harness:
    tasks:
      - hellaswag
      - task: boolq
        dataset_kwargs:
          trust_remote_code: true
      - piqa
      - task: social_iqa
        dataset_kwargs:
          trust_remote_code: true
      - winogrande
      - openbookqa
      - arc_easy
      - arc_challenge
      - race
      - commonsense_qa
      # - coqa
      - copa
      - mmlu
      - mmlu_pro
      # - task: nq_open
      #   num_fewshot: 5
      # - triviaqa
      # - gsm8k
      # - bbh
  validation:
    max_steps: 1000
  generator:
    max_tokens: 8192
    dtype: bf16
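As a quick sanity check on the `llama_7B.yaml` settings above, a back-of-the-envelope count lands near 7B parameters. The SwiGLU sizing rule and the ~128k vocabulary are assumptions (the vocabulary is only suggested by the `cl_toplang_128k.tiktoken` tokenizer path), not values read from this commit:

```python
# Rough parameter-count estimate for apps/main/configs/llama_7B.yaml.
# Assumes a Llama-style SwiGLU FFN and untied ~128k input/output embeddings (both assumptions).
dim, n_layers, multiple_of, ffn_dim_multiplier, vocab = 4096, 32, 256, 1.0, 128_256

hidden = int(2 * (4 * dim) / 3 * ffn_dim_multiplier)
hidden = multiple_of * ((hidden + multiple_of - 1) // multiple_of)  # rounds up to 11008

attn = 4 * dim * dim       # wq, wk, wv, wo
ffn = 3 * dim * hidden     # w1, w2, w3
total = n_layers * (attn + ffn) + 2 * vocab * dim  # norms omitted (negligible)

print(f"~{total / 1e9:.1f}B parameters")  # ≈ 7.5B, i.e. the "7B" scale
```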
apps/main/eval.py ADDED
@@ -0,0 +1,354 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import json
import logging
import os
from collections import defaultdict
from dataclasses import asdict, dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, List, Optional, Tuple, Union

import torch
from lingua.args import dump_config
from lingua.data import init_choice_state, setup_sources
from lm_eval import simple_evaluate
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from omegaconf import OmegaConf

from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
from bytelatent.distributed import (
    DistributedArgs,
    dist_mean_dict,
    get_global_rank,
    get_world_size,
    setup_torch_distributed,
)
from bytelatent.transformer import LMTransformer, LMTransformerArgs

from apps.main.generate import (
    PackedCausalTransformerGenerator,
    PackedCausalTransformerGeneratorArgs,
    load_consolidated_model_and_tokenizer,
)

EVAL_FOLDER_NAME = "{:010d}"

logger = logging.getLogger()


@dataclass
class LMHarnessArgs:
    tasks: Optional[List[Any]] = None
    num_fewshot: Optional[int] = None
    device: Optional[str] = None
    use_cache: Optional[str] = None
    cache_requests: bool = False
    rewrite_requests_cache: bool = False
    delete_requests_cache: bool = False
    limit: Optional[Union[int, float]] = None
    bootstrap_iters: int = 100000
    check_integrity: bool = False
    write_out: bool = False
    log_samples: bool = True
    system_instruction: Optional[str] = None
    apply_chat_template: Union[bool, str] = False
    fewshot_as_multiturn: bool = False
    gen_kwargs: Optional[str] = None
    verbosity: str = "INFO"
    predict_only: bool = False
    random_seed: int = 0
    numpy_random_seed: int = 1234
    torch_random_seed: int = 1234
    fewshot_random_seed: int = 1234


@dataclass
class ValidationArgs:
    max_steps: Optional[int] = (
        None  # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu)
    )
    use_val_from_train_src: bool = True  # Use the validation set from training sources
    root_dir: str = ""
    sources: List[str] = field(default_factory=list)  # Other sources to eval on


@dataclass
class EvalArgs:
    name: str = "evals"
    dump_dir: Optional[str] = None
    metric_log_dir: Optional[str] = None
    ckpt_dir: str = ""
    generator: PackedCausalTransformerGeneratorArgs = field(
        default_factory=PackedCausalTransformerGeneratorArgs
    )
    harness: Optional[LMHarnessArgs] = field(default_factory=LMHarnessArgs)
    validation: Optional[ValidationArgs] = field(default_factory=ValidationArgs)

    wandb: Optional[Any] = None

    global_step: Optional[int] = None  # for in-training evaluation


def all_dicts_same(dict_list):
    if not dict_list:  # Check if the list is empty
        return True

    # Compare each dictionary to the first one
    first_dict = dict_list[0]
    return all(d == first_dict for d in dict_list)


class MockAccelerator:
    def gather(self, tensor):
        l = [torch.zeros_like(tensor) for _ in range(get_world_size())]
        torch.distributed.all_gather(l, tensor)
        return torch.stack(l)

    def wait_for_everyone(self):
        torch.distributed.barrier()


# Light wrapper around generator for lm-eval harness
class EvalHarnessLM(LM):
    def __init__(self, generator):
        super().__init__()
        self.generator = generator
        self.accelerator = MockAccelerator()
        self._rank = get_global_rank()
        self._world_size = get_world_size()
        self.device = generator.device

    def generate_until(self, requests: List[Instance]) -> List[str]:
        prompts, gen_args = zip(*[req.args for req in requests])
        assert all_dicts_same(gen_args), "Doesn't support different gen args for now"
        gen_args = gen_args[0]
        temperature = gen_args.get("temperature", 0.0)
        top_p = gen_args.get("top_p", None)
        top_k = gen_args.get("top_k", None)
        until = gen_args.get("until", [])

        self.generator.temperature = temperature
        self.generator.top_p = top_p
        self.generator.top_k = top_k
        self.generator.until = until
        generations, _, _ = self.generator.generate(prompts)
        filtered_gen = []
        for g in generations:
            for e in until:
                g = g.replace(e, "")
            filtered_gen.append(g)
        return filtered_gen

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        prompts, continuations = zip(*[req.args for req in requests])
        inputs = [req.args[0] + req.args[1] for req in requests]
        max_gen_len = self.generator.max_gen_len
        # We temporarily lower max gen len
        self.generator.max_gen_len = 1
        _, lls, greedy = self.generator.generate(inputs)
        results = []
        for p, ll, gr in zip(prompts, lls, greedy):
            p_len = len(
                self.generator.tokenizer.encode(p, add_bos=False, add_eos=False)
            )
            results.append((ll[p_len:].sum().item(), gr[p_len:].all().item()))

        self.generator.max_gen_len = max_gen_len
        return results

    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
        prompts = [req.args[0] for req in requests]
        max_gen_len = self.generator.max_gen_len
        # We temporarily lower max gen len
        self.generator.max_gen_len = 1
        _, lls, _ = self.generator.generate(prompts)
        results = []
        for ll in lls:
            results.append((ll.sum().item(),))
        self.generator.max_gen_len = max_gen_len

        return results


def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
    srcs = {}
    for src in val_args.sources:
        path = os.path.join(val_args.root_dir, src)
        srcs[path] = 1.0
    for src in train_cfg.data.sources:
        path = os.path.join(train_cfg.data.root_dir, src)
        srcs[path] = 1.0

    multi_state = init_choice_state(
        "", srcs, 0, get_global_rank(), get_world_size(), "*.val.jsonl"
    )
    path_to_iter = setup_sources(multi_state)

    max_gen_len = generator.max_gen_len
    # We temporarily lower max gen len
    generator.max_gen_len = 1

    all_val_metrics = {}
    for src in path_to_iter:
        jsonl_iterator = path_to_iter[src]
        texts = []
        logger.info(f"Running validation on {src}...")
        for step, (content, state) in enumerate(jsonl_iterator):
            if state["current_iter"] > 0 or (
                val_args.max_steps is not None and step >= val_args.max_steps
            ):
                break
            content_key = "text" if ("text" in content) else "content"
            texts.append(content[content_key])

        _, loglikelihood, _ = generator.generate(texts)

        metrics = defaultdict(list)
        for i, ll in enumerate(loglikelihood):
            tmp = ll.sum().item()
            metrics["nll"].append(tmp)
            metrics["nll_per_token"].append(tmp / len(ll))
            metrics["nll_per_char"].append(tmp / len(texts[i]))

            metrics["avg_seqlen"].append(len(ll))

        for m in metrics:
            metrics[m] = sum(metrics[m]) / len(metrics[m])
        metrics.update(dist_mean_dict(metrics))
        logger.info(f"Validation on {src} done. Metrics: {metrics}")

        name = os.path.basename(src)
        if name in all_val_metrics:
            logger.warning(
                f"Duplicate source name {name}, path {src} in validation sources, renaming to {name}_1"
            )
            name = f"{name}_1"
        all_val_metrics[name] = metrics

    generator.max_gen_len = max_gen_len

    return all_val_metrics


def launch_eval(cfg: EvalArgs):
    if not torch.distributed.is_initialized():
        setup_torch_distributed(DistributedArgs())
    if (
        Path(cfg.ckpt_dir).exists()
        and (Path(cfg.ckpt_dir) / "params.json").exists()
        and next(Path(cfg.ckpt_dir).glob("*.pth"), None) is not None
    ):
        consolidate_path = Path(cfg.ckpt_dir)
    else:
        consolidate_path = Path(cfg.ckpt_dir) / CONSOLIDATE_FOLDER
        if not consolidate_path.exists() and get_global_rank() == 0:
            consolidate_path = consolidate_checkpoints(cfg.ckpt_dir)

    Path(cfg.dump_dir).mkdir(parents=True, exist_ok=True)
    dump_config(cfg, Path(cfg.dump_dir) / "config.yaml", log_config=False)

    consolidate_path = str(consolidate_path)
    torch.distributed.barrier()
    logger.info("Loading model")
    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
        consolidate_path,
        model_cls=LMTransformer,
        model_args_cls=LMTransformerArgs,
    )
    logger.info("Model loaded")
    model.eval()
    generator = PackedCausalTransformerGenerator(cfg.generator, model, tokenizer)

    wrap = EvalHarnessLM(generator)
    results = simple_evaluate(wrap, **asdict(cfg.harness))
    val_results = None
    if cfg.validation:
        val_results = eval_on_val(generator, cfg.validation, train_cfg)
    if get_global_rank() == 0:
        with open(Path(cfg.dump_dir) / "results.json", "w") as f:
            f.write(json.dumps(results))
        logger.info(f"All evaluation results: {results['results']}")
        if val_results is not None:
            with open(Path(cfg.dump_dir) / "validation.json", "w") as f:
                f.write(json.dumps(val_results))
            logger.info(f"All validation results: {val_results}")
    if cfg.metric_log_dir and get_global_rank() == 0:
        metric_log_path = Path(cfg.metric_log_dir) / "metrics.eval.jsonl"

        logger.info(f"Writing metric logs to {metric_log_path}")
        timestamp = {
            "created_at": datetime.utcnow().isoformat(),
        }
        if cfg.global_step is not None:
            timestamp["global_step"] = cfg.global_step
        print(
            json.dumps(timestamp | results["results"]),
            file=open(metric_log_path, mode="a"),
            flush=True,
        )

        val_log_path = Path(cfg.metric_log_dir) / "metrics.validation.jsonl"
        if val_results is not None:
            print(
                json.dumps(timestamp | val_results),
                file=open(val_log_path, mode="a"),
                flush=True,
            )

    del generator


def main():
    """
    The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments
    This accepts arguments as a dot list
    So if the dataclass looks like

    @dataclass
    class DummyArgs:
        name: str
        model: LMTransformerArgsgs

    @dataclass
    class LMTransformerArgsgs:
        dim: int

    Then you can pass model.dim=32 to change values in LMTransformerArgsgs
    or just name=tictac for top level attributes.

    The behavior here is as follows:
    1. We instantiate EvalArgs with its default values
    2. We override those default values with the ones in the provided config file
    3. We override the result with the additional arguments provided through command line

    For example, if the config is the following

    model:
        dim: 128
        n_layers: 4

    and you call eval.py with eval.py model.dim=64

    Then the final TrainArgs will have

    model:
        dim: 64
        n_layers: 4

    Plus all the default values in EvalArgs dataclass.
    """
    cli_args = OmegaConf.from_cli()
    file_cfg = OmegaConf.load(cli_args.config)
    # We remove 'config' attribute from config as the underlying DataClass does not have it
    del cli_args.config

    default_cfg = OmegaConf.structured(EvalArgs())
    cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
    cfg = OmegaConf.to_object(cfg)
    launch_eval(cfg)


if __name__ == "__main__":
    main()
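The `main()` docstring above describes the defaults → config file → command-line dot-list merge. The following self-contained sketch reproduces that pattern with a made-up `DummyArgs` dataclass; it illustrates the OmegaConf API used here and is not code from the repository:

```python
# Sketch of the defaults -> YAML -> CLI merge described in eval.py's main().
# DummyArgs/ModelArgs and the in-memory "file" config are made up for illustration.
from dataclasses import dataclass, field

from omegaconf import OmegaConf


@dataclass
class ModelArgs:
    dim: int = 128
    n_layers: int = 4


@dataclass
class DummyArgs:
    name: str = "evals"
    model: ModelArgs = field(default_factory=ModelArgs)


default_cfg = OmegaConf.structured(DummyArgs())                     # 1. dataclass defaults
file_cfg = OmegaConf.create({"model": {"dim": 256}})                # 2. stands in for OmegaConf.load(config)
cli_cfg = OmegaConf.from_dotlist(["model.dim=64", "name=tictac"])   # 3. command-line overrides win last

cfg = OmegaConf.to_object(OmegaConf.merge(default_cfg, file_cfg, cli_cfg))
assert cfg.model.dim == 64 and cfg.model.n_layers == 4 and cfg.name == "tictac"
```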
apps/main/generate.py ADDED
@@ -0,0 +1,463 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional

import torch
from lingua.args import dataclass_from_dict
from lingua.tokenizers.abstract_tokenizer import Tokenizer
from lingua.tokenizers.build_tokenizer import build_tokenizer
from omegaconf import OmegaConf
from torch import nn
from torch.nn import functional as F
from torch.nn.attention.flex_attention import create_block_mask
from tqdm import tqdm

from bytelatent.base_transformer import (
    Attention,
    causal_mask,
    generate_doc_mask_mod,
    lengths_to_local_ids,
    lengths_to_start_ids,
)
from bytelatent.checkpoint import CONSOLIDATE_NAME
from bytelatent.transformer import LMTransformer, LMTransformerArgs


def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token


def sample_top_k(probs, k):
    topk_value, _ = torch.topk(probs, k)  # batch_sz x topk
    min_value_top_k = topk_value[:, [-1]]
    probs[probs < min_value_top_k] = 0.0
    probs.div_(probs.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token


def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None):
    shape = logits.shape
    logits = logits.flatten(end_dim=-2)
    if temperature > 0.0:
        probs = torch.softmax(logits / temperature, dim=-1)

        if top_p is not None:
            next_token = sample_top_p(probs, top_p)
        elif top_k is not None:
            next_token = sample_top_k(probs, top_k)
        else:
            next_token = torch.multinomial(probs, num_samples=1)
    else:
        next_token = torch.argmax(logits, dim=-1)
    return next_token.view(shape[:-1])


def pack_prompts(prompts: List[int]):
    res = []
    lengths = []
    for i, p in enumerate(prompts):
        p = torch.tensor(p, dtype=torch.long)
        l = p.size(0)
        res.append(p)
        lengths.append(l)
    lengths = torch.tensor(lengths, dtype=torch.long)
    res = torch.cat(res)
    return res, lengths


def batch_prompts(prompts, max_elements, lengths=None):
    batches = []
    current_batch = []
    current_count = 0

    for i in range(len(prompts)):
        prt = prompts[i]
        prompt_size = len(prt) if lengths is None else lengths[i]
        if current_count + prompt_size <= max_elements:
            current_batch.append(prt)
            current_count += prompt_size
        else:
            if current_batch:  # Add the current batch to batches
                batches.append(current_batch)
            # Start a new batch with the current prompt
            current_batch = [prt]
            current_count = prompt_size

    # Add the last batch if it contains any prompts
    if current_batch:
        batches.append(current_batch)

    return batches


class KVCache(nn.Module):
    def __init__(self, bsz, seqlen, n_heads, head_dim, dtype, device):
        super().__init__()
        shape = (bsz, seqlen, n_heads, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype, device=device))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype, device=device))
        self.offset = 0

    def reset(self):
        self.k_cache.zero_()
        self.v_cache.zero_()
        self.offset = 0

    def update(self, k_val, v_val, tok_idx):
        # input_pos: [B], k_val: [B, S, H, D]
        self.k_cache.index_copy_(1, self.offset + tok_idx, k_val)
        self.v_cache.index_copy_(1, self.offset + tok_idx, v_val)
        return self.k_cache, self.v_cache


@dataclass
class PackedCausalTransformerGeneratorArgs:
    temperature: float = 0.0
    top_p: Optional[float] = None
    top_k: Optional[float] = None
    max_gen_len: int = 512  # Maximum number of tokens to generate
    max_tokens: int = 1024  # Maximum number of tokens that can go through the model
    max_prompt_len: Optional[int] = None
    until: List[str] = field(default_factory=list)
    compile_prefilling: bool = False
    reduce_generation_overhead: bool = False
    show_progress: bool = False
    dtype: Optional[str] = "bf16"
    device: Optional[str] = "cuda"


class PackedCausalTransformerGenerator:
    def __init__(
        self,
        cfg: PackedCausalTransformerGeneratorArgs,
        model: nn.Module,
        tokenizer: Tokenizer,
    ):
        """
        This class wraps a causal transformer model with its corresponding tokenizer
        and provides an efficient way to pack prompts together and do generation on
        the packed sequence.

        For example, if we had the prompts "Hello, I am a " and "Initiating calibration "
        Then this class will concatenate those sequence (pack them together)
        "Hello, I am a Initiating calibration"
        And make the necessary attention masks such that a sequence only attends to itself
        during prefilling and generation.

        This class creates a fixed size cache of size max_tokens or sum of prompt sizes
        + the max number of generated tokens per sequence.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.temperature = cfg.temperature
        self.top_p = cfg.top_p
        self.top_k = cfg.top_k

        self.max_gen_len = cfg.max_gen_len
        self.max_tokens = cfg.max_tokens
        self.max_prompt_len = cfg.max_prompt_len
        self.until = cfg.until
        self.max_until_size = max([len(e) for e in self.until]) if self.until else 1
        self.device = cfg.device

        # Compile if necessary
        self.prefill = torch.compile(self.prefill, disable=not cfg.compile_prefilling)
        self.generate_next_token = torch.compile(
            self.generate_next_token,
            mode="reduce-overhead",
            disable=not cfg.reduce_generation_overhead,
        )

        self.show_progress = cfg.show_progress
        self.dtype = dict(fp32=torch.float32, bf16=torch.bfloat16)[cfg.dtype]

        self.prefill_doc_id, self.prefill_tok_id = None, None
        self.padded_doc_id, self.padded_tok_id = None, None
        self.current_doc_id, self.current_tok_id = None, None
        self.padded_doc_start = None
        self.prefill_mask = None

    def clear_cache(self, offset):
        for module in self.model.modules():
            if isinstance(module, Attention):
                if not hasattr(module, "kv_cache"):
                    module.kv_cache = KVCache(
                        1,
                        self.max_tokens,
                        module.n_kv_heads,
                        module.head_dim,
                        self.dtype,
                        self.device,
                    )
                module.kv_cache.offset = offset

    @torch.compiler.disable
    def setup_prefilling(self, lengths: torch.Tensor):
        # The KV cache is a fixed size tensor of size max_tokens that we need
        # to update in order to do correct autoregressive generation.

        # Here we will generate token by token but on multiple sequences
        # at once. To do so, we need to have an attention mask that makes
        # each sequence independent.

        # Each sequence will write to its allocated space in the KV Cache.
        # We allocate len(seq) + max_gen_len to each sequence in the cache.

        # We will generate max_gen_len for each document
        padded_lengths = lengths + self.max_gen_len
        max_tokens = self.max_tokens or padded_lengths.sum().item()
        # The last document might have more padding to fill up to max_tokens
        padded_lengths[-1] += max_tokens - padded_lengths.sum()

        # This is the start index in the cache for each document
        self.padded_doc_start = lengths_to_start_ids(padded_lengths)
        # For example with ab--123--cdef--
        # this would be 0, 4, 9 if max_gen_len is 2

        # We repeat interleave to align with tokens for prefilling
        # Ex: ab--123--cdef--
        #     000044444999999
        prefill_offset = torch.repeat_interleave(self.padded_doc_start, lengths)
        # This offset will make sure the tokens are written to the
        # correct positions in the cache during prefilling

        # We either init the cache or clear it by resetting the offset to prefill_offset
        self.clear_cache(prefill_offset)

        # The prefilling mask looks like the following for
        # the two packed sequences ab and 123 : ab123
        # Where spaces are empty cache positions
        #                  keys
        #                 ab---123---
        #   queries    a  10000000000
        #              b  11000000000
        #              1  00000100000
        #              2  00000110000
        #              3  00000111000
        # We make sure to skip the empty cache positions
        # and only attend to positions within the same sequence
        doc_mask_mod = generate_doc_mask_mod(causal_mask, lengths, padded_lengths)
        self.prefill_mask = create_block_mask(
            doc_mask_mod, 1, None, lengths.sum(), max_tokens
        )

        # This creates the prefilling token ids which look like
        # the following for the packed sequence abcdefg1234
        # abcdefg1234
        # 01234560123
        # The token id gives us the position within each sequence
        # This is used to compute ROPE and to update the cache
        # At each forward pass the current tokens are written to
        # offset + tok_id
        self.prefill_doc_id, self.prefill_tok_id = lengths_to_local_ids(lengths)

        # This creates the padded token and document ids
        # which look like the following for the packed sequence ab123
        #               ab---123---                ab---123---
        # padded_doc_id 00000111111  padded_tok_id 01234012345
        # This will later be useful for the attention mask at generation
        self.padded_doc_id, self.padded_tok_id = lengths_to_local_ids(padded_lengths)

    @torch.compiler.disable
    def setup_generation(self, lengths):
        # KV Cache offset is set to the start of the padded documents
        for module in self.model.modules():
            if isinstance(module, Attention):
                module.kv_cache.offset = self.padded_doc_start
        # The token ids during generations correspond to the lengths of each doc
        # current_tok_id will be incremented during generation
        self.current_tok_id = lengths.clone()
        # Since we're generating one token per document
        # the document id is just an arange
        self.current_doc_id = torch.arange(lengths.size(0), device=lengths.device)

    # From here on some methods for generation
    def prefill(self, tokens: torch.Tensor, lengths: torch.Tensor):
        # Prefilling is done by taking multiple packed sequences and
        # doing block diagonal attention on them so they remain independent
        self.setup_prefilling(lengths=lengths)
        prefill_out = self.model.forward(
            tokens,
            tok_idx=self.prefill_tok_id,
            mask=self.prefill_mask,
            attn_impl="flex_attention",
        )
        self.setup_generation(lengths=lengths)
        return prefill_out

    def generate_next_token(self, current_token):
        # Since we're doing generation with multiple sequences at once
        # we need to ignore tokens and cache entries from other sequences
        # or in the future.
        # Example mask :
        #                  keys
        #                 abc--1234--
        #   queries    c  11100000000
        #              4  00000111100

        # mask shape : (n_seqs, cache_size)
        doc_mask = self.current_doc_id.unsqueeze(1) == self.padded_doc_id.unsqueeze(0)
        caus_mask = self.current_tok_id.unsqueeze(1) >= self.padded_tok_id.unsqueeze(0)
        mask = doc_mask & caus_mask
        out = self.model.forward(
            current_token,
            tok_idx=self.current_tok_id,  # n_seqs
            mask=mask,
            attn_impl="sdpa",
        )
        self.current_tok_id += 1
        return out

    @torch.inference_mode()
    def generate(self, prompts):
        # Tokenize
        prompts = [
            self.tokenizer.encode(p, add_bos=True, add_eos=False) for p in prompts
        ]
        # Truncate
        max_seqlen = (
            self.max_tokens
            if not hasattr(self.model, "max_seqlen")
            else self.model.max_seqlen
        )
        max_prompt_len = self.max_prompt_len or min(
            max_seqlen - self.max_gen_len, self.max_tokens - self.max_gen_len
        )
        prompts = [p[-max_prompt_len:] for p in prompts]
        # Account for the generation in lengths
        padded_lengths = [len(p) + self.max_gen_len for p in prompts]
        generation = []
        loglikelihood = []
        greedy = []
        it = batch_prompts(prompts, self.max_tokens, lengths=padded_lengths)
        if self.show_progress:
            it = tqdm(it)
        for batch in it:
            n_seqs = len(batch)
            generated_tokens = [[] for _ in range(n_seqs)]
            is_done = [False for _ in range(n_seqs)]
            packed_batch, lengths = pack_prompts(batch)
            packed_batch, lengths = packed_batch.cuda(), lengths.cuda()
            n_seqs = lengths.size(0)

            # Prefilling cache
            prompt_logits = self.prefill(packed_batch.unsqueeze(0), lengths)
            # Selecting last token in each prompt
            all_tokens = sample_tokens(
                prompt_logits, self.temperature, self.top_p, self.top_k
            )
            start_token = all_tokens[:, lengths.cumsum(0) - 1]

            for seq_id, tok in enumerate(start_token.squeeze(0).tolist()):
                generated_tokens[seq_id].append(tok)

            current_token = start_token
            for i in range(1, self.max_gen_len):

                next_logits = self.generate_next_token(current_token)
                next_token = sample_tokens(
                    next_logits.clone(), self.temperature, self.top_p, self.top_k
                )

                for seq_id, tok in enumerate(next_token.squeeze(0).tolist()):
                    if not is_done[seq_id]:
                        generated_tokens[seq_id].append(tok)
                        current_end_str = self.tokenizer.decode(
                            generated_tokens[seq_id][-self.max_until_size :]
                        )
                        contains_end_string = any(
                            [e in current_end_str for e in self.until]
                        )
                        is_done[seq_id] = (
                            contains_end_string or tok == self.tokenizer.eos_id
                        )
                if all(is_done):
                    break

                current_token = next_token

            generation.extend([self.tokenizer.decode(g) for g in generated_tokens])

            for p, logit in zip(
                batch, prompt_logits.squeeze(0).split(lengths.tolist())
            ):
                x = logit[:-1]
                y = torch.tensor(p[1:], device=x.device)
                loglikelihood.append(-F.cross_entropy(x, y, reduction="none").cpu())
                greedy.append((x.argmax(dim=-1) == y).cpu())

        return generation, loglikelihood, greedy


def load_consolidated_model_and_tokenizer(
    consolidated_path,
    model_cls=LMTransformer,
    model_args_cls=LMTransformerArgs,
):
    ckpt_path = Path(consolidated_path)
    config = ckpt_path / "params.json"
    config = OmegaConf.load(config)

    param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
        config.distributed.model_dtype
    ]
    model_args = dataclass_from_dict(model_args_cls, config.model, strict=False)
    tokenizer = build_tokenizer(config.data.tokenizer.name, config.data.tokenizer.path)
    model = model_cls(model_args)
    st_dict = torch.load(ckpt_path / CONSOLIDATE_NAME, weights_only=True)
    model.load_state_dict(st_dict["model"])
    model = model.cuda().eval()
    for param in model.parameters():
        param.data = param.data.to(dtype=param_dtype)
    return model, tokenizer, config


def main():
    # Load CLI arguments (overrides) and combine with a YAML config
    cfg = OmegaConf.from_cli()
    gen_cfg = dataclass_from_dict(
        PackedCausalTransformerGeneratorArgs, cfg, strict=False
    )
    print(cfg)

    model, tokenizer, _ = load_consolidated_model_and_tokenizer(cfg.ckpt)

    generator = PackedCausalTransformerGenerator(gen_cfg, model, tokenizer)

    # Allow multiple prompts
    prompts = []
    while True:
        prompt = input("Enter a prompt (or press enter to finish): ")
        if not prompt:
            break
        prompts.append(prompt)

    # Start generation
    start_time = time.time()
    generation, loglikelihood, greedy = generator.generate(prompts)
    end_time = time.time()

    # Calculate tokens per second
    total_tokens = sum(len(tokenizer.encode(gen, False, False)) for gen in generation)
    tokens_per_second = total_tokens / (end_time - start_time)

    # Display the results
    for i, gen in enumerate(generation):
        print(f"\nPrompt {i+1}: {prompts[i]}")
        print(f"Generated Text: {gen}")

    print(f"\nTokens per second: {tokens_per_second:.2f}")


if __name__ == "__main__":
    main()
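A hypothetical usage example of the sampling helpers defined above (`sample_tokens`, which dispatches to `sample_top_p` / `sample_top_k`), assuming PyTorch is installed and the repository is on `PYTHONPATH`; the tensor shapes below are made up:

```python
# Hypothetical usage of apps/main/generate.py's sampling helpers (shapes are illustrative).
import torch

from apps.main.generate import sample_tokens

logits = torch.randn(2, 5, 32)  # (batch, seq, vocab) as produced by a forward pass

greedy_ids = sample_tokens(logits)                               # temperature 0.0 -> argmax
nucleus_ids = sample_tokens(logits, temperature=1.0, top_p=0.9)  # top-p (nucleus) sampling
topk_ids = sample_tokens(logits, temperature=0.8, top_k=5)       # top-k sampling

print(greedy_ids.shape, nucleus_ids.shape)  # both (2, 5): one token id per position
```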
apps/main/lingua_train.py ADDED
@@ -0,0 +1,654 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
3
+
4
+ import gc
5
+ import logging
6
+ import os
7
+ import sys
8
+ from contextlib import ExitStack
9
+ from copy import deepcopy
10
+ from dataclasses import asdict, dataclass, field
11
+ from pathlib import Path
12
+ from timeit import default_timer as timer
13
+ from typing import Any, Dict, Optional
14
+
15
+ import torch
16
+ import torch.distributed
17
+ import wandb
18
+ import xformers.profiler
19
+ from lingua.args import dataclass_from_dict, dump_config, flatten_dict
20
+ from lingua.data import (
21
+ DataArgs,
22
+ PackTokensState,
23
+ build_dataloader_from_args,
24
+ init_dataloader_state_from_args,
25
+ )
26
+ from lingua.tokenizers.build_tokenizer import TokenizerArgs
27
+ from omegaconf import OmegaConf
28
+ from pydantic import BaseModel
29
+ from torch.distributed._tensor import DTensor
30
+ from torch.distributed.checkpoint.stateful import Stateful
31
+ from torch.optim import lr_scheduler
32
+
33
+ from bytelatent.checkpoint import (
34
+ CheckpointArgs,
35
+ CheckpointManager,
36
+ load_from_checkpoint,
37
+ )
38
+ from bytelatent.distributed import (
39
+ DistributedArgs,
40
+ EnvironmentArgs,
41
+ check_model_value_range,
42
+ clean_env,
43
+ dist_mean_dict,
44
+ get_device_mesh,
45
+ get_is_master,
46
+ get_world_size,
47
+ init_signal_handler,
48
+ parallelize_model,
49
+ requeue_slurm_job,
50
+ setup_env,
51
+ setup_torch_distributed,
52
+ )
53
+ from bytelatent.logger import init_logger
54
+ from bytelatent.metrics import (
55
+ GPUMemoryMonitor,
56
+ LoggingArgs,
57
+ MetricLogger,
58
+ get_num_params,
59
+ )
60
+ from bytelatent.optim import OptimArgs, build_optimizer
61
+ from bytelatent.probe import AutoProbeD
62
+ from bytelatent.profiling import ProfilerArgs, maybe_run_profiler
63
+ from bytelatent.stool import StoolArgs, launch_job
64
+ from bytelatent.transformer import (
65
+ LMTransformer,
66
+ LMTransformerArgs,
67
+ build_fsdp_grouping_plan,
68
+ get_no_recompute_ops,
69
+ get_num_flop_per_token,
70
+ tp_parallelize,
71
+ )
72
+
73
+ logger = logging.getLogger()
74
+
75
+
76
+ class TrainArgs(BaseModel):
77
+ name: str = "lingua"
78
+ dump_dir: str = ""
79
+
80
+ seed: int = 42
81
+
82
+ # Number of gradient accumulation steps
83
+ # Total batch size is batch_size*grad_acc_steps
84
+ grad_acc_steps: int = 1
85
+
86
+ gc_collect_freq: int = 1000
87
+ probe_freq: int | None = None
88
+
89
+ # Nb optimizer steps to take
90
+ steps: int = 1000
91
+
92
+ data: DataArgs
93
+ optim: OptimArgs
94
+ model: LMTransformerArgs
95
+ distributed: DistributedArgs
96
+ env: EnvironmentArgs
97
+
98
+ checkpoint: CheckpointArgs
99
+ profiling: ProfilerArgs
100
+ logging: LoggingArgs
101
+
102
+ # If set to None, eval is run locally otherwise it launches a new job with the given number of gpus
103
+ async_eval_gpus: int | None = None
104
+ eval: Any | None = None
105
+
106
+
107
+ @dataclass
108
+ class TrainState(Stateful):
109
+ step: int # Nb of steps taken by the optimizer
110
+ acc_step: int # Nb of accumulation steps done since last optimizer step
111
+ scheduler: lr_scheduler.LambdaLR
112
+ data_loader_state: PackTokensState
113
+
114
+ def state_dict(self) -> Dict[str, Any]:
115
+ return {
116
+ "step": self.step,
117
+ "acc_step": self.acc_step,
118
+ "data_loader_state": self.data_loader_state,
119
+ "scheduler": self.scheduler.state_dict(),
120
+ }
121
+
122
+ def load_state_dict(self, state_dict):
123
+ self.step = state_dict["step"]
124
+ self.acc_step = state_dict["acc_step"]
125
+ self.data_loader_state = PackTokensState(**state_dict["data_loader_state"])
126
+ self.scheduler.load_state_dict(state_dict["scheduler"])
127
+
128
+
129
+ def validate_train_args(args: TrainArgs, output_size: int):
130
+ if args.model.vocab_size < 0:
131
+ logger.info(f"Setting model output size to {output_size}")
132
+ args.model.vocab_size = output_size
133
+ assert (
134
+ args.model.vocab_size == output_size
135
+ ), "Vocab size should be the same as output size"
136
+
137
+ assert args.dump_dir, "Dump dir not set"
138
+
139
+ if args.checkpoint.path is None:
140
+ args.checkpoint.path = str(Path(args.dump_dir) / "checkpoints")
141
+ logger.info(f"Setting checkpoint path to {args.checkpoint.path}")
142
+
143
+ for source in args.data.sources:
144
+ data_path = os.path.join(args.data.root_dir, source)
145
+ assert os.path.exists(data_path), f"{data_path} doesn't exist"
146
+
147
+ if (
148
+ args.distributed.dp_replicate
149
+ * args.distributed.dp_shard
150
+ * args.distributed.tp_size
151
+ != get_world_size()
152
+ ):
153
+ assert get_world_size() % args.distributed.dp_shard == 0
154
+ args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard
155
+
156
+ assert args.distributed.dp_replicate % args.distributed.tp_size == 0
157
+ args.distributed.dp_replicate = (
158
+ args.distributed.dp_replicate // args.distributed.tp_size
159
+ )
160
+
161
+ logger.warning(
162
+ f"Setting Data Parallel size to {args.distributed.dp_replicate * args.distributed.dp_shard}"
163
+ )
164
+ assert (
165
+ args.distributed.dp_replicate
166
+ * args.distributed.dp_shard
167
+ * args.distributed.tp_size
168
+ == get_world_size()
169
+ )
170
+
171
+ if args.distributed.fsdp_type == "no_shard":
172
+ assert (
173
+ args.distributed.dp_shard == 1
174
+ and args.distributed.dp_replicate == get_world_size()
175
+ )
176
+
177
+ args.model.max_seqlen = args.data.seq_len
178
+
179
+ if args.distributed.tp_size != 1:
180
+ logger.warning(
181
+ "Tensor parallelism has not been tested for a while, use at your own risk"
182
+ )
183
+
184
+ assert (
185
+ args.probe_freq != args.profiling.mem_steps
186
+ ), "Don't profile during probe step"
187
+ assert (
188
+ args.probe_freq != args.profiling.profile_steps
189
+ ), "Don't profile during probe step"
190
+ if args.logging.wandb is not None:
191
+ args.logging.wandb.name = args.name
192
+
193
+ if args.probe_freq is not None:
194
+ assert (
195
+ args.distributed.tp_size == 1
196
+ ), "Probing not supported with tensor parallelism"
197
+ assert (
198
+ args.distributed.selective_activation_checkpointing is False
199
+ ), "Probing not supported with selective activation checkpointing"
200
+
201
+
202
+ preemption_flag = dict(flag=False)
203
+
204
+
205
+ def set_preemption_flag(signum, frame):
206
+ logger.warning("Signal handler called with signal " + str(signum))
207
+ logger.warning("Preemption! Checkpointing ASAP and exiting.")
208
+ preemption_flag["flag"] = True
209
+
210
+
211
+ def every_n_steps(train_state, freq, acc_step=None, acc_freq=None):
212
+ test = train_state.step % freq == 0
213
+ if acc_step is not None:
214
+ test = test and (train_state.acc_step == acc_step)
215
+ elif acc_freq is not None:
216
+ test = test and ((train_state.acc_step % acc_freq) == 0)
217
+ return test
218
+
219
+
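A quick sanity check of how every_n_steps gates periodic work in the training loop below (a sketch, assuming every_n_steps as defined above is in scope; SimpleNamespace stands in for TrainState):

    from types import SimpleNamespace

    # Fires only when the optimizer step is a multiple of freq and the accumulation phase matches.
    state = SimpleNamespace(step=200, acc_step=0)
    assert every_n_steps(state, freq=100, acc_step=0)      # e.g. logging/checkpointing on full steps
    assert not every_n_steps(state, freq=100, acc_step=1)  # wrong accumulation phase

    # With acc_freq, it instead fires every acc_freq accumulation micro-steps.
    state = SimpleNamespace(step=200, acc_step=2)
    assert every_n_steps(state, freq=100, acc_freq=2)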
220
+ def train(args: TrainArgs):
221
+ with ExitStack() as context_stack:
222
+ tokenizer_args = TokenizerArgs(
223
+ name=args.data.name,
224
+ init_kwargs=args.data.tokenizer.init_kwargs,
225
+ )
226
+ tokenizer = tokenizer_args.build()
227
+ validate_train_args(
228
+ args,
229
+ tokenizer.n_words,
230
+ )
231
+ if get_is_master():
232
+ os.makedirs(args.dump_dir, exist_ok=True)
233
+ dump_config(args, Path(args.dump_dir) / "config.yaml")
234
+ init_logger(Path(args.dump_dir) / "train.log")
235
+ init_signal_handler(set_preemption_flag) # For handling preemption signals.
236
+ setup_env(args.env)
237
+ setup_torch_distributed(args.distributed)
238
+ world_mesh = get_device_mesh(args.distributed)
239
+ logger.info(f"Starting job: {args.name}")
240
+
241
+ # build dataloader
242
+ # need dp world size and rank
243
+ dp_mesh = world_mesh["dp_replicate"]
244
+ dp_degree = dp_mesh.size()
245
+ dp_rank = dp_mesh.get_local_rank()
246
+ if args.distributed.dp_shard > 1:
247
+ dp_rank = dp_rank * dp_degree + world_mesh["dp_shard"].get_local_rank()
248
+ dp_degree *= world_mesh["dp_shard"].size()
249
+
250
+ logger.info(f"Running on dp rank : {dp_rank}")
251
+ logger.info(f"Running on dp size : {dp_degree}")
252
+
253
+ torch.manual_seed(args.seed)
254
+ logger.info("Building model")
255
+
256
+ # Initializing the model on the meta device lets us build models much bigger than a single GPU's memory
257
+ with torch.device("meta"):
258
+ model = LMTransformer(args.model)
259
+ logger.info("Model is built!")
260
+
261
+ model_param_count = get_num_params(model)
262
+
263
+ model = parallelize_model(
264
+ model,
265
+ world_mesh,
266
+ args.model,
267
+ args.distributed,
268
+ fsdp_grouping_plan=build_fsdp_grouping_plan(args.model),
269
+ tp_parallelize=tp_parallelize,
270
+ no_recompute_ops=get_no_recompute_ops(),
271
+ )
272
+
273
+ # Once we shard the model on different gpus we can actually initialize the model
274
+ # First we create empty tensors of the correct shapes
275
+ model = model.to_empty(device="cuda")
276
+ # Then we init the model. Please make sure this function initializes *ALL* parameters
277
+ # and buffers, otherwise you will have random values in the uninitialized tensors
278
+ # which will silently fail (give nan gradients for example)
279
+
280
+ if args.checkpoint.init_ckpt_path:
281
+ logger.info(f"Loading initial model from {args.checkpoint.init_ckpt_path}")
282
+ load_from_checkpoint(
283
+ args.checkpoint.init_ckpt_path, model, model_key="model"
284
+ ) # Put model_key="" if it's directly the model checkpoint
285
+ model.rope_embeddings.reset_parameters() # Re-init RoPE: it's a non-persistent buffer, so it might not be loaded from the checkpoint
286
+ else:
287
+ with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
288
+ torch.manual_seed(args.model.seed)
289
+ model.init_weights()
290
+ check_model_value_range(model, range=10.0, std=1.0)
291
+
292
+ # log model size
293
+
294
+ logger.info(f"Model size: {model_param_count:,} total parameters")
295
+
296
+ gpu_memory_monitor = GPUMemoryMonitor("cuda")
297
+ logger.info(
298
+ f"GPU capacity: {gpu_memory_monitor.device_name} ({gpu_memory_monitor.device_index}) "
299
+ f"with {gpu_memory_monitor.device_capacity_gib:.2f}GiB memory"
300
+ )
301
+ logger.info(f"GPU memory usage: {gpu_memory_monitor}")
302
+
303
+ # build optimizer after apply parallelisms to the model
304
+ optimizer, scheduler = build_optimizer(model, args.optim, args.steps)
305
+ data_loader_state = init_dataloader_state_from_args(
306
+ args.data, dp_rank, dp_degree
307
+ )
308
+
309
+ train_state = TrainState(
310
+ step=0,
311
+ acc_step=0,
312
+ data_loader_state=data_loader_state,
313
+ scheduler=scheduler,
314
+ )
315
+
316
+ checkpoint = CheckpointManager.instantiate_and_make_dir(args.checkpoint)
317
+ checkpoint.load(model, optimizer, train_state, world_mesh)
318
+ # Either load from latest checkpoint or start from scratch
319
+ if args.probe_freq is not None:
320
+ if get_is_master():
321
+ os.makedirs(Path(args.dump_dir) / "probe", exist_ok=True)
322
+ torch.distributed.barrier()
323
+ probe = AutoProbeD(
324
+ model,
325
+ (
326
+ Path(args.dump_dir) / "probe" / f"probe.{dp_rank}.jsonl"
327
+ if (dp_rank % 128 == 0)
328
+ else None
329
+ ),
330
+ )
331
+ probe_mod = model._orig_mod if args.distributed.compile else model
332
+
333
+ gc.disable()
334
+
335
+ # train loop
336
+ model.train()
337
+ metric_logger = context_stack.enter_context(
338
+ MetricLogger(Path(args.dump_dir) / "metrics.jsonl", args)
339
+ )
340
+ data_loader = context_stack.enter_context(
341
+ build_dataloader_from_args(
342
+ args.data,
343
+ state=train_state.data_loader_state,
344
+ )
345
+ )
346
+ torch_profiler = context_stack.enter_context(
347
+ maybe_run_profiler(args.dump_dir, model, args.profiling)
348
+ )
349
+
350
+ nwords_since_last_log = 0
351
+ time_last_log = timer()
352
+ gc.collect()
353
+ while train_state.step < args.steps:
354
+ # We constrain train_state.acc_step to be in range 0 to args.grad_acc_steps - 1
355
+ train_state.acc_step += 1
356
+ train_state.acc_step = train_state.acc_step % args.grad_acc_steps
357
+
358
+ # get batch
359
+ curr_lr = float(optimizer.param_groups[0]["lr"])
360
+ data_load_start = timer()
361
+ batch, train_state.data_loader_state = next(data_loader)
362
+ batch = torch.tensor(
363
+ batch,
364
+ dtype=torch.long,
365
+ )
366
+
367
+ if every_n_steps(train_state, args.gc_collect_freq, acc_step=0):
368
+ logger.info("garbage collection")
369
+ # we do garbage collection manually otherwise different processes
370
+ # run the GC at different times so they slow down the whole pipeline
371
+ gc.collect()
372
+
373
+ input_ids = batch[:, :, 0].cuda()
374
+ labels = batch[:, :, 1].cuda()
375
+ data_load_time = round(timer() - data_load_start, 4)
376
+ nwords_since_last_log += input_ids.numel()
377
+
378
+ bsz, seqlen = labels.shape
379
+
380
+ # forward
381
+ start_timer = torch.cuda.Event(enable_timing=True)
382
+ end_timer = torch.cuda.Event(enable_timing=True)
383
+ start_timer.record()
384
+
385
+ # This is an automatic probe that will compute statistics
386
+ # of all linears' inputs, weights and outputs
387
+ # along with attention logits and entropy
388
+ # both in forward and backward pass
389
+ if (args.probe_freq is not None) and every_n_steps(
390
+ train_state, args.probe_freq, acc_step=1 % args.grad_acc_steps
391
+ ):
392
+ # Here we do a fake forward and backward pass on a smaller
393
+ # batch size to avoid OOM
394
+ # This assumes the model has no stateful layers (batch norm..)
395
+ assert (
396
+ next(probe_mod.parameters()).grad is None
397
+ ), "Can't probe model if grads are not reset"
398
+
399
+ with probe:
400
+ probe.metadata = {
401
+ "it": train_state.step,
402
+ "global_step": train_state.step,
403
+ "loop": "lingua",
404
+ }
405
+ # Non compiled model uses roughly 2x memory in our exps
406
+ # So we divide bsz by 2 or seqlen by 2
407
+ probe_bsz = max(1, bsz // 2)
408
+ probe_seq = seqlen if (bsz // 2 >= 1) else (seqlen // 2)
409
+ probe_loss = probe_mod(
410
+ input_ids[:probe_bsz, :probe_seq],
411
+ labels[:probe_bsz, :probe_seq],
412
+ )
413
+ probe_loss.backward()
414
+ # We zero grads to cancel this fake step
415
+ optimizer.zero_grad()
416
+
417
+ assert (
418
+ next(probe_mod.parameters()).grad is None
419
+ ), "Probe model shouldn't have grads at this point"
420
+ loss = model(input_ids, labels)
421
+
422
+ # We scale loss with grad_acc_steps so the gradient is the same
423
+ # regardless of grad_acc_steps
424
+ loss = loss / args.grad_acc_steps
425
+ # backward on scaled loss to create scaled gradients
426
+ loss.backward()
427
+ # For logging we undo that scaling
428
+ loss = loss.detach() * args.grad_acc_steps
429
+
430
+ grad_norm = torch.nn.utils.clip_grad_norm_(
431
+ model.parameters(), max_norm=args.optim.clip, foreach=True
432
+ )
433
+
434
+ grad_norm = (
435
+ grad_norm.full_tensor() if isinstance(grad_norm, DTensor) else grad_norm
436
+ ).item()
437
+
438
+ # optimizer step
439
+ if train_state.acc_step == 0:
440
+ optimizer.step()
441
+ scheduler.step()
442
+ optimizer.zero_grad()
443
+ train_state.step += 1
444
+
445
+ # updates the scale for next iteration
446
+ # training iteration complete
447
+ end_timer.record()
448
+
449
+ torch.cuda.synchronize()
450
+
451
+ curr_iter_time = round(start_timer.elapsed_time(end_timer) * 1e-3, 4)
452
+
453
+ # if profiler is active
454
+ if torch_profiler:
455
+ xformers.profiler.step()
456
+
457
+ # log metrics
458
+ if every_n_steps(
459
+ train_state,
460
+ args.logging.freq,
461
+ acc_step=None if args.logging.acc_freq else 0,
462
+ acc_freq=args.logging.acc_freq,
463
+ ):
464
+ time_delta = timer() - time_last_log
465
+ wps = nwords_since_last_log / (time_delta * args.distributed.tp_size)
466
+
467
+ gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
468
+
469
+ total_acc_steps = (
470
+ args.grad_acc_steps * train_state.step + train_state.acc_step
471
+ )
472
+ tokens_per_gpu = (
473
+ total_acc_steps * args.data.batch_size * args.data.seq_len
474
+ )
475
+ total_tokens = dp_degree * tokens_per_gpu
476
+ # This is an estimate and the correct values may change
477
+ # if you change the architecture
478
+ # Use xformer's analyze profile trace to get actual measurement
479
+ FLOPS = (
480
+ get_num_flop_per_token(
481
+ model_param_count - args.model.vocab_size * args.model.dim,
482
+ args.model.n_layers,
483
+ args.model.dim,
484
+ args.data.seq_len,
485
+ )
486
+ * wps
487
+ )
488
+ metrics = flatten_dict(
489
+ {
490
+ "global_step": train_state.step,
491
+ "acc_step": train_state.acc_step,
492
+ "speed": {
493
+ "wps": wps,
494
+ "FLOPS": FLOPS,
495
+ "curr_iter_time": curr_iter_time,
496
+ "data_load_time": data_load_time,
497
+ },
498
+ "optim": {
499
+ "grad_norm": grad_norm,
500
+ "lr": curr_lr,
501
+ "total_tokens": total_tokens,
502
+ },
503
+ "memory": gpu_mem_stats._asdict(),
504
+ },
505
+ sep="/",
506
+ )
507
+
508
+ to_sync = {}
509
+ to_sync["loss/out"] = loss.item()
510
+ metrics.update(dist_mean_dict(to_sync))
511
+
512
+ if get_is_master():
513
+ metric_logger.log(metrics)
514
+
515
+ gpu_memory_monitor.reset_peak_stats()
516
+ nwords_since_last_log = 0
517
+ time_last_log = timer()
518
+ logger.info(
519
+ f"step: {train_state.step}"
520
+ f" acc: {train_state.acc_step}"
521
+ f" loss: {round(loss.item(),4):>7}"
522
+ f" grad: {grad_norm:.2e}"
523
+ f" flops: {FLOPS:.2e}"
524
+ f" wps: {wps:.2e}"
525
+ f" iter: {curr_iter_time:>7}"
526
+ f" data: {data_load_time:>5}"
527
+ f" lr: {curr_lr:.2e}"
528
+ f" mem: {gpu_mem_stats.max_active_pct:.0f}%"
529
+ f" pow: {gpu_mem_stats.power_draw/1000} W"
530
+ )
531
+
532
+ saved = False
533
+ if every_n_steps(
534
+ train_state, args.checkpoint.dump.every, acc_step=0
535
+ ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
536
+ saved = checkpoint.save(
537
+ model,
538
+ optimizer,
539
+ train_state,
540
+ args,
541
+ device_mesh=world_mesh,
542
+ )
543
+
544
+ if args.eval is not None and every_n_steps(
545
+ train_state, args.checkpoint.eval.every, acc_step=0
546
+ ):
547
+ from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval
548
+
549
+ eval_args = dataclass_from_dict(EvalArgs, args.eval)
550
+
551
+ eval_args.global_step = train_state.step
552
+ eval_args.ckpt_dir = str(checkpoint.existing_saves[-1])
553
+ eval_args.dump_dir = str(
554
+ os.path.join(
555
+ args.dump_dir,
556
+ "evals",
557
+ EVAL_FOLDER_NAME.format(train_state.step),
558
+ )
559
+ )
560
+ eval_args.metric_log_dir = args.dump_dir
561
+ if args.async_eval_gpus is None:
562
+ launch_eval(eval_args)
563
+ elif get_is_master():
564
+ if wandb.run is not None and args.logging.wandb is not None:
565
+ eval_args.wandb = deepcopy(args.logging.wandb)
566
+ assert args.async_eval_gpus > 0
567
+ logger.info(f"Launching evals on {args.async_eval_gpus} gpus")
568
+ with clean_env():
569
+ launch_job(
570
+ StoolArgs(
571
+ asdict(eval_args),
572
+ script="apps.main.eval",
573
+ copy_code=False,
574
+ nodes=args.async_eval_gpus // 8,
575
+ qos="lowest",
576
+ )
577
+ )
578
+
579
+ if preemption_flag["flag"]:
580
+ if not saved:
581
+ checkpoint.save(
582
+ model,
583
+ optimizer,
584
+ train_state,
585
+ args,
586
+ device_mesh=world_mesh,
587
+ )
588
+ requeue_slurm_job()
589
+ sys.exit(0)
590
+
591
+ if not saved:
592
+ checkpoint.save(
593
+ model,
594
+ optimizer,
595
+ train_state,
596
+ args,
597
+ device_mesh=world_mesh,
598
+ )
599
+ gc.collect()
600
+
601
+
602
+ def main():
603
+ """
604
+ The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments
605
+ This accepts arguments as a dot list
606
+ So if the dataclass looks like
607
+
608
+ @dataclass
609
+ class DummyArgs:
610
+ name: str
611
+ model: LMTransformerArgs
612
+
613
+ @dataclass
614
+ class LMTransformerArgs:
615
+ dim: int
616
+
617
+ Then you can pass model.dim=32 to change values in LMTransformerArgs
618
+ or just name=tictac for top level attributes.
619
+
620
+ The behavior here is as follows:
621
+ 1. We instantiate TrainArgs with its default values
622
+ 2. We override those default values with the ones in the provided config file
623
+ 3. We override the result with the additional arguments provided through command line
624
+
625
+ For example, if the config is the following
626
+
627
+ model:
628
+ dim: 128
629
+ n_layers: 4
630
+
631
+ and you call train.py with model.dim=64 on the command line
632
+
633
+ Then the final TrainArgs will have
634
+
635
+ model:
636
+ dim: 64
637
+ n_layers: 4
638
+
639
+ Plus all the default values in TrainArgs dataclass.
640
+ """
641
+ cli_args = OmegaConf.from_cli()
642
+ file_cfg = OmegaConf.load(cli_args.config)
643
+ # We remove 'config' attribute from config as the underlying DataClass does not have it
644
+ del cli_args.config
645
+
646
+ default_cfg = OmegaConf.structured(TrainArgs())
647
+ cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
648
+ cfg = OmegaConf.to_object(cfg)
649
+
650
+ train(cfg)
651
+
652
+
653
+ if __name__ == "__main__":
654
+ main()
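The default -> YAML -> CLI precedence described in the docstring above can be exercised in isolation. A minimal sketch (the dataclasses are hypothetical stand-ins, and the YAML and CLI inputs are inlined instead of coming from a file or sys.argv):

    from dataclasses import dataclass, field
    from omegaconf import OmegaConf

    @dataclass
    class ModelArgs:
        dim: int = 128
        n_layers: int = 4

    @dataclass
    class DummyTrainArgs:
        name: str = "lingua"
        model: ModelArgs = field(default_factory=ModelArgs)

    default_cfg = OmegaConf.structured(DummyTrainArgs())               # 1. defaults
    file_cfg = OmegaConf.create({"model": {"dim": 256}})               # 2. stands in for the YAML config file
    cli_cfg = OmegaConf.from_dotlist(["model.dim=64", "name=tictac"])  # 3. stands in for the CLI dot list
    cfg = OmegaConf.merge(default_cfg, file_cfg, cli_cfg)
    print(OmegaConf.to_yaml(cfg))  # name: tictac, model: {dim: 64, n_layers: 4}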
blt-figure.jpg ADDED
blt-figure.pdf ADDED
Binary file (62.5 kB).
 
bytelatent/.DS_Store ADDED
Binary file (6.15 kB).
 
bytelatent/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ class ByteLatentError(Exception):
3
+ pass
bytelatent/args.py ADDED
@@ -0,0 +1,199 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ import logging
3
+ import os
4
+ from typing import Any
5
+
6
+ import numpy as np
7
+ import yaml
8
+ from pydantic import BaseModel, ConfigDict
9
+
10
+ from bytelatent.checkpoint import CheckpointArgs
11
+ from bytelatent.data.data_types import Batch
12
+ from bytelatent.data.iterators.abstract_iterator import StatefulIterator
13
+ from bytelatent.data.iterators.arrow_iterator import (
14
+ ArrowFileIterator,
15
+ find_and_sanitize_chunks,
16
+ )
17
+ from bytelatent.data.iterators.looping_iterator import LoopingIterator
18
+ from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
19
+ from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator
20
+ from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
21
+ from bytelatent.data.iterators.sampling_iterator import SamplingIterator
22
+ from bytelatent.data.iterators.sequence_iterator import (
23
+ SequenceIterator,
24
+ SequencePackingArgs,
25
+ )
26
+ from bytelatent.data.patcher import PatcherArgs
27
+ from bytelatent.distributed import DistributedArgs, EnvironmentArgs
28
+ from bytelatent.metrics import LoggingArgs
29
+ from bytelatent.model.blt import ByteLatentTransformerArgs
30
+ from bytelatent.optim import OptimArgs
31
+ from bytelatent.profiling import ProfilerArgs
32
+ from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
33
+
34
+ logger = logging.getLogger()
35
+
36
+
37
+ def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]:
38
+ return np.random.default_rng((seed, rank, world_size)).bit_generator.state
39
+
40
+
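A sketch of how these per-rank RNG states are meant to be consumed (restoring a NumPy Generator from the stored bit_generator state; get_rng_state above is assumed to be in scope):

    import numpy as np

    state = get_rng_state(seed=42, rank=0, world_size=8)
    rng = np.random.default_rng()
    rng.bit_generator.state = state  # restore the per-rank stream, e.g. after a checkpoint reload
    print(rng.random())              # deterministic given (seed, rank, world_size)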
41
+ def distribute_data_to_rank(
42
+ *,
43
+ dataset_path: str,
44
+ preprocess_dir: str,
45
+ entropy_model_name: str | None,
46
+ arrow_batch_size: int,
47
+ rank: int,
48
+ world_size: int,
49
+ ) -> ArrowFileIterator:
50
+ dataset_chunks = find_and_sanitize_chunks(dataset_path, world_size)
51
+ n_workers_per_chunk = world_size // len(dataset_chunks)
52
+ rank_to_arrow_iterator_params = []
53
+ for chunk_path in dataset_chunks:
54
+ for worker_id in range(n_workers_per_chunk):
55
+ rank_to_arrow_iterator_params.append(
56
+ ArrowFileIterator(
57
+ file_path=chunk_path,
58
+ worker_id=worker_id,
59
+ num_workers=n_workers_per_chunk,
60
+ preprocess_dir=preprocess_dir,
61
+ dataset_files=None,
62
+ entropy_model_name=entropy_model_name,
63
+ arrow_batch_size=arrow_batch_size,
64
+ )
65
+ )
66
+ return rank_to_arrow_iterator_params[rank]
67
+
68
+
69
+ class DataloaderArgs(BaseModel):
70
+ model_config = ConfigDict(extra="forbid")
71
+ root_dir: str | None = None
72
+ sources: dict[str, float] = {}
73
+ batch_size: int = 2
74
+ seq_len: int = 2048
75
+ seed: int = 42
76
+ add_bos: bool = True
77
+ add_eos: bool = True
78
+ load_async: bool = True
79
+ prefetch_size: int = 64
80
+ preprocess_dir: str | None = None
81
+ dataset_files: list[str] | None = None
82
+ entropy_model_name: str | None = "transformer_100m"
83
+ arrow_batch_size: int = 100
84
+ buffer_size: int = 64
85
+
86
+ pad_to_max_length: bool = True
87
+ max_encoder_seq_length: int = 12288
88
+ enable_byte_ngrams: bool = False
89
+
90
+ tokenizer_args: TokenizerArgs = TokenizerArgs()
91
+ patcher_args: PatcherArgs = PatcherArgs()
92
+
93
+ def _create_sequence_iterators(
94
+ self, rank: int, world_size: int
95
+ ) -> dict[str, SequenceIterator]:
96
+ sequence_packing_args = SequencePackingArgs(
97
+ output_seq_len=self.seq_len,
98
+ buffer_size=self.buffer_size,
99
+ )
100
+ source_to_sequence_iterator: dict[str, SequenceIterator] = {}
101
+ for dataset_path in self.sources:
102
+ shuffle_rng_state = get_rng_state(self.seed + 1, rank, world_size)
103
+ arrow_iterator = distribute_data_to_rank(
104
+ dataset_path=os.path.join(self.root_dir, dataset_path),
105
+ preprocess_dir=self.preprocess_dir,
106
+ entropy_model_name=self.entropy_model_name,
107
+ arrow_batch_size=self.arrow_batch_size,
108
+ rank=rank,
109
+ world_size=world_size,
110
+ )
111
+ looping_iterator = LoopingIterator(arrow_iterator)
112
+ preprocess_iterator = PreprocessIterator(
113
+ looping_iterator,
114
+ patcher_args=self.patcher_args,
115
+ tokenizer_args=self.tokenizer_args,
116
+ )
117
+ sequence_iterator = SequenceIterator(
118
+ preprocess_iterator,
119
+ sequence_packing_args=sequence_packing_args,
120
+ rng_state=shuffle_rng_state,
121
+ )
122
+
123
+ source_to_sequence_iterator[dataset_path] = sequence_iterator
124
+ return source_to_sequence_iterator
125
+
126
+ def build_from_rank(
127
+ self, rank: int, world_size: int
128
+ ) -> StatefulIterator[Batch, Any]:
129
+ source_to_sequence_iterators = self._create_sequence_iterators(rank, world_size)
130
+ weight_rng_state = get_rng_state(self.seed + 1, rank, world_size)
131
+ sampling_iterator = SamplingIterator(
132
+ rng_state=weight_rng_state,
133
+ source_to_weight=self.sources,
134
+ source_to_iterator=source_to_sequence_iterators,
135
+ )
136
+ tokenizer = self.tokenizer_args.build()
137
+ packing_args = PackingArgs(
138
+ batch_size=self.batch_size,
139
+ seq_len=self.seq_len,
140
+ pad_id=tokenizer.boe_id,
141
+ max_length=self.max_encoder_seq_length,
142
+ pad_to_max_length=self.pad_to_max_length,
143
+ enable_byte_ngrams=self.enable_byte_ngrams,
144
+ )
145
+ packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args)
146
+ mp_iterator = MultiprocessIterator(
147
+ packing_iterator, n_batches_to_prefetch=self.prefetch_size
148
+ )
149
+
150
+ return mp_iterator
151
+
152
+
153
+ class TrainArgs(BaseModel):
154
+ model_config = ConfigDict(extra="forbid")
155
+ name: str = "lingua"
156
+ dump_dir: str = ""
157
+
158
+ seed: int = 42
159
+
160
+ # Number of gradient accumulation steps
161
+ # Total batch size is batch_size*grad_acc_steps
162
+ grad_acc_steps: int = 1
163
+
164
+ gc_collect_freq: int = 1000
165
+ probe_freq: int | None = None
166
+
167
+ # Nb optimizer steps to take
168
+ steps: int = 1000
169
+
170
+ data: DataloaderArgs = DataloaderArgs()
171
+ optim: OptimArgs = OptimArgs()
172
+ model: ByteLatentTransformerArgs = ByteLatentTransformerArgs()
173
+ distributed: DistributedArgs = DistributedArgs()
174
+ env: EnvironmentArgs = EnvironmentArgs()
175
+
176
+ checkpoint: CheckpointArgs = CheckpointArgs()
177
+ profiling: ProfilerArgs = ProfilerArgs()
178
+ logging: LoggingArgs = LoggingArgs()
179
+
180
+ # If set to None, eval is run locally otherwise it launches a new job with the given number of gpus
181
+ async_eval_gpus: int | None = None
182
+ eval: Any | None = None
183
+ eval_on_gpus: int | None = None
184
+
185
+ def dump_to_yaml_file(
186
+ self, path: str, log_config: bool = True, sort_keys: bool = True
187
+ ):
188
+ model_dict = self.model_dump(mode="json")
189
+ yaml_str = yaml.dump(
190
+ model_dict,
191
+ allow_unicode=True,
192
+ sort_keys=sort_keys,
193
+ default_flow_style=False,
194
+ )
195
+ with open(path, "w") as f:
196
+ if log_config:
197
+ logger.info("Using the following config for this run:")
198
+ logger.info(yaml_str)
199
+ f.write(yaml_str)
bytelatent/base_transformer.py ADDED
@@ -0,0 +1,585 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from enum import Enum
4
+ from typing import Optional, Tuple, Union
5
+
6
+ import torch
7
+ from pydantic import BaseModel
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+ from torch.nn.attention.flex_attention import (
11
+ BlockMask,
12
+ _mask_mod_signature,
13
+ flex_attention,
14
+ )
15
+ from xformers.ops import AttentionBias, fmha
16
+
17
+ from bytelatent import probe
18
+
19
+ flex_attention_comp = torch.compile(flex_attention)
20
+
21
+
22
+ class InitStdFactor(Enum):
23
+ DISABLED = "disabled" # Init std is divided by 1.0
24
+ GLOBAL_DEPTH = "global_depth" # Init std is divided by sqrt(2*n_layers)
25
+ CURRENT_DEPTH = "current_depth" # Init std is divided by sqrt(2*depth)
26
+ DIM_RATIO = "dim_ratio" # Init std is divided by model_dim/4096
27
+
28
+
29
+ class BaseTransformerArgs(BaseModel):
30
+ dim: int = 512
31
+ n_layers: int = 8
32
+ head_dim: Optional[int] = None
33
+ n_heads: Optional[int] = None
34
+ n_kv_heads: Optional[int] = None
35
+
36
+ ffn_dim_multiplier: Optional[float] = None
37
+
38
+ multiple_of: int = 256
39
+
40
+ norm_eps: float = 1e-5
41
+
42
+ rope_theta: float = 10000.0
43
+
44
+ init_base_std: Optional[float] = None
45
+ init_std_factor: InitStdFactor = InitStdFactor.DISABLED
46
+
47
+ max_seqlen: int = 1024
48
+
49
+
50
+ def cross_entropy(pred, target, **kwargs):
51
+ return F.nll_loss(
52
+ F.log_softmax(pred.flatten(end_dim=-2).float(), -1),
53
+ target.flatten(end_dim=-1),
54
+ **kwargs,
55
+ )
56
+
57
+
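This helper is log_softmax plus nll_loss on logits upcast to float32. A quick equivalence check against F.cross_entropy (a sketch, assuming cross_entropy above is in scope):

    import torch
    import torch.nn.functional as F

    pred = torch.randn(2, 5, 11, dtype=torch.bfloat16)  # (batch, seq, vocab) logits
    target = torch.randint(0, 11, (2, 5))
    a = cross_entropy(pred, target)
    b = F.cross_entropy(pred.flatten(end_dim=-2).float(), target.flatten())
    print(torch.allclose(a, b))  # True: both compute the mean loss in float32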
58
+ def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor:
59
+ """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
60
+ assert dim == 2, "Only dim=2 is supported. Check the implementation for other dims."
61
+ bs, slen, n_kv_heads, head_dim = x.shape
62
+ if n_rep == 1:
63
+ return x
64
+ return (
65
+ x[:, :, :, None, :]
66
+ .expand(bs, slen, n_kv_heads, n_rep, head_dim)
67
+ .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
68
+ )
69
+
70
+
71
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
72
+ """
73
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
74
+
75
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
76
+ and the end index 'end'. The 'theta' parameter scales the frequencies.
77
+ The returned tensor stores the cos/sin values as real-valued 2x2 rotation blocks of shape (end, dim // 2, 2, 2).
78
+
79
+ Args:
80
+ dim (int): Dimension of the frequency tensor.
81
+ end (int): End index for precomputing frequencies.
82
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
83
+
84
+ Returns:
85
+ torch.Tensor: Precomputed frequency tensor with complex exponentials.
86
+ """
87
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
88
+ t = torch.arange(end, device=freqs.device)
89
+ freqs = torch.outer(t, freqs).float()
90
+
91
+ cos, sin = freqs.cos(), freqs.sin()
92
+
93
+ return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2)
94
+
95
+
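A small shape check for the precomputed RoPE table (a sketch; dim=4 and end=3 are arbitrary, and precompute_freqs_cis above is assumed to be in scope):

    import torch

    fc = precompute_freqs_cis(dim=4, end=3)
    print(fc.shape)  # torch.Size([3, 2, 2, 2]) -> (position, dim // 2, 2, 2)
    # Each trailing 2x2 block is a rotation matrix [[cos, -sin], [sin, cos]];
    # position 0 is the identity rotation for every frequency.
    print(torch.allclose(fc[0], torch.eye(2).expand(2, 2, 2)))  # True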
96
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int):
97
+ """
98
+ Reshape frequency tensor for broadcasting it with another tensor.
99
+
100
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
101
+ for the purpose of broadcasting the frequency tensor during element-wise operations.
102
+
103
+ Args:
104
+ freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
105
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
106
+ seq_dim (int): Sequence dimension index.
107
+
108
+ Returns:
109
+ torch.Tensor: Reshaped frequency tensor.
110
+ """
111
+ ndim = x.ndim
112
+ assert 0 <= seq_dim < ndim
113
+ assert freqs_cis.shape == (
114
+ x.shape[seq_dim],
115
+ x.shape[-3],
116
+ 2,
117
+ 2,
118
+ ), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}"
119
+ shape = [
120
+ d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2])
121
+ ] + [2, 2]
122
+ return freqs_cis.view(*shape)
123
+
124
+
125
+ def apply_rotary_emb(
126
+ xq: torch.Tensor,
127
+ xk: torch.Tensor,
128
+ seq_dim: int,
129
+ freqs_cis: torch.Tensor,
130
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
131
+ xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2
132
+ xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2
133
+ freqs_cis = reshape_for_broadcast(
134
+ freqs_cis, xq_, seq_dim
135
+ ).float() # S D/2 2 2 -> 1 S 1 D/2 2 2
136
+ xq_out = (xq_ * freqs_cis).sum(5).flatten(3)
137
+ xk_out = (xk_ * freqs_cis).sum(5).flatten(3)
138
+ return xq_out.type_as(xq), xk_out.type_as(xk)
139
+
140
+
141
+ def causal_mask(b, h, q_idx, kv_idx):
142
+ return q_idx >= kv_idx
143
+
144
+
145
+ def lengths_to_start_ids(lengths):
146
+ doc_start = lengths.cumsum(0)
147
+ doc_start = doc_start.roll(1)
148
+ doc_start[0] = 0
149
+ return doc_start
150
+
151
+
152
+ def lengths_to_local_ids(lengths):
153
+ assert lengths.ndim == 1
154
+ nb_seqs = lengths.size(0)
155
+ total_seqlen = lengths.sum()
156
+ # This gives the document id of each token
157
+ doc_id = torch.repeat_interleave(lengths)
158
+ # Compute document start for each document
159
+ doc_start = lengths_to_start_ids(lengths)
160
+ # Compute document start for each token
161
+ doc_start = doc_start[doc_id]
162
+ # Compute the position of each token within each document
163
+ tok_id = torch.arange(total_seqlen, device=lengths.device) - doc_start
164
+
165
+ return doc_id, tok_id
166
+
167
+
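A tiny worked example of the two helpers above (a sketch, assuming both are in scope):

    import torch

    lengths = torch.tensor([2, 3])
    doc_id, tok_id = lengths_to_local_ids(lengths)
    print(lengths_to_start_ids(lengths))  # tensor([0, 2]) -> start offset of each document
    print(doc_id)                         # tensor([0, 0, 1, 1, 1]) -> owning document of each stacked token
    print(tok_id)                         # tensor([0, 1, 0, 1, 2]) -> position of each token inside its document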
168
+ def generate_doc_mask_mod(
169
+ mask_mod: _mask_mod_signature,
170
+ lengths: torch.Tensor,
171
+ kv_lengths: Optional[torch.Tensor] = None,
172
+ ) -> _mask_mod_signature:
173
+ """Generates mask mods that apply to inputs to flex attention in the sequence stacked
174
+ format.
175
+
176
+ Args:
177
+ mask_mod: The mask mod to apply to the documents
178
+ lengths: Lengths of each document
179
+
180
+ Note:
181
+ What is the sequence stacked format? When assembling batches of inputs, we
182
+ take multiple sequences and stack them together to form 1 large sequence. We then
183
+ use masking to ensure that the attention scores are only applied to tokens within
184
+ the same document.
185
+
186
+ Example:
187
+
188
+ - Square mask
189
+ doc_mask lengths
190
+ a a b b b c c 2 3 2
191
+ a 1 0 0 0 0 0 0
192
+ a 1 1 0 0 0 0 0
193
+ b 0 0 1 0 0 0 0
194
+ b 0 0 1 1 0 0 0
195
+ b 0 0 1 1 1 0 0
196
+ c 0 0 0 0 0 1 0
197
+ c 0 0 0 0 0 1 1
198
+
199
+ """
200
+ kv_lengths = kv_lengths if kv_lengths is not None else lengths
201
+ q_document_id, q_token_id = lengths_to_local_ids(lengths)
202
+ kv_document_id, kv_token_id = lengths_to_local_ids(kv_lengths)
203
+ q_max_idx = lengths.sum() - 1
204
+ kv_max_idx = kv_lengths.sum() - 1
205
+
206
+ def doc_mask_mod(b, h, q_idx, kv_idx):
207
+ q_idx_cap = torch.minimum(q_max_idx, q_idx)
208
+ kv_idx_cap = torch.minimum(kv_max_idx, kv_idx)
209
+ valid_idx = (q_idx <= q_max_idx) & (kv_idx <= kv_max_idx)
210
+ same_doc = q_document_id[q_idx_cap] == kv_document_id[kv_idx_cap]
211
+ q_logical = q_token_id[q_idx_cap]
212
+ kv_logical = kv_token_id[kv_idx_cap]
213
+ inner_mask = mask_mod(b, h, q_logical, kv_logical)
214
+ return same_doc & inner_mask & valid_idx
215
+
216
+ return doc_mask_mod
217
+
218
+
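A hedged sketch of turning this mask mod into a flex_attention BlockMask for the stacked-documents example in the docstring (create_block_mask is PyTorch's public torch.nn.attention.flex_attention helper; the out-of-range capping above is what keeps the mod valid when lengths are padded up to the block size):

    import torch
    from torch.nn.attention.flex_attention import create_block_mask

    lengths = torch.tensor([2, 3, 2])
    total_len = int(lengths.sum())
    doc_mask_mod = generate_doc_mask_mod(causal_mask, lengths)
    block_mask = create_block_mask(
        doc_mask_mod, B=None, H=None, Q_LEN=total_len, KV_LEN=total_len, device="cpu"
    )
    # block_mask can then be passed as `mask` to Attention.forward with attn_impl="flex_attention".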
219
+ # Rotary embedding as in xformers; check whether the torchtrain implementation is better. It might also be useful to make it work with batch*seqlen collapsed.
220
+ class RotaryEmbedding(torch.nn.Module):
221
+ """
222
+ RotaryEmbedding Module
223
+ """
224
+
225
+ def __init__(self, theta: float, head_dim: int, max_seqlen: int = 1024):
226
+ super().__init__()
227
+
228
+ self.theta = theta
229
+ self.head_dim = head_dim
230
+ self.max_seqlen = max_seqlen
231
+
232
+ self.register_buffer(
233
+ "freqs_cis",
234
+ precompute_freqs_cis(dim=head_dim, end=max_seqlen, theta=theta),
235
+ persistent=False,
236
+ )
237
+
238
+ def reset_parameters(self):
239
+ self.freqs_cis[...] = precompute_freqs_cis(
240
+ dim=self.head_dim, end=self.max_seqlen, theta=self.theta
241
+ )
242
+
243
+ def forward(
244
+ self, seqlen: Optional[int] = None, tok_idx: Optional[torch.Tensor] = None
245
+ ):
246
+ """
247
+ Return freqs_cis corresponding to consecutive seqlen positions or the corresponding tok_idx positions
248
+ Args:
249
+ seqlen (int): Contiguous sequence length
250
+ tok_idx (torch.Tensor[int]): Position indices of each token this overrides seqlen
251
+
252
+ Returns:
253
+ Tuple(torch.Tensor, torch.Tensor): Embedded input tensor and freqs_cis
254
+ """
255
+ test = (seqlen is not None) or (tok_idx is not None)
256
+ assert test, "Should provide at least seqlen or tok_idx"
257
+ if tok_idx is not None:
258
+ return self.freqs_cis[tok_idx]
259
+ elif seqlen is not None:
260
+ return self.freqs_cis[0:seqlen]
261
+
262
+
263
+ class RMSNorm(nn.Module):
264
+ """
265
+ Initialize the RMSNorm normalization layer.
266
+
267
+ Args:
268
+ dim (int): The dimension of the input tensor.
269
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
270
+
271
+ Attributes:
272
+ eps (float): A small value added to the denominator for numerical stability.
273
+ weight (nn.Parameter): Learnable scaling parameter.
274
+
275
+ """
276
+
277
+ def __init__(self, dim: int, eps: float = 1e-6):
278
+ super().__init__()
279
+ self.eps = eps
280
+ self.weight = nn.Parameter(torch.ones(dim))
281
+
282
+ def _norm(self, x: torch.Tensor):
283
+ return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
284
+
285
+ def forward(self, x: torch.Tensor):
286
+ x = probe.log_stats(x, "resid")
287
+ output = self._norm(x.float())
288
+ return (output * self.weight.float()).type_as(x)
289
+
290
+ def reset_parameters(self):
291
+ torch.nn.init.ones_(self.weight) # type: ignore
292
+
293
+
294
+ class Attention(nn.Module):
295
+ def __init__(
296
+ self,
297
+ dim: int,
298
+ head_dim: int,
299
+ n_heads: int,
300
+ n_kv_heads: int,
301
+ rope_theta: float,
302
+ ):
303
+ super().__init__()
304
+
305
+ self.dim = dim
306
+ self.head_dim = head_dim
307
+ self.rope_theta = rope_theta
308
+
309
+ self.n_heads = n_heads
310
+ self.n_kv_heads = n_kv_heads
311
+ self.heads_per_group = self.n_heads // self.n_kv_heads
312
+
313
+ self.wq = nn.Linear(
314
+ dim,
315
+ n_heads * head_dim,
316
+ bias=False,
317
+ )
318
+ self.wk = nn.Linear(
319
+ dim,
320
+ n_kv_heads * head_dim,
321
+ bias=False,
322
+ )
323
+ self.wv = nn.Linear(
324
+ dim,
325
+ n_kv_heads * head_dim,
326
+ bias=False,
327
+ )
328
+
329
+ self.wo = nn.Linear(
330
+ n_heads * head_dim,
331
+ dim,
332
+ bias=False,
333
+ )
334
+
335
+ def forward(
336
+ self,
337
+ x: torch.Tensor,
338
+ freq_cis: torch.Tensor,
339
+ tok_idx: Optional[torch.Tensor] = None,
340
+ mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
341
+ attn_impl: str = "sdpa",
342
+ ) -> torch.Tensor:
343
+ # B S D
344
+ bsz, seq_len, dim = x.shape
345
+ xq = self.wq(x.view_as(x))
346
+ xk = self.wk(x.view_as(x))
347
+ xv = self.wv(x.view_as(x))
348
+
349
+ output_shape = xq.shape
350
+ # B S D -> B S H D
351
+ xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
352
+ xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
353
+ xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
354
+
355
+ xq, xk = apply_rotary_emb(xq, xk, 1, freq_cis[0:seq_len])
356
+
357
+ # This condition helps us be easily compatible
358
+ # with inference by adding a pluggable KVCache
359
+ if hasattr(self, "kv_cache"):
360
+ xk, xv = self.kv_cache.update(xk, xv, tok_idx)
361
+
362
+ xk = repeat_kv(xk, self.heads_per_group, dim=2)
363
+ xv = repeat_kv(xv, self.heads_per_group, dim=2)
364
+
365
+ if attn_impl == "flex_attention":
366
+ assert mask is None or isinstance(mask, BlockMask)
367
+ xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv))
368
+ output = flex_attention_comp(xq, xk, xv, block_mask=mask)
369
+ output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
370
+
371
+ elif attn_impl == "fmha":
372
+ assert mask is None or isinstance(mask, AttentionBias)
373
+ output = fmha.memory_efficient_attention(xq, xk, xv, attn_bias=mask)
374
+ # This uses B S H D instead of B H S D of pytorch
375
+
376
+ elif attn_impl == "sdpa":
377
+ xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv))
378
+ assert mask is None or isinstance(mask, (str, torch.Tensor))
379
+ is_causal = (mask == "causal") if isinstance(mask, str) else False
380
+ mask = mask if isinstance(mask, torch.Tensor) else None
381
+ output = F.scaled_dot_product_attention(
382
+ xq,
383
+ xk,
384
+ xv,
385
+ is_causal=is_causal,
386
+ attn_mask=mask,
387
+ )
388
+ output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
389
+ else:
390
+ raise NotImplementedError(
391
+ f"Attention implementation {attn_impl} not supported"
392
+ )
393
+
394
+ output = self.wo(output.reshape(output_shape))
395
+
396
+ return output
397
+
398
+ def reset_parameters(self, init_std=None, factor=1.0):
399
+ init_std = init_std or (self.dim ** (-0.5))
400
+
401
+ for w in [self.wq, self.wk, self.wv]:
402
+ nn.init.trunc_normal_(
403
+ w.weight,
404
+ mean=0.0,
405
+ std=init_std,
406
+ a=-3 * init_std,
407
+ b=3 * init_std,
408
+ )
409
+
410
+ nn.init.trunc_normal_(
411
+ self.wo.weight,
412
+ mean=0.0,
413
+ std=init_std / factor,
414
+ a=-3 * init_std,
415
+ b=3 * init_std,
416
+ )
417
+
418
+
419
+ class FeedForward(nn.Module):
420
+ def __init__(
421
+ self,
422
+ dim: int,
423
+ hidden_dim: int,
424
+ multiple_of: int,
425
+ ffn_dim_multiplier: Optional[float],
426
+ mp_size: int = 1,
427
+ ):
428
+ super().__init__()
429
+
430
+ hidden_dim = int(2 * hidden_dim / 3)
431
+ if ffn_dim_multiplier is not None:
432
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
433
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
434
+ assert hidden_dim % mp_size == 0
435
+
436
+ self.dim = dim
437
+ self.hidden_dim = hidden_dim
438
+
439
+ self.w1 = nn.Linear(
440
+ dim,
441
+ hidden_dim,
442
+ bias=False,
443
+ )
444
+ self.w3 = nn.Linear(
445
+ dim,
446
+ hidden_dim,
447
+ bias=False,
448
+ )
449
+ self.w2 = nn.Linear(
450
+ hidden_dim,
451
+ dim,
452
+ bias=False,
453
+ )
454
+
455
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
456
+ # B S D
457
+ x1 = self.w1(x.view_as(x))
458
+ x3 = self.w3(x.view_as(x))
459
+ output = self.w2(F.silu(x1) * x3)
460
+ return output
461
+
462
+ def reset_parameters(self, init_std=None, factor=1.0):
463
+ in_init_std = init_std or (self.dim ** (-0.5))
464
+ out_init_std = init_std or (self.hidden_dim ** (-0.5))
465
+ in_init_std = in_init_std
466
+ out_init_std = out_init_std / factor
467
+ for w in [self.w1, self.w3]:
468
+ nn.init.trunc_normal_(
469
+ w.weight,
470
+ mean=0.0,
471
+ std=in_init_std,
472
+ a=-3 * in_init_std,
473
+ b=3 * in_init_std,
474
+ )
475
+ nn.init.trunc_normal_(
476
+ self.w2.weight,
477
+ mean=0.0,
478
+ std=out_init_std,
479
+ a=-3 * out_init_std,
480
+ b=3 * out_init_std,
481
+ )
482
+
483
+
484
+ class TransformerBlock(nn.Module):
485
+ def __init__(self, args: BaseTransformerArgs):
486
+ super().__init__()
487
+
488
+ assert (args.head_dim is not None) or (
489
+ args.n_heads is not None
490
+ ), "Should specify at least head_dim or n_heads"
491
+ self.head_dim = args.head_dim or args.dim // args.n_heads
492
+ self.n_heads = args.n_heads or args.dim // args.head_dim
493
+ self.n_kv_heads = args.n_kv_heads or self.n_heads
494
+
495
+ assert args.n_heads % self.n_kv_heads == 0
496
+ assert args.dim % args.n_heads == 0
497
+
498
+ self.attention = Attention(
499
+ dim=args.dim,
500
+ head_dim=self.head_dim,
501
+ n_heads=self.n_heads,
502
+ n_kv_heads=self.n_kv_heads,
503
+ rope_theta=args.rope_theta,
504
+ )
505
+ self.feed_forward = FeedForward(
506
+ dim=args.dim,
507
+ hidden_dim=4 * args.dim,
508
+ multiple_of=args.multiple_of,
509
+ ffn_dim_multiplier=args.ffn_dim_multiplier,
510
+ )
511
+ self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
512
+ self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
513
+
514
+ def forward(
515
+ self,
516
+ x: torch.Tensor,
517
+ freq_cis: torch.Tensor,
518
+ tok_idx: Optional[torch.Tensor] = None,
519
+ mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
520
+ attn_impl: str = "sdpa",
521
+ ) -> torch.Tensor:
522
+ h = x + self.attention(
523
+ self.attention_norm(x),
524
+ freq_cis,
525
+ tok_idx=tok_idx,
526
+ mask=mask,
527
+ attn_impl=attn_impl,
528
+ )
529
+ out = h + self.feed_forward(self.ffn_norm(h))
530
+ return out
531
+
532
+ def init_weights(self, init_std=None, factor=1.0):
533
+ self.attention.reset_parameters(init_std, factor)
534
+ self.attention_norm.reset_parameters()
535
+
536
+ self.feed_forward.reset_parameters(init_std, factor)
537
+ self.ffn_norm.reset_parameters()
538
+
539
+
540
+ class BaseTransformer(nn.Module):
541
+ def __init__(self, args: BaseTransformerArgs):
542
+ super().__init__()
543
+ self.dim = args.dim
544
+ self.init_base_std = args.init_base_std
545
+ self.init_std_factor = InitStdFactor(args.init_std_factor)
546
+ self.max_seqlen = args.max_seqlen
547
+ self.rope_embeddings = RotaryEmbedding(
548
+ theta=args.rope_theta,
549
+ head_dim=args.head_dim or args.dim // args.n_heads,
550
+ max_seqlen=args.max_seqlen,
551
+ )
552
+
553
+ self.layers = nn.ModuleList()
554
+ for _ in range(args.n_layers):
555
+ self.layers.append(TransformerBlock(args))
556
+
557
+ def forward(
558
+ self,
559
+ h,
560
+ tok_idx: Optional[torch.Tensor] = None,
561
+ mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
562
+ attn_impl: str = "sdpa",
563
+ ):
564
+
565
+ freq_cis = self.rope_embeddings(seqlen=self.max_seqlen, tok_idx=tok_idx)
566
+
567
+ for i, layer in enumerate(self.layers):
568
+ h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
569
+ return h
570
+
571
+ def reset_parameters(self):
572
+ # Either use fixed base std or sqrt model dim
573
+ self.rope_embeddings.reset_parameters()
574
+
575
+ def init_weights(self):
576
+ self.reset_parameters()
577
+ for depth, layer in enumerate(self.layers):
578
+ factor = {
579
+ InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
580
+ InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
581
+ InitStdFactor.DIM_RATIO: self.dim / 4096,
582
+ InitStdFactor.DISABLED: 1.0,
583
+ }[self.init_std_factor]
584
+
585
+ layer.init_weights(self.init_base_std, factor)
bytelatent/checkpoint.py ADDED
@@ -0,0 +1,311 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import json
4
+ import logging
5
+ import os
6
+ import re
7
+ from pathlib import Path
8
+ from typing import List, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.distributed as dist
12
+ import torch.distributed.checkpoint as dcp
13
+ import torch.nn as nn
14
+ import torch.optim.optimizer
15
+ from pydantic import BaseModel, ConfigDict
16
+ from torch.distributed._tensor import DeviceMesh
17
+ from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
18
+ from torch.distributed.checkpoint.state_dict import (
19
+ get_model_state_dict,
20
+ get_state_dict,
21
+ set_state_dict,
22
+ )
23
+
24
+ from bytelatent.distributed import get_is_master
25
+
26
+ logger = logging.getLogger("CHECKPOINT")
27
+
28
+ FOLDER_NAME = "{:010d}"
29
+ RE_FOLDER = r"\d{10}"
30
+
31
+ RE_CKPT = r"__\d_\d\.distcp"
32
+
33
+ CONSOLIDATE_FOLDER = "consolidated"
34
+ CONSOLIDATE_NAME = "consolidated.pth"
35
+
36
+ CONFIG_NAME = "params.json"
37
+ TRAIN_STATE_NAME = "train_state_{:05d}.json"
38
+ RE_DIGITS = re.compile(r"\d+")
39
+
40
+
41
+ class SaveEvery(BaseModel):
42
+ model_config = ConfigDict(extra="forbid")
43
+ every: int = 1000
44
+ keep: int = 0
45
+
46
+
47
+ class CheckpointArgs(BaseModel):
48
+ model_config = ConfigDict(extra="forbid")
49
+ dump: SaveEvery = SaveEvery()
50
+ eval: SaveEvery = SaveEvery()
51
+ path: str | None = None
52
+ init_ckpt_path: str | None = None
53
+ continue_training_from_init: bool = False
54
+
55
+
56
+ def _get_key_step(name: str):
57
+ return int(re.findall(RE_DIGITS, name)[-1])
58
+
59
+
60
+ def consolidate_checkpoints(ckpt_dir: str):
61
+ """
62
+ Consolidates all FSDP checkpoints in a directory to a single file
63
+ Consolidate checkpoint is saved in a subdirectory of ckpt_dir
64
+
65
+ Parameters:
66
+ ckpt_dir: str - path to the directory containing the checkpoints
67
+
68
+ Returns the path to the consolidated checkpoint
69
+ """
70
+ consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER
71
+ if not (consolidate_path / CONSOLIDATE_NAME).exists():
72
+ consolidate_path.mkdir(exist_ok=True)
73
+ logger.info(f"Consolidating to: {str(consolidate_path)}")
74
+ dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME))
75
+ (consolidate_path / CONFIG_NAME).write_text(
76
+ (Path(ckpt_dir) / CONFIG_NAME).read_text()
77
+ )
78
+ logger.info("Consolidated!")
79
+ return consolidate_path
80
+
81
+
82
+ def load_from_checkpoint(
83
+ ckpt_dir: str,
84
+ model: nn.Module,
85
+ optimizer: Optional[torch.optim.Optimizer] = None,
86
+ model_key: str = "model",
87
+ optim_key: str = "optim",
88
+ ):
89
+ if not (Path(ckpt_dir) / ".metadata").exists():
90
+ raise ValueError(
91
+ "Please convert the checkpoint to the distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it"
92
+ )
93
+
94
+ state_dict = {}
95
+ if optimizer is not None:
96
+ state_dict[model_key], state_dict[optim_key] = get_state_dict(model, optimizer)
97
+ else:
98
+ state_dict[model_key] = get_model_state_dict(model)
99
+ if model_key == "": # If only loading a model directly, the key should be empty
100
+ state_dict = state_dict.pop(model_key)
101
+
102
+ dcp.load(state_dict, checkpoint_id=ckpt_dir)
103
+
104
+
105
+ class CheckpointManager:
106
+ def __init__(self, args: CheckpointArgs):
107
+ self.path = args.path
108
+ self.dump_every = args.dump
109
+ self.eval_every = args.eval
110
+ self.init_ckpt_path = args.init_ckpt_path
111
+ self.continue_training_from_init = args.continue_training_from_init
112
+
113
+ assert os.path.exists(
114
+ self.path
115
+ ), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)"
116
+
117
+ self.existing_saves = self.get_existing_saves()
118
+
119
+ def get_existing_saves(self) -> List[Path]:
120
+ folders = [
121
+ p
122
+ for p in Path(self.path).iterdir()
123
+ if p.is_dir() and re.match(RE_FOLDER, p.name)
124
+ ]
125
+ folders.sort(key=lambda p: _get_key_step(p.name))
126
+ return folders
127
+
128
+ def clean_up(self):
129
+ logger.info("Cleaning up checkpoints...")
130
+ dump_folders = []
131
+ eval_folders = []
132
+ other_folders = []
133
+ for p in self.existing_saves:
134
+ is_dump = _get_key_step(p.name) % self.dump_every.every == 0
135
+ is_eval = _get_key_step(p.name) % self.eval_every.every == 0
136
+ if is_dump:
137
+ dump_folders.append(p)
138
+ if is_eval:
139
+ eval_folders.append(p)
140
+ if not (is_dump or is_eval):
141
+ other_folders.append(p)
142
+
143
+ logger.info(f"Dump folders: {dump_folders}")
144
+ logger.info(f"Eval folders: {eval_folders}")
145
+ logger.info(f"Other folders: {other_folders}")
146
+
147
+ if self.dump_every.keep > 0:
148
+ dump_folders = dump_folders[-self.dump_every.keep :]
149
+ if self.eval_every.keep > 0:
150
+ eval_folders = eval_folders[-self.eval_every.keep :]
151
+
152
+ folder_to_keep = set(other_folders + dump_folders + eval_folders)
153
+ folder_to_remove = set(self.existing_saves) - folder_to_keep
154
+
155
+ logger.info(f"Removing folders: {folder_to_remove}")
156
+
157
+ if dist.get_rank() == 0:
158
+ for folder in folder_to_remove:
159
+ for file in folder.iterdir():
160
+ if file.is_file():
161
+ file.unlink()
162
+ elif file.is_dir():
163
+ assert file.name in [CONSOLIDATE_FOLDER]
164
+ for f in file.iterdir():
165
+ f.unlink()
166
+ file.rmdir()
167
+ folder.rmdir()
168
+
169
+ dist.barrier()
170
+
171
+ self.existing_saves = list(folder_to_keep)
172
+ self.existing_saves.sort(key=lambda p: _get_key_step(p.name))
173
+
174
+ def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]:
175
+ path = None
176
+ for p in reversed(self.existing_saves):
177
+ if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file():
178
+ path = p
179
+ break
180
+ return path
181
+
182
+ def _create_folder(self, base_path: Path, folder_name: str) -> Path:
183
+ folder = base_path / folder_name
184
+ if get_is_master():
185
+ folder.mkdir(parents=False, exist_ok=True)
186
+ if dist.is_initialized():
187
+ dist.barrier()
188
+ return folder
189
+
190
+ def _get_dp_tp_mesh(
191
+ self, device_mesh: Optional[DeviceMesh] = None
192
+ ) -> Tuple[int, int]:
193
+ dp_rank = 0
194
+ tp_rank = 0
195
+ if device_mesh is not None:
196
+ if "dp_replicate" in device_mesh.mesh_dim_names:
197
+ dp_rank = device_mesh.get_local_rank("dp_replicate")
198
+ if "dp_shard" in device_mesh.mesh_dim_names:
199
+ dp_rank = dp_rank * device_mesh[
200
+ "dp_replicate"
201
+ ].size() + device_mesh.get_local_rank("dp_shard")
202
+ if "tp" in device_mesh.mesh_dim_names:
203
+ tp_rank = device_mesh.get_local_rank("tp")
204
+ return dp_rank, tp_rank
205
+
206
+ @torch.no_grad()
207
+ def get_state_dict(
208
+ self,
209
+ model,
210
+ optimizer,
211
+ ):
212
+ model_sd, optim_sd = get_state_dict(model, optimizer)
213
+ return {"model": model_sd, "optim": optim_sd}
214
+
215
+ def save(
216
+ self,
217
+ model,
218
+ optimizer,
219
+ train_state,
220
+ config,
221
+ device_mesh: Optional[DeviceMesh] = None,
222
+ ) -> bool:
223
+
224
+ # When creating directory check if only rank0 or is there other solution
225
+ path = Path(self.path)
226
+ curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step))
227
+ logger.info(f"Saving to: {str(curr_save_dir)}")
228
+
229
+ if dist.is_initialized():
230
+ dist.barrier()
231
+
232
+ logger.info("Saving...")
233
+ state_dict = self.get_state_dict(model, optimizer)
234
+ dcp.save(state_dict, checkpoint_id=curr_save_dir)
235
+ logger.info("State dict saved!")
236
+
237
+ if dist.is_initialized():
238
+ dist.barrier()
239
+
240
+ if get_is_master():
241
+ config.dump_to_yaml_file(curr_save_dir / CONFIG_NAME)
242
+
243
+ # Add json dump here
244
+ dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
245
+ if tp_rank == 0:
246
+ train_state_name = TRAIN_STATE_NAME.format(dp_rank)
247
+ logger.info(
248
+ f"Saving train state to: {str(curr_save_dir / train_state_name)}"
249
+ )
250
+ with open(curr_save_dir / train_state_name, "w") as f:
251
+ json.dump(train_state.state_dict(), f)
252
+ logger.info("Train state saved !")
253
+
254
+ self.existing_saves.append(curr_save_dir)
255
+
256
+ self.clean_up()
257
+
258
+ if dist.is_initialized():
259
+ dist.barrier()
260
+ return True
261
+
262
+ @torch.no_grad()
263
+ def load(
264
+ self,
265
+ model: nn.Module,
266
+ optimizer,
267
+ train_state,
268
+ device_mesh: DeviceMesh,
269
+ path: Optional[Path] = None,
270
+ ):
271
+ dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
272
+ # Loading tries the provided path first, then the last saved step, and finally the init path
273
+ path = path or self.get_last_step_path(dp_rank=dp_rank)
274
+ # If none of those are available don't do anything
275
+ if path is None:
276
+ # If no checkpoints exist do nothing
277
+ return
278
+
279
+ # Only load train state if it's provided, the files exist and we're not loading from init path
280
+ train_state_name = TRAIN_STATE_NAME.format(dp_rank)
281
+ logger.info("Reloading train state")
282
+ with open(path / train_state_name, "r") as f:
283
+ train_state_dict = json.load(f)
284
+ train_state.load_state_dict(train_state_dict)
285
+ logger.info("Train state reloaded")
286
+
287
+ logger.info(f"Loading from: {str(path)}")
288
+ state_dict = self.get_state_dict(
289
+ model=model,
290
+ optimizer=optimizer,
291
+ )
292
+ dcp.load(state_dict, checkpoint_id=path)
293
+ logger.info("State dict loaded.")
294
+
295
+ logger.info("Reloading model and optim")
296
+
297
+ set_state_dict(
298
+ model,
299
+ optimizer,
300
+ model_state_dict=state_dict["model"],
301
+ optim_state_dict=state_dict["optim"],
302
+ )
303
+ logger.info("Model and optim reloaded")
304
+
305
+ @classmethod
306
+ def instantiate_and_make_dir(cls, args: CheckpointArgs):
307
+ if get_is_master():
308
+ os.makedirs(args.path, exist_ok=True)
309
+ dist.barrier()
310
+
311
+ return cls(args)
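A minimal sketch (not part of this commit) of the retention rule that clean_up() applies above, using the dump/eval settings from the debug config later in this diff (dump.every=500, dump.keep=3, eval.every=1000, eval.keep=-1); the helper name is invented for illustration.

def surviving_steps(steps, dump_every=500, dump_keep=3, eval_every=1000, eval_keep=-1):
    # Mirror clean_up(): classify checkpoint steps, then keep the last `keep` of each kind.
    dumps = [s for s in steps if s % dump_every == 0]
    evals = [s for s in steps if s % eval_every == 0]
    others = [s for s in steps if s % dump_every != 0 and s % eval_every != 0]
    if dump_keep > 0:
        dumps = dumps[-dump_keep:]
    if eval_keep > 0:
        evals = evals[-eval_keep:]
    return sorted(set(others + dumps + evals))

# Checkpoints at steps 500..4000: only the last 3 dump-only folders survive,
# while eval-step folders are kept forever because eval.keep == -1.
print(surviving_steps(list(range(500, 4001, 500))))  # [1000, 2000, 3000, 3500, 4000]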
bytelatent/configs/debug.yaml ADDED
@@ -0,0 +1,110 @@
1
+ # Template config, need to change dump_dir, data.root_dir, data.preprocess_dir and data.tokenizer_args.init_kwargs.bpe_tokenizer_path
2
+ # Evals can be activated by uncommenting the eval config
3
+ # python -m launchers.stool config=bytelatent/configs/debug.yaml nodes=8 account=fair_amaia_cw_codegen qos=lowest
4
+
5
+ dump_dir: /tmp/
6
+ name: "debug"
7
+ steps: 100_000
8
+ probe_freq: null
9
+ seed: 777
10
+ optim:
11
+ lr: 4e-04
12
+ warmup: 500
13
+ lr_min_ratio: 0.1
14
+ clip: 10.0
15
+
16
+ distributed:
17
+ fsdp_type: full_shard
18
+ compile: true
19
+ model_dtype: bf16
20
+ matmul_allow_tf32: false
21
+ selective_activation_checkpointing: false
22
+ tp_size: 1
23
+
24
+ model:
25
+ n_heads: 8
26
+ dim: 512
27
+ vocab_size: 260
28
+ dim_token: 256
29
+ patch_size: 6
30
+ tokenization_mode: "bytes"
31
+ patching_mode: "space"
32
+ tie_local_encoder_decoder_logits: false
33
+ data_loader_patching: true
34
+ max_encoder_seq_length: 12288
35
+ pad_to_max_length: true
36
+ patching_threshold: 3.1439168453216553
37
+ encoder_hash_byte_group_size: [4]
38
+ encoder_hash_byte_group_vocab: 50002
39
+ encoder_hash_byte_group_nb_functions: 3
40
+ encoder_enable_byte_ngrams: false
41
+ cross_attn_encoder: true # assuming cross_attention is true
42
+ cross_attn_decoder: true # assuming cross_attention is true
43
+ cross_attn_window_encoder: 512
44
+ cross_attn_window_decoder: 512
45
+ dim_local_encoder: 256
46
+ dim_local_decoder: 256
47
+ cross_attn_k: 8
48
+ cross_attn_nheads: 4
49
+ cross_attn_all_layers_decoder: true
50
+ cross_attn_all_layers_encoder: true
51
+ cross_attn_use_flex_attention: true
52
+ cross_attn_init_by_pooling: true
53
+ log_patch_lengths: true
54
+ non_linearity: "swiglu"
55
+ use_rope: true
56
+ recompute_fc1_out: false
57
+ recompute_fc3_out: false
58
+ recompute_attn: false
59
+ custom_bwd: false
60
+ layer_ckpt: "none"
61
+ efficient_attn: "sdpa"
62
+ patch_only_encoder: false
63
+ patch_only_decoder: false
64
+ use_local_encoder_transformer: true
65
+ init_use_gaussian: true
66
+ init_use_depth: "current"
67
+ attn_bias_type: "block_causal"
68
+ alpha_depth: "disabled"
69
+ max_length: 256
70
+ local_attention_window_len: 512
71
+ max_seqlen: 12288
72
+ downsampling_by_pooling: "max"
73
+
74
+ data:
75
+ root_dir: ???
76
+ sources:
77
+ dclm_baseline_1.0: 1.0
78
+ batch_size: 2
79
+ prefetch_size: 64
80
+ seq_len: 4096
81
+ load_async: true
82
+ preprocess_dir: ???
83
+ tokenizer_args:
84
+ name: blt
85
+ init_kwargs:
86
+ bpe_tokenizer_path: ???
87
+
88
+ profiling:
89
+ run: false
90
+
91
+ checkpoint:
92
+ dump:
93
+ every: 500
94
+ keep: 3
95
+ eval:
96
+ every: 1000
97
+ keep: -1
98
+
99
+ logging:
100
+ freq: 10
101
+
102
+ eval_on_gpus: 8
103
+ eval:
104
+ dataset_dir: /checkpoint/amaia/codegen/datasets/eval
105
+ tasks: boolq,hellaswag,nq,piqa,siqa,tqa,winogrande,obqa,arc_easy,arc_challenge,race.middle,race.high,gsm8k,math,bbh,copa,human_eval_plus,mbpp,mmlu
106
+ generator:
107
+ max_tokens: 65536
108
+ dtype: bf16
109
+
110
+ mp_size: 1
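The ??? entries above are required overrides. As a rough sketch (not the actual training entrypoint, which consumes this file through its own arg parsing and `config=... key=value` overrides), the config can be inspected or patched from Python with PyYAML:

import yaml  # assumes PyYAML is available

with open("bytelatent/configs/debug.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["model"]["patch_size"])          # 6
print(cfg["checkpoint"]["dump"]["every"])  # 500

# The ??? placeholders (data.root_dir, data.preprocess_dir, bpe_tokenizer_path)
# must be filled in, either by editing the file or via command-line overrides.
cfg["data"]["root_dir"] = "/path/to/data"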
bytelatent/constants.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ import os
3
+ from pathlib import Path
4
+
5
+ BLT_DATA = Path(os.environ.get("BLT_DATA", "data"))
bytelatent/data/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
bytelatent/data/data_types.py ADDED
@@ -0,0 +1,115 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ import json
3
+ from dataclasses import dataclass
4
+ from typing import Any, Iterator
5
+
6
+ import numpy as np
7
+ from pydantic import BaseModel, ConfigDict
8
+
9
+
10
+ class BltExample(BaseModel):
11
+ model_config = ConfigDict(extra="forbid")
12
+ sample_id: str
13
+ text: str
14
+ tokens: list[int] | None
15
+ entropies: list[float] | None
16
+ patch_lengths: list[int] | None
17
+ mask: list[bool] | None
18
+
19
+
20
+ class MultiChoiceState(BaseModel):
21
+ model_config = ConfigDict(extra="forbid")
22
+ root_dir: str
23
+ sources: dict[str, float]
24
+ source_to_state: dict[str, Any]
25
+ rng_state: dict[str, Any]
26
+
27
+
28
+ class PrefetchState(BaseModel):
29
+ model_config = ConfigDict(extra="forbid")
30
+ seq_idx: int
31
+ rng_state: dict[str, Any]
32
+ prefetch_size: int
33
+ batch_size: int
34
+
35
+
36
+ class BltPackTokensState(BaseModel):
37
+ model_config = ConfigDict(extra="forbid")
38
+ start_token: int
39
+ output_seq_len: int
40
+ n_views: int = 2
41
+
42
+
43
+ class DataLoaderState(BaseModel):
44
+ model_config = ConfigDict(extra="forbid")
45
+ multi_choice_state: MultiChoiceState
46
+ pack_tokens_state: BltPackTokensState
47
+ prefetch_state: PrefetchState
48
+
49
+
50
+ BltIterator = Iterator[tuple[BltExample, DataLoaderState]]
51
+
52
+
53
+ class BltSequence(BaseModel):
54
+ tokens: list[int]
55
+ mask: list[bool]
56
+ patch_lengths: list[int]
57
+
58
+
59
+ @dataclass
60
+ class Batch:
61
+ x: np.ndarray
62
+ y: np.ndarray
63
+ mask: np.ndarray | None = None
64
+ patch_lengths: np.ndarray | None = None
65
+ ngram_ids: np.ndarray | None = None
66
+ is_final: bool = False
67
+
68
+ def to_python_dict(self) -> dict:
69
+ x = self.x.tolist()
70
+ y = self.y.tolist()
71
+ if self.mask is None:
72
+ mask = None
73
+ else:
74
+ mask = self.mask.tolist()
75
+ if self.patch_lengths is None:
76
+ patch_lengths = None
77
+ else:
78
+ patch_lengths = self.patch_lengths.tolist()
79
+ if self.ngram_ids is None:
80
+ ngram_ids = None
81
+ else:
82
+ ngram_ids = self.ngram_ids.tolist()
83
+ return {
84
+ "x": x,
85
+ "y": y,
86
+ "mask": mask,
87
+ "patch_lengths": patch_lengths,
88
+ "ngram_ids": ngram_ids,
89
+ "is_final": self.is_final,
90
+ }
91
+
92
+ @classmethod
93
+ def from_python_dict(cls, data: dict) -> "Batch":
94
+ x = np.array(data["x"])
95
+ y = np.array(data["y"])
96
+ if data["mask"] is None:
97
+ mask = None
98
+ else:
99
+ mask = np.array(data["mask"])
100
+ if data["patch_lengths"] is None:
101
+ patch_lengths = None
102
+ else:
103
+ patch_lengths = np.array(data["patch_lengths"])
104
+ if data["ngram_ids"] is None:
105
+ ngram_ids = None
106
+ else:
107
+ ngram_ids = np.array(data["ngram_ids"])
108
+ return Batch(
109
+ x=x,
110
+ y=y,
111
+ mask=mask,
112
+ patch_lengths=patch_lengths,
113
+ ngram_ids=ngram_ids,
114
+ is_final=data["is_final"],
115
+ )
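A quick round-trip sketch for the Batch helpers above; it only uses fields defined in this file, and the values are arbitrary.

import json

import numpy as np

from bytelatent.data.data_types import Batch

batch = Batch(x=np.array([[1, 2, 3]]), y=np.array([[2, 3, 4]]))
payload = json.dumps(batch.to_python_dict())           # None fields serialize as null
restored = Batch.from_python_dict(json.loads(payload))
assert np.array_equal(restored.x, batch.x) and restored.is_final is False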
bytelatent/data/iterators/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
bytelatent/data/iterators/abstract_iterator.py ADDED
@@ -0,0 +1,23 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ import abc
3
+ from typing import Any, Generator, Generic, TypeVar
4
+
5
+ T = TypeVar("T")
6
+ C = TypeVar("C")
7
+
8
+
9
+ class StatefulIterator(Generic[T, C], abc.ABC):
10
+
11
+ @abc.abstractmethod
12
+ def get_state(self) -> C:
13
+ pass
14
+
15
+ @abc.abstractmethod
16
+ def create_iter(self) -> Generator[T, Any, None]:
17
+ pass
18
+
19
+
20
+ class IteratorState(Generic[C]):
21
+ @abc.abstractmethod
22
+ def build(self) -> StatefulIterator[T, C]:
23
+ pass
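A toy implementation (not part of this commit) showing the contract above: get_state() captures a resumable position, and IteratorState.build() reconstructs an equivalent iterator from it.

from pydantic import BaseModel

from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator


class CountState(BaseModel, IteratorState):
    position: int

    def build(self) -> "CountIterator":
        return CountIterator(position=self.position)


class CountIterator(StatefulIterator[int, CountState]):
    def __init__(self, position: int = 0):
        self.position = position

    def get_state(self) -> CountState:
        return CountState(position=self.position)

    def create_iter(self):
        while True:
            self.position += 1
            yield self.position


it = CountIterator()
gen = it.create_iter()
assert (next(gen), next(gen)) == (1, 2)
resumed = it.get_state().build()  # a fresh iterator that continues from position=2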
bytelatent/data/iterators/arrow_iterator.py ADDED
@@ -0,0 +1,216 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ import re
3
+ from logging import getLogger
4
+ from pathlib import Path
5
+ from typing import Any, Generator
6
+
7
+ import pyarrow as pa
8
+
9
+ # pyarrow needs the initialization from this import
10
+ import pyarrow.dataset # pyright: ignore
11
+ from pydantic import BaseModel, ConfigDict
12
+
13
+ from bytelatent import ByteLatentError
14
+ from bytelatent.data.data_types import BltExample
15
+ from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
16
+
17
+ logger = getLogger(__name__)
18
+
19
+
20
+ class ArrowFileIteratorState(BaseModel, IteratorState):
21
+ model_config = ConfigDict(extra="forbid")
22
+ file_path: str | None
23
+ row_num: int
24
+ num_workers: int
25
+ worker_id: int
26
+ preprocess_dir: str | None
27
+ dataset_files: list[str] | None
28
+ entropy_model_name: str | None
29
+ arrow_batch_size: int = 100
30
+
31
+ def build(self) -> "ArrowFileIterator":
32
+ arrow_file = ArrowFileIterator(
33
+ file_path=self.file_path,
34
+ worker_id=self.worker_id,
35
+ num_workers=self.num_workers,
36
+ preprocess_dir=self.preprocess_dir,
37
+ entropy_model_name=self.entropy_model_name,
38
+ arrow_batch_size=self.arrow_batch_size,
39
+ dataset_files=self.dataset_files,
40
+ )
41
+ if self.row_num != 0:
42
+ arrow_file._set_row_num(self.row_num)
43
+ return arrow_file
44
+
45
+
46
+ def shard_sort_key(file: str | Path):
47
+ match = re.search(r".+\.shard_([0-9]+)\.arrow", str(file))
48
+ shard_number = int(match.group(1))
49
+ return shard_number
50
+
51
+
52
+ class ArrowFileIterator(StatefulIterator):
53
+ def __init__(
54
+ self,
55
+ *,
56
+ file_path: str | None,
57
+ worker_id: int,
58
+ num_workers: int,
59
+ preprocess_dir: str | None,
60
+ entropy_model_name: str | None,
61
+ arrow_batch_size: int,
62
+ dataset_files: list[str] | None = None,
63
+ ):
64
+ assert 0 <= worker_id < num_workers, (worker_id, num_workers)
65
+ if file_path is None and dataset_files is None:
66
+ raise ByteLatentError("file_path and dataset_files cannot both be None")
67
+ self.row_num = 0
68
+ self.iter_id = 0
69
+ self.batch_iterator = None
70
+ self.batch_to_consume = None
71
+ self.dataset = None
72
+ self.file_path = file_path
73
+ self.worker_id = worker_id
74
+ self.num_workers = num_workers
75
+ self.preprocess_dir = preprocess_dir
76
+ self.entropy_model_name = entropy_model_name
77
+ self.arrow_batch_size = arrow_batch_size
78
+ if dataset_files is None:
79
+ # Prepare arrow shards
80
+ jsonl_file = Path(file_path)
81
+ parts = re.match(r"(.+)\.chunk\.[0-9]+\.jsonl", jsonl_file.name)
82
+ assert parts is not None
83
+ dataset = parts.group(1)
84
+ data_dir = Path(preprocess_dir) / dataset / entropy_model_name
85
+ shard_files = list(data_dir.glob(f"{jsonl_file.name}.shard_*.arrow"))
86
+ for s in shard_files:
87
+ if not (data_dir / f"{s.name}.complete").exists():
88
+ raise ValueError(f"Missing .complete for input file: {s}")
89
+
90
+ shard_files = sorted(shard_files, key=shard_sort_key)
91
+ if len(shard_files) == 0:
92
+ raise ByteLatentError(
93
+ f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {jsonl_file.name}.shard_*.arrow"
94
+ )
95
+ self.dataset_files = [str(f) for f in shard_files]
96
+ else:
97
+ self.preprocess_dir = None
98
+ self.dataset_files = dataset_files
99
+
100
+ def get_state(self) -> ArrowFileIteratorState:
101
+ return ArrowFileIteratorState(
102
+ file_path=self.file_path,
103
+ row_num=self.row_num,
104
+ worker_id=self.worker_id,
105
+ num_workers=self.num_workers,
106
+ preprocess_dir=self.preprocess_dir,
107
+ entropy_model_name=self.entropy_model_name,
108
+ arrow_batch_size=self.arrow_batch_size,
109
+ dataset_files=self.dataset_files,
110
+ )
111
+
112
+ def create_iter(
113
+ self,
114
+ ) -> Generator[BltExample, Any, None]:
115
+ if self.dataset is None:
116
+ self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow")
117
+ self.batch_iterator = self.dataset.to_batches(
118
+ batch_size=self.arrow_batch_size
119
+ )
120
+ self.iter_id += 1
121
+ if self.batch_to_consume is not None:
122
+ batch_columns: dict[str, list] = self.batch_to_consume
123
+ self.batch_to_consume = None
124
+ sample_ids = batch_columns["sample_id"]
125
+ texts = batch_columns["text"]
126
+ entropies = batch_columns["entropies"]
127
+ for i in range(len(sample_ids)):
128
+ out = BltExample(
129
+ sample_id=sample_ids[i],
130
+ entropies=entropies[i],
131
+ text=texts[i],
132
+ tokens=None,
133
+ mask=None,
134
+ patch_lengths=None,
135
+ )
136
+ self.row_num += 1
137
+ if (self.row_num - 1) % self.num_workers == self.worker_id:
138
+ yield out
139
+
140
+ for batch in self.batch_iterator:
141
+ batch_columns = batch.to_pydict()
142
+ sample_ids = batch_columns["sample_id"]
143
+ texts = batch_columns["text"]
144
+ entropies = batch_columns["entropies"]
145
+ for i in range(len(sample_ids)):
146
+ out = BltExample(
147
+ sample_id=sample_ids[i],
148
+ entropies=entropies[i],
149
+ text=texts[i],
150
+ tokens=None,
151
+ mask=None,
152
+ patch_lengths=None,
153
+ )
154
+ self.row_num += 1
155
+ if (self.row_num - 1) % self.num_workers == self.worker_id:
156
+ yield out
157
+
158
+ def _set_row_num(self, target_row_num: int):
159
+ logger.info(
160
+ f"Setting arrow position to {target_row_num} for {self.dataset_files}"
161
+ )
162
+ if target_row_num is None or target_row_num == 0:
163
+ self.row_num = 0
164
+ self.dataset = None
165
+ self.batch_iterator = None
166
+ self.batch_to_consume = None
167
+ else:
168
+ self.dataset = pa.dataset.dataset(self.dataset_files, format="arrow")
169
+ self.batch_iterator = self.dataset.to_batches(
170
+ batch_size=self.arrow_batch_size
171
+ )
172
+ curr_remaining = target_row_num
173
+ for batch in self.batch_iterator:
174
+ if len(batch) > curr_remaining:
175
+ batch_columns: dict[str, list] = batch.to_pydict()
176
+ batch_columns["sample_id"] = batch_columns["sample_id"][
177
+ curr_remaining:
178
+ ]
179
+ batch_columns["entropies"] = batch_columns["entropies"][
180
+ curr_remaining:
181
+ ]
182
+ batch_columns["text"] = batch_columns["text"][curr_remaining:]
183
+ self.batch_to_consume = batch_columns
184
+ break
185
+ elif len(batch) == curr_remaining:
186
+ # We are exactly at the end of the batch,
187
+ # so the next batch is the right spot
188
+ break
189
+ else:
190
+ curr_remaining -= len(batch)
191
+ self.row_num = target_row_num
192
+ logger.info(
193
+ f"Finished setting arrow position to {target_row_num} for {self.dataset_files}"
194
+ )
195
+
196
+
197
+ TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"
198
+
199
+
200
+ def find_and_sanitize_chunks(
201
+ dataset_path: str, world_size: int, file_pattern: str = TRAIN_DATA_FILE_PATTERN
202
+ ):
203
+ dataset_chunks = [str(p) for p in Path(dataset_path).glob(file_pattern)]
204
+ n_chunks = len(dataset_chunks)
205
+
206
+ if n_chunks > world_size:
207
+ n_discard = n_chunks - world_size
208
+ dataset_chunks = dataset_chunks[:world_size]
209
+ else:
210
+ assert (
211
+ world_size % n_chunks == 0
212
+ ), "World size should be a multiple of number of chunks"
213
+
214
+ assert n_chunks > 0, f"No valid chunks in {dataset_path}"
215
+
216
+ return dataset_chunks
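Two small checks, separate from the code above, that illustrate the shard ordering and the round-robin worker split used by create_iter(); the filename below is invented.

from bytelatent.data.iterators.arrow_iterator import shard_sort_key

# shard_sort_key orders shard files by their numeric suffix.
assert shard_sort_key("dclm.chunk.00.jsonl.shard_07.arrow") == 7

# create_iter() increments row_num for every row, then yields the row only when
# (row_num - 1) % num_workers == worker_id, i.e. a round-robin split over rows.
num_workers, worker_id = 4, 1
rows_for_this_worker = [i for i in range(12) if i % num_workers == worker_id]
assert rows_for_this_worker == [1, 5, 9]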
bytelatent/data/iterators/looping_iterator.py ADDED
@@ -0,0 +1,36 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ from pydantic import BaseModel
3
+
4
+ from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
5
+ from bytelatent.data.iterators.arrow_iterator import (
6
+ ArrowFileIterator,
7
+ ArrowFileIteratorState,
8
+ )
9
+
10
+
11
+ class LoopingIteratorState(BaseModel, IteratorState):
12
+ file_iterator_state: ArrowFileIteratorState
13
+ epoch: int
14
+
15
+ def build(self) -> "LoopingIterator":
16
+ return LoopingIterator(
17
+ file_iterator=self.file_iterator_state.build(),
18
+ epoch=self.epoch,
19
+ )
20
+
21
+
22
+ class LoopingIterator(StatefulIterator):
23
+ def __init__(self, file_iterator: ArrowFileIterator, epoch: int = -1):
24
+ self.file_iterator = file_iterator
25
+ self.epoch = epoch
26
+
27
+ def get_state(self):
28
+ return LoopingIteratorState(
29
+ file_iterator_state=self.file_iterator.get_state(), epoch=self.epoch
30
+ )
31
+
32
+ def create_iter(self):
33
+ while True:
34
+ self.epoch += 1
35
+ iterator = self.file_iterator.create_iter()
36
+ yield from iterator
bytelatent/data/iterators/multiprocess_iterator.py ADDED
@@ -0,0 +1,243 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ import json
3
+ import logging
4
+ import multiprocessing as mp
5
+ from multiprocessing.synchronize import Event as EventClass
6
+ from queue import Empty, Full
7
+
8
+ import numpy as np
9
+ from pydantic import BaseModel, ConfigDict
10
+
11
+ from bytelatent.data.data_types import Batch
12
+ from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
13
+ from bytelatent.data.iterators.packing_iterator import PackingIteratorState
14
+
15
+ logger = logging.getLogger()
16
+
17
+
18
+ class MultiprocessIteratorState(BaseModel, IteratorState):
19
+ model_config = ConfigDict(extra="forbid")
20
+ base_iterator_state: PackingIteratorState
21
+ n_batches_to_prefetch: int
22
+ serialized_prefetch_buffer: str
23
+
24
+ def build(self):
25
+ base_iterator = self.base_iterator_state.build()
26
+ data = json.loads(self.serialized_prefetch_buffer)
27
+ prefetch_buffer = [Batch.from_python_dict(item) for item in data]
28
+ return MultiprocessIterator(
29
+ base_iterator,
30
+ n_batches_to_prefetch=self.n_batches_to_prefetch,
31
+ prefetch_buffer=prefetch_buffer,
32
+ )
33
+
34
+
35
+ def start_work_from_state(
36
+ batch_queue: mp.Queue,
37
+ state_queue: mp.Queue,
38
+ stop_event: EventClass,
39
+ state_dumped_event: EventClass,
40
+ state: IteratorState,
41
+ ):
42
+ logging.info("Worker thread: Starting base_iterator work")
43
+ stateful_iterator = state.build()
44
+ iterator = stateful_iterator.create_iter()
45
+ for item in iterator:
46
+ while not stop_event.is_set():
47
+ try:
48
+ # Attempt to put the item on the queue, or time out and try again (the main thread may be busy)
49
+ batch_queue.put(item, timeout=0.1)
50
+ # On success, stop trying
51
+ break
52
+ except Full:
53
+ pass
54
+ if stop_event.is_set():
55
+ # Signal the end of output. This ensures that even if the queue takes a while to
56
+ # buffer, the main thread still receives everything (and tosses this fake batch)
57
+ logging.info(
58
+ "Worker thread: Stop event detected, outputting is_final=True batch"
59
+ )
60
+ batch_queue.put(
61
+ Batch(
62
+ x=np.zeros((1, 1)),
63
+ y=np.zeros((1, 1)),
64
+ is_final=True,
65
+ mask=None,
66
+ patch_lengths=None,
67
+ ngram_ids=None,
68
+ )
69
+ )
70
+ break
71
+
72
+ try:
73
+ logging.info("Worker thread: outputting state")
74
+ state_queue.put(iterator.get_state(), timeout=1)
75
+ logging.info("Worker thread: state dump complete")
76
+ state_dumped_event.set()
77
+ logging.info("Worker thread: set state_dump_event")
78
+ except Full:
79
+ raise ValueError(
80
+ "Attempted to dump state into the state queue, but it was full"
81
+ )
82
+
83
+
84
+ class MultiprocessIterator(StatefulIterator):
85
+ """
86
+ Design sketch of the multiprocess iterator:
87
+
88
+ Given the base_iterator, the only thing we do with this is call get_state()
89
+ so that we can pass that through to the background worker process.
90
+
91
+ The background process will receive this, rebuild the iterator, then start yielding from it.
92
+
93
+ However, in order to implement MultiprocessIterator.get_state(), we need to be able to accurately get
94
+ (1) the state of the iterator in the worker process
95
+ (2) the currently buffered items in the Queue
96
+
97
+ To do this, we use:
98
+ - batch_queue: This is the prefetch buffer the worker yields to and the main loop yields from
99
+ - state_queue: This size 1 queue will be how the worker sends the iterator state once it has halted iterating.
100
+ It must hold the state in addition to the last batch, if the queue was full at the time the stop event is sent.
101
+ - stop_iterating_event: Once this is issued from the main loop, the worker will stop iterating and enter cleanup.
102
+ During cleanup, the iterator will send the state of the current iterator to the main loop,
103
+ in addition to possibly the last batch if the batch_queue was full at the time
104
+ - state_dumped_event: When the main loop issues the stop_iterating_event, it will wait until the state_dumped_event to attempt
105
+ to get state from the state_queue. It must do this since the worker may take some time to create and send the state.
106
+ Once received by the main loop, the main loop can safely store the Queue (plus maybe the last batch) as the prefetch buffer,
107
+ get the worker iterator's state, and terminate the background process + delete associated objects.
108
+
109
+ At this point, calling create_iter() again will bootstrap everything from the stored state and the old iterator will throw an error
110
+ since it will not iterate anymore (so the caller must call create_iter() again to get a python iterator).
111
+
112
+ """
113
+
114
+ def __init__(
115
+ self,
116
+ base_iterator: StatefulIterator,
117
+ *,
118
+ n_batches_to_prefetch: int,
119
+ prefetch_buffer: list | None = None
120
+ ):
121
+ self.base_iterator = base_iterator
122
+ self.n_batches_to_prefetch = n_batches_to_prefetch
123
+ if prefetch_buffer is None:
124
+ prefetch_buffer = []
125
+ self.prefetch_buffer = prefetch_buffer
126
+ self.batch_queue = None
127
+ self.state_queue = None
128
+ self.producer = None
129
+ self.stop_iterating_event = None
130
+ self.state_dumped_event = None
131
+
132
+ def get_state(self) -> MultiprocessIteratorState:
133
+ """
134
+ This is slightly unusual in effectively destroying the current iterator, its necessary
135
+ to halt the background process and allow it to write the state to the main loop
136
+ in order to not lose data
137
+ """
138
+ if self.producer is None:
139
+ serialized_prefetch_buffer = json.dumps(
140
+ [b.to_python_dict() for b in self.prefetch_buffer]
141
+ )
142
+ return MultiprocessIteratorState(
143
+ base_iterator_state=self.base_iterator.get_state(),
144
+ n_batches_to_prefetch=self.n_batches_to_prefetch,
145
+ serialized_prefetch_buffer=serialized_prefetch_buffer,
146
+ )
147
+ else:
148
+ logging.info("Main thread: Sending stop iteration event")
149
+ self.stop_iterating_event.set()
150
+ logging.info("Main thread: Waiting for state_dumped event")
151
+ self.state_dumped_event.wait()
152
+ self.prefetch_buffer = []
153
+ final_batch_received = False
154
+ while True:
155
+ try:
156
+ batch = self.batch_queue.get(timeout=1)
157
+ if batch.is_final:
158
+ final_batch_received = True
159
+ break
160
+ self.prefetch_buffer.append(batch)
161
+ except Empty:
162
+ logging.warning("Main thread: batch_queue is abnormally empty")
163
+ assert final_batch_received
164
+
165
+ try:
166
+ base_iterator_state = self.state_queue.get(timeout=1)
167
+ assert isinstance(base_iterator_state, IteratorState)
168
+ except Empty:
169
+ raise ValueError(
170
+ "Attempted to get the state, but it was unexpectantly missing"
171
+ )
172
+
173
+ self.base_iterator = base_iterator_state.build()
174
+ self.producer.close()
175
+ self.producer = None
176
+ self.batch_queue = None
177
+ self.state_queue = None
178
+ self.stop_iterating_event = None
179
+ self.state_dumped_event = None
180
+
181
+ return MultiprocessIteratorState(
182
+ base_iterator_state=self.base_iterator.get_state(),
183
+ n_batches_to_prefetch=self.n_batches_to_prefetch,
184
+ serialized_prefetch_buffer=json.dumps(
185
+ [b.to_python_dict() for b in self.prefetch_buffer]
186
+ ),
187
+ )
188
+
189
+ def create_iter(self):
190
+ logging.info("Main thread: Creating MP iterator")
191
+ # First yield from the stored prefetch buffer.
192
+ if self.prefetch_buffer is not None:
193
+ while len(self.prefetch_buffer) > 0:
194
+ item = self.prefetch_buffer.pop(0)
195
+ yield item
196
+ self.prefetch_buffer = None
197
+
198
+ assert (
199
+ self.producer is None
200
+ ), "Cannot create two parallel iterators at once, call get_state() then remake to have two."
201
+
202
+ # using the "forkserver" multiprocessing context avoids excessive CPU load
203
+ ctx = mp.get_context("forkserver")
204
+ self.batch_queue = ctx.Manager().Queue(maxsize=self.n_batches_to_prefetch)
205
+
206
+ # We should only ever have one state, which is output when a stop event is detected
207
+ self.state_queue = ctx.Manager().Queue(maxsize=1)
208
+
209
+ self.stop_iterating_event = ctx.Event()
210
+ self.state_dumped_event = ctx.Event()
211
+
212
+ self.producer = mp.Process(
213
+ name="blt_data_loader",
214
+ target=start_work_from_state,
215
+ args=(
216
+ self.batch_queue,
217
+ self.state_queue,
218
+ self.stop_iterating_event,
219
+ self.state_dumped_event,
220
+ self.base_iterator.get_state(),
221
+ ),
222
+ )
223
+ logger.info("Async dataloader started")
224
+ self.producer.start()
225
+
226
+ while True:
227
+ if self.producer.exitcode is not None:
228
+ raise RuntimeError(
229
+ "Data loader quit unexpectedly, real error has been raised previously"
230
+ )
231
+ try:
232
+ batch = self.batch_queue.get(timeout=0.1)
233
+ assert isinstance(batch, Batch)
234
+ assert (
235
+ not batch.is_final
236
+ ), "is_final should only be used during get_state() being called"
237
+ yield batch
238
+ except Empty:
239
+ pass
240
+ if self.producer is None:
241
+ raise ValueError(
242
+ "Attempted to call this iterator after calling get_state(). You must call create_iter() to make a new iterator instead."
243
+ )
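A hedged usage sketch for the class above: `packing_iterator` is a stand-in for any StatefulIterator that yields Batch objects (such as the PackingIterator in the next file), and the numbers are arbitrary.

from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator

# `packing_iterator` is assumed to already be built elsewhere.
mp_it = MultiprocessIterator(packing_iterator, n_batches_to_prefetch=8)

batches = mp_it.create_iter()      # starts the background producer process
for _ in range(100):
    batch = next(batches)          # Batch objects arrive via the prefetch queue

state = mp_it.get_state()          # halts the worker and drains the queue into the buffer
# ... persist `state` alongside the training checkpoint, then later ...
resumed_iter = state.build().create_iter()  # replays buffered batches, then continues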
bytelatent/data/iterators/packing_iterator.py ADDED
@@ -0,0 +1,226 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ from typing import Any
3
+
4
+ import numpy as np
5
+ from pydantic import BaseModel, ConfigDict
6
+
7
+ from bytelatent.data.data_types import Batch, BltSequence
8
+ from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
9
+ from bytelatent.data.iterators.sampling_iterator import SamplingIteratorState
10
+
11
+
12
+ class PackingArgs(BaseModel):
13
+ model_config = ConfigDict(extra="forbid")
14
+ batch_size: int
15
+ seq_len: int
16
+ pad_id: int
17
+ max_length: int | None
18
+ pad_to_max_length: bool
19
+ enable_byte_ngrams: bool
20
+
21
+
22
+ class PackingIteratorState(BaseModel, IteratorState):
23
+ model_config = ConfigDict(extra="forbid")
24
+ sequence_iterator_state: SamplingIteratorState
25
+ packing_args: PackingArgs
26
+
27
+ def build(self) -> "PackingIterator":
28
+ return PackingIterator(
29
+ sequence_iterator=self.sequence_iterator_state.build(),
30
+ packing_args=self.packing_args,
31
+ )
32
+
33
+
34
+ def _merge_patch_seq_masks(bs, slen: int, mask_seqs: list[list[bool]]):
35
+ assert len(mask_seqs) == bs
36
+ lens = [len(m) for m in mask_seqs]
37
+ if all(all(m) for m in mask_seqs) and all(lens[0] == l for l in lens):
38
+ return None
39
+ assert slen == max(lens) - 1
40
+ mask = np.zeros((bs, slen), dtype=bool)
41
+ for i, m in enumerate(mask_seqs):
42
+ if m is None:
43
+ print(
44
+ "Did not implement None mask, the mask should be True for all toks, so we need to pass that to this function."
45
+ )
46
+ raise NotImplementedError
47
+ mask[i][: len(mask_seqs[i]) - 1] = mask_seqs[i][1:]
48
+ return mask
49
+
50
+
51
+ def truncate_batch(
52
+ batch: Batch,
53
+ max_length: int,
54
+ pad_id: int,
55
+ pad_to_max_length: bool = False,
56
+ *,
57
+ enable_byte_ngrams: bool,
58
+ ):
59
+ """
60
+ Truncate x to a given size, making sure we remove the corresponding patch sizes in patch_lengths
61
+ and fix batch.mask.
62
+
63
+ batch.patch_lengths has unchanged shape
64
+ x, y, and mask may reduce in size
65
+ """
66
+ if batch.patch_lengths is None:
67
+ return batch
68
+
69
+ seq_lengths = batch.patch_lengths.sum(axis=1)
70
+ max_length_adj = max_length + 1
71
+ if np.any(seq_lengths > max_length_adj):
72
+ for i in range(batch.x.shape[0]):
73
+ if seq_lengths[i] > max_length_adj:
74
+ # Find id of patch that tips over max_length + 1
75
+ count, j = 0, 0
76
+ while count + batch.patch_lengths[i, j] <= max_length_adj:
77
+ count += batch.patch_lengths[i, j]
78
+ j += 1
79
+ # Edit the batch
80
+ assert j < batch.patch_lengths.shape[1]
81
+ batch.x[i, max_length:] = pad_id
82
+ batch.y[i, max_length:] = pad_id
83
+ if batch.mask is not None:
84
+ batch.mask[i, max_length:] = False
85
+ batch.patch_lengths[i, j:] = 0
86
+ batch.patch_lengths[i, j] = max_length_adj - count
87
+
88
+ # Truncate if necessary.
89
+ if max_length < batch.x.shape[1]:
90
+ batch.x = batch.x[:, :max_length]
91
+ batch.y = batch.y[:, :max_length]
92
+ if batch.mask is not None:
93
+ batch.mask = batch.mask[:, :max_length]
94
+
95
+ # Right pad to max_length if necessary
96
+ elif pad_to_max_length:
97
+ if batch.x.shape[1] < max_length:
98
+ # NOTE: this has to be done on an actual patch.
99
+ non_zero_indices = (batch.patch_lengths != 0).sum(axis=1) - 1
100
+ non_zero_indices = np.maximum(0, non_zero_indices)
101
+ batch.patch_lengths[range(len(batch.patch_lengths)), non_zero_indices] += (
102
+ max_length - batch.x.shape[1]
103
+ )
104
+ # TODO: We could get rid of many of these complications by moving this function directly into the dataloader.
105
+ x = np.full((batch.x.shape[0], max_length), pad_id, dtype=batch.x.dtype)
106
+ x[:, : batch.x.shape[1]] = batch.x
107
+ batch.x = x
108
+ if batch.y.shape[1] < max_length:
109
+ y = np.full((batch.y.shape[0], max_length), pad_id, dtype=batch.y.dtype)
110
+ y[:, : batch.y.shape[1]] = batch.y
111
+ batch.y = y
112
+ if batch.mask is not None and batch.mask.shape[1] < max_length:
113
+ mask = np.full(
114
+ (batch.mask.shape[0], max_length), False, dtype=batch.mask.dtype
115
+ )
116
+ mask[:, : batch.mask.shape[1]] = batch.mask
117
+ batch.mask = mask
118
+
119
+ assert batch.x.shape[1] <= max_length
120
+ assert batch.y.shape[1] <= max_length
121
+ assert batch.mask is None or batch.mask.shape[1] <= max_length
122
+ assert np.all(max_length_adj - batch.patch_lengths.sum(axis=1) == 0)
123
+ if pad_to_max_length:
124
+ assert batch.x.shape[1] == max_length
125
+ assert batch.y.shape[1] == max_length
126
+ assert batch.mask is None or batch.mask.shape[1] == max_length
127
+ if enable_byte_ngrams:
128
+ raise NotImplementedError()
129
+ # (num_ngram, batch_size, seq_len)
130
+ ngram_ids = np.array(tokenizer.encode_token_ngrams(batch.x))
131
+ assert ngram_ids.shape[2] == batch.x.shape[1]
132
+ else:
133
+ ngram_ids = None
134
+ batch.ngram_ids = ngram_ids
135
+
136
+
137
+ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
138
+ def __init__(
139
+ self,
140
+ sequence_iterator: StatefulIterator[BltSequence, Any],
141
+ *,
142
+ packing_args: PackingArgs,
143
+ ):
144
+ self.sequence_iterator = sequence_iterator
145
+ self.packing_args = packing_args
146
+
147
+ def get_state(self):
148
+ return PackingIteratorState(
149
+ sequence_iterator_state=self.sequence_iterator.get_state(),
150
+ packing_args=self.packing_args,
151
+ )
152
+
153
+ def create_iter(self):
154
+ sequence_iter = self.sequence_iterator.create_iter()
155
+ batch_size = self.packing_args.batch_size
156
+ pad_id = self.packing_args.pad_id
157
+ seq_len = self.packing_args.seq_len
158
+ pad_to_max_length = self.packing_args.pad_to_max_length
159
+ enable_byte_ngrams = self.packing_args.enable_byte_ngrams
160
+ max_length = self.packing_args.max_length
161
+ while True:
162
+ tokens: list[list[int]] = []
163
+ masks: list[list[bool]] = []
164
+ patch_lengths: list[list[int]] = []
165
+
166
+ for _ in range(self.packing_args.batch_size):
167
+ sequence = next(sequence_iter)
168
+ _tokens = sequence.tokens
169
+ _mask = sequence.mask
170
+ _patch_lengths = sequence.patch_lengths
171
+ assert len(sequence.patch_lengths) == self.packing_args.seq_len
172
+ last_patch_length = 0
173
+ if _patch_lengths[0] > 1:
174
+ last_patch_length = _patch_lengths[-1]
175
+ _patch_lengths[0] -= 1
176
+ _patch_lengths = [1] + _patch_lengths[:-1]
177
+ tokens.append(_tokens[: len(_tokens) - last_patch_length])
178
+ masks.append(_mask[: len(_mask) - last_patch_length])
179
+ patch_lengths.append(_patch_lengths)
180
+
181
+ x_patch_lengths = np.array(patch_lengths)
182
+ # pad batch to same length
183
+ tok_seq_len = max([len(toks) for toks in tokens]) - 1
184
+ x = np.full((batch_size, tok_seq_len), fill_value=pad_id)
185
+ y = np.full((batch_size, tok_seq_len), fill_value=pad_id)
186
+
187
+ for i, tok_seq in enumerate(tokens):
188
+ x[i, : len(tok_seq) - 1] = tok_seq[:-1]
189
+ y[i, : len(tok_seq) - 1] = tok_seq[1:]
190
+ # Adjust patch lengths to match x
191
+ x_patch_lengths[i, -1] += tok_seq_len - (len(tok_seq) - 1)
192
+
193
+ assert x_patch_lengths.shape == (batch_size, seq_len)
194
+
195
+ if enable_byte_ngrams:
196
+ raise NotImplementedError()
197
+ else:
198
+ ngram_ids = None
199
+
200
+ batch = Batch(
201
+ x=x,
202
+ y=y,
203
+ patch_lengths=x_patch_lengths,
204
+ ngram_ids=ngram_ids,
205
+ mask=_merge_patch_seq_masks(batch_size, tok_seq_len, masks),
206
+ )
207
+ assert (
208
+ x_patch_lengths.sum() == x.size + batch_size
209
+ ), f"{x_patch_lengths.sum()} != {x.size + batch_size}"
210
+ assert (
211
+ batch.mask is None or np.sum(x != pad_id) == batch.mask.sum()
212
+ ), f"{np.sum(x != pad_id)} != {batch.mask.sum()}"
213
+ assert np.all(
214
+ x_patch_lengths[:, 0] == 1
215
+ ), f"first patch should always be 1, {x_patch_lengths[:, 0]}"
216
+ # cuda_gb_allocated = (torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024)
217
+ # cuda_gb_reserved = torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024
218
+ # print(f"dataloader cuda_gb_allocated: {cuda_gb_allocated}, cuda_gb_reserved: {cuda_gb_reserved}")
219
+ truncate_batch(
220
+ batch,
221
+ max_length=max_length,
222
+ pad_id=pad_id,
223
+ pad_to_max_length=pad_to_max_length,
224
+ enable_byte_ngrams=enable_byte_ngrams,
225
+ )
226
+ yield batch
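A worked numeric example, independent of the iterator above, of the invariant asserted in create_iter(): each row's patch lengths sum to its token count, while x and y hold one token less (the shifted views), so patch_lengths.sum() exceeds x.size by exactly batch_size.

import numpy as np

tokens = [[5, 6, 7, 8, 9]]                  # one row of 5 tokens (batch_size = 1)
patch_lengths = np.array([[1, 2, 2]])       # sums to 5; the first patch is forced to 1
x = np.array([row[:-1] for row in tokens])  # [[5, 6, 7, 8]]
y = np.array([row[1:] for row in tokens])   # [[6, 7, 8, 9]]
assert patch_lengths.sum() == x.size + len(tokens)  # 5 == 4 + 1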
bytelatent/data/iterators/preprocess_iterator.py ADDED
@@ -0,0 +1,111 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ from typing import Any, Generator
3
+
4
+ import torch
5
+ from pydantic import BaseModel, ConfigDict
6
+
7
+ from bytelatent.data.data_types import BltExample
8
+ from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
9
+ from bytelatent.data.iterators.arrow_iterator import (
10
+ ArrowFileIterator,
11
+ ArrowFileIteratorState,
12
+ )
13
+ from bytelatent.data.iterators.looping_iterator import LoopingIteratorState
14
+ from bytelatent.data.patcher import Patcher, PatcherArgs, PatchingModeEnum
15
+ from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
16
+ from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
17
+
18
+
19
+ class PreprocessIteratorState(BaseModel, IteratorState):
20
+ model_config = ConfigDict(extra="forbid")
21
+ arrow_file_iterator_state: ArrowFileIteratorState | LoopingIteratorState
22
+ add_tokens: bool
23
+ add_patches: bool
24
+ tokenizer_args: TokenizerArgs
25
+ patcher_args: PatcherArgs
26
+
27
+ def build(self):
28
+ arrow_iterator = self.arrow_file_iterator_state.build()
29
+ return PreprocessIterator(
30
+ arrow_iterator,
31
+ patcher_args=self.patcher_args,
32
+ tokenizer_args=self.tokenizer_args,
33
+ add_tokens=self.add_tokens,
34
+ add_patches=self.add_patches,
35
+ )
36
+
37
+
38
+ class PreprocessIterator(StatefulIterator):
39
+ """
40
+ Take BltExamples with fields filled in only from ArrowFileIterator, and fill in fields that require
41
+ preprocessing like tokenization and patching
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ arrow_iterator: ArrowFileIterator,
47
+ *,
48
+ patcher_args: PatcherArgs,
49
+ tokenizer_args: TokenizerArgs,
50
+ add_tokens: bool = True,
51
+ add_patches: bool = True,
52
+ ):
53
+ self.arrow_iterator = arrow_iterator
54
+ self.tokenizer_args = tokenizer_args
55
+ self.patcher_args = patcher_args
56
+ self.add_tokens = add_tokens
57
+ self.add_patches = add_patches
58
+ self.tokenizer: BltTokenizer | None = None
59
+ self.patcher: Patcher | None = None
60
+
61
+ def get_state(self) -> PreprocessIteratorState:
62
+ """
63
+ The only state to maintain here is from arrow, there
64
+ isn't any internal state on this iterator.
65
+ """
66
+ return PreprocessIteratorState(
67
+ arrow_file_iterator_state=self.arrow_iterator.get_state(),
68
+ tokenizer_args=self.tokenizer_args,
69
+ patcher_args=self.patcher_args,
70
+ add_tokens=self.add_tokens,
71
+ add_patches=self.add_patches,
72
+ )
73
+
74
+ def create_iter(self) -> Generator[BltExample, Any, None]:
75
+ if self.tokenizer is None and self.add_tokens:
76
+ self.tokenizer = self.tokenizer_args.build()
77
+ if self.patcher is None and self.add_patches:
78
+ self.patcher = self.patcher_args.build()
79
+
80
+ example_iter = self.arrow_iterator.create_iter()
81
+ for example in example_iter:
82
+ if self.add_tokens:
83
+ tokens = self.tokenizer.encode(example.text)
84
+ else:
85
+ tokens = example.tokens
86
+ if (
87
+ self.patcher is not None
88
+ and self.patcher.patching_mode == PatchingModeEnum.entropy
89
+ ):
90
+ assert (
91
+ example.entropies is not None
92
+ ), "For patching, entropies cannot be None"
93
+ entropies = torch.tensor(example.entropies).unsqueeze(0)
94
+ else:
95
+ entropies = None
96
+ if self.patcher is None:
97
+ patch_lengths = None
98
+ else:
99
+ patch_lengths = self.patcher.patch(
100
+ torch.tensor(tokens).unsqueeze(0),
101
+ include_next_token=False,
102
+ entropies=entropies,
103
+ )[0][0].tolist()
104
+ yield BltExample(
105
+ sample_id=example.sample_id,
106
+ text=example.text,
107
+ tokens=tokens,
108
+ mask=[True] * len(tokens),
109
+ patch_lengths=patch_lengths,
110
+ entropies=example.entropies,
111
+ )
bytelatent/data/iterators/sampling_iterator.py ADDED
@@ -0,0 +1,66 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ from typing import Any
3
+
4
+ import numpy as np
5
+ from pydantic import BaseModel, ConfigDict
6
+
7
+ from bytelatent.data.iterators.abstract_iterator import StatefulIterator
8
+ from bytelatent.data.iterators.sequence_iterator import SequenceIteratorState
9
+
10
+
11
+ class SamplingIteratorState(BaseModel):
12
+ model_config = ConfigDict(extra="forbid")
13
+ rng_state: dict[str, Any]
14
+ source_to_weight: dict[str, float]
15
+ source_to_iterator_state: dict[str, SequenceIteratorState]
16
+
17
+ def build(self) -> "SamplingIterator":
18
+ return SamplingIterator(
19
+ rng_state=self.rng_state,
20
+ source_to_weight=self.source_to_weight,
21
+ source_to_iterator={
22
+ source: state.build()
23
+ for source, state in self.source_to_iterator_state.items()
24
+ },
25
+ )
26
+
27
+
28
+ class SamplingIterator(StatefulIterator):
29
+ def __init__(
30
+ self,
31
+ *,
32
+ rng_state: dict[str, Any],
33
+ source_to_weight: dict[str, float],
34
+ source_to_iterator: dict[str, StatefulIterator],
35
+ ):
36
+ self.rng = np.random.default_rng()
37
+ self.rng.bit_generator.state = rng_state
38
+ self.source_to_weight = source_to_weight
39
+ self.source_to_iterator = source_to_iterator
40
+
41
+ def get_state(self) -> SamplingIteratorState:
42
+ return SamplingIteratorState(
43
+ rng_state=self.rng.bit_generator.state,
44
+ source_to_weight=self.source_to_weight,
45
+ source_to_iterator_state={
46
+ source: iterator.get_state()
47
+ for source, iterator in self.source_to_iterator.items()
48
+ },
49
+ )
50
+
51
+ def create_iter(self):
52
+ n_sources = len(self.source_to_weight)
53
+ possible_sources = []
54
+ weights = []
55
+ for source, w in self.source_to_weight.items():
56
+ possible_sources.append(source)
57
+ weights.append(w)
58
+
59
+ source_to_python_iter = {
60
+ source: self.source_to_iterator[source].create_iter()
61
+ for source in possible_sources
62
+ }
63
+ while True:
64
+ norm_weights = np.array(weights) / np.array(weights).sum()
65
+ source_choice = possible_sources[self.rng.choice(n_sources, p=norm_weights)]
66
+ yield next(source_to_python_iter[source_choice])
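A stand-alone sketch of the weighted draw performed in create_iter() above, using the same numpy Generator API; the source names and weights are made up.

import numpy as np

rng = np.random.default_rng(0)
sources = ["web", "code"]
weights = np.array([3.0, 1.0])
norm_weights = weights / weights.sum()  # [0.75, 0.25]

picks = [sources[rng.choice(len(sources), p=norm_weights)] for _ in range(1000)]
# Over many draws, "web" is chosen roughly three times as often as "code".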
bytelatent/data/iterators/sequence_iterator.py ADDED
@@ -0,0 +1,122 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ from logging import getLogger
3
+ from typing import Any
4
+
5
+ import numpy as np
6
+ from pydantic import BaseModel, ConfigDict
7
+
8
+ from bytelatent.data.data_types import BltSequence
9
+ from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
10
+ from bytelatent.data.iterators.preprocess_iterator import (
11
+ PreprocessIterator,
12
+ PreprocessIteratorState,
13
+ )
14
+
15
+ logger = getLogger()
16
+
17
+
18
+ class SequencePackingArgs(BaseModel):
19
+ model_config = ConfigDict(extra="forbid")
20
+ output_seq_len: int
21
+ buffer_size: int
22
+
23
+
24
+ class SequenceIteratorState(BaseModel, IteratorState):
25
+ model_config = ConfigDict(extra="forbid")
26
+ sequence_packing_args: SequencePackingArgs
27
+ preprocess_iterator_state: PreprocessIteratorState
28
+ rng_state: dict[str, Any]
29
+
30
+ def build(self):
31
+ preprocess_iterator = self.preprocess_iterator_state.build()
32
+ return SequenceIterator(
33
+ preprocess_iterator,
34
+ sequence_packing_args=self.sequence_packing_args,
35
+ rng_state=self.rng_state,
36
+ )
37
+
38
+
39
+ class SequenceIterator(StatefulIterator):
40
+ def __init__(
41
+ self,
42
+ preprocess_iterator: PreprocessIterator,
43
+ *,
44
+ rng_state: dict[str, Any],
45
+ sequence_packing_args: SequencePackingArgs,
46
+ ):
47
+ self.preprocess_iterator = preprocess_iterator
48
+ self.sequence_packing_args = sequence_packing_args
49
+ self.output_seq_len = sequence_packing_args.output_seq_len
50
+ self.buffer_size = sequence_packing_args.buffer_size
51
+ self.rng = np.random.default_rng()
52
+ self.rng.bit_generator.state = rng_state
53
+
54
+ def get_state(self):
55
+ # TODO: need to also persist the current shuffle buffer
56
+ return SequenceIteratorState(
57
+ sequence_packing_args=self.sequence_packing_args,
58
+ preprocess_iterator_state=self.preprocess_iterator.get_state(),
59
+ rng_state=self.rng.bit_generator.state,
60
+ )
61
+
62
+ def create_iter(self):
63
+ example_iter = self.preprocess_iterator.create_iter()
64
+ n_buffer_patches = self.buffer_size * self.output_seq_len
65
+
66
+ patch_lengths: list[int] = []
67
+ tokens: list[int] = []
68
+ mask: list[bool] = []
69
+ first = True
70
+ for example in example_iter:
71
+ assert example.tokens is not None
72
+ assert example.mask is not None
73
+ assert example.patch_lengths is not None
74
+ assert len(example.tokens) != 0
75
+ assert len(example.mask) != 0
76
+ assert len(example.tokens) == len(example.mask)
77
+ assert len(example.tokens) == sum(example.patch_lengths)
78
+
79
+ tokens.extend(example.tokens)
80
+ mask.extend(example.mask)
81
+ patch_lengths.extend(example.patch_lengths)
82
+
83
+ while len(patch_lengths) >= n_buffer_patches:
84
+ if first:
85
+ first = False
86
+ logger.info("First buffer complete")
87
+
88
+ x_patches = np.array(patch_lengths[:n_buffer_patches]).reshape(
89
+ self.buffer_size, self.output_seq_len
90
+ )
91
+ seq_tokens = []
92
+ seq_mask = []
93
+ start_id = 0
94
+ # We fix the number of patches (and therefore global steps) per batch,
95
+ # so each sequence has a variable number of tokens that we need to account for
96
+ for num_tokens in x_patches.sum(axis=-1):
97
+ seq_tokens.append(tokens[start_id : start_id + num_tokens])
98
+ seq_mask.append(mask[start_id : start_id + num_tokens])
99
+ start_id += num_tokens
100
+
101
+ assert start_id == x_patches.sum()
102
+
103
+ # Remove what we just added from the buffer
104
+ patch_lengths = patch_lengths[n_buffer_patches:]
105
+ tokens = tokens[x_patches.sum() :]
106
+ mask = mask[x_patches.sum() :]
107
+
108
+ seq_patch_lengths: list[list[int]] = x_patches.tolist()
109
+ assert len(seq_patch_lengths) == self.buffer_size
110
+ for idx in self.rng.permutation(len(seq_patch_lengths)):
111
+ assert len(seq_patch_lengths[idx]) == self.output_seq_len
112
+ assert (
113
+ sum(seq_patch_lengths[idx])
114
+ == len(seq_tokens[idx])
115
+ == len(seq_mask[idx])
116
+ ), f"{sum(seq_patch_lengths[idx])}, {len(seq_tokens[idx])} {len(seq_mask[idx])}, idx={idx}"
117
+ assert seq_patch_lengths[idx][0] > 0, f"{seq_patch_lengths[idx]}"
118
+ yield BltSequence(
119
+ tokens=seq_tokens[idx],
120
+ mask=seq_mask[idx],
121
+ patch_lengths=seq_patch_lengths[idx],
122
+ )
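A small numeric sketch of the buffering step above: the flat patch buffer is cut into buffer_size rows of output_seq_len patches, and each row then consumes however many tokens its patches add up to.

import numpy as np

buffer_size, output_seq_len = 2, 3
patch_lengths = [1, 2, 2, 1, 4, 3, 1, 1]         # accumulated across examples
n_buffer_patches = buffer_size * output_seq_len  # 6 patches per flush

x_patches = np.array(patch_lengths[:n_buffer_patches]).reshape(
    buffer_size, output_seq_len
)
# x_patches.sum(axis=-1) == [5, 8]: the first sequence takes the next 5 tokens
# from the token buffer, the second takes the following 8.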
bytelatent/data/iterators/test_arrow_iterator.py ADDED
@@ -0,0 +1,89 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ import numpy as np
3
+ import pyarrow as pa
4
+
5
+ # pyarrow needs the initialization from this import
6
+ import pyarrow.dataset # pyright: ignore
7
+
8
+ from bytelatent.constants import BLT_DATA
9
+ from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState
10
+
11
+ ENTROPY_MODEL = "transformer_100m"
12
+ ARROW_TEST_DATA_1 = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_00.arrow")
13
+ ARROW_TEST_DATA_2 = str(BLT_DATA / "stackexchange.chunk.00.jsonl.shard_01.arrow")
14
+
15
+
16
+ def test_basic_arrow_file():
17
+ dataset = pa.dataset.dataset(ARROW_TEST_DATA_1, format="arrow")
18
+ n_head = 1000
19
+ head_df = dataset.head(n_head).to_pandas()
20
+
21
+ initial_state = ArrowFileIteratorState(
22
+ file_path=None,
23
+ num_workers=1,
24
+ worker_id=0,
25
+ preprocess_dir=None,
26
+ entropy_model_name=ENTROPY_MODEL,
27
+ dataset_files=[ARROW_TEST_DATA_1],
28
+ row_num=0,
29
+ arrow_batch_size=100,
30
+ )
31
+ arrow_file = initial_state.build()
32
+ start_state = arrow_file.get_state()
33
+ assert start_state.row_num == initial_state.row_num
34
+
35
+ sample_id = None
36
+ for example in arrow_file.create_iter():
37
+ sample_id = example.sample_id
38
+ assert head_df.iloc[0]["sample_id"] == sample_id
39
+ break
40
+
41
+ assert arrow_file.get_state().row_num == 1
42
+ arrow_file = initial_state.build()
43
+ for example in arrow_file.create_iter():
44
+ assert example.sample_id == sample_id
45
+ assert head_df.iloc[0]["sample_id"] == sample_id
46
+ break
47
+
48
+ # Test resume far enough in to be past the batch size of 100
49
+ resumed_state = ArrowFileIteratorState(
50
+ file_path=None,
51
+ num_workers=1,
52
+ worker_id=0,
53
+ preprocess_dir=None,
54
+ entropy_model_name=ENTROPY_MODEL,
55
+ dataset_files=[ARROW_TEST_DATA_1],
56
+ row_num=251,
57
+ arrow_batch_size=100,
58
+ )
59
+ arrow_file = resumed_state.build()
60
+ for example in arrow_file.create_iter():
61
+ assert example.sample_id == head_df.iloc[251]["sample_id"]
62
+ assert arrow_file.get_state().row_num == 252
63
+ break
64
+
65
+ world_rank = 1
66
+ world_size = 4
67
+ # Test World Size and Rank
68
+ rank_state = ArrowFileIteratorState(
69
+ file_path=None,
70
+ num_workers=world_size,
71
+ worker_id=world_rank,
72
+ preprocess_dir=None,
73
+ entropy_model_name=ENTROPY_MODEL,
74
+ dataset_files=[ARROW_TEST_DATA_1],
75
+ row_num=0,
76
+ arrow_batch_size=100,
77
+ )
78
+ arrow_file = rank_state.build()
79
+ expected_ids = []
80
+ for i in range(n_head):
81
+ if i % world_size == world_rank:
82
+ expected_ids.append(head_df.iloc[i]["sample_id"])
83
+ print(len(expected_ids))
84
+ i = 0
85
+ for example in arrow_file.create_iter():
86
+ assert example.sample_id == expected_ids[i]
87
+ i += 1
88
+ if i >= len(expected_ids):
89
+ break
bytelatent/data/iterators/test_iters.py ADDED
@@ -0,0 +1,162 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ import pandas as pd
3
+ from pydantic import BaseModel
4
+
5
+ from bytelatent.constants import BLT_DATA
6
+ from bytelatent.data.data_types import BltExample
7
+ from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
8
+ from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
9
+ from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum
10
+ from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
11
+
12
+
13
+ class BltTestIteratorState(BaseModel, IteratorState):
14
+ position: int
15
+ total: int
16
+
17
+ def build(self):
18
+ blt_iter = BltTestIterator(total=self.total)
19
+ blt_iter.position = self.position
20
+ return blt_iter
21
+
22
+
23
+ class BltTestIterator(StatefulIterator):
24
+ def __init__(self, total: int):
25
+ self.position = 0
26
+ self.total = total
27
+
28
+ def get_state(self):
29
+ return BltTestIteratorState(position=self.position, total=self.total)
30
+
31
+ def create_iter(self):
32
+ for i in range(self.total):
33
+ self.position += 1
34
+ yield BltExample(
35
+ sample_id=f"test_{i}",
36
+ text=f"This is some test {i} text.",
37
+ tokens=None,
38
+ mask=None,
39
+ entropies=None,
40
+ patch_lengths=None,
41
+ )
42
+
43
+
44
+ class BltTestWithEntropiesIteratorState(BaseModel, IteratorState):
45
+ position: int
46
+ total: int
47
+
48
+ def build(self):
49
+ blt_iter = BltTestWithEntropiesIterator(total=self.total)
50
+ blt_iter.position = self.position
51
+ return blt_iter
52
+
53
+
54
+ class BltTestWithEntropiesIterator(StatefulIterator):
55
+ def __init__(self, total: int):
56
+ self.position = 0
57
+ self.total = total
58
+
59
+ def get_state(self):
60
+ return BltTestWithEntropiesIteratorState(position=self.position, total=self.total)
61
+
62
+ def create_iter(self):
63
+ text = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin."
64
+ df = pd.read_json("fixtures/tokens_with_entropies.json")
65
+ tokens = df["token_ids"].tolist()
66
+ entropies = df["entropies"].tolist()
67
+ # BOS and EOS
68
+ assert len(tokens) == len(text) + 2
69
+ for i in range(self.total):
70
+ self.position += 1
71
+ yield BltExample(
72
+ sample_id=f"test_{i}",
73
+ text=text,
74
+ tokens=tokens,
75
+ mask=[True] * len(tokens),
76
+ entropies=entropies,
77
+ patch_lengths=None,
78
+ )
79
+
80
+
81
+ def test_preprocess_iter():
82
+ total = 3
83
+ tokenizer_args = TokenizerArgs(
84
+ name="blt",
85
+ init_kwargs={
86
+ "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
87
+ },
88
+ )
89
+ for mode in [
90
+ PatchingModeEnum.bpe,
91
+ PatchingModeEnum.space,
92
+ ]:
93
+ data_it = BltTestIterator(total)
94
+ patcher_args = PatcherArgs(patching_mode=mode)
95
+ example_it = PreprocessIterator(
96
+ data_it, tokenizer_args=tokenizer_args, patcher_args=patcher_args
97
+ )
98
+ count = 0
99
+ for example in example_it.create_iter():
100
+ assert isinstance(example.tokens, list)
101
+ assert isinstance(example.tokens[0], int)
102
+ # BOS and EOS
103
+ assert len(example.tokens) == len(example.text) + 2
104
+ assert example.mask is not None
105
+ assert len(example.tokens) == len(example.mask)
106
+ count += 1
107
+
108
+ assert count == total
109
+
110
+
111
+ def test_non_entropy_patch_iter():
112
+ total = 3
113
+ tokenizer_args = TokenizerArgs(
114
+ name="blt",
115
+ init_kwargs={
116
+ "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
117
+ },
118
+ )
119
+ for mode in [
120
+ PatchingModeEnum.bpe,
121
+ PatchingModeEnum.space,
122
+ ]:
123
+ patcher_args = PatcherArgs(patching_mode=mode)
124
+ data_it = BltTestIterator(total)
125
+ example_it = PreprocessIterator(
126
+ data_it, tokenizer_args=tokenizer_args, patcher_args=patcher_args
127
+ )
128
+
129
+ count = 0
130
+ for example in example_it.create_iter():
131
+ assert isinstance(example.patch_lengths, list)
132
+ assert isinstance(example.patch_lengths[0], int)
133
+ assert len(example.tokens) == sum(example.patch_lengths)
134
+ count += 1
135
+
136
+ assert count == total
137
+
138
+
139
+ def test_entropy_patch_iter():
140
+ total = 2
141
+ patcher_args = PatcherArgs(
142
+ patching_mode=PatchingModeEnum.entropy, threshold=1.335442066192627
143
+ )
144
+ tokenizer_args = TokenizerArgs(
145
+ name="blt",
146
+ init_kwargs={
147
+ "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
148
+ },
149
+ )
150
+ data_it = BltTestWithEntropiesIterator(total)
151
+ example_it = PreprocessIterator(
152
+ data_it, tokenizer_args=tokenizer_args, patcher_args=patcher_args
153
+ )
154
+
155
+ count = 0
156
+ for example in example_it.create_iter():
157
+ assert isinstance(example.patch_lengths, list)
158
+ assert isinstance(example.patch_lengths[0], int)
159
+ assert len(example.tokens) == sum(example.patch_lengths)
160
+ count += 1
161
+
162
+ assert count == total
bytelatent/data/ngram_processor.py ADDED
@@ -0,0 +1,146 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ import pickle
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+
7
+ from bytelatent import ByteLatentError
8
+
9
+ LOOKUP_OFFSET = 4
10
+
11
+
12
+ def apply_lookup_table_wrapper(ngram_to_idx: dict[tuple, int], lookup_offset=1):
13
+ """
14
+ Wrapper function for applying the lookup table to each n-gram.
15
+
16
+ :param ngram_to_idx: Dictionary mapping n-gram tuples to their ids.
17
+ :param lookup_offset: Offset added to each id returned by the lookup.
18
+ :return: A function that maps an n-gram (array of numbers) to its id plus
19
+ lookup_offset, or 0 if the n-gram is not in the dictionary.
20
+ """
21
+
22
+ def apply_lookup_table(ngram):
23
+ """
24
+ Function to apply to each n-gram: converts it to a tuple and looks it up in a dictionary.
25
+
26
+ :param ngram: Array of numbers representing an n-gram.
27
+ :return: The id associated with the n-gram tuple plus lookup_offset, or 0 if the n-gram is not found.
28
+ """
29
+ # Convert the n-gram to a tuple
30
+ ngram_tuple = tuple(ngram)
31
+
32
+ if ngram_tuple not in ngram_to_idx:
33
+ return 0
34
+ else:
35
+ return ngram_to_idx[ngram_tuple] + lookup_offset
36
+
37
+ return apply_lookup_table
38
+
39
+
40
+ def get_byte_ngrams_ids(
41
+ byte_array: np.ndarray, n: int, ngram_to_idx: dict[tuple, int], pad_value=0
42
+ ):
43
+ """
44
+ Generate n-grams from a 2D numpy array.
45
+
46
+ :param n: The length of each n-gram.
47
+ :param pad_value: Value used to left-pad each row of byte values so the n-gram array keeps the input's dimensions.
48
+ :return: A 2D numpy array where each element is the ID of an n-gram offset by LOOKUP_OFFSET.
49
+ """
50
+ num_rows, num_cols = byte_array.shape
51
+
52
+ # Create an array to hold the padded version of the original array
53
+ padded_array = np.pad(
54
+ byte_array, ((0, 0), (n - 1, 0)), mode="constant", constant_values=pad_value
55
+ )
56
+
57
+ # Use stride tricks to avoid explicit looping
58
+ strided = np.lib.stride_tricks.as_strided
59
+ shape = (num_rows, num_cols, n)
60
+ strides = padded_array.strides[:2] + (padded_array.strides[1],)
61
+ ngrams = strided(padded_array, shape=shape, strides=strides)
62
+
63
+ ngram_ids = np.apply_along_axis(
64
+ apply_lookup_table_wrapper(ngram_to_idx, lookup_offset=LOOKUP_OFFSET), 2, ngrams
65
+ )
66
+ assert ngram_ids.shape == byte_array.shape
67
+ return ngram_ids
68
+
69
+
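A quick sketch of what the sliding-window lookup above computes, using a toy 2-gram table (the import path and the table contents are made up for illustration):

    import numpy as np
    from bytelatent.data.ngram_processor import get_byte_ngrams_ids

    # Toy table: two known 2-grams with ids 0 and 1 (offset by LOOKUP_OFFSET = 4 in the output).
    ngram_to_idx = {(104, 105): 0, (105, 32): 1}
    byte_array = np.array([[104, 105, 32]])  # one row of byte values
    ids = get_byte_ngrams_ids(byte_array, n=2, ngram_to_idx=ngram_to_idx, pad_value=0)
    # Row 0 sees the 2-grams (0, 104) -> unknown -> 0, (104, 105) -> 0 + 4, (105, 32) -> 1 + 4.
    print(ids)  # [[0 4 5]]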
70
+ def reload_tables(
71
+ ngram_table_dir: str, ngram_to_size: dict[int, int], offset: int = LOOKUP_OFFSET
72
+ ) -> tuple[dict[int, dict[tuple, int]], dict[int, list], dict[int, int]]:
73
+ """
74
+ Reload lookup tables from a directory. Reload only the ngrams in the dictionary and per ngram,
75
+ only load up to the max specified size. Return the actual number of ngrams taken per ngram size.
76
+ """
77
+ idx_to_ngram_tables = {}
78
+ ngram_to_idx_tables = {}
79
+ vocab_sizes = {}
80
+ for ngram, size in ngram_to_size.items():
81
+ with open(Path(ngram_table_dir) / f"ngram-{ngram}.pickle", "rb") as f:
82
+ # These are already sorted by count
83
+ # Value: tuple of: count, ngram, dataset
84
+ ngram_data: list[tuple[tuple, tuple[int, int, str]]] = pickle.load(f)[
85
+ "counts"
86
+ ]
87
+ table = [ngram for ngram, _ in ngram_data][:size]
88
+ if len(table) != size:
89
+ raise ValueError(
90
+ f"Ngram table for {ngram}-gram is not large enough to get {size} ngrams, max size is {len(ngram_data)}"
91
+ )
92
+ ngram_to_idx = {ngram: idx for idx, ngram in enumerate(table)}
93
+ actual_size = len(table)
94
+ idx_to_ngram_tables[ngram] = table
95
+ ngram_to_idx_tables[ngram] = ngram_to_idx
96
+ vocab_sizes[ngram] = actual_size + offset
97
+ return ngram_to_idx_tables, idx_to_ngram_tables, vocab_sizes
98
+
99
+
100
+ def parse_ngram_to_size(ngram_to_size_str: str | None) -> dict[int, int]:
101
+ if ngram_to_size_str is None:
102
+ return None
103
+ ngram_to_size = {}
104
+ for entry in ngram_to_size_str.split(","):
105
+ ngram, size = entry.split(":")
106
+ ngram = int(ngram)
107
+ size = int(size)
108
+ ngram_to_size[ngram] = size
109
+ return ngram_to_size
110
+
111
+
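For reference, the config string parsed here is a comma-separated list of n:size pairs; a small example (values made up, import path assumed):

    from bytelatent.data.ngram_processor import parse_ngram_to_size

    parse_ngram_to_size("2:1000,3:500,4:250")  # -> {2: 1000, 3: 500, 4: 250}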
112
+ class NgramProcessor:
113
+ def __init__(
114
+ self,
115
+ ngram_table_dir: str | None = None,
116
+ ngram_to_size: dict[int, int] | None = None,
117
+ ):
118
+ if ngram_table_dir is None or ngram_to_size is None:
119
+ raise ByteLatentError(
120
+ "ngram_table_dir and ngram_to_size cannot be none if enable_byte_ngrams is True"
121
+ )
122
+ (
123
+ self.ngram_to_idx_tables,
124
+ self.idx_to_ngram_tables,
125
+ self.ngram_vocab_sizes,
126
+ ) = reload_tables(ngram_table_dir, ngram_to_size)
127
+ # Lowest to highest ngram
128
+ self.ngram_sizes = sorted(list(self.ngram_to_idx_tables.keys()))
129
+ # Although the model might not use all the ngrams, we need the tokenizer
130
+ # to produce ngram_ids such that index zero is the 2-gram, later on in
131
+ # src.model.megabyte.Megabyte.forward
132
+ assert self.ngram_sizes[0] == 2
133
+
134
+ def encode_single_ngram_table(self, data: np.ndarray, n: int):
135
+ """
136
+ Return the n-grams of the input data for a given n
137
+ numpy array with ids of shape data.shape
138
+ """
139
+ return get_byte_ngrams_ids(data, n, self.ngram_to_idx_tables[n], pad_value=0)
140
+
141
+ def encode_token_ngrams(self, data: np.ndarray):
142
+ """
143
+ Return the n-grams of the input data.
144
+ output shape: [ids with data.shape for n in self.ngram_sizes]
145
+ """
146
+ return [self.encode_single_ngram_table(data, n) for n in self.ngram_sizes]
bytelatent/data/patcher.py ADDED
@@ -0,0 +1,609 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ import math
3
+ import time
4
+ from collections import defaultdict
5
+ from enum import Enum
6
+
7
+ import torch
8
+ from pydantic import BaseModel
9
+ from torch.nn import functional as F
10
+
11
+ from bytelatent.distributed import get_local_rank
12
+ from bytelatent.entropy_model import load_entropy_model
13
+
14
+ from bytelatent.tokenizers.constants import BPE_ID, OFFSET
17
+
18
+
19
+ class PatchingModeEnum(str, Enum):
20
+ entropy = "entropy"
21
+ bpe = "bpe"
22
+ bpe_patcher = "bpe_patcher"
23
+ space = "space"
24
+
25
+
26
+ class PatcherArgs(BaseModel):
27
+ patching_mode: PatchingModeEnum = PatchingModeEnum.entropy
28
+ patching_device: str = "cuda"
29
+ entropy_model_checkpoint_dir: str | None = None
30
+ realtime_patching: bool = False
31
+ threshold: float = 1.335442066192627
32
+ threshold_add: float | None = None
33
+ max_patch_length: int | None = None
34
+ patch_size: float = 4.5
35
+ patching_batch_size: int = 1
36
+ data_loader_patching: bool = False
37
+ device: str = "cuda"
38
+ monotonicity: bool = False
39
+ log_time: bool = False
40
+
41
+ def build(self) -> "Patcher":
42
+ return Patcher(self)
43
+
44
+
45
+ def entropy(scores):
46
+ """
47
+ scores: [bs, seq_len, vocab]
48
+ returns [bs, seq_len]
49
+
50
+ Computes the entropy for each token in the batch.
51
+ Note: uses natural log.
52
+ """
53
+ log_probs = F.log_softmax(scores, dim=-1)
54
+ probs = torch.exp(log_probs)
55
+ p_log_p = log_probs * probs
56
+ entropy = -p_log_p.sum(dim=-1)
57
+ return entropy
58
+
59
+
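A quick sanity check of the entropy helper above (not part of the commit; the import path is assumed): a uniform distribution over V symbols should give ln(V).

    import math
    import torch
    from bytelatent.data.patcher import entropy

    scores = torch.zeros(1, 1, 260)   # equal logits over a 260-symbol byte vocabulary
    print(entropy(scores))            # tensor([[5.5607]])
    print(math.log(260))              # 5.5606...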
60
+ def calculate_entropies(
61
+ tokens: torch.tensor, entropy_model, patching_batch_size, device: str | None = None
62
+ ):
63
+ """
64
+ tokens: 2D tensor of shape [batch_size, seq_len]
65
+ Return 2D tensor of shape [batch_size, seq_len] with entropies for each token.
66
+
67
+ Splits the tokens into chunks of size max_length and calculates entropies for each chunk.
68
+ Entropy model can be executed on cpu or gpu, specify either 'cuda' or 'cpu' in the device argument.
69
+ """
70
+ with torch.no_grad():
71
+ entropies = []
72
+ max_length = getattr(entropy_model, "max_length", 8192)
73
+ batch_numel = max_length * patching_batch_size
74
+ splits = torch.split(tokens.flatten(), batch_numel)
75
+ for split in splits:
76
+ pad_size = (max_length - (split.numel() % max_length)) % max_length
77
+ pad = torch.zeros(
78
+ pad_size, dtype=split.dtype, device=split.device, requires_grad=False
79
+ )
80
+ split = torch.cat((split, pad), dim=0)
81
+ split = split.reshape(-1, max_length)
82
+ if device is not None:
83
+ split = split.to(device)
84
+ assert torch.all(split >= 0) and torch.all(split < 260)
85
+ pred, _ = entropy_model(split)
86
+ pred = pred.reshape(-1, pred.shape[-1])[
87
+ : split.numel() - pad_size, :
88
+ ] # [batch_size * seq_len, vocab]
89
+ pred_entropies = entropy(pred)
90
+ entropies.append(pred_entropies)
91
+
92
+ entropies = torch.cat(entropies, dim=0)
93
+ entropies = entropies.reshape(tokens.shape)
94
+ return entropies
95
+
96
+
97
+ def patch_start_mask_from_entropy_with_monotonicity(entropies, t):
98
+ """
99
+ entropies: [bs, seq_len] torch tensor of entropies
100
+ t: threshold
101
+ returns [bs, seq_len] mask where True indicates the start of a patch
102
+ """
103
+ bs, seq_len = entropies.shape
104
+ mask = torch.zeros_like(entropies, dtype=torch.bool)
105
+ mask[:, 0] = True
106
+
107
+ # Calculate differences between consecutive elements along the sequence length
108
+ differences = entropies[:, 1:] - entropies[:, :-1]
109
+
110
+ # Calculate conditions for all elements except the first one in each sequence
111
+ condition = differences > t
112
+
113
+ # Update the mask based on the condition
114
+ mask[:, 1:] = condition
115
+
116
+ return mask
117
+
118
+
119
+ def patch_start_mask_global_and_monotonicity(entropies, t, t_add=0):
120
+ """
121
+ entropies: [bs, seq_len] torch tensor of entropies
122
+ t: threshold
123
+ returns [bs, seq_len] mask where True indicates the start of a patch
124
+ """
125
+ bs, seq_len = entropies.shape
126
+ mask = torch.zeros_like(entropies, dtype=torch.bool)
127
+ mask[:, 0] = True
128
+
129
+ # Calculate differences between consecutive elements along the sequence length
130
+ differences = entropies[:, 1:] - entropies[:, :-1]
131
+
132
+ # Calculate conditions for all elements except the first one in each sequence
133
+ condition = (differences > t_add) & (entropies[:, 1:] > t) & (~mask[:, :-1])
134
+
135
+ # Update the mask based on the condition
136
+ mask[:, 1:] = condition
137
+
138
+ return mask
139
+
140
+
141
+ def patch_start_ids_from_patch_start_mask(patch_start_mask):
142
+ bs, trunc_seq_len = patch_start_mask.shape
143
+ max_patches = patch_start_mask.sum(dim=1).max()
144
+ if max_patches == 0:
145
+ patch_start_ids = torch.full(
146
+ (bs, trunc_seq_len),
147
+ trunc_seq_len,
148
+ dtype=torch.long,
149
+ device=patch_start_mask.device,
150
+ )
151
+ else:
152
+ patch_ids = (
153
+ torch.arange(trunc_seq_len, device=patch_start_mask.device)
154
+ .unsqueeze(0)
155
+ .repeat(bs, 1)
156
+ )
157
+ extra_patch_ids = torch.full(
158
+ (bs, trunc_seq_len),
159
+ trunc_seq_len,
160
+ dtype=torch.long,
161
+ device=patch_start_mask.device,
162
+ )
163
+ all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1)
164
+ patch_start_mask_padded = torch.cat(
165
+ (patch_start_mask, ~patch_start_mask), dim=1
166
+ )
167
+ patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(
168
+ bs, trunc_seq_len
169
+ )[:, :max_patches]
170
+ return patch_start_ids
171
+
172
+
173
+ def check_non_zero_after_zero(tensor):
174
+ zero_mask = tensor == 0
175
+ shifted_mask = torch.cat(
176
+ [
177
+ torch.zeros(tensor.shape[0], 1, dtype=torch.bool, device=tensor.device),
178
+ zero_mask[:, :-1],
179
+ ],
180
+ dim=1,
181
+ )
182
+ non_zero_after_zero = (tensor != 0) & shifted_mask
183
+ return non_zero_after_zero.any()
184
+
185
+
186
+ def patch_lengths_from_start_ids(patch_start_ids, seq_len):
187
+ """
188
+ Calculate patch lengths from start ids.
189
+ start ids: ex: [0, 1, 7, 7, 7, 7, 7], it has the start ids of the patches (here 0, 1), and then
190
+ the rest are filled to the seq len.
191
+ seq_len: ex: 7 length of the sequence
192
+
193
+ returns the patch lengths:
194
+ [1, 6] for the above example.
195
+ """
196
+ last_ids = torch.full_like(patch_start_ids[:, :1], seq_len - 1)
197
+ patch_end_ids = torch.cat((patch_start_ids[:, 1:] - 1, last_ids), dim=1)
198
+ patch_lengths = patch_end_ids - patch_start_ids + 1
199
+ assert torch.all(patch_lengths >= 0), f"{patch_lengths}"
200
+ assert not check_non_zero_after_zero(patch_lengths), f"{patch_lengths}"
201
+ return patch_lengths
202
+
203
+
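The docstring example above, spelled out as a runnable sketch (import path assumed from this commit's layout):

    import torch
    from bytelatent.data.patcher import patch_lengths_from_start_ids

    start_ids = torch.tensor([[0, 1, 7, 7, 7, 7, 7]])  # real starts at 0 and 1, the rest padded to seq_len
    print(patch_lengths_from_start_ids(start_ids, seq_len=7))
    # tensor([[1, 6, 0, 0, 0, 0, 0]]) -> the two real patches have lengths 1 and 6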
204
+ def find_space_patch_start_ids(tokens):
205
+ bs, seq_len = tokens.shape
206
+ tokens_no_offset = tokens - OFFSET
207
+ patch_end_mask = (
208
+ (tokens_no_offset < ord("0"))
209
+ | ((ord("9") < tokens_no_offset) & (tokens_no_offset < ord("A")))
210
+ | ((ord("Z") < tokens_no_offset) & (tokens_no_offset < ord("a")))
211
+ | ((ord("z") < tokens_no_offset) & (tokens_no_offset < 0b1000_0000))
212
+ | (0b1100_0000 <= tokens_no_offset)
213
+ )
214
+ patch_end_mask[:, 1:] &= patch_end_mask[:, :-1].bitwise_not()
215
+ patch_end_mask |= tokens < OFFSET
216
+
217
+ patch_start_mask = torch.cat(
218
+ [
219
+ torch.tensor([1, 1], device=tokens.device, dtype=torch.bool)
220
+ .unsqueeze(0)
221
+ .repeat(bs, 1),
222
+ patch_end_mask[:, 1:],
223
+ ],
224
+ dim=1,
225
+ )
226
+ max_patches = patch_start_mask.sum(dim=1).max()
227
+
228
+ patch_ids = (
229
+ torch.arange(seq_len + 1, device=tokens.device).unsqueeze(0).repeat(bs, 1)
230
+ )
231
+ extra_patch_ids = torch.full(
232
+ (bs, seq_len + 1), seq_len + 1, dtype=torch.long, device=tokens.device
233
+ )
234
+ all_patch_ids = torch.cat((patch_ids, extra_patch_ids), dim=1)
235
+ patch_start_mask_padded = torch.cat((patch_start_mask, ~patch_start_mask), dim=1)
236
+
237
+ patch_start_ids = all_patch_ids[patch_start_mask_padded].reshape(bs, -1)[
238
+ :, :max_patches
239
+ ]
240
+ return patch_start_ids
241
+
242
+
243
+ def to_device(entropy_model, device=None):
244
+ if device == "cuda":
245
+ rank = get_local_rank()
246
+ device = f"cuda:{rank}"
247
+ entropy_model = entropy_model.to(device)
248
+ return entropy_model, device
249
+
250
+
251
+ def model_pred_to_bpe_patching_pred(pred):
252
+ _, indices = torch.max(pred, dim=1)
253
+ return indices == BPE_ID
254
+
255
+
256
+ def apply_bpe_patcher(tokens, bpe_patcher, patching_batch_size, device=None):
257
+ assert tokens.device == torch.device(
258
+ "cpu"
259
+ ), f"{tokens.device} != cpu expects tokens to be on cpu"
260
+ with torch.no_grad():
261
+ bpe_patcher_device, device = to_device(
262
+ bpe_patcher, device
263
+ ) # Get entropy model to right rank device.
264
+ bpe_patching_mask = []
265
+ max_length = getattr(bpe_patcher, "max_length", 8192)
266
+ batch_numel = max_length * patching_batch_size
267
+ splits = torch.split(tokens.flatten(), batch_numel)
268
+ for split in splits:
269
+ pad_size = (max_length - (split.numel() % max_length)) % max_length
270
+ pad = torch.zeros(
271
+ pad_size, dtype=split.dtype, device=split.device, requires_grad=False
272
+ )
273
+ split = torch.cat((split, pad), dim=0)
274
+ split = split.reshape(-1, max_length).to(device)
275
+ assert torch.all(split >= 0) and torch.all(split < 260)
276
+ pred = bpe_patcher_device(split)
277
+ pred_cpu = pred[0].cpu()
278
+ pred_cpu = pred_cpu.reshape(-1, pred_cpu.shape[-1])[
279
+ : split.numel() - pad_size, :
280
+ ] # [batch_size * seq_len, vocab]
281
+ bpe_patching_pred = model_pred_to_bpe_patching_pred(pred_cpu)
282
+ bpe_patching_mask.append(bpe_patching_pred)
283
+ bpe_patching_mask = torch.cat(bpe_patching_mask, dim=0)
284
+ bpe_patching_mask = bpe_patching_mask.reshape(tokens.shape)
285
+ return bpe_patching_mask
286
+
287
+
288
+ def find_bpe_patcher_patch_start_ids(
289
+ tokens, bpe_patcher, patching_batch_size, device=None, include_next_token=True
290
+ ):
291
+ bs, seq_len = tokens.shape
292
+
293
+ first_ids = (
294
+ torch.tensor([0, 1], dtype=torch.long, device=tokens.device)
295
+ .unsqueeze(0)
296
+ .repeat(bs, 1)
297
+ )
298
+ preds_truncation_len = first_ids.shape[1]
299
+ token_input = tokens[:, 1:] if include_next_token else tokens[:, 1:-1]
300
+ if token_input.shape[1] >= 1:
301
+ patch_start_mask = apply_bpe_patcher(
302
+ token_input, bpe_patcher, patching_batch_size, device
303
+ )
304
+ assert (
305
+ patch_start_mask.shape[1]
306
+ == tokens.shape[1] + include_next_token - preds_truncation_len
307
+ ), f"{patch_start_mask.shape[1]} != {tokens.shape[1] + include_next_token - preds_truncation_len}"
308
+ patch_start_ids = patch_start_ids_from_patch_start_mask(patch_start_mask)
309
+ patch_start_ids = torch.cat(
310
+ (first_ids, patch_start_ids + preds_truncation_len), dim=1
311
+ )
312
+ else:
313
+ patch_start_ids = first_ids
314
+ return patch_start_ids
315
+
316
+
317
+ def find_entropy_patch_start_ids(
318
+ entropies,
319
+ patch_size=None,
320
+ threshold=None,
321
+ threshold_add=None,
322
+ monotonicity=False,
323
+ include_next_token=True,
324
+ ):
325
+ """
326
+ Use entropies to find the start ids of each patch.
327
+ Use patch_size or threshold to figure out the total number of patches to allocate.
328
+
329
+ When threshold is not None the number of patches is not constant between
330
+ different sequences, but patches can be identified incrementally rather than
331
+ decided globally using the entire sequence.
332
+ """
333
+ bs, seq_len = entropies.shape[:2]
334
+
335
+ first_ids = (
336
+ torch.tensor([0, 1], dtype=torch.long, device=entropies.device)
337
+ .unsqueeze(0)
338
+ .repeat(bs, 1)
339
+ )
340
+ preds_truncation_len = first_ids.shape[
341
+ 1
342
+ ] # remove the first preds because they will be start of patches.
343
+ entropies = entropies[:, 1:]
344
+ if threshold is None:
345
+ num_patches = seq_len // patch_size
346
+ patch_start_ids = entropies.topk(num_patches - 2, dim=1).indices
347
+ patch_start_ids = patch_start_ids.sort(dim=1).values
348
+ else:
349
+ # Assumes that there is at least one token going over the threshold
350
+ if monotonicity:
351
+ patch_start_mask = patch_start_mask_from_entropy_with_monotonicity(
352
+ entropies, threshold
353
+ )
354
+ elif threshold_add is not None and threshold is not None:
355
+ patch_start_mask = patch_start_mask_global_and_monotonicity(
356
+ entropies, threshold, threshold_add
357
+ )
358
+ else:
359
+ patch_start_mask = entropies > threshold
360
+ if not include_next_token:
361
+ patch_start_mask = patch_start_mask[:, :-1]
362
+ # patch_start_mask[1:] |= tokens[:-1] < OFFSET
363
+ patch_start_ids = patch_start_ids_from_patch_start_mask(patch_start_mask)
364
+
365
+ patch_start_ids = torch.cat(
366
+ (first_ids, patch_start_ids + preds_truncation_len), dim=1
367
+ )
368
+ return patch_start_ids
369
+
370
+
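A small sketch of the threshold mode (the entropy values and the threshold are made up; import path assumed):

    import torch
    from bytelatent.data.patcher import (
        find_entropy_patch_start_ids,
        patch_lengths_from_start_ids,
    )

    entropies = torch.tensor([[0.1, 2.0, 0.3, 0.2, 1.9, 0.1]])  # [bs=1, seq_len=6]
    start_ids = find_entropy_patch_start_ids(
        entropies, threshold=1.5, include_next_token=False
    )
    print(start_ids)  # tensor([[0, 1, 2, 5]]): positions 0 and 1 always start patches,
                      # plus a new patch right after each entropy spike above the threshold
    print(patch_lengths_from_start_ids(start_ids, seq_len=6))  # tensor([[1, 1, 3, 1]])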
371
+ def rightpad(seq, pad_id, max_len):
372
+ return seq + [pad_id] * (max_len - len(seq))
373
+
374
+
375
+ def find_bpe_delim_patch_start_ids(tokens, delim):
376
+ ids = (tokens[:, :-1] == delim).nonzero(as_tuple=False)
377
+ out = [[0, 1] for _ in range(tokens.shape[0])]
378
+ for x, y in ids:
379
+ # start is at delim + 1, delim should be the last element in the patch.
380
+ out[x.item()].append(y.item() + 1)
381
+ max_len = max([len(elt) for elt in out])
382
+ out = [rightpad(elt, tokens.shape[1], max_len) for elt in out]
383
+ patch_start_ids = torch.tensor(out, dtype=tokens.dtype, device=tokens.device)
384
+ return patch_start_ids
385
+
386
+
387
+ def find_lookup_table_start_mask(
388
+ tokens: torch.Tensor, lookup_table: torch.Tensor, include_next_token=True
389
+ ):
390
+ window_size = lookup_table.ndim
391
+ # Unfold the tensor to get sliding windows
392
+ unfolded = tokens.unfold(1, window_size, 1)
393
+ # Gather indices for each dimension
394
+ indices = [unfolded[..., i] for i in range(window_size)]
395
+ # Access the lookup table using the gathered indices
396
+ result = lookup_table[indices]
397
+ return result
398
+
399
+
400
+ def find_lookup_table_patch_start_ids(
401
+ tokens: torch.Tensor, lookup_table: torch.Tensor, include_next_token=True
402
+ ):
403
+ bs, seq_len = tokens.shape
404
+
405
+ first_ids = (
406
+ torch.tensor([0, 1], dtype=torch.long, device=tokens.device)
407
+ .unsqueeze(0)
408
+ .repeat(bs, 1)
409
+ )
410
+ preds_truncation_len = first_ids.shape[1]
411
+ window_size = lookup_table.ndim
412
+ assert window_size == 2, f"{window_size} != 2"
413
+ # output dimensions: token_input shape - window_size + 1 --> we want first ids + this = tokens shape + 1 if next token otherwise just token shape
414
+ token_input = (
415
+ tokens if include_next_token else tokens[:, : -preds_truncation_len + 1]
416
+ )
417
+ if token_input.shape[1] >= window_size:
418
+ patch_start_mask = find_lookup_table_start_mask(
419
+ token_input, lookup_table, include_next_token
420
+ )
421
+ assert (
422
+ patch_start_mask.shape[1]
423
+ == tokens.shape[1] + include_next_token - preds_truncation_len
424
+ ), f"{patch_start_mask.shape[1]} != {tokens.shape[1] + include_next_token - preds_truncation_len}"
425
+ patch_start_ids = patch_start_ids_from_patch_start_mask(patch_start_mask)
426
+ patch_start_ids = torch.cat(
427
+ (first_ids, patch_start_ids + preds_truncation_len), dim=1
428
+ )
429
+ else:
430
+ patch_start_ids = first_ids
431
+ return patch_start_ids
432
+
433
+
434
+ def split_large_numbers(lst, m):
435
+ new_lst = []
436
+ for i in lst:
437
+ if i > m:
438
+ while i > m:
439
+ new_lst.append(m)
440
+ i -= m
441
+ new_lst.append(i)
442
+ else:
443
+ new_lst.append(i)
444
+ assert sum(new_lst) == sum(lst), f"{sum(new_lst)} != {sum(lst)}"
445
+ return new_lst
446
+
447
+
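For instance, with max_patch_length = 4 an over-long patch is chopped into chunks while preserving the total (a quick illustration, import path assumed):

    from bytelatent.data.patcher import split_large_numbers

    split_large_numbers([3, 10, 2], 4)  # -> [3, 4, 4, 2, 2]; both sum to 15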
448
+ class Patcher:
449
+ def __init__(self, patcher_args: PatcherArgs):
450
+ self.patcher_args = patcher_args
451
+ self.patching_mode = patcher_args.patching_mode
452
+ self.realtime_patching = patcher_args.realtime_patching
453
+ if self.realtime_patching:
454
+ assert (
455
+ patcher_args.entropy_model_checkpoint_dir is not None
456
+ ), "Cannot require realtime patching without an entropy model checkpoint"
457
+ entropy_model = load_entropy_model(
458
+ patcher_args.entropy_model_checkpoint_dir
459
+ )
460
+ entropy_model, _ = to_device(entropy_model, patcher_args.patching_device)
461
+ self.entropy_model = entropy_model
462
+ else:
463
+ self.entropy_model = None
464
+ self.threshold = patcher_args.threshold
465
+ self.threshold_add = patcher_args.threshold_add
466
+ self.max_patch_length = patcher_args.max_patch_length
467
+ self.patch_size = patcher_args.patch_size
468
+ self.patching_batch_size = patcher_args.patching_batch_size
469
+ self.data_loader_patching = patcher_args.data_loader_patching
470
+ self.device = patcher_args.device
471
+ self.monotonicity = patcher_args.monotonicity
472
+ self.log_time = patcher_args.log_time
473
+ if self.log_time:
474
+ self.log = defaultdict(float)
475
+
476
+ def patch(
477
+ self,
478
+ tokens: torch.Tensor,
479
+ include_next_token: bool = False,
480
+ preds: torch.Tensor | None = None,
481
+ entropies: torch.Tensor | None = None,
482
+ threshold: float = None,
483
+ ) -> torch.Tensor:
484
+ """
485
+ tokens: 2D tensor of shape [batch_size, seq_len] that needs to be patched
486
+ Returns patch lengths and optionally scores associated with the tokens (i.e. entropies, logprobs etc.)
487
+ -> output tensor: [batch_size, max_num_patches]
488
+ each tensor is processed independently and gets right padded with zeros.
489
+
490
+ Patching with the following modes:
491
+ 1. patching_mode = None: static patch size
492
+ 2. patching_mode = "entropy":
493
+ calculate entropy of each token, allocate patches so that the total
494
+ number of patches is the same as static patching but choose to begin
495
+ patches on tokens where the model is most uncertain (highest entropy).
496
+
497
+ When threshold is provided, it uses the threshold to decide when to
498
+ start a new patch.
499
+ 3. patching_mode = "space":
500
+ use space like tokens to define the patches.
501
+ 4. patching_mode = "bpe":
502
+ use bpe delim tokens to define the patches.
503
+
504
+ To correctly patch the last token, it may be necessary to include the next token in the patch
505
+ lengths calculations. This is controlled by the include_next_token argument.
506
+ """
507
+ bs, seq_len = tokens.shape
508
+ seq_len_next_tok = seq_len + 1 if include_next_token else seq_len
509
+ scores = None
510
+ # STATIC
511
+ if self.patching_mode is None:
512
+ patch_lengths = torch.zeros(
513
+ (bs, math.ceil(seq_len_next_tok / self.patch_size)),
514
+ dtype=tokens.dtype,
515
+ device=tokens.device,
516
+ ).fill_(self.patch_size)
517
+ if seq_len_next_tok % self.patch_size != 0:
518
+ patch_lengths[:, -1] = seq_len_next_tok % self.patch_size
519
+ # ENTROPY
520
+ elif self.patching_mode == PatchingModeEnum.entropy:
521
+ if self.log_time:
522
+ s = time.time()
523
+ if entropies is not None:
524
+ scores = torch.tensor(entropies, dtype=torch.float32)
525
+ elif preds is not None:
526
+ scores = entropy(preds)
527
+ else:
528
+ start_entropies = time.time()
529
+ scores = calculate_entropies(
530
+ tokens,
531
+ self.entropy_model,
532
+ self.patching_batch_size,
533
+ self.device,
534
+ )
535
+ if self.log_time:
536
+ self.log["calculate_entropies"] += time.time() - s
537
+ s = time.time()
538
+ patch_start_ids = find_entropy_patch_start_ids(
539
+ scores,
540
+ self.patch_size,
541
+ include_next_token=include_next_token,
542
+ threshold=threshold if threshold is not None else self.threshold,
543
+ threshold_add=self.threshold_add,
544
+ monotonicity=self.monotonicity,
545
+ )
546
+ if self.log_time:
547
+ self.log["find_entropy_patch_start_ids"] += time.time() - s
548
+ s = time.time()
549
+ patch_lengths = patch_lengths_from_start_ids(
550
+ patch_start_ids, seq_len_next_tok
551
+ )
552
+ if self.log_time:
553
+ self.log["patch_lengths_from_start_ids"] += time.time() - s
554
+ s = time.time()
555
+ # BPE
556
+ elif self.patching_mode == PatchingModeEnum.bpe:
557
+ patch_start_ids = find_bpe_delim_patch_start_ids(tokens, delim=BPE_ID)
558
+ patch_lengths = patch_lengths_from_start_ids(
559
+ patch_start_ids, seq_len_next_tok
560
+ )
561
+ elif self.patching_mode == PatchingModeEnum.bpe_patcher:
562
+ patch_start_ids = find_bpe_patcher_patch_start_ids(
563
+ tokens,
564
+ self.entropy_model,
565
+ self.patching_batch_size,
566
+ self.device,
567
+ include_next_token,
568
+ )
569
+ patch_lengths = patch_lengths_from_start_ids(
570
+ patch_start_ids, seq_len_next_tok
571
+ )
572
+ # SPACE
573
+ elif self.patching_mode == PatchingModeEnum.space:
574
+ patch_start_ids = find_space_patch_start_ids(tokens)
575
+ patch_lengths = patch_lengths_from_start_ids(
576
+ patch_start_ids, seq_len_next_tok
577
+ )
578
+ else:
579
+ raise NotImplementedError(f"self.patching_mode {self.patching_mode}")
580
+
581
+ # Apply any processing to patch lengths
582
+ if self.max_patch_length is not None:
583
+ # TODO: avoid going back to a list here.
584
+ patch_lengths = [
585
+ split_large_numbers(pl, self.max_patch_length)
586
+ for pl in patch_lengths.tolist()
587
+ ]
588
+ max_len = max([len(pl) for pl in patch_lengths])
589
+ patch_lengths = [rightpad(pl, 0, max_len=max_len) for pl in patch_lengths]
590
+ patch_lengths = torch.tensor(
591
+ patch_lengths, dtype=tokens.dtype, device=tokens.device
592
+ )
593
+ assert not check_non_zero_after_zero(patch_lengths)
594
+ # Find the last non-zero column index using argmax on a reversed version of the tensor
595
+ last_non_zero_col_reversed = (
596
+ (patch_lengths != 0).flip(dims=[1]).int().argmax(dim=1).min()
597
+ )
598
+ # Slice the tensor up to the last non-zero column
599
+ patch_lengths = patch_lengths[
600
+ :, : patch_lengths.shape[1] - last_non_zero_col_reversed
601
+ ]
602
+ assert (
603
+ torch.sum(patch_lengths)
604
+ == tokens.numel() + include_next_token * tokens.shape[0]
605
+ ), f"{torch.sum(patch_lengths)} != {tokens.numel() + include_next_token * tokens.shape[0]}"
606
+ if self.log_time:
607
+ self.log["postprocessing_patch_lengths"] += time.time() - s
608
+ self.log["tokens"] += patch_lengths.sum().item()
609
+ return patch_lengths, scores
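A minimal end-to-end sketch of the space-patching path, which needs no entropy model (the byte OFFSET of 4 and the import path are assumptions based on this commit):

    import torch
    from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum

    patcher = PatcherArgs(patching_mode=PatchingModeEnum.space).build()
    # Byte values shifted by the assumed tokenizer OFFSET (4), after a leading special token (1).
    text_bytes = [ord(c) + 4 for c in "hi yo"]
    tokens = torch.tensor([[1] + text_bytes])
    patch_lengths, _ = patcher.patch(tokens)
    print(patch_lengths)  # e.g. tensor([[1, 3, 2]]): one patch per space-delimited chunk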
bytelatent/distributed.py ADDED
@@ -0,0 +1,478 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import atexit
4
+ import contextlib
5
+ import logging
6
+ import multiprocessing as mp
7
+ import os
8
+ import random
9
+ import shutil
10
+ import signal
11
+ import socket
12
+ import subprocess
13
+ import sys
14
+ import tempfile
15
+ from dataclasses import asdict, dataclass
16
+ from functools import lru_cache, partial, reduce
17
+ from itertools import chain
18
+ from typing import List, Optional, Tuple, Union
19
+
20
+ import torch
21
+
22
+ # for no recompute ops
23
+ import xformers.ops
24
+ from pydantic import BaseModel, ConfigDict
25
+ from torch import distributed as dist
26
+ from torch.distributed import ReduceOp
27
+ from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard
28
+ from torch.distributed._tensor import DTensor
29
+ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
30
+ checkpoint_wrapper,
31
+ )
32
+ from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
33
+ from torch.nn.parallel import DistributedDataParallel as DDP
34
+ from torch.utils.checkpoint import (
35
+ CheckpointPolicy,
36
+ create_selective_checkpoint_contexts,
37
+ )
38
+
39
+ from bytelatent.float8 import convert_linears_to_fp8
40
+
41
+ logger = logging.getLogger()
42
+
43
+ # for selective AC
44
+ default_no_recompute_ops = {
45
+ torch.ops.aten.mm.default,
46
+ torch.ops.aten._scaled_mm.default,
47
+ torch.ops.aten._scaled_dot_product_efficient_attention.default,
48
+ torch.ops.aten._scaled_dot_product_flash_attention.default,
49
+ torch.ops.c10d_functional.reduce_scatter_tensor.default,
50
+ torch.ops.xformers_flash.flash_fwd.default,
51
+ torch.ops.xformers.efficient_attention_forward_cutlass.default,
52
+ }
53
+
54
+
55
+ class DistributedArgs(BaseModel):
56
+ model_config = ConfigDict(extra="forbid")
57
+ dp_shard: int = (
58
+ 1 # Into how many shards to split the model weights. Typically the number of GPUs in a node.
59
+ )
60
+ dp_replicate: int = (
61
+ 1 # How many times to replicate the model weights. Typically the number of nodes.
62
+ )
63
+ tp_size: int = 1
64
+ selective_activation_checkpointing: bool = False
65
+ compile: bool = False
66
+ fsdp_type: str = "no_shard"
67
+ model_dtype: str = "bf16"
68
+ float8_recipe: str | None = None
69
+ float8_filter: str = r"layers\.[0-9]+\."
70
+
71
+ matmul_allow_tf32: bool = False
72
+ allow_bf16_reduced_precision_reduction: bool = True
73
+ detect_anomaly: bool = False
74
+
75
+ compile_cache_size_limit: int = 8
76
+
77
+ spawn_method: str = "forkserver"
78
+
79
+
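As a concrete (hypothetical) example of how these fields compose: for a 2-node job with 8 GPUs per node one could replicate across nodes and shard within each node, since dp_replicate * dp_shard * tp_size must equal the world size (16 here).

    from bytelatent.distributed import DistributedArgs

    args = DistributedArgs(dp_replicate=2, dp_shard=8, tp_size=1, fsdp_type="full_shard")
    # get_device_mesh(args) would then build a ("dp_replicate", "dp_shard") mesh of shape (2, 8).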
80
+ class EnvironmentArgs(BaseModel):
81
+ model_config = ConfigDict(extra="forbid")
82
+ # Use GNU openMP (GOMP) instead of Intel OpenMP [Intel Math Kernel Library (MKL)]
83
+ MKL_SERVICE_FORCE_INTEL: str = "GNU"
84
+ OMP_NUM_THREADS: str = "1"
85
+ MKL_NUM_THREADS: str = "1"
86
+ # faster intra-node collectives, seems to be a cluster specific flag
87
+ ENABLE_INTRA_NODE_COMM: str = "1"
88
+ # avoids OOMs with long context
89
+ TORCH_NCCL_AVOID_RECORD_STREAMS: str = "1"
90
+ # increase the NCCL timeout before raising NCCL errors; a value of 22 should give a 16s timeout
91
+ NCCL_IB_TIMEOUT: str = "22"
92
+ NCCL_DEBUG: str = "INFO"
93
+ TORCH_NCCL_ASYNC_ERROR_HANDLING: str = "1"
94
+
95
+
96
+ def get_device_mesh(distributed_args: DistributedArgs):
97
+ tp_size = distributed_args.tp_size
98
+ dp_replicate = distributed_args.dp_replicate
99
+ dp_shard = distributed_args.dp_shard
100
+
101
+ assert (
102
+ dp_replicate * dp_shard * tp_size == get_world_size()
103
+ ), f"dp_replicate * dp_shard * tp_size ({dp_replicate} * {dp_shard} * {tp_size}) != world_size ({get_world_size()})"
104
+
105
+ dims = []
106
+ names = []
107
+ if dp_replicate >= 1:
108
+ dims.append(dp_replicate)
109
+ names.append("dp_replicate")
110
+ if dp_shard > 1 or distributed_args.fsdp_type == "no_shard":
111
+ dims.append(dp_shard)
112
+ names.append("dp_shard")
113
+ if tp_size > 1:
114
+ dims.append(tp_size)
115
+ names.append("tp")
116
+ dims = tuple(dims)
117
+ names = tuple(names)
118
+
119
+ return init_device_mesh("cuda", mesh_shape=dims, mesh_dim_names=names)
120
+
121
+
122
+ def dist_max(x: Union[int, float], mesh: DeviceMesh = None):
123
+ tensor = torch.tensor(x).cuda()
124
+ dist.all_reduce(tensor, op=ReduceOp.MAX, group=mesh.get_group() if mesh else None)
125
+ return tensor
126
+
127
+
128
+ def dist_mean(x: Union[int, float], mesh: DeviceMesh = None):
129
+ tensor = torch.tensor(x).cuda()
130
+ dist.all_reduce(tensor, op=ReduceOp.AVG, group=mesh.get_group() if mesh else None)
131
+ return tensor
132
+
133
+
134
+ def dist_mean_dict(x):
135
+ r = dict()
136
+ for k in x:
137
+ r[k] = dist_mean(x[k])
138
+ r[k] = r[k].item() if (r[k].dim() == 0) else r[k].tolist()
139
+ return r
140
+
141
+
142
+ @lru_cache()
143
+ def get_is_torch_run() -> bool:
144
+ return os.environ.get("LOCAL_RANK") is not None
145
+
146
+
147
+ @lru_cache()
148
+ def get_is_slurm_job() -> bool:
149
+ return "SLURM_JOB_ID" in os.environ and not get_is_torch_run()
150
+
151
+
152
+ @lru_cache()
153
+ def get_global_rank() -> int:
154
+ if get_is_torch_run():
155
+ return int(os.environ["RANK"])
156
+ elif get_is_slurm_job():
157
+ return int(os.environ["SLURM_PROCID"])
158
+ else:
159
+ return 0
160
+
161
+
162
+ @lru_cache()
163
+ def get_local_rank() -> int:
164
+ if get_is_torch_run():
165
+ return int(os.environ["LOCAL_RANK"])
166
+ elif get_is_slurm_job():
167
+ return int(os.environ["SLURM_LOCALID"])
168
+ else:
169
+ return 0
170
+
171
+
172
+ @lru_cache()
173
+ def get_world_size() -> int:
174
+ if get_is_torch_run():
175
+ return int(os.environ["WORLD_SIZE"])
176
+ elif get_is_slurm_job():
177
+ return int(os.environ["SLURM_NTASKS"])
178
+ else:
179
+ return 1
180
+
181
+
182
+ @lru_cache()
183
+ def get_is_master() -> bool:
184
+ return get_global_rank() == 0
185
+
186
+
187
+ @lru_cache()
188
+ def get_master_port(job_id: int) -> int:
189
+ if get_is_torch_run():
190
+ return int(os.environ["MASTER_PORT"])
191
+ else:
192
+ MIN_MASTER_PORT, MAX_MASTER_PORT = (20000, 60000)
193
+ rng = random.Random(job_id)
194
+ return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT)
195
+
196
+
197
+ @lru_cache()
198
+ def get_master_addr() -> str:
199
+ if get_is_torch_run():
200
+ return os.environ["MASTER_ADDR"]
201
+ elif get_is_slurm_job():
202
+ hostnames = subprocess.check_output(
203
+ ["scontrol", "show", "hostnames", os.environ["SLURM_JOB_NODELIST"]]
204
+ )
205
+ return hostnames.split()[0].decode("utf-8")
206
+ else:
207
+ return "127.0.0.1"
208
+
209
+
210
+ def setup_env(env_args: EnvironmentArgs):
211
+ env_vars = env_args.model_dump()
212
+
213
+ # When using Triton, it attempts to locate prebuilt kernels in a cache
214
+ # located at ~/.triton/cache, but when that's backed by NFS this can fail
215
+ # with a "OSError: [Errno 116] Stale file handle" error. If we were to set
216
+ # it to a local directory it would belong to the first user who created it
217
+ # and it would fail for the job of any other successive user assigned to
218
+ # that machine. To avoid all this mess we use a temporary per-process cache.
219
+ triton_cache_dir = tempfile.mkdtemp()
220
+ atexit.register(shutil.rmtree, triton_cache_dir, ignore_errors=True)
221
+ env_vars["TRITON_CACHE_DIR"] = triton_cache_dir
222
+
223
+ # We change the tmp dir to /scratch in case it's a slurm job
224
+ # This avoids filling up the host's usually limited tmpfs
225
+ # A full tmpfs leads to very slow creation of processes and weird bugs
226
+ if get_is_slurm_job():
227
+ new_tmp = f"/scratch/slurm_tmpdir/{os.environ['SLURM_JOB_ID']}"
228
+ if os.path.exists(new_tmp):
229
+ env_vars["TMP_DIR"] = new_tmp
230
+
231
+ for name, value in env_vars.items():
232
+ if os.environ.get(name) != str(value):
233
+ os.environ[name] = str(value)
234
+ logger.warning(f"WARNING: Setting {name} to {value}")
235
+
236
+
237
+ def setup_torch_distributed(dist_args):
238
+ """
239
+ Handle single and multi-GPU / multi-node / SLURM jobs.
240
+ Initialize the following variables:
241
+ - global_rank
242
+ - world_size
243
+ """
244
+ mp.set_start_method(dist_args.spawn_method)
245
+ with mp.Manager():
246
+ pass
247
+
248
+ local_rank = get_local_rank()
249
+
250
+ os.environ["RANK"] = str(get_global_rank())
251
+ os.environ["WORLD_SIZE"] = str(get_world_size())
252
+ os.environ["MASTER_ADDR"] = get_master_addr()
253
+ os.environ["MASTER_PORT"] = str(
254
+ get_master_port(job_id=int(os.environ.get("SLURM_JOB_ID", -1)))
255
+ )
256
+
257
+ if get_is_torch_run():
258
+ logger.info(f"Run launched with torchrun, local rank: {local_rank}")
259
+ elif get_is_slurm_job():
260
+ logger.info(f"Run launched with slurm, local rank: {local_rank}")
261
+ else:
262
+ logger.info("Single GPU job")
263
+
264
+ logger.info(f"ENV: {os.environ}")
265
+
266
+ # set GPU device
267
+ assert 0 <= local_rank < 8
268
+ if dist_args.matmul_allow_tf32:
269
+ torch.backends.cuda.matmul.allow_tf32 = True
270
+ logger.warning(
271
+ f"WARNING: Setting torch.backends.matmul.allow_tf32 to True. This is faster but less accurate."
272
+ )
273
+ torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
274
+ dist_args.allow_bf16_reduced_precision_reduction
275
+ )
276
+ if torch.cuda.device_count() > 1:
277
+ torch.cuda.set_device(local_rank)
278
+ torch.distributed.init_process_group(init_method="env://", backend="nccl")
279
+ torch.autograd.set_detect_anomaly(dist_args.detect_anomaly)
280
+
281
+
282
+ def get_module(module, access_string):
283
+ names = access_string.split(sep=".")
284
+ return reduce(getattr, names, module)
285
+
286
+
287
+ def set_module(module, access_string, value):
288
+ names = access_string.split(sep=".")
289
+ parent = reduce(getattr, names[:-1], module)
290
+ setattr(parent, names[-1], value)
291
+
292
+
293
+ def default_fsdp_grouping_plan(n_layers: int) -> List[Tuple[str, bool]]:
294
+ return [(f"layers.{i}", i < n_layers - 1) for i in range(n_layers)]
295
+
296
+
297
+ def get_default_policy(no_recompute_ops=None):
298
+ no_recompute_ops = no_recompute_ops or default_no_recompute_ops
299
+
300
+ def default_policy(ctx, func, *args, **kwargs):
301
+ return (
302
+ CheckpointPolicy.MUST_SAVE
303
+ if func in no_recompute_ops
304
+ else CheckpointPolicy.PREFER_RECOMPUTE
305
+ )
306
+
307
+ return default_policy
308
+
309
+
310
+ @torch.no_grad()
311
+ def check_model_value_range(
312
+ model: torch.nn.Module, range: float = 1e3, std: float = 1e3
313
+ ):
314
+ for name, param in chain(model.named_parameters(), model.named_buffers()):
315
+ if isinstance(param, DTensor):
316
+ param = param.to_local()
317
+
318
+ if param.numel() == 0:
319
+ logger.warning(
320
+ f"Model parameter {name} is empty, probably because of FSDP sharding"
321
+ )
322
+ continue
323
+
324
+ if torch.isnan(param).any() or torch.isinf(param).any():
325
+ logger.warning(f"Model parameter {name} contains NaN or Inf")
326
+
327
+ param_range = param.max() - param.min()
328
+ param_std = param.std()
329
+ if param_range > range:
330
+ logger.warning(
331
+ f"Model parameter {name} has a suspiciously large range ({param_range}): please check initialization and init_weights is defined and called"
332
+ )
333
+ if param_std > std:
334
+ logger.warning(
335
+ f"Model parameter {name} has a suspiciously large standard deviation ({param_std}): please check initialization and init_weights is defined and called"
336
+ )
337
+ if (param == 0).all():
338
+ logger.warning(
339
+ f"Model parameter {name} is all zeros: it might be because of a missing initialization"
340
+ )
341
+
342
+
343
+ def init_signal_handler(callable):
344
+ """
345
+ Handle signals sent by SLURM for time limit / pre-emption.
346
+ """
347
+ signal.signal(signal.SIGUSR2, callable)
348
+ logger.warning("Signal handler installed.")
349
+
350
+
351
+ def requeue_slurm_job():
352
+ prod_id = int(os.environ["SLURM_PROCID"])
353
+ logger.warning("Host: %s - Global rank: %i" % (socket.gethostname(), prod_id))
354
+ if prod_id == 0 and os.environ.get("LAUNCH_WITH", "") != "DORA":
355
+ logger.warning("Requeuing job " + os.environ["SLURM_JOB_ID"])
356
+ os.system("scontrol requeue " + os.environ["SLURM_JOB_ID"])
357
+ else:
358
+ logger.warning("Not the master process, no need to requeue.")
359
+ sys.exit(0)
360
+
361
+
362
+ @contextlib.contextmanager
363
+ def clean_env():
364
+ distrib_names = (
365
+ "MASTER_ADDR",
366
+ "MASTER_PORT",
367
+ "RANK",
368
+ "WORLD_SIZE",
369
+ "LOCAL_RANK",
370
+ "LOCAL_WORLD_SIZE",
371
+ "TORCHELASTIC_RUN_ID",
372
+ "DORA_FORCE_DISTRIB",
373
+ )
374
+ cluster_env = {
375
+ x: os.environ.pop(x)
376
+ for x in os.environ
377
+ if x.startswith(
378
+ ("SLURM_", "SLURMD_", "SRUN_", "SBATCH_", "SUBMITIT_", "WANDB_")
379
+ )
380
+ or x in distrib_names
381
+ }
382
+ try:
383
+ yield
384
+ finally:
385
+ os.environ.update(cluster_env)
386
+
387
+
388
+ def parallelize_model(
389
+ model,
390
+ device_mesh,
391
+ model_args,
392
+ distributed_args: DistributedArgs,
393
+ fsdp_grouping_plan: Optional[List[Tuple[str, bool]]] = None,
394
+ tp_parallelize=None,
395
+ no_recompute_ops=None,
396
+ ):
397
+ if distributed_args.tp_size > 1:
398
+ assert (
399
+ distributed_args.fsdp_type == "full_shard"
400
+ ), "Only full shard is supported for TP parallelism"
401
+ assert tp_parallelize is not None, "TP plan is required for TP parallelism"
402
+ assert (
403
+ distributed_args.compile == False
404
+ ), "Compile is not supported for TP parallelism"
405
+
406
+ tp_parallelize(model, device_mesh["tp"], model_args, distributed_args)
407
+
408
+ if distributed_args.float8_recipe is not None:
409
+ if distributed_args.tp_size > 1:
410
+ raise RuntimeError("float8 is incompatible with tensor-parallelism for now")
411
+ model = convert_linears_to_fp8(
412
+ model, distributed_args.float8_recipe, distributed_args.float8_filter
413
+ )
414
+
415
+ param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
416
+ distributed_args.model_dtype
417
+ ]
418
+ if (
419
+ distributed_args.fsdp_type == "full_shard"
420
+ or distributed_args.fsdp_type == "no_shard"
421
+ ):
422
+ if distributed_args.fsdp_type == "no_shard":
423
+ assert (
424
+ distributed_args.dp_shard == 1
425
+ ), "dp_shard must be 1 for no_shard fsdp_type"
426
+ assert (
427
+ device_mesh["dp_shard"].size() == 1
428
+ ), "dp_shard must be 1 for no_shard fsdp_type"
429
+
430
+ fsdp_config = dict(
431
+ mp_policy=(
432
+ MixedPrecisionPolicy(
433
+ param_dtype=param_dtype,
434
+ reduce_dtype=torch.float32,
435
+ )
436
+ ),
437
+ mesh=(
438
+ device_mesh["dp_replicate", "dp_shard"]
439
+ if distributed_args.dp_shard > 1
440
+ or distributed_args.fsdp_type == "no_shard"
441
+ else device_mesh["dp_replicate"]
442
+ ),
443
+ )
444
+
445
+ if fsdp_grouping_plan is None:
446
+ # Assume that the model has list of layers and group around it
447
+ fsdp_grouping_plan = default_fsdp_grouping_plan(len(model.layers))
448
+
449
+ for path, reshard_after_forward in fsdp_grouping_plan:
450
+ module = get_module(model, path)
451
+ set_module(
452
+ model,
453
+ path,
454
+ fully_shard(
455
+ module, **fsdp_config, reshard_after_forward=reshard_after_forward
456
+ ),
457
+ )
458
+
459
+ model = fully_shard(model, **fsdp_config, reshard_after_forward=True)
460
+ else:
461
+ raise ValueError(f"Invalid fsdp_type: {distributed_args.fsdp_type}")
462
+
463
+ if distributed_args.selective_activation_checkpointing:
464
+ model = checkpoint_wrapper(
465
+ model,
466
+ context_fn=partial(
467
+ create_selective_checkpoint_contexts,
468
+ get_default_policy(no_recompute_ops),
469
+ ),
470
+ )
471
+
472
+ if distributed_args.compile:
473
+ torch._dynamo.config.cache_size_limit = (
474
+ distributed_args.compile_cache_size_limit
475
+ )
476
+ model = torch.compile(model)
477
+
478
+ return model
bytelatent/entropy_model.py ADDED
@@ -0,0 +1,36 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ import json
3
+ import os
4
+ import re
5
+
6
+ import torch
7
+
8
+ from bytelatent.transformer import LMTransformer, LMTransformerArgs
9
+
10
+
11
+ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"):
12
+ with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr:
13
+ reloaded = json.loads(fr.read())
14
+
15
+ torch.set_default_dtype(torch.bfloat16)
16
+ model_params = reloaded["model"]
17
+ entropy_model = LMTransformer(
18
+ LMTransformerArgs(
19
+ dim=model_params["dim"],
20
+ n_layers=model_params["n_layers"],
21
+ n_heads=model_params["n_heads"],
22
+ max_seqlen=model_params["max_length"],
23
+ ffn_dim_multiplier=model_params["ffn_dim_multiplier"],
24
+ vocab_size=model_params["vocab_size"],
25
+ )
26
+ )
27
+
28
+ entropy_model.load_state_dict(
29
+ torch.load(state_dict_path, map_location=device), strict=False
30
+ )
31
+ entropy_model.to(device)
32
+ entropy_model = entropy_model.eval()
33
+ # no grads for the model:
34
+ for param in entropy_model.parameters():
35
+ param.requires_grad = False
36
+ return entropy_model
bytelatent/float8.py ADDED
@@ -0,0 +1,152 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import re
4
+ import warnings
5
+ from typing import Callable
6
+
7
+ import torch
8
+
9
+ # avoid division by zero when calculating scale
10
+ EPS = 1e-12
11
+
12
+
13
+ def scale(t, amax_t, dtype_t):
14
+ min_v, max_v = torch.finfo(dtype_t).min, torch.finfo(dtype_t).max
15
+ scale_t = torch.clamp(amax_t.float(), min=EPS) / max_v
16
+ t_fp8 = (t / scale_t).clamp(min=min_v, max=max_v).to(dtype_t)
17
+ return t_fp8, scale_t
18
+
19
+
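A quick sketch of the per-row scaling above (import path assumed; requires a PyTorch build with float8 dtypes):

    import torch
    from bytelatent.float8 import scale

    t = torch.tensor([[0.5, -2.0, 1.0]], dtype=torch.bfloat16)
    amax = t.abs().amax(dim=-1, keepdim=True)      # per-row absolute max: 2.0
    t_fp8, s = scale(t, amax, torch.float8_e4m3fn)
    # s = 2.0 / 448 (the e4m3 max); t_fp8 stores t / s clamped into the fp8 range.
    print(t_fp8.dtype, s)  # torch.float8_e4m3fn tensor([[0.0045]])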
20
+ def matmul(
21
+ first, amax_first, dtype_first, second_t, amax_second_t, dtype_second_t, bias
22
+ ):
23
+ first_fp8, scale_first = scale(first, amax_first, dtype_first)
24
+ second_t_fp8, scale_second_t = scale(second_t, amax_second_t, dtype_second_t)
25
+ output = torch._scaled_mm(
26
+ first_fp8,
27
+ second_t_fp8.t(),
28
+ scale_a=scale_first,
29
+ scale_b=scale_second_t.t(),
30
+ bias=bias,
31
+ out_dtype=torch.bfloat16,
32
+ use_fast_accum=True,
33
+ )
34
+ return output
35
+
36
+
37
+ @torch._dynamo.allow_in_graph
38
+ class Fp8LinearFn(torch.autograd.Function):
39
+ @staticmethod
40
+ def forward(ctx, a, b_t, bias):
41
+ amax_a = a.abs().amax(dim=-1, keepdim=True)
42
+ amax_b_t = b_t.abs().amax(dim=-1, keepdim=True)
43
+ out = matmul(
44
+ a, amax_a, torch.float8_e4m3fn, b_t, amax_b_t, torch.float8_e4m3fn, bias
45
+ )
46
+
47
+ ctx.a_requires_grad = a.requires_grad
48
+ ctx.b_requires_grad = b_t.requires_grad
49
+ ctx.bias_requires_grad = bias.requires_grad if bias is not None else False
50
+
51
+ ctx.save_for_backward(a, b_t, amax_b_t.max())
52
+
53
+ return out
54
+
55
+ @staticmethod
56
+ def backward(ctx, grad_out):
57
+ a, b_t, amax_b = ctx.saved_tensors
58
+
59
+ if ctx.a_requires_grad:
60
+ b = b_t.t().contiguous()
61
+ amax_grad_out = grad_out.abs().amax(dim=-1, keepdim=True)
62
+ amax_b = amax_b.repeat(b.shape[0], 1)
63
+ grad_a = matmul(
64
+ grad_out,
65
+ amax_grad_out,
66
+ torch.float8_e4m3fn,
67
+ b,
68
+ amax_b,
69
+ torch.float8_e4m3fn,
70
+ None,
71
+ )
72
+ else:
73
+ grad_a = None
74
+ if ctx.b_requires_grad:
75
+ grad_b = grad_out.t() @ a
76
+ else:
77
+ grad_b = None
78
+ if ctx.bias_requires_grad:
79
+ grad_bias = grad_out.sum(dim=0)
80
+ else:
81
+ grad_bias = None
82
+
83
+ return grad_a, grad_b, grad_bias
84
+
85
+
86
+ class Fp8Linear(torch.nn.Linear):
87
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
88
+ out = Fp8LinearFn.apply(input.flatten(end_dim=-2), self.weight, self.bias)
89
+ out = out.unflatten(0, input.shape[:-1])
90
+ return out
91
+
92
+
93
+ def named_replace(
94
+ fn: Callable[[torch.nn.Module, str], torch.nn.Module],
95
+ module: torch.nn.Module,
96
+ name="",
97
+ ) -> torch.nn.Module:
98
+ for child_name, child_module in list(module.named_children()):
99
+ full_name = f"{name}.{child_name}" if name else child_name
100
+ new_child_module = named_replace(fn, child_module, full_name)
101
+ setattr(module, child_name, new_child_module)
102
+ module = fn(module, name)
103
+ return module
104
+
105
+
106
+ def convert_linears_to_fp8(
107
+ root_module: torch.nn.Module, recipe: str, filter: str
108
+ ) -> torch.nn.Module:
109
+ if recipe not in ["rowwise"]:
110
+ raise RuntimeError(f"Unknown float8 recipe {recipe!r}")
111
+
112
+ if recipe == "rowwise" and torch.__version__ < "2.5":
113
+ # We need https://github.com/pytorch/pytorch/pull/134781.
114
+ warnings.warn("Float8 row-wise scaling is slow in PyTorch prior to v2.5.0")
115
+
116
+ # Multi-kernel makes Inductor auto-tune between a regular "streaming"-based
117
+ # reduction kernel and a "persistent" reduction kernel. Since fp8 has some
118
+ # multi-pass steps (e.g., first get amax, then scale), persistent kernels
119
+ # should perform better.
120
+ torch._inductor.config.triton.multi_kernel = 1
121
+
122
+ filter_re = re.compile(filter)
123
+
124
+ def replace(module: torch.nn.Module, name: str) -> torch.nn.Module:
125
+ if not isinstance(module, torch.nn.Linear) or not filter_re.search(name):
126
+ return module
127
+ if type(module) == torch.nn.Linear:
128
+ if recipe == "rowwise":
129
+ new_module = Fp8Linear(
130
+ in_features=module.in_features,
131
+ out_features=module.out_features,
132
+ bias=module.bias is not None,
133
+ dtype=module.weight.dtype,
134
+ device=module.weight.device,
135
+ )
136
+ new_module.weight = module.weight
137
+ new_module.bias = module.bias
138
+ else:
139
+ assert False, recipe
140
+ else:
141
+ assert False, str(type(module))
142
+ return new_module
143
+
144
+ out = named_replace(replace, root_module)
145
+
146
+ # Force re-compile everything
147
+ torch._dynamo.reset_code_caches()
148
+ from torch._inductor.cudagraph_trees import reset_cudagraph_trees
149
+
150
+ reset_cudagraph_trees()
151
+
152
+ return out
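A small usage sketch (assumptions: the module imports as bytelatent.float8, a recent PyTorch is installed, and a permissive filter is used instead of the default layer regex):

    import torch
    from bytelatent.float8 import convert_linears_to_fp8

    model = torch.nn.Sequential(
        torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 64))
    )
    # Children are visited as "0.0", "0.1", ... so match everything for this toy example.
    model = convert_linears_to_fp8(model, recipe="rowwise", filter=r".*")
    print(type(model[0][0]).__name__)  # Fp8Linear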
bytelatent/logger.py ADDED
@@ -0,0 +1,129 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import logging
4
+ import math
5
+ import sys
6
+ import time
7
+ from datetime import timedelta
8
+
9
+ from bytelatent.distributed import get_global_rank, get_is_slurm_job
10
+
11
+
12
+ class LogFormatter(logging.Formatter):
13
+ """
14
+ Custom logger for distributed jobs, displaying rank
15
+ and preserving indent from the custom prefix format.
16
+ """
17
+
18
+ def __init__(self):
19
+ self.start_time = time.time()
20
+ self.rank = get_global_rank()
21
+ self.show_rank = not get_is_slurm_job() # srun has --label
22
+
23
+ def formatTime(self, record):
24
+ subsecond, seconds = math.modf(record.created)
25
+ curr_date = (
26
+ time.strftime("%y-%m-%d %H:%M:%S", time.localtime(seconds))
27
+ + f".{int(subsecond * 1_000_000):06d}"
28
+ )
29
+ delta = timedelta(seconds=round(record.created - self.start_time))
30
+ return f"{curr_date} - {delta}"
31
+
32
+ def formatPrefix(self, record):
33
+ fmt_time = self.formatTime(record)
34
+ if self.show_rank:
35
+ return f"{self.rank}: {record.levelname:<7} {fmt_time} - "
36
+ else:
37
+ return f"{record.levelname:<7} {fmt_time} - "
38
+
39
+ def formatMessage(self, record, indent: str):
40
+ content = record.getMessage()
41
+ content = content.replace("\n", "\n" + indent)
42
+ # Exception handling as in the default formatter, albeit with indenting
43
+ # according to our custom prefix
44
+ if record.exc_info:
45
+ # Cache the traceback text to avoid converting it multiple times
46
+ # (it's constant anyway)
47
+ if not record.exc_text:
48
+ record.exc_text = self.formatException(record.exc_info)
49
+ if record.exc_text:
50
+ if content[-1:] != "\n":
51
+ content = content + "\n" + indent
52
+ content = content + indent.join(
53
+ [l + "\n" for l in record.exc_text.splitlines()]
54
+ )
55
+ if content[-1:] == "\n":
56
+ content = content[:-1]
57
+ if record.stack_info:
58
+ if content[-1:] != "\n":
59
+ content = content + "\n" + indent
60
+ stack_text = self.formatStack(record.stack_info)
61
+ content = content + indent.join([l + "\n" for l in stack_text.splitlines()])
62
+ if content[-1:] == "\n":
63
+ content = content[:-1]
64
+
65
+ return content
66
+
67
+ def format(self, record):
68
+ prefix = self.formatPrefix(record)
69
+ indent = " " * len(prefix)
70
+ content = self.formatMessage(record, indent)
71
+ return prefix + content
72
+
73
+
74
+ def set_root_log_level(log_level: str):
75
+ logger = logging.getLogger()
76
+ level: int | str = log_level.upper()
77
+ try:
78
+ level = int(log_level)
79
+ except ValueError:
80
+ pass
81
+ try:
82
+ logger.setLevel(level) # type: ignore
83
+ except Exception:
84
+ logger.warning(
85
+ f"Failed to set logging level to {log_level}, using default 'NOTSET'"
86
+ )
87
+ logger.setLevel(logging.NOTSET)
88
+
89
+
90
+ def init_logger(
91
+ log_file: str | None = None,
92
+ *,
93
+ name: str | None = None,
94
+ level: str = "NOTSET",
95
+ ):
96
+ """
97
+ Setup logging.
98
+
99
+ Args:
100
+ log_file: A file name to save file logs to.
101
+ name: The name of the logger to configure, by default the root logger.
102
+ level: The logging level to use.
103
+ """
104
+ set_root_log_level(level)
105
+ logger = logging.getLogger(name)
106
+
107
+ # stdout: everything
108
+ stdout_handler = logging.StreamHandler(sys.stdout)
109
+ stdout_handler.setLevel(logging.NOTSET)
110
+ stdout_handler.setFormatter(LogFormatter())
111
+
112
+ # stderr: warnings / errors and above
113
+ stderr_handler = logging.StreamHandler(sys.stderr)
114
+ stderr_handler.setLevel(logging.WARNING)
115
+ stderr_handler.setFormatter(LogFormatter())
116
+
117
+ # set stream handlers
118
+ logger.handlers.clear()
119
+ logger.handlers.append(stdout_handler)
120
+ logger.handlers.append(stderr_handler)
121
+
122
+ if log_file is not None and get_global_rank() == 0:
123
+ # build file handler
124
+ file_handler = logging.FileHandler(log_file, "a")
125
+ file_handler.setLevel(logging.NOTSET)
126
+ file_handler.setFormatter(LogFormatter())
127
+ # update logger
128
+ logger = logging.getLogger()
129
+ logger.addHandler(file_handler)
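A minimal usage sketch for the logger setup above, assuming the package is installed so bytelatent.logger is importable and that this is a single-process run (the file name and level are placeholders):

import logging

from bytelatent.logger import init_logger

# Route everything to stdout, warnings and above to stderr, and (on global rank 0)
# append a copy of all records to "train.log" using LogFormatter.
init_logger(log_file="train.log", level="INFO")
logging.getLogger(__name__).info("starting run")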
bytelatent/metrics.py ADDED
@@ -0,0 +1,232 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
3
+
4
+ import json
5
+ import logging
6
+ from collections import namedtuple
7
+ from dataclasses import asdict
8
+ from datetime import datetime, timezone
9
+ from pathlib import Path
10
+ from typing import Any, Union
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import wandb
15
+ from pydantic import BaseModel, ConfigDict
16
+
17
+ from bytelatent.distributed import get_is_master
18
+
19
+ logger = logging.getLogger()
20
+
21
+
22
+ class WandbArgs(BaseModel):
23
+ model_config = ConfigDict(extra="forbid")
24
+ job_type: str | None = None
25
+ dir: str | None = None
26
+ project: str | None = None
27
+ entity: str | None = None
28
+ tags: list | None = None
29
+ group: str | None = None
30
+ name: str | None = None
31
+ notes: str | None = None
32
+ config_exclude_keys: list[str] | None = None
33
+ config_include_keys: list[str] | None = None
34
+ anonymous: str | None = None
35
+ mode: str | None = None
36
+ allow_val_change: bool | None = None
37
+ resume: Union[bool, str] | None = None
38
+ force: bool | None = None
39
+ tensorboard: bool | None = None
40
+ sync_tensorboard: bool | None = None
41
+ monitor_gym: bool | None = None
42
+ save_code: bool | None = None
43
+ id: str | None = None
44
+ fork_from: str | None = None
45
+ resume_from: str | None = None
46
+
47
+
48
+ class LoggingArgs(BaseModel):
49
+ model_config = ConfigDict(extra="forbid")
50
+ freq: int = 10 # Log every freq optimizer steps
51
+ acc_freq: int | None = None # Log every acc_freq gradient accumulation steps
52
+
53
+ wandb: WandbArgs | None = None
54
+
55
+
56
+ class MetricLogger:
57
+ def __init__(self, outdir: Path, args: Any | None = None):
58
+ self.outdir = outdir
59
+ self.jsonl_writer = None
60
+ self.args = args
61
+
62
+ def open(self):
63
+ if self.jsonl_writer is None:
64
+ self.jsonl_writer = open(self.outdir, "a")
65
+ if (
66
+ self.args is not None
67
+ and self.args.logging.wandb is not None
68
+ and get_is_master()
69
+ ):
70
+ run = wandb.init(
71
+ config=asdict(self.args),
72
+ **asdict(self.args.logging.wandb),
73
+ )
74
+
75
+ def log(self, metrics: dict[str, Any]):
76
+ if (
77
+ self.args is not None
78
+ and self.args.logging.wandb is not None
79
+ and (wandb.run is not None)
80
+ ):
81
+ wandb.log(metrics, step=metrics["global_step"])
82
+
83
+ metrics.update({"created_at": datetime.now(timezone.utc).isoformat()})
84
+ print(json.dumps(metrics), file=self.jsonl_writer, flush=True)
85
+
86
+ def close(self):
87
+ if self.jsonl_writer is not None:
88
+ self.jsonl_writer.close()
89
+ self.jsonl_writer = None
90
+
91
+ def __enter__(self):
92
+ self.open()
93
+ return self
94
+
95
+ def __exit__(self, exc_type, exc_value, traceback):
96
+ self.close()
97
+
98
+ def __del__(self):
99
+ self.close()
100
+
101
+
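A minimal usage sketch for MetricLogger, assuming the package is installed as bytelatent; no wandb section is configured here, so only the JSON-lines file is written (path and values are illustrative):

from pathlib import Path

from bytelatent.metrics import MetricLogger

# Used as a context manager, the logger appends one JSON object per log() call.
with MetricLogger(Path("metrics.jsonl")) as metric_logger:
    metric_logger.log({"global_step": 10, "loss": 2.31, "lr": 3e-4})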
102
+ GPUMemStats = namedtuple(
103
+ "GPUMemStats",
104
+ [
105
+ "max_active_gib",
106
+ "max_active_pct",
107
+ "max_reserved_gib",
108
+ "max_reserved_pct",
109
+ "num_alloc_retries",
110
+ "num_ooms",
111
+ "power_draw",
112
+ ],
113
+ )
114
+
115
+
116
+ class GPUMemoryMonitor:
117
+ """
118
+ Class to monitor GPU memory usage
119
+ """
120
+
121
+ def __init__(self, device: str = "cuda:0"):
122
+ self.device = torch.device(device) # device object
123
+ self.device_name = torch.cuda.get_device_name(self.device)
124
+ self.device_index = torch.cuda.current_device()
125
+ self.device_capacity = torch.cuda.get_device_properties(
126
+ self.device
127
+ ).total_memory
128
+ self.device_capacity_gib = self._to_gib(self.device_capacity)
129
+
130
+ # reset stats, clear cache
131
+ torch.cuda.reset_peak_memory_stats()
132
+ torch.cuda.empty_cache()
133
+
134
+ def _to_gib(self, memory_in_bytes):
135
+ # NOTE: GiB (gibibyte) is 1024, vs GB is 1000
136
+ _gib_in_bytes = 1024 * 1024 * 1024
137
+ memory_in_gib = memory_in_bytes / _gib_in_bytes
138
+ return memory_in_gib
139
+
140
+ def _to_pct(self, memory):
141
+ return 100 * memory / self.device_capacity
142
+
143
+ def get_peak_stats(self):
144
+ cuda_info = torch.cuda.memory_stats(self.device)
145
+
146
+ max_active = cuda_info["active_bytes.all.peak"]
147
+ max_active_gib = self._to_gib(max_active)
148
+ max_active_pct = self._to_pct(max_active)
149
+
150
+ max_reserved = cuda_info["reserved_bytes.all.peak"]
151
+ max_reserved_gib = self._to_gib(max_reserved)
152
+ max_reserved_pct = self._to_pct(max_reserved)
153
+
154
+ num_retries = cuda_info["num_alloc_retries"]
155
+ num_ooms = cuda_info["num_ooms"]
156
+ power_draw = torch.cuda.power_draw()
157
+
158
+ if num_retries > 0:
159
+ logger.warning(f"{num_retries} CUDA memory allocation retries.")
160
+ if num_ooms > 0:
161
+ logger.warning(f"{num_ooms} CUDA OOM errors thrown.")
162
+
163
+ return GPUMemStats(
164
+ max_active_gib,
165
+ max_active_pct,
166
+ max_reserved_gib,
167
+ max_reserved_pct,
168
+ num_retries,
169
+ num_ooms,
170
+ power_draw,
171
+ )
172
+
173
+ def reset_peak_stats(self):
174
+ torch.cuda.reset_peak_memory_stats()
175
+ torch.cuda.reset_accumulated_memory_stats()
176
+
177
+ def __str__(self):
178
+ mem_stats = self.get_peak_stats()
179
+ display_str = f"{self.device_name} ({self.device_index}): {self.device_capacity_gib:.2f} GiB capacity, "
180
+ display_str += (
181
+ f"{mem_stats.max_reserved_gib:.2f} GiB peak, {mem_stats.max_reserved_pct:.2f}% peak"
182
+ )
183
+ return f"{display_str}"
184
+
185
+
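A short sketch of how the monitor is meant to wrap a training step (requires a CUDA device; the print format is illustrative):

from bytelatent.metrics import GPUMemoryMonitor

monitor = GPUMemoryMonitor("cuda:0")
# ... run a forward/backward pass here ...
stats = monitor.get_peak_stats()
print(f"peak reserved: {stats.max_reserved_gib:.2f} GiB ({stats.max_reserved_pct:.1f}%)")
monitor.reset_peak_stats()  # start a fresh measurement window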
186
+ def upload_train_to_wandb(
187
+ ckpt_dir, project="lingua", entity="codegen-team", train=True, eval=True
188
+ ):
189
+ import json
190
+ from pathlib import Path
191
+
192
+ import wandb
193
+ from omegaconf import OmegaConf
194
+
195
+ cfg = OmegaConf.load(Path(ckpt_dir) / "config.yaml")
196
+ cfg = OmegaConf.to_container(cfg)
197
+
198
+ if train:
199
+ wandb.init(config=cfg, name=cfg["name"], project=project, entity=entity)
200
+
201
+ with open(Path(ckpt_dir) / "metrics.jsonl") as f:
202
+ for l in f:
203
+ m = json.loads(l)
204
+ wandb.log(m, step=m["global_step"])
205
+
206
+ wandb.finish()
207
+
208
+ if eval:
209
+ wandb.init(config=cfg, name=cfg["name"], project=project, entity=entity)
210
+
211
+ with open(Path(ckpt_dir) / "metrics.eval.jsonl") as f:
212
+ for l in f:
213
+ m = json.loads(l)
214
+ wandb.log(
215
+ {
216
+ f"evals/{name.replace('/','.')}": value
217
+ for name, value in m.items()
218
+ if "/" in name
219
+ },
220
+ step=m["global_step"],
221
+ )
222
+
223
+ wandb.finish()
224
+
225
+
226
+ def get_num_params(model: nn.Module) -> int:
227
+ """
228
+ Get the total number of model parameters.
229
+ Args: model: the module whose parameters are counted (trainable and frozen alike).
230
+ """
231
+ numel = {n: p.numel() for n, p in model.named_parameters()}
232
+ return sum(numel.values())
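A quick sanity check for get_num_params on a toy module (the expected count is just the layer's weights plus biases):

import torch.nn as nn

from bytelatent.metrics import get_num_params

layer = nn.Linear(512, 256)
print(get_num_params(layer))  # 512 * 256 + 256 = 131328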
bytelatent/model/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
bytelatent/model/blt.py ADDED
@@ -0,0 +1,1064 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from enum import Enum, auto
4
+ from typing import Any, Optional
5
+
6
+ import torch
7
+ from pydantic import ConfigDict, model_validator
8
+ from torch import nn
9
+ from torch.nn.attention.flex_attention import create_block_mask
10
+ from typing_extensions import Self
11
+
12
+ from bytelatent.base_transformer import (
13
+ BaseTransformerArgs,
14
+ InitStdFactor,
15
+ TransformerBlock,
16
+ )
17
+ from bytelatent.data.patcher import Patcher, PatcherArgs
18
+ from bytelatent.model.local_models import LocalDecoder, LocalEncoder
19
+ from bytelatent.model.transformer import GlobalTransformer
20
+ from bytelatent.model.utils import downsample
21
+ from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID
22
+
23
+
24
+ def attention_flops_per_token(n_layers, seq_len, dim, causal):
25
+ # Formula from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py#L27-L30
26
+ return 3.5 * (4 * n_layers * seq_len * dim // (2 if causal else 1))
27
+
28
+
29
+ def get_num_flop_per_token(
30
+ num_non_embed_params: int, n_layers: int, dim: int, seq_len: int
31
+ ) -> int:
32
+ return 6 * num_non_embed_params + attention_flops_per_token(
33
+ n_layers, seq_len, dim, True
34
+ )
35
+
36
+
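A small sketch of the FLOP estimate above; the parameter count and dimensions are made up for illustration:

from bytelatent.model.blt import get_num_flop_per_token

# 6 * params for the dense matmuls plus the causal attention term.
flops_per_token = get_num_flop_per_token(
    num_non_embed_params=100_000_000, n_layers=12, dim=768, seq_len=2048
)
print(f"{flops_per_token:.3e} FLOPs per token")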
37
+ def causal_mask(b, h, q_idx, kv_idx):
38
+ return q_idx >= kv_idx
39
+
40
+
41
+ def setattrs(_self, **kwargs):
42
+ for k, v in kwargs.items():
43
+ setattr(_self, k, v)
44
+
45
+
46
+ def get_encoder_dim_token_emb(args):
47
+ if args.dim_token is not None:
48
+ dim_token_emb = args.dim_token
49
+ elif args.use_local_encoder_transformer:
50
+ dim_token_emb = args.dim_local_encoder
51
+ else:
52
+ dim_token_emb = args.dim_global // args.patch_size
53
+ return dim_token_emb
54
+
55
+
56
+ def get_encoder_dim_patch_emb(args):
57
+ dim_patch_emb = None
58
+ if args.cross_attn_encoder:
59
+ if args.cross_attn_init_by_pooling:
60
+ dim_patch_emb = args.dim_local_encoder
61
+ else:
62
+ dim_patch_emb = args.dim_global
63
+ return dim_patch_emb
64
+
65
+
66
+ def get_global_dim_patch_emb(args):
67
+ dim_token_emb = get_encoder_dim_token_emb(args)
68
+ if args.cross_attn_encoder:
69
+ dim_patch_emb = dim_token_emb * args.cross_attn_k
70
+ elif (
71
+ args.downsampling_by_pooling is None
72
+ or not args.downsampling_by_pooling
73
+ or len(args.downsampling_by_pooling) == 0
74
+ ):
75
+ dim_patch_emb = dim_token_emb * args.patch_size
76
+ else:
77
+ dim_patch_emb = dim_token_emb * sum(
78
+ [
79
+ pooling in args.downsampling_by_pooling
80
+ for pooling in ["avg", "min", "max"]
81
+ ]
82
+ )
83
+ return dim_patch_emb
84
+
85
+
86
+ def get_decoder_dim_token_emb(args):
87
+ if args.share_encoder_decoder_emb:
88
+ dim_token_emb = get_encoder_dim_token_emb(args)
89
+ elif args.dim_token is not None:
90
+ dim_token_emb = args.dim_token
91
+ else:
92
+ dim_token_emb = args.dim_local_decoder
93
+ return dim_token_emb
94
+
95
+
96
+ def parse_ngram_to_size(ngram_to_size_str: str | None) -> dict[int, int] | None:
97
+ if ngram_to_size_str is None:
98
+ return None
99
+ ngram_to_size = {}
100
+ for entry in ngram_to_size_str.split(","):
101
+ ngram, size = entry.split(":")
102
+ ngram = int(ngram)
103
+ size = int(size)
104
+ ngram_to_size[ngram] = size
105
+ return ngram_to_size
106
+
107
+
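For reference, the expected string format is comma-separated "ngram:size" pairs:

from bytelatent.model.blt import parse_ngram_to_size

print(parse_ngram_to_size("2:1000,3:2000"))  # {2: 1000, 3: 2000}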
108
+ def fill_tokens(tokens, patch_size, fill_id):
109
+ batch_size, seq_len = tokens.shape
110
+ if seq_len % patch_size == 0:
111
+ return tokens
112
+ else:
113
+ remaining = patch_size - seq_len % patch_size
114
+ final_padding = tokens.new(batch_size, remaining).fill_(fill_id)
115
+ return torch.cat((tokens, final_padding), dim=1)
116
+
117
+
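A tiny example of the padding behaviour (token ids and fill_id are arbitrary):

import torch

from bytelatent.model.blt import fill_tokens

tokens = torch.tensor([[3, 4, 5, 6, 7]])
# A length-5 sequence is padded up to the next multiple of patch_size=4.
print(fill_tokens(tokens, patch_size=4, fill_id=0))
# tensor([[3, 4, 5, 6, 7, 0, 0, 0]])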
118
+ def decoder_patch_ids_from_lengths(patch_lengths, nb_boe, seq_len):
119
+ first_patch_length = patch_lengths[0, 0]
120
+ assert torch.all(
121
+ first_patch_length == patch_lengths[:, 0]
122
+ ), "first patch should always be the same size (1 for dynamic, patch_size for static)."
123
+ assert (
124
+ first_patch_length - nb_boe == 1
125
+ ), f"First patch (patch length: {first_patch_length}) should have one non-boe token (boe toks: {nb_boe})"
126
+ # Remove first patch from patch_ids for local decoder inputs and shift the last patch.
127
+ # decoder_patch_lengths = patch_lengths[:, 1:].clone()
128
+ # decoder_patch_lengths = add_to_last_nonzero_patch(decoder_patch_lengths, 1)
129
+ decoder_patch_lengths = patch_lengths[:, 1:]
130
+ assert (
131
+ decoder_patch_lengths.sum() + (nb_boe + 1) * patch_lengths.shape[0]
132
+ == patch_lengths.sum()
133
+ ), f"{decoder_patch_lengths.sum() + (nb_boe + 1) * patch_lengths.shape[0]} != {patch_lengths.sum()}"
134
+ assert torch.all(decoder_patch_lengths >= 0), f"{decoder_patch_lengths}"
135
+ decoder_patch_ids = patch_ids_from_lengths(
136
+ patch_lengths=decoder_patch_lengths, seq_len=seq_len
137
+ )
138
+ return decoder_patch_ids
139
+
140
+
141
+ primes = [
142
+ 1000000007,
143
+ 5915587277,
144
+ 1500450271,
145
+ 3267000013,
146
+ 5754853343,
147
+ 4093082899,
148
+ 9576890767,
149
+ 3628273133,
150
+ 2860486313,
151
+ 5463458053,
152
+ 3367900313,
153
+ ]
154
+
155
+
156
+ def rolling_polynomial_hash(t, hash_func_nb: int = 0):
157
+ prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=t.device)
158
+ prime_powers = torch.stack([prime**i for i in range(t.shape[-1])])
159
+ return torch.sum(t * prime_powers, dim=-1)
160
+
161
+
162
+ def get_rolling_polynomial_hash_fn(hash_func_nb: int = 0, group_size: int = 2):
163
+ prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64)
164
+ prime_powers = torch.stack([prime**i for i in range(group_size)])
165
+
166
+ def rolling_polynomial_hash_fn(t):
167
+ return torch.sum(t * prime_powers, dim=-1)
168
+
169
+ return rolling_polynomial_hash_fn
170
+
171
+
172
+ def byte_group_hash_function(
173
+ x: torch.Tensor, group_size: int = 2, hash_func_nb: int = 0, max_hash: int = 30000
174
+ ):
175
+ """
176
+ Returns a hash of the input x and maps it to a value in the range [0, max_hash).
177
+
178
+ expects: x of shape (batch_size, seq_len) with values as ids in the token vocab.
179
+ returns a tensor of shape (batch_size, seq_len) with values in the range [0, max_hash).
180
+
181
+ Note: max hash can make a big difference on the number of collisions.
182
+ """
183
+ with torch.no_grad():
184
+ bs, seq_len = x.shape
185
+ # x_numpy = x.numpy()
186
+ # hash_values = torch.zeros(bs, seq_len, dtype=torch.int64, requires_grad=False)
187
+ # for i in range(bs):
188
+ # for j in range(seq_len):
189
+ # start = max(j, j-group_size+1)
190
+ # end = j+1
191
+ # hash_values[i, j] = hash_array(x_numpy[i, start:end], max_hash)
192
+
193
+ prefix = torch.zeros(bs, group_size - 1, dtype=torch.int64, device=x.device)
194
+ x = torch.cat([prefix, x], dim=1)
195
+ windows = x.unfold(1, group_size, 1)
196
+ # hashes = get_rolling_polynomial_hash_fn(hash_func_nb, group_size)(windows)
197
+ hashes = rolling_polynomial_hash(windows, hash_func_nb)
198
+ hash_values_range = hashes % max_hash
199
+ hash_values_range.requires_grad = False
200
+ return hash_values_range
201
+
202
+
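A small sketch of the hashing helpers on a toy batch (token ids are arbitrary):

import torch

from bytelatent.model.blt import byte_group_hash_function

x = torch.tensor([[3, 4, 5, 6]])
# Each position is hashed together with the previous group_size - 1 bytes
# (left-padded with zeros), then bucketed into [0, max_hash).
bucket_ids = byte_group_hash_function(x, group_size=2, hash_func_nb=0, max_hash=30000)
print(bucket_ids.shape)  # torch.Size([1, 4])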
203
+ def create_patch_mask_from_ids(
204
+ patch_ids, num_patches, window=None, patches_as_queries=False
205
+ ):
206
+ """
207
+ Creates a tensor of shape [bs, seq_len, num_patches] where each element at position (i, j, k)
208
+ is True when the patch id at position (i, j) equals the patch index k (or falls within the optional window starting at k).
209
+ Args:
210
+ patch_ids (torch.Tensor): Tensor of shape [bs, seq_len] containing patch ids.
211
+ num_patches (int): Total number of patches.
212
+ window (int): If not None, only considers patches within a window of size window.
213
+ patches_as_queries (bool): If True, the patches are used as queries
214
+ Returns:
215
+ torch.Tensor: Tensor of shape [bs, q_len, kv_len] with the desired mask.
216
+ """
217
+ bs, seq_len = patch_ids.shape
218
+ if not patches_as_queries:
219
+ q_ids = patch_ids.unsqueeze(-1).expand(bs, seq_len, num_patches)
220
+ kv_ids = (
221
+ torch.arange(num_patches, device=patch_ids.device)
222
+ .unsqueeze(0)
223
+ .unsqueeze(0)
224
+ .expand(bs, seq_len, num_patches)
225
+ )
226
+ else:
227
+ kv_ids = patch_ids.unsqueeze(1).expand(bs, num_patches, seq_len)
228
+ q_ids = (
229
+ torch.arange(num_patches, device=patch_ids.device)
230
+ .unsqueeze(0)
231
+ .unsqueeze(-1)
232
+ .expand(bs, num_patches, seq_len)
233
+ )
234
+ if window is None:
235
+ mask = q_ids == kv_ids
236
+ else:
237
+ mask = (kv_ids <= q_ids) & (q_ids < kv_ids + window)
238
+ return mask
239
+
240
+
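A minimal example of the mask construction: tokens 0-2 belong to patch 0 and tokens 3-4 to patch 1.

import torch

from bytelatent.model.blt import create_patch_mask_from_ids

patch_ids = torch.tensor([[0, 0, 0, 1, 1]])
mask = create_patch_mask_from_ids(patch_ids, num_patches=2)
print(mask.shape)     # torch.Size([1, 5, 2])
print(mask[0, :, 0])  # tensor([True, True, True, False, False])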
241
+ def cross_attn_mask(
242
+ patch_ids,
243
+ patch_lengths,
244
+ N,
245
+ patches_as_queries=False,
246
+ cross_attn_k=1,
247
+ window=None,
248
+ block_mask=True,
249
+ ):
250
+ bs = patch_ids.shape[0]
251
+ with torch.no_grad():
252
+ # Create the patch mask
253
+ cross_mask = create_patch_mask_from_ids(
254
+ patch_ids,
255
+ patch_lengths.shape[1],
256
+ window=window,
257
+ patches_as_queries=patches_as_queries,
258
+ ).repeat_interleave(cross_attn_k, dim=1 if patches_as_queries else -1)
259
+ q_len = patch_lengths.shape[1] * cross_attn_k if patches_as_queries else N
260
+ kv_len = N if patches_as_queries else patch_lengths.shape[1] * cross_attn_k
261
+ assert cross_mask.shape == (
262
+ bs,
263
+ q_len,
264
+ kv_len,
265
+ ), f"{cross_mask.shape} != {(bs, q_len, kv_len)}"
266
+ if block_mask:
267
+
268
+ def patch_mask(b, h, q_idx, kv_idx):
269
+ return cross_mask[b, q_idx, kv_idx]
270
+
271
+ block_mask = create_block_mask(
272
+ patch_mask,
273
+ B=bs,
274
+ H=None,
275
+ Q_LEN=q_len,
276
+ KV_LEN=kv_len,
277
+ _compile=True,
278
+ )
279
+ return block_mask
280
+ else:
281
+ return torch.where(
282
+ cross_mask, torch.tensor(0.0), torch.tensor(float("-inf"))
283
+ ).unsqueeze(
284
+ 1
285
+ ) # [bs, 1, q_len, kv_len]
286
+
287
+
288
+ def get_blt_input(
289
+ tokens: torch.Tensor,
290
+ enforce_patch_size_multiple: bool,
291
+ nb_boe: torch.Tensor,
292
+ patch_size: int,
293
+ boe_id: int,
294
+ ):
295
+ """
296
+ This function returns X_et, X_gt and X_dt, the encoder, global, and decoder
297
+ tokens respectively.
298
+
299
+ Consider the input and target sequences:
300
+ X=[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13]
301
+ Y=[4,5,6,7,eos,bos,8,9,10,eos,bos,11,12,13,14]
302
+ with patch_size=4
303
+
304
+ Note 1: that there will be no special tokens introduced at the patch level.
305
+ Note 2: X_e needs to be trimmed to be passed to Global
306
+
307
+ Current without boe:
308
+ X_et = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]]
309
+ X_g = [[boe,boe,boe,boe] [3,4,5,6], [7,eos,bos,8], [9,10,eos,bos] [11,12,13, pad]] # remove last glob patch
310
+ X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]]
311
+ Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]]
312
+
313
+ --> lag fix:
314
+ X_et = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11] [12,13,pad,pad]]
315
+ X_g = [[boe,boe,boe,3] [4,5,6,7], [eos,bos,8,9], [10,eos,bos,11]]
316
+ X_dt = [[3,4,5,6] [7,eos,bos,8], [9,10,eos,bos], [11,12,13]]
317
+ Y = [[4,5,6,7] [eos,bos,8,9], [10,eos,bos,11], [12,13,14]]
318
+
319
+ Dynamic (current):
320
+ X = [3,4,5,6,7,eos,bos,8,9,10,eos,bos]
321
+ Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11]
322
+
323
+ entropy patching:
324
+ input: 7, bos, 9, 10
325
+ pred (high entropy): eos, 8, 10, eos
326
+
327
+ X_et = [[boe,3,4,5,6,7,eos,bos,8,9,10,eos,bos]
328
+ X_g = [[boe], [3,4,5,6], [7,eos],[bos,8],[9], [10,eos]]
329
+ X_dt = [[3,4,5,6], [7,eos], [bos,8],[9], [10,eos],[bos]]
330
+ Y = [4,5,6,7,eos,bos,8,9,10,eos,bos,11]
331
+
332
+ --> lag fix no boe (force single byte first patch):
333
+ X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12]
334
+ X_g = [[3], [4,5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch
335
+ X_dt = [[3,4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]]
336
+ Y = [4,5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13]
337
+
338
+ input: 4, 7, bos, 9, 10
339
+ pred (high entropy): 5, eos, 8, 10, eos
340
+
341
+ X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12]
342
+ X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # remove last global patch
343
+ X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11,12]]
344
+ Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12,13]
345
+
346
+ Handle the last byte properly.
347
+ patch_lengths = [1, 1, 3, 2, 2, 1, 2, 2, 1]
348
+ X_et = [[3,4,5,6,7,eos,bos,8,9,10,eos,bos,11,12]
349
+ X_g = [[3], [4] , [5,6,7], [eos,bos],[8,9], [10], [eos,bos], [11,12]] # do not remove last global patch
350
+ X_dt = [[3] [4,5,6], [7,eos], [bos,8], [9], [10,eos], [bos,11] [12]]
351
+ Y = [4,] [5,6,7, eos,bos, 8,9, 10, eos,bos, 11,12, 13]]
352
+
353
+
354
+ bpe delim
355
+ X_et = [[3,4,5,6,7,<d>,eos,bos,<d>,8,9,<d>,10,<d>,eos,bos,11,12]
356
+ X_g = [[3], [4,5,6,7,<d>], [eos,bos,<d>], ..
357
+ X_dt = [[3,4,5,6,7], [<d>,eos,bos], [<d>,bos,8], ..
358
+ Y = [4,5,6,7,<d>, eos,bos,<d> 8,9,<d>, ..
359
+
360
+
361
+ Note 1: that there will be no special tokens introduced at the patch level.
362
+ Note 2: X_e needs to be trimmed to be passed to Global
363
+ """
364
+ batch_size, seq_len = tokens.shape
365
+ local_encoder_tokens = tokens
366
+ local_decoder_tokens = tokens
367
+
368
+ if nb_boe > 0:
369
+ padded_patch = tokens.new(batch_size, nb_boe).fill_(boe_id)
370
+ local_encoder_tokens = torch.cat((padded_patch, local_encoder_tokens), dim=1)
371
+ # global_tokens = tokens.new(batch_size, ((seq_len-1) // patch_size)+1).fill_(boe_id)
372
+
373
+ # create global tokens, contains boe tokens and eos
374
+ # padded_local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id)
375
+ # patches = padded_local_encoder_tokens.view(batch_size, -1, patch_size)
376
+ # global_tokens = (patches.eq(eos_id).any(dim=2).int() * eos_id)[:, 1:]
377
+ # global_tokens += global_tokens.eq(0).int() * boe_id
378
+ # TODO: fix this when we want to use block causal in the global.
379
+
380
+ if enforce_patch_size_multiple and local_encoder_tokens.shape[-1] % patch_size != 0:
381
+ local_encoder_tokens = fill_tokens(local_encoder_tokens, patch_size, boe_id)
382
+
383
+ return local_encoder_tokens, None, local_decoder_tokens
384
+
385
+
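A short sketch of the token streams produced above, using static patching with nb_boe=3 (the ids and boe_id are arbitrary):

import torch

from bytelatent.model.blt import get_blt_input

tokens = torch.tensor([[3, 4, 5, 6, 7]])
enc, _, dec = get_blt_input(
    tokens, enforce_patch_size_multiple=False, nb_boe=3, patch_size=4, boe_id=0
)
print(enc)  # tensor([[0, 0, 0, 3, 4, 5, 6, 7]]) -- encoder stream is left-padded with boe
print(dec)  # tensor([[3, 4, 5, 6, 7]])          -- decoder stream is unchanged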
386
+ def patch_ids_from_lengths(patch_lengths, seq_len):
387
+ bs, num_patches = patch_lengths.shape
388
+ # Create a tensor of cumulative sums of the patch lengths
389
+ cum_d = torch.cat(
390
+ [
391
+ torch.zeros(bs, 1, dtype=patch_lengths.dtype, device=patch_lengths.device),
392
+ patch_lengths.cumsum(dim=-1),
393
+ ],
394
+ dim=-1,
395
+ )
396
+ patch_ids = (cum_d.unsqueeze(-1) <= torch.arange(seq_len, device=cum_d.device)).sum(
397
+ dim=-2
398
+ ) - 1
399
+ assert not (
400
+ torch.max(patch_ids) > patch_lengths.shape[-1] or torch.min(patch_ids) < 0
401
+ ), f"{torch.max(patch_ids)} > {patch_lengths.shape[-1]} or {torch.min(patch_ids)} < 0"
402
+ return patch_ids
403
+
404
+
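A small example of expanding per-patch lengths into a per-token patch id map:

import torch

from bytelatent.model.blt import patch_ids_from_lengths

patch_lengths = torch.tensor([[1, 4, 2]])
print(patch_ids_from_lengths(patch_lengths, seq_len=7))
# tensor([[0, 1, 1, 1, 1, 2, 2]])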
405
+ class ByteLatentTransformerArgs(BaseTransformerArgs):
406
+ model_config = ConfigDict(extra="forbid")
407
+ # Basic model configuration
408
+ seed: int = 42
409
+ vocab_size: int = -1
410
+ dim: int = 512
411
+ n_layers: int = 8
412
+ n_heads: int = 8
413
+ # TODO: What is the purpose of this parameter?
414
+ weight_tying: bool = False
415
+ sliding_window: Optional[int] = None
416
+
417
+ # Architecture and dimensions
418
+ dim_token: int = 256
419
+ dim_global: int = 512
420
+ dim_local_decoder: int = 512
421
+ dim_local_encoder: int = 512
422
+ n_layers_global: int = 8
423
+ n_layers_local_decoder: int = 8
424
+ n_layers_local_encoder: int = 8
425
+
426
+ # Tokenization and patching
427
+ tokenization_mode: str = "bpe"
428
+ patch_size: float | None = None
429
+ patching_mode: str | None = None
430
+ patching_threshold: float | None = None
431
+ patching_threshold_add: float | None = None
432
+ monotonicity: bool = False
433
+ patching_batch_size: int = 1
434
+ patching_device: str = "cuda"
435
+ data_loader_patching: bool = False
436
+ max_patch_length: int | None = None
437
+
438
+ # Encoder/Decoder configuration
439
+ tie_local_encoder_decoder_logits: bool = False
440
+ use_local_encoder_transformer: bool = False
441
+ encoder_lm_loss: bool = False
442
+ max_encoder_seq_length: int | None = None
443
+ pad_to_max_length: bool = False
444
+ encoder_enable_byte_ngrams: bool = False
445
+ encoder_enable_byte_group_hash: bool = False
446
+ ngram_vocab_sizes: int | None = None
447
+
448
+ # Cross attention configurations
449
+ cross_attn_encoder: bool = False
450
+ cross_attn_decoder: bool = False
451
+ cross_attn_window_encoder: int | None = None
452
+ cross_attn_window_decoder: int | None = None
453
+ cross_attn_k: int | None = None
454
+ cross_attn_nheads: int | None = None
455
+ cross_attn_all_layers_decoder: bool = False
456
+ cross_attn_all_layers_encoder: bool = False
457
+ cross_attn_use_flex_attention: bool = True
458
+ cross_attn_init_by_pooling: bool = False
459
+
460
+ # Encoder hash configurations
461
+ encoder_hash_byte_group_size: Any | None = None
462
+ encoder_hash_byte_group_vocab: int = 30000
463
+ encoder_hash_byte_group_nb_functions: int = 3
464
+
465
+ # Model behavior and optimization
466
+ log_patch_lengths: bool = False
467
+ non_linearity: str = "swiglu"
468
+ use_rope: bool = True
469
+ recompute_fc1_out: bool = False
470
+ recompute_fc3_out: bool = False
471
+ recompute_attn: bool = True
472
+ custom_bwd: bool = False
473
+ layer_ckpt: str = "all"
474
+ efficient_attn: str | None = None
475
+
476
+ # Architecture options
477
+ patch_only_encoder: bool = False
478
+ patch_only_decoder: bool = False
479
+
480
+ # Initialization and attention
481
+ init_use_gaussian: bool = True
482
+ init_use_depth: str = "current"
483
+ attn_bias_type: str = "causal"
484
+ alpha_depth: str = "disabled"
485
+ max_length: int = 2048
486
+
487
+ # Norm configuration
488
+ norm_eps: float = 1e-5
489
+ norm_affine: bool = True
490
+ pre_norm: bool = True
491
+ norm_type: str = "rmsnorm"
492
+
493
+ # Additional configurations
494
+ multiple_of: int = 256
495
+ ffn_dim_multiplier: float = 1.0
496
+ dropout: float = 0
497
+ output_size: int = -1
498
+
499
+ # Additional parameters from ModelArgs
500
+ architecture: str = "vanilla"
501
+ share_encoder_decoder_emb: bool = True
502
+ global_local_decoder_residual_layer: str | None = None
503
+
504
+ tokenize_with_bpe_delimiter: bool = False
505
+ patching_thresholds_str: str | None = None
506
+ tie_local_encoder_decoder: bool = False
507
+ encoder_preds_low_entropy_toks: float | None = None
508
+ encoder_preds_random_toks: float | None = None
509
+ dim_token_emb: int | None = None
510
+ dim_patch_emb: int | None = None
511
+
512
+ encoder_ngram_table_dir: str | None = None
513
+ encoder_ngram_to_size_str: str | None = None
514
+
515
+ # Model architecture params
516
+ entropy_model_checkpoint_dir: str | None = None
517
+ entropy_model_is_ngram_model: bool = False
518
+ downsampling_by_pooling: str | None = None
519
+ n_heads_global: int = 8
520
+ n_heads_local_decoder: int = 8
521
+ n_heads_local_encoder: int = 8
522
+ n_kv_heads: int | None = None
523
+ n_kv_heads_global: int | None = None
524
+ conv_kernel_size: int | None = None
525
+ local_attention_window_len: int | None = None
526
+
527
+ # Performance optimization
528
+ sequence_parallel: bool = False
529
+ loss_parallel: bool = False
530
+ fuse_sequence_parallel: bool = False
531
+ use_fsdp: bool = True
532
+ attn_to_keep: str = "all"
533
+
534
+ # RoPE parameters
535
+ rope_theta: float = 10000.0
536
+ rope_use_fp32_in_outer_product: bool = False
537
+
538
+ # Parameter mixing
539
+ pm_size: int = 0
540
+
541
+ # Logging
542
+ full_logging_n_layers: int = 4
543
+
544
+ # Special token config
545
+ eos_id: int | None = None
546
+
547
+ @model_validator(mode="after")
548
+ def check_hash_byte_sizes(self) -> Self:
549
+ if (
550
+ self.encoder_hash_byte_group_size is not None
551
+ and type(self.encoder_hash_byte_group_size) == str
552
+ ):
553
+ self.encoder_hash_byte_group_size = [
554
+ int(x)
555
+ for x in self.encoder_hash_byte_group_size.split(",")
556
+ if len(x) > 0
557
+ ]
558
+ return self
559
+
560
+
561
+ class LocalEncoderArgs(ByteLatentTransformerArgs):
562
+ # Local encoder specific dimensions
563
+ n_heads_local_encoder: int = 8
564
+ dim_token_emb: int | None = None
565
+ dim_patch_emb: int | None = None
566
+
567
+ def __post_init__(self):
568
+ # Override base args with local encoder specific values
569
+ self.dim = self.dim_local_encoder
570
+ self.n_layers = self.n_layers_local_encoder
571
+ self.n_heads = self.n_heads_local_encoder
572
+ self.cross_attn_decoder = False
573
+ self.cross_attn_k = self.cross_attn_k if self.cross_attn_encoder else None
574
+ self.attn_bias_type = "local_block_causal"
575
+
576
+
577
+ class GlobalTransformerArgs(ByteLatentTransformerArgs):
578
+ # Global encoder specific dimensions
579
+ dim_token_emb: int | None = None
580
+ dim_patch_emb: int | None = None
581
+
582
+ def __post_init__(self):
583
+ # Override base args with global encoder specific values
584
+ self.dim = self.dim_global
585
+ self.n_layers = self.n_layers_global
586
+ self.n_heads = self.n_heads_global
587
+ self.n_kv_heads = self.n_kv_heads_global
588
+ self.local_attention_window_len = None
589
+ self.cross_attn_encoder = False
590
+ self.cross_attn_decoder = False
591
+
592
+
593
+ class LocalDecoderArgs(ByteLatentTransformerArgs):
594
+ # Local decoder specific dimensions
595
+ dim_token_emb: int | None = None
596
+ dim_patch_emb: int | None = None
597
+
598
+ def __post_init__(self):
599
+ # Override base args with local decoder specific values
600
+ self.dim = self.dim_local_decoder
601
+ self.n_layers = self.n_layers_local_decoder
602
+ self.n_heads = self.n_heads_local_decoder
603
+ self.cross_attn_encoder = False
604
+ self.cross_attn_init_by_pooling = False
605
+ self.attn_bias_type = "local_block_causal"
606
+
607
+
608
+ def create_global_transformer(args: ByteLatentTransformerArgs) -> GlobalTransformer:
609
+ global_args = args.model_copy(
610
+ deep=True,
611
+ update=dict(
612
+ dim=args.dim_global,
613
+ n_layers=args.n_layers_global,
614
+ n_heads=args.n_heads_global,
615
+ n_kv_heads=args.n_kv_heads_global,
616
+ local_attention_window_len=None,
617
+ dim_token_emb=get_global_dim_patch_emb(args),
618
+ dim_patch_emb=None,
619
+ cross_attn_encoder=False,
620
+ cross_attn_decoder=False,
621
+ ),
622
+ )
623
+
624
+ return GlobalTransformer(global_args)
625
+
626
+
627
+ def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder:
628
+ # First deep copy the original args
629
+ # Replace with local encoder specific values
630
+ local_encoder_args = args.model_copy(
631
+ deep=True,
632
+ update=dict(
633
+ dim=args.dim_local_encoder,
634
+ n_layers=args.n_layers_local_encoder,
635
+ n_heads=args.n_heads_local_encoder,
636
+ dim_token_emb=get_encoder_dim_token_emb(args),
637
+ dim_patch_emb=get_encoder_dim_patch_emb(args),
638
+ cross_attn_decoder=False,
639
+ cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None,
640
+ attn_bias_type="local_block_causal",
641
+ ),
642
+ )
643
+
644
+ return LocalEncoder(local_encoder_args)
645
+
646
+
647
+ def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder:
648
+ # First deep copy the original args
649
+ local_decoder_args = args.model_copy(
650
+ deep=True,
651
+ update=dict(
652
+ dim=args.dim_local_decoder,
653
+ n_layers=args.n_layers_local_decoder,
654
+ n_heads=args.n_heads_local_decoder,
655
+ cross_attn_encoder=False,
656
+ cross_attn_init_by_pooling=False, # states are already defined
657
+ dim_token_emb=get_decoder_dim_token_emb(args),
658
+ dim_patch_emb=args.dim_global,
659
+ cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None,
660
+ ),
661
+ )
662
+
663
+ return LocalDecoder(local_decoder_args)
664
+
665
+
666
+ class EmbeddingType(Enum):
667
+ HASH_TOK = auto()
668
+ NGRAM = auto()
669
+
670
+
671
+ def init_embeddings(
672
+ args,
673
+ embedding_type: EmbeddingType,
674
+ local_encoder_dim: int,
675
+ encoder_hash_byte_group_size: list = None,
676
+ ):
677
+ if (
678
+ embedding_type == EmbeddingType.HASH_TOK
679
+ and args.encoder_hash_byte_group_size is None
680
+ ):
681
+ return None
682
+ if embedding_type == EmbeddingType.NGRAM and args.encoder_ngram_to_size_str is None:
683
+ return None
684
+
685
+ embeddings = []
686
+
687
+ if embedding_type == EmbeddingType.HASH_TOK:
688
+ emb_dim = local_encoder_dim
689
+ encoder_hash_byte_group_vocab = args.encoder_hash_byte_group_vocab
690
+ for _ in range(args.encoder_hash_byte_group_nb_functions):
691
+ for _ in encoder_hash_byte_group_size:
692
+ embeddings.append(
693
+ nn.Embedding(
694
+ encoder_hash_byte_group_vocab,
695
+ emb_dim,
696
+ )
697
+ )
698
+
699
+ elif embedding_type == EmbeddingType.NGRAM:
700
+ encoder_ngram_to_size = parse_ngram_to_size(args.encoder_ngram_to_size_str)
701
+ emb_dim = local_encoder_dim
702
+ OFFSET = 4 # This should be passed as parameter if it's variable
703
+ for ngram_vocab_size in encoder_ngram_to_size.values():
704
+ embeddings.append(nn.Embedding(ngram_vocab_size + OFFSET, emb_dim))
705
+
706
+ return nn.ModuleList(embeddings)
707
+
708
+
709
+ def compute_hash_embeddings(
710
+ local_encoder_tokens: torch.Tensor,
711
+ local_encoder,
712
+ encoder_hash_tok_embedding: nn.ModuleList,
713
+ encoder_hash_byte_group_nb_functions: int,
714
+ encoder_hash_byte_group_size: list,
715
+ encoder_hash_byte_group_vocab: int,
716
+ ) -> torch.Tensor:
717
+ """
718
+ Compute embeddings using hash token embeddings.
719
+
720
+ Args:
721
+ local_encoder_tokens: Input tokens tensor
722
+ local_encoder: Encoder object with tok_embeddings method
723
+ encoder_hash_tok_embedding: ModuleList of hash token embeddings
724
+ encoder_hash_byte_group_nb_functions: Number of hash functions
725
+ encoder_hash_byte_group_size: List of byte group sizes
726
+ encoder_hash_byte_group_vocab: Vocabulary size for hash embeddings
727
+
728
+ Returns:
729
+ torch.Tensor: Combined embeddings
730
+ """
731
+ if encoder_hash_tok_embedding is None:
732
+ return None
733
+
734
+ local_encoder_embeds = local_encoder.tok_embeddings(local_encoder_tokens)
735
+
736
+ i = 0
737
+ for func_nb in range(encoder_hash_byte_group_nb_functions):
738
+ for byte_group_size in encoder_hash_byte_group_size:
739
+ hash_ids = byte_group_hash_function(
740
+ local_encoder_tokens,
741
+ byte_group_size,
742
+ hash_func_nb=func_nb,
743
+ max_hash=encoder_hash_byte_group_vocab,
744
+ )
745
+ hash_tok_embedding = encoder_hash_tok_embedding[i]
746
+ local_encoder_embeds = local_encoder_embeds + hash_tok_embedding(hash_ids)
747
+ i += 1
748
+
749
+ assert i == len(encoder_hash_tok_embedding)
750
+ return local_encoder_embeds
751
+
752
+
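A self-contained sketch of the embedding combination above, using a stand-in encoder; the dimensions, vocabulary size, and byte group sizes are illustrative only:

import torch
import torch.nn as nn

from bytelatent.model.blt import compute_hash_embeddings


class StubEncoder(nn.Module):
    """Minimal stand-in exposing the tok_embeddings attribute the helper expects."""

    def __init__(self, vocab: int = 260, dim: int = 16):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab, dim)


encoder = StubEncoder()
# One hash table per (hash function, byte group size) pair: 1 function * 2 sizes = 2 tables.
hash_embeddings = nn.ModuleList([nn.Embedding(30000, 16) for _ in range(2)])
tokens = torch.randint(0, 260, (1, 8))

embeds = compute_hash_embeddings(
    local_encoder_tokens=tokens,
    local_encoder=encoder,
    encoder_hash_tok_embedding=hash_embeddings,
    encoder_hash_byte_group_nb_functions=1,
    encoder_hash_byte_group_size=[3, 4],
    encoder_hash_byte_group_vocab=30000,
)
print(embeds.shape)  # torch.Size([1, 8, 16])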
753
+ class ByteLatentTransformer(nn.Module):
754
+ """
755
+ The ByteLatentTransformer (BLT) is a byte-level language model architecture that processes byte sequences
756
+ by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers,
757
+ and local decoders to efficiently encode and decode byte sequences, leveraging patch-based processing for
758
+ improved performance and inference efficiency.
759
+ """
760
+
761
+ def __init__(self, args: ByteLatentTransformerArgs):
762
+ super().__init__()
763
+
764
+ # General configuration
765
+ self.weight_tying = args.weight_tying
766
+ self.sliding_window = args.sliding_window
767
+ self.patch_size = args.patch_size
768
+ self.patching_mode = args.patching_mode
769
+ self.boe_id, self.bos_id, self.pad_id, self.eos_id = (
770
+ BOE_ID,
771
+ BOS_ID,
772
+ PAD_ID,
773
+ EOS_ID,
774
+ )
775
+ self.downsampling_by_pooling = args.downsampling_by_pooling
776
+ self.patching_threshold = args.patching_threshold
777
+ self.dim = args.dim
778
+ self.init_base_std = args.init_base_std
779
+ self.init_std_factor = InitStdFactor(args.init_std_factor)
780
+ self.max_seqlen = args.max_seqlen
781
+
782
+ # Cross attention configuration
783
+ self.cross_attn_encoder = args.cross_attn_encoder
784
+ self.cross_attn_decoder = args.cross_attn_decoder
785
+ self.cross_attn_k = args.cross_attn_k
786
+ self.cross_attn_window_encoder = args.cross_attn_window_encoder
787
+ self.cross_attn_window_decoder = args.cross_attn_window_decoder
788
+ self.cross_attn_use_flex_attention = args.cross_attn_use_flex_attention
789
+
790
+ # Encoder hash configuration
791
+ self.encoder_hash_byte_group_size = args.encoder_hash_byte_group_size
792
+ self.encoder_hash_byte_group_vocab = args.encoder_hash_byte_group_vocab
793
+ self.encoder_hash_byte_group_nb_functions = (
794
+ args.encoder_hash_byte_group_nb_functions
795
+ )
796
+
797
+ # ByteLatent modules
798
+ self.local_encoder = create_local_encoder(args)
799
+ self.global_transformer = create_global_transformer(args)
800
+ self.local_decoder = create_local_decoder(args)
801
+ self.encoder_hash_tok_embedding = init_embeddings(
802
+ args,
803
+ EmbeddingType.HASH_TOK,
804
+ local_encoder_dim=self.local_encoder.dim,
805
+ encoder_hash_byte_group_size=self.encoder_hash_byte_group_size,
806
+ )
807
+ self.encoder_ngram_embedding = init_embeddings(
808
+ args,
809
+ EmbeddingType.NGRAM,
810
+ local_encoder_dim=self.local_encoder.dim,
811
+ encoder_hash_byte_group_size=None,
812
+ )
813
+ self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim)
814
+
815
+ # Transformer layers
816
+ self.layers = nn.ModuleList(
817
+ [TransformerBlock(args) for _ in range(args.n_layers)]
818
+ )
819
+
820
+ # Encoder ngram embedding tables
821
+ self.encoder_ngram_embedding = None
822
+ if args.encoder_enable_byte_ngrams:
823
+ self.encoder_ngram_embedding = nn.ModuleList()
824
+ assert args.ngram_vocab_sizes is not None
825
+ self.encoder_ngram_to_size = parse_ngram_to_size(
826
+ args.encoder_ngram_to_size_str
827
+ )
828
+ ngram_emb_dim = self.local_encoder.dim
829
+ for ngram_vocab_size in self.encoder_ngram_to_size.values():
830
+ self.encoder_ngram_embedding.append(
831
+ nn.Embedding(ngram_vocab_size + OFFSET, ngram_emb_dim)
832
+ )
833
+
834
+ # Output layer
835
+ assert args.vocab_size > 0, "vocab_size must be greater than 0"
836
+ self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
837
+ if args.weight_tying:
838
+ self.output.weight = self.tok_embeddings.weight
839
+
840
+ # Patcher module
841
+ if not args.data_loader_patching:
842
+ self.patcher = Patcher(
843
+ PatcherArgs(
844
+ patch_size=args.patch_size,
845
+ patching_mode=args.patching_mode,
846
+ patching_threshold=args.patching_threshold,
847
+ patching_threshold_add=args.patching_threshold_add,
848
+ monotonicity=args.monotonicity,
849
+ max_patch_length=args.max_patch_length,
850
+ )
851
+ )
852
+
853
+ def forward(
854
+ self,
855
+ tokens: torch.Tensor,
856
+ patch_lengths: Optional[torch.Tensor] = None,
857
+ ngram_ids: Optional[torch.Tensor] = None,
858
+ ):
859
+ # Ensure ngram_ids is either a tensor or None
860
+ assert (
861
+ isinstance(ngram_ids, torch.Tensor) or ngram_ids is None
862
+ ), f"ngram_ids must be a tensor or None, but was: {type(ngram_ids)}"
863
+
864
+ bs, N = tokens.shape # Batch size and sequence length
865
+
866
+ # Get megabyte inputs
867
+ nb_boe = int(0 if self.patching_mode != "" else self.patch_size - 1)
868
+ local_encoder_tokens, _, local_decoder_tokens = get_blt_input(
869
+ tokens=tokens,
870
+ enforce_patch_size_multiple=False,
871
+ nb_boe=nb_boe,
872
+ patch_size=self.patch_size,
873
+ boe_id=self.boe_id,
874
+ )
875
+
876
+ # Patching
877
+ if patch_lengths is None:
878
+ assert (
879
+ getattr(self, "patcher", None) is not None
880
+ ), "Patcher not defined and no patch_lengths passed."
881
+ patch_lengths, tok_scores = self.patcher.patch(
882
+ local_encoder_tokens,
883
+ include_next_token=True,
884
+ threshold=self.patcher.threshold,
885
+ )
886
+ else:
887
+ if nb_boe > 0:
888
+ patch_lengths[:, 0] += nb_boe
889
+
890
+ assert torch.min(patch_lengths) >= 0
891
+
892
+ # Generate patch IDs from patch_lengths
893
+ patch_ids = patch_ids_from_lengths(
894
+ patch_lengths, local_encoder_tokens.shape[-1]
895
+ )
896
+ assert torch.max(patch_ids) + 1 <= torch.max(
897
+ (patch_lengths != 0).sum(dim=-1)
898
+ ), f"{torch.max(patch_ids) + 1} > {torch.max((patch_lengths != 0).sum(dim=-1))}"
899
+
900
+ cross_attn_mask_enc = None
901
+ # Cross-attention encoder
902
+ if self.cross_attn_encoder:
903
+ cross_attn_mask_enc = cross_attn_mask(
904
+ patch_ids,
905
+ patch_lengths,
906
+ N,
907
+ patches_as_queries=True,
908
+ cross_attn_k=self.cross_attn_k,
909
+ window=self.cross_attn_window_encoder,
910
+ block_mask=self.cross_attn_use_flex_attention,
911
+ )
912
+
913
+ # Hashing and embedding
914
+ local_encoder_embeds = compute_hash_embeddings(
915
+ local_encoder_tokens=local_encoder_tokens,
916
+ local_encoder=self.local_encoder,
917
+ encoder_hash_tok_embedding=self.encoder_hash_tok_embedding,
918
+ encoder_hash_byte_group_nb_functions=self.encoder_hash_byte_group_nb_functions,
919
+ encoder_hash_byte_group_size=self.encoder_hash_byte_group_size,
920
+ encoder_hash_byte_group_vocab=self.encoder_hash_byte_group_vocab,
921
+ )
922
+
923
+ # N-gram table embeddings
924
+ if self.encoder_ngram_embedding is not None:
925
+ assert ngram_ids is not None, "ngram_ids must be provided"
926
+ if local_encoder_embeds is None:
927
+ local_encoder_embeds = self.local_encoder.tok_embeddings(
928
+ local_encoder_tokens
929
+ )
930
+ assert len(ngram_ids) == len(
931
+ self.encoder_ngram_embedding
932
+ ), f"ngram_ids.shape[0]={ngram_ids.shape[0]} versus len(encoder_ngram_embedding)={len(self.encoder_ngram_embedding)}, ngram_ids.shape={ngram_ids.shape}"
933
+ for i in range(ngram_ids.shape[0]):
934
+ ngram_embedding = self.encoder_ngram_embedding[i]
935
+ ngram_embeds = ngram_embedding(ngram_ids[i])
936
+ assert (
937
+ local_encoder_embeds.shape == ngram_embeds.shape
938
+ ), f"Shape mismatch: {local_encoder_embeds.shape} vs {ngram_embeds.shape}, ngram_ids.shape={ngram_ids.shape}"
939
+ local_encoder_embeds = local_encoder_embeds + ngram_embeds
940
+
941
+ # Local encoder
942
+ h_cross = None
943
+ (h_encoder, h_cross), cache_encoder = self.local_encoder(
944
+ tokens=local_encoder_tokens,
945
+ embeds=local_encoder_embeds,
946
+ patch_embeds=h_cross if self.cross_attn_encoder else None,
947
+ cross_mask=cross_attn_mask_enc,
948
+ num_patches=patch_lengths.shape[1],
949
+ patch_ids=patch_ids,
950
+ )
951
+
952
+ # Downsampling
953
+ if not self.cross_attn_encoder:
954
+ assert (
955
+ patch_ids.shape[1] == h_encoder.shape[1]
956
+ ), f"{patch_ids.shape[1]} != {h_encoder.shape[1]}"
957
+ h = downsample(
958
+ h_encoder,
959
+ patch_lengths.shape[1],
960
+ patch_lengths,
961
+ patch_ids,
962
+ downsampling_by_pooling=self.downsampling_by_pooling,
963
+ patch_size=self.patch_size,
964
+ )
965
+ else:
966
+ # Reshape h_cross
967
+ h = h_cross.view(bs, patch_lengths.shape[1], -1)
968
+
969
+ # Global transformer
970
+ global_tokens = tokens.new(h.shape[0], h.shape[1]).fill_(self.boe_id)
971
+ rows, cols = torch.where(local_encoder_tokens == self.eos_id)
972
+ eos_patch_ids = patch_ids[rows, cols]
973
+ global_tokens[rows, eos_patch_ids] = self.eos_id
974
+
975
+ h, _ = self.global_transformer(
976
+ embeds=h,
977
+ tokens=global_tokens,
978
+ )
979
+
980
+ # Unpatching
981
+ dec_embeds = h_encoder[:, nb_boe : nb_boe + N, :]
982
+
983
+ # Generate decoder patch IDs
984
+ decoder_patch_ids = decoder_patch_ids_from_lengths(
985
+ patch_lengths, nb_boe, local_decoder_tokens.shape[-1]
986
+ )
987
+ assert (
988
+ torch.max(decoder_patch_ids) + 1 <= h.shape[1]
989
+ ), f"{torch.max(decoder_patch_ids) + 1} > {h.shape[1]}"
990
+ assert (
991
+ decoder_patch_ids.shape[1] == dec_embeds.shape[1]
992
+ ), f"{decoder_patch_ids.shape[1]} != {dec_embeds.shape[1]}"
993
+
994
+ # Cross-attention decoder
995
+ if not self.cross_attn_decoder:
996
+ h = torch.gather(
997
+ h, 1, decoder_patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1])
998
+ )
999
+ cross_attn_mask_dec = None
1000
+ assert local_decoder_tokens.shape == h.shape[:-1]
1001
+ else:
1002
+ cross_attn_mask_dec = cross_attn_mask(
1003
+ decoder_patch_ids,
1004
+ patch_lengths,
1005
+ N,
1006
+ patches_as_queries=False,
1007
+ cross_attn_k=self.cross_attn_k,
1008
+ window=self.cross_attn_window_decoder,
1009
+ block_mask=self.cross_attn_use_flex_attention,
1010
+ )
1011
+
1012
+ # Local decoder
1013
+ output, _ = self.local_decoder(
1014
+ embeds=dec_embeds,
1015
+ patch_embeds=h,
1016
+ tokens=local_decoder_tokens,
1017
+ cross_mask=cross_attn_mask_dec,
1018
+ )
1019
+ return output
1020
+
1021
+ def reset_parameters(self, init_std=None):
1022
+ # Either use fixed base std or sqrt model dim
1023
+ init_std = init_std or (self.dim ** (-0.5))
1024
+ nn.init.trunc_normal_(
1025
+ self.tok_embeddings.weight,
1026
+ mean=0.0,
1027
+ std=init_std,
1028
+ a=-3 * init_std,
1029
+ b=3 * init_std,
1030
+ )
1031
+ if not self.weight_tying:
1032
+ nn.init.trunc_normal_(
1033
+ self.output.weight,
1034
+ mean=0.0,
1035
+ std=init_std,
1036
+ a=-3 * init_std,
1037
+ b=3 * init_std,
1038
+ )
1039
+
1040
+ def init_weights(self):
1041
+ self.reset_parameters()
1042
+ self.init_base_std = self.init_base_std or (self.dim ** (-0.5))
1043
+ for depth, layer in enumerate(self.layers):
1044
+ factor = {
1045
+ InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
1046
+ InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
1047
+ InitStdFactor.DIM_RATIO: self.dim / 4096,
1048
+ InitStdFactor.DISABLED: 1.0,
1049
+ }[self.init_std_factor]
1050
+
1051
+ layer.init_weights(self.init_base_std, factor)
1052
+
1053
+ self.local_decoder.init_weights(self.init_base_std)
1054
+ self.global_transformer.init_weights(self.init_base_std)
1055
+ self.local_encoder.init_weights(self.init_base_std)
1056
+
1057
+ for emb in self.encoder_hash_tok_embedding:
1058
+ nn.init.trunc_normal_(
1059
+ emb.weight,
1060
+ mean=0.0,
1061
+ std=self.init_base_std,
1062
+ a=-3 * self.init_base_std,
1063
+ b=3 * self.init_base_std,
1064
+ )
bytelatent/model/local_models.py ADDED
@@ -0,0 +1,356 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import logging
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ import torch.nn
8
+ import torch.nn as nn
9
+ from torch.nn import functional as F
10
+ from torch.nn.attention.flex_attention import BlockMask
11
+ from xformers.ops import AttentionBias
12
+
13
+ from bytelatent.base_transformer import (
14
+ InitStdFactor,
15
+ RMSNorm,
16
+ RotaryEmbedding,
17
+ TransformerBlock,
18
+ )
19
+ from bytelatent.model.transformer import CrossAttention
20
+ from bytelatent.model.utils import create_causal_mask, downsample
21
+ from bytelatent.tokenizers.blt_tokenizer import BOE_ID
22
+
23
+ logger = logging.getLogger()
24
+
25
+
26
+ class LocalModelBase(nn.Module):
27
+ def __init__(self, args):
28
+ super().__init__()
29
+
30
+ self.dim = args.dim
31
+ self.dropout = args.dropout
32
+ self.vocab_size = args.vocab_size + args.pm_size
33
+ self.patch_size = args.patch_size
34
+
35
+ self.efficient_attn = args.efficient_attn
36
+ self.sliding_window = args.sliding_window
37
+ self.use_rope = args.use_rope
38
+ self.init_std_factor = args.init_std_factor
39
+ self.cross_attn_encoder = getattr(args, "cross_attn_encoder", None)
40
+ self.cross_attn_decoder = getattr(args, "cross_attn_decoder", None)
41
+ self.cross_attn_k = getattr(args, "cross_attn_k", None)
42
+
43
+ self.boe_id = BOE_ID
44
+
45
+ self.norm = RMSNorm(args.dim, eps=args.norm_eps)
46
+ self.layers = nn.ModuleList(
47
+ [TransformerBlock(args) for _ in range(args.n_layers)]
48
+ )
49
+
50
+ self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)
51
+ if not self.use_rope:
52
+ self.pos_embeddings = nn.Embedding(args.max_length, args.dim)
53
+ else:
54
+ self.rope = RotaryEmbedding(
55
+ theta=args.rope_theta,
56
+ head_dim=args.head_dim or args.dim // args.n_heads,
57
+ max_seqlen=getattr(args, "max_encoder_seq_length", args.max_length),
58
+ )
59
+ self.pos_embeddings = None
60
+
61
+ self.token_embedding_projection = (
62
+ nn.Linear(args.dim_token_emb, args.dim, bias=False)
63
+ if hasattr(args, "dim_token_emb") and args.dim_token_emb != self.dim
64
+ else None
65
+ )
66
+
67
+ self.patch_embedding_projection = self._create_patch_projection(args)
68
+
69
+ def _should_create_patch_projection(self, args):
70
+ dimension_mismatch = (
71
+ getattr(args, "dim_patch_emb") and args.dim_patch_emb != self.dim
72
+ )
73
+
74
+ # Check cross attention conditions
75
+ cross_attn_conditions = (
76
+ hasattr(args, "cross_attn_encoder")
77
+ and args.cross_attn_encoder
78
+ and getattr(args, "cross_attn_init_by_pooling")
79
+ ) or (
80
+ hasattr(args, "cross_attn_decoder")
81
+ and args.cross_attn_decoder
82
+ and getattr(args, "cross_attn_init_by_pooling")
83
+ )
84
+
85
+ return dimension_mismatch or cross_attn_conditions
86
+
87
+ def _create_patch_projection(self, args):
88
+ if not self._should_create_patch_projection(args):
89
+ return None
90
+
91
+ output_dim = args.dim_token_emb * (self.cross_attn_k or 1)
92
+
93
+ return nn.Linear(
94
+ in_features=args.dim_patch_emb,
95
+ out_features=output_dim,
96
+ bias=False,
97
+ )
98
+
99
+ def apply_embedding(self, tokens, embeds):
100
+ if embeds is not None:
101
+ return embeds
102
+ else:
103
+ return self.tok_embeddings(tokens)
104
+
105
+ def init_weights(self, init_std=None):
106
+ self.rope.reset_parameters()
107
+
108
+ init_std = init_std or (self.dim ** (-0.5))
109
+ nn.init.trunc_normal_(
110
+ self.tok_embeddings.weight,
111
+ mean=0.0,
112
+ std=init_std,
113
+ a=-3 * init_std,
114
+ b=3 * init_std,
115
+ )
116
+ if self.pos_embeddings is not None:
117
+ nn.init.trunc_normal_(
118
+ self.pos_embeddings.weight,
119
+ mean=0.0,
120
+ std=init_std,
121
+ a=-3 * init_std,
122
+ b=3 * init_std,
123
+ )
124
+
125
+ for depth, layer in enumerate(self.layers):
126
+ factor = {
127
+ InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
128
+ InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
129
+ InitStdFactor.DIM_RATIO: self.dim / 4096,
130
+ InitStdFactor.DISABLED: 1.0,
131
+ }[self.init_std_factor]
132
+
133
+ layer.init_weights(init_std, factor)
134
+
135
+ if self.token_embedding_projection is not None:
136
+ nn.init.trunc_normal_(
137
+ self.token_embedding_projection.weight,
138
+ mean=0.0,
139
+ std=init_std,
140
+ a=-3 * init_std,
141
+ b=3 * init_std,
142
+ )
143
+
144
+ if self.patch_embedding_projection is not None:
145
+ nn.init.trunc_normal_(
146
+ self.patch_embedding_projection.weight,
147
+ mean=0.0,
148
+ std=init_std,
149
+ a=-3 * init_std,
150
+ b=3 * init_std,
151
+ )
152
+
153
+ if hasattr(self, "output"):
154
+ nn.init.trunc_normal_(
155
+ self.output.weight,
156
+ mean=0.0,
157
+ std=init_std,
158
+ a=-3 * init_std,
159
+ b=3 * init_std,
160
+ )
161
+
162
+ if self.cross_attn_layers is not None:
163
+ for depth, layer in enumerate(self.cross_attn_layers):
164
+ factor = {
165
+ InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
166
+ InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
167
+ InitStdFactor.DIM_RATIO: self.dim / 4096,
168
+ InitStdFactor.DISABLED: 1.0,
169
+ }[self.init_std_factor]
170
+
171
+ layer.init_weights(init_std, factor)
172
+
173
+
174
+ class LocalEncoder(LocalModelBase):
175
+ def __init__(self, args):
176
+ super().__init__(args)
177
+ self.output_proj = (
178
+ args.patching_mode in ["entropy", "probmax"]
179
+ ) and args.entropy_model_checkpoint_dir is None
180
+
181
+ self.apply_transformer = args.use_local_encoder_transformer
182
+ self.downsampling_by_pooling = args.downsampling_by_pooling
183
+ self.patch_only = args.patch_only_encoder
184
+ self.expects_hash_embeddings = args.encoder_hash_byte_group_size is not None
185
+ self.cross_attn_encoder = args.cross_attn_encoder
186
+ self.cross_attn_all_layers_encoder = args.cross_attn_all_layers_encoder
187
+ self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling
188
+ self.cross_attn_nheads = args.cross_attn_nheads
189
+
190
+ if self.cross_attn_encoder:
191
+ self.cross_attn_layers = torch.nn.ModuleList()
192
+ layers_to_add = args.n_layers if self.cross_attn_all_layers_encoder else 1
193
+ for _ in range(layers_to_add):
194
+ self.cross_attn_layers.append(
195
+ CrossAttention(
196
+ dim=self.dim,
197
+ head_dim=self.dim // self.cross_attn_nheads,
198
+ n_heads=self.cross_attn_nheads,
199
+ n_kv_heads=self.cross_attn_nheads,
200
+ norm_eps=args.norm_eps,
201
+ )
202
+ )
203
+
204
+ def apply_embedding(self, tokens, embeds):
205
+ if embeds is not None:
206
+ assert (
207
+ self.expects_hash_embeddings
208
+ ), "Not expecting embeddings to be passed."
209
+ return embeds
210
+ else:
211
+ return self.tok_embeddings(tokens)
212
+
213
+ def forward(
214
+ self,
215
+ tokens: torch.Tensor,
216
+ embeds: Optional[torch.Tensor] = None,
217
+ patch_embeds: Optional[torch.Tensor] = None,
218
+ mask: Optional[Union["BlockMask", "AttentionBias", torch.Tensor, str]] = None,
219
+ cross_mask: Optional[torch.Tensor] = None,
220
+ num_patches: Optional[int] = None,
221
+ patch_ids: Optional[torch.Tensor] = None,
222
+ cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
223
+ ):
224
+ """ """
225
+ bs, seqlen = tokens.shape
226
+ if mask is None:
227
+ mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window)
228
+
229
+ h = self.apply_embedding(tokens, embeds)
230
+ freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None
231
+
232
+ h = F.dropout(h, p=self.dropout, training=self.training)
233
+
234
+ for i, layer in enumerate(self.layers):
235
+ h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn)
236
+ # check whether cross attention should be applied to all layers or only the last layer
237
+ if self.cross_attn_encoder and (
238
+ i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder
239
+ ):
240
+ patch_embeds = self.apply_cross_attention(
241
+ h, patch_embeds, i, bs, num_patches, patch_ids, cross_mask
242
+ )
243
+
244
+ h_residual = patch_embeds if self.cross_attn_encoder else None
245
+ return (h, h_residual), cache
246
+
247
+ def apply_cross_attention(
248
+ self, h, patch_embeds, layer_idx, bs, num_patches, patch_ids, cross_mask
249
+ ):
250
+ # apply pooling and project
251
+ if self.cross_attn_init_by_pooling and patch_embeds is None:
252
+ patch_embeds = downsample(
253
+ h,
254
+ num_patches,
255
+ patch_ids=patch_ids,
256
+ downsampling_by_pooling=self.downsampling_by_pooling,
257
+ patch_size=self.patch_size,
258
+ )
259
+ if self.patch_embedding_projection is not None:
260
+ patch_embeds = self.patch_embedding_projection(patch_embeds)
261
+ patch_embeds = patch_embeds.reshape(
262
+ bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim
263
+ )
264
+
265
+ layer_idx = layer_idx if self.cross_attn_all_layers_encoder else 0
266
+ patch_embeds_cross = self.cross_attn_layers[layer_idx](
267
+ x=patch_embeds,
268
+ kv=h,
269
+ mask=cross_mask,
270
+ )
271
+ patch_embeds += patch_embeds_cross
272
+ return patch_embeds
273
+
274
+
275
+ class LocalDecoder(LocalModelBase):
276
+ def __init__(self, args):
277
+ super().__init__(args)
278
+
279
+ # Model configuration flags
280
+ self.patch_only = args.patch_only_decoder
281
+ self.expects_embeddings = args.share_encoder_decoder_emb
282
+ self.cross_attn_decoder = args.cross_attn_decoder
283
+ self.cross_attn_all_layers_decoder = args.cross_attn_all_layers_decoder
284
+ self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling
285
+ self.cross_attn_nheads = args.cross_attn_nheads
286
+
287
+ if self.cross_attn_decoder:
288
+ self.cross_attn_layers = torch.nn.ModuleList()
289
+ layers_to_add = args.n_layers if self.cross_attn_all_layers_decoder else 1
290
+ for _ in range(layers_to_add):
291
+ self.cross_attn_layers.append(
292
+ CrossAttention(
293
+ dim=self.dim,
294
+ head_dim=self.dim // self.cross_attn_nheads,
295
+ n_heads=self.cross_attn_nheads,
296
+ n_kv_heads=self.cross_attn_nheads,
297
+ norm_eps=args.norm_eps,
298
+ )
299
+ )
300
+
301
+ self.output = nn.Linear(
302
+ self.dim,
303
+ args.vocab_size,
304
+ bias=False,
305
+ )
306
+
307
+ def forward(
308
+ self,
309
+ tokens: torch.Tensor,
310
+ embeds: Optional[torch.Tensor],
311
+ patch_embeds: Optional[torch.Tensor] = None,
312
+ mask: Optional[Union["BlockMask", "AttentionBias", torch.Tensor, str]] = None,
313
+ cross_mask: Optional[torch.Tensor] = None,
314
+ cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
315
+ ):
316
+ bs, seqlen = tokens.shape
317
+ assert embeds is not None, "Embeddings must be provided"
318
+
319
+ if mask is None:
320
+ mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window)
321
+
322
+ h = embeds
323
+
324
+ if self.patch_embedding_projection is not None:
325
+ assert patch_embeds is not None, "Patch embeddings must be passed."
326
+ patch_embeds = self.patch_embedding_projection(patch_embeds)
327
+ if self.cross_attn_k is not None:
328
+ patch_embeds = patch_embeds.reshape(
329
+ bs, patch_embeds.shape[1] * self.cross_attn_k, self.dim
330
+ )
331
+
332
+ if patch_embeds is not None and not self.cross_attn_decoder:
333
+ h = h + patch_embeds
334
+
335
+ freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None
336
+
337
+ h = F.dropout(h, p=self.dropout, training=self.training)
338
+ for i, layer in enumerate(self.layers):
339
+ if self.cross_attn_decoder and (
340
+ i == 0 or self.cross_attn_all_layers_decoder
341
+ ):
342
+ # Use cross attention to extract info from patch_embeds into h
343
+ h_cross = self.cross_attn_layers[i](
344
+ x=h,
345
+ kv=patch_embeds,
346
+ mask=cross_mask,
347
+ )
348
+ h = h + h_cross
349
+
350
+ h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn)
351
+
352
+ h_preds = self.norm(h)
353
+ h_preds = F.dropout(h_preds, p=self.dropout, training=self.training)
354
+ h_preds = self.output(h_preds)
355
+ h_preds = h_preds.float()
356
+ return h_preds, cache
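For orientation, here is a minimal shape-level sketch (not part of the commit) of the encoder-side pooling step above: when cross_attn_init_by_pooling is set, apply_cross_attention initializes patch embeddings by pooling byte states with downsample from bytelatent.model.utils. All sizes below are made up for illustration; the decoder then either adds these patch embeddings back to byte positions or cross-attends into them.

import torch
from bytelatent.model.utils import downsample

bs, seqlen, dim, num_patches = 2, 12, 16, 3
h = torch.randn(bs, seqlen, dim)                     # byte-level hidden states
patch_ids = torch.arange(seqlen).repeat(bs, 1) // 4  # 4 bytes per patch -> ids 0..2

# Mean-pool each patch's byte states into one embedding per patch.
patch_embeds = downsample(
    h, num_patches, patch_ids=patch_ids, downsampling_by_pooling="mean"
)
print(patch_embeds.shape)  # torch.Size([2, 3, 16])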
bytelatent/model/transformer.py ADDED
@@ -0,0 +1,199 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ import logging
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from torch.nn.attention.flex_attention import BlockMask
+ from xformers.ops import AttentionBias
+
+ from bytelatent.base_transformer import (
+     BaseTransformer,
+     RMSNorm,
+     flex_attention_comp,
+     repeat_kv,
+ )
+ from bytelatent.model.utils import create_causal_mask
+
+ logger = logging.getLogger()
+
+
+ class CrossAttention(nn.Module):
+     """
+     CrossAttention block to attend to the encoder states from the decoder.
+     Rope is not supported.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         head_dim: int,
+         n_heads: int,
+         n_kv_heads: int,
+         norm_eps: float,
+     ):
+         super().__init__()
+
+         self.dim = dim
+         self.head_dim = head_dim
+
+         self.n_heads = n_heads
+         self.n_kv_heads = n_kv_heads
+         self.heads_per_group = self.n_heads // self.n_kv_heads
+
+         self.cross_attn_norm_q = RMSNorm(dim, eps=norm_eps)
+         self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps)
+
+         self.wq = nn.Linear(
+             dim,
+             n_heads * head_dim,
+             bias=False,
+         )
+         self.wk = nn.Linear(
+             dim,
+             n_kv_heads * head_dim,
+             bias=False,
+         )
+         self.wv = nn.Linear(
+             dim,
+             n_kv_heads * head_dim,
+             bias=False,
+         )
+
+         self.wo = nn.Linear(
+             n_heads * head_dim,
+             dim,
+             bias=False,
+         )
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         kv: torch.Tensor,
+         mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
+     ) -> torch.Tensor:
+         # B S D
+         bsz, seq_len, _ = x.shape
+         _, slen_kv, _ = kv.shape
+         x = self.cross_attn_norm_q(x)
+         kv = self.cross_attn_norm_kv(kv)
+
+         xq = self.wq(x)
+         xk = self.wk(kv)
+         xv = self.wv(kv)
+
+         output_shape = xq.shape
+         # B S D -> B S H D
+         xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
+         xk = xk.view(bsz, slen_kv, self.n_kv_heads, self.head_dim)
+         xv = xv.view(bsz, slen_kv, self.n_kv_heads, self.head_dim)
+
+         xk = repeat_kv(xk, self.heads_per_group, dim=2)
+         xv = repeat_kv(xv, self.heads_per_group, dim=2)
+
+         assert mask is None or isinstance(mask, BlockMask)
+         xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv))
+         output = flex_attention_comp(xq, xk, xv, block_mask=mask)
+         output = output.transpose(1, 2).contiguous()  # B H S D -> B S H D
+
+         output = self.wo(output.reshape(output_shape))
+
+         return x + output
+
+     def init_weights(self, base_std: float, factor: float = 1.0):
+         std = base_std * factor
+
+         nn.init.trunc_normal_(
+             self.wq.weight,
+             mean=0.0,
+             std=std,
+             a=-3 * std,
+             b=3 * std,
+         )
+
+         nn.init.trunc_normal_(
+             self.wk.weight,
+             mean=0.0,
+             std=std,
+             a=-3 * std,
+             b=3 * std,
+         )
+
+         nn.init.trunc_normal_(
+             self.wv.weight,
+             mean=0.0,
+             std=std,
+             a=-3 * std,
+             b=3 * std,
+         )
+
+         output_std = std / (2**0.5)
+         nn.init.trunc_normal_(
+             self.wo.weight,
+             mean=0.0,
+             std=output_std,
+             a=-3 * output_std,
+             b=3 * output_std,
+         )
+         self.cross_attn_norm_q.reset_parameters()
+         self.cross_attn_norm_kv.reset_parameters()
+
+
+ class GlobalTransformer(BaseTransformer):
+     def __init__(self, args):
+         super().__init__(args)
+         self.dropout = args.dropout
+         self.sliding_window = args.sliding_window
+         self.efficient_attn = args.efficient_attn
+
+         self.token_embedding_projection = None
+         if args.dim_token_emb is not None and args.dim_token_emb != self.dim:
+             self.token_embedding_projection = nn.Linear(
+                 args.dim_token_emb,
+                 args.dim,
+                 bias=False,
+             )
+
+     def forward(
+         self,
+         tokens: torch.Tensor,
+         tok_idx: Optional[torch.Tensor] = None,
+         embeds: Optional[torch.Tensor] = None,
+         mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None,
+         cache: Optional[List[Tuple[torch.Tensor, torch.Tensor, int]]] = None,
+     ):
+         """
+         Similar to BaseTransformer.forward, but with an additional embeds argument
+         and projection to the token space.
+         """
+         bs, seqlen = tokens.shape
+         attn_impl = self.efficient_attn
+
+         h = embeds
+
+         mask = (
+             mask
+             if mask is not None
+             else create_causal_mask(seqlen, attn_impl, self.sliding_window)
+         )
+
+         if self.token_embedding_projection is not None and h.shape[-1] != self.dim:
+             h = self.token_embedding_projection(h)
+
+         h = F.dropout(h, p=self.dropout, training=self.training)
+
+         h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
+         return h, cache
+
+     def init_weights(self, init_base_std: float):
+         super().init_weights()
+         if self.token_embedding_projection is not None:
+             nn.init.trunc_normal_(
+                 self.token_embedding_projection.weight,
+                 mean=0.0,
+                 std=init_base_std,
+                 a=-3 * init_base_std,
+                 b=3 * init_base_std,
+             )
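A small usage sketch (not part of the commit) for CrossAttention above, with made-up dimensions: queries come from one sequence (e.g. patch embeddings), keys/values from another (e.g. byte-level states), and mask=None lets every query attend to every key. It assumes a PyTorch build where flex attention is usable, since forward routes through flex_attention_comp.

import torch
from bytelatent.model.transformer import CrossAttention

dim, n_heads = 32, 4
attn = CrossAttention(
    dim=dim,
    head_dim=dim // n_heads,
    n_heads=n_heads,
    n_kv_heads=n_heads,
    norm_eps=1e-5,
)
attn.init_weights(base_std=0.02)

x = torch.randn(2, 6, dim)    # queries, e.g. 6 patches
kv = torch.randn(2, 24, dim)  # keys/values, e.g. 24 bytes
out = attn(x=x, kv=kv, mask=None)
print(out.shape)  # torch.Size([2, 6, 32]); the residual to x is added inside forward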
bytelatent/model/utils.py ADDED
@@ -0,0 +1,116 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ import torch
+ from torch.nn.attention.flex_attention import create_block_mask
+ from xformers.ops import fmha
+
+
+ def patch_reduce(h, max_num_patches, reduction, patch_ids):
+     """
+     Reduce variable-length patches to a single embedding per patch.
+     Note: this works with a variable number of patches across sequences in the batch.
+     It handles variable-length patches by assuming that patch_lengths will be 0 for
+     any extra patches on the *right*.
+     Any embeddings on the right that are not allocated to a patch
+     (i.e. if sum(patch_lengths[i]) < seq_len for some i)
+     are sent to a dummy patch, which is trimmed before returning.
+     """
+     bs, seq_len, emb_dim = h.shape
+
+     patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1])
+
+     reduced_embs = torch.zeros(
+         (bs, max_num_patches, emb_dim), dtype=h.dtype, device=h.device
+     )
+     reduced_embs = reduced_embs.scatter_reduce(
+         src=h,
+         dim=1,
+         index=patch_ids,
+         reduce=reduction,
+         include_self=False,
+     )
+     reduced_embs = reduced_embs[:, :max_num_patches, :]
+
+     return reduced_embs
+
+
+ def concat_downsample(h, patch_lengths, patch_size):
+     # The assumption in this function is that seq_len = patch_size * num_patches.
+     bs, seq_len, emb_dim = h.shape
+     patch_end_ids = torch.cumsum(patch_lengths, dim=1)
+     patch_ids = patch_end_ids.unsqueeze(-1) - torch.arange(patch_size, 0, -1).to(
+         patch_end_ids.device
+     )
+     # Is clamp ok here?
+     patch_ids = patch_ids.clamp(min=0).unsqueeze(-1).expand(-1, -1, -1, h.shape[-1])
+     patch_ids = patch_ids.view(bs, -1, emb_dim)
+     # after gather h.shape = [batch_size, seq_len, dim]
+     h = torch.gather(h, 1, patch_ids)
+     h = h.reshape(bs, patch_lengths.shape[1], patch_size * h.size(-1))
+     return h
+
+
+ def pooling_downsample(h, max_num_patches, pooling_mode, patch_ids):
+     cat = []
+     if "avg" in pooling_mode or "mean" in pooling_mode:
+         cat.append(patch_reduce(h, max_num_patches, "mean", patch_ids))
+     if "min" in pooling_mode:
+         cat.append(patch_reduce(h, max_num_patches, "amin", patch_ids))
+     if "max" in pooling_mode:
+         cat.append(patch_reduce(h, max_num_patches, "amax", patch_ids))
+     assert len(cat) > 0
+     h = torch.cat(cat, dim=-1)
+     return h
+
+
+ def downsample(
+     h,
+     num_patches,
+     patch_lengths=None,
+     patch_ids=None,
+     downsampling_by_pooling=None,
+     patch_size=4,
+ ):
+     """
+     Downsampling:
+         a. concatenating embeddings in the patch
+             Note: with dynamic patching, only the last patch_size tokens of each patch are used.
+         b. pooling embeddings in the patch
+     """
+     # input: h.shape = [batch_size, seq_len, dim]
+     # input: pool h.shape = [batch_size, seq_len / patch_size, dim]
+     # if we don't use cross_attn, we pool so that the byte representation is converted to a patch representation
+     if downsampling_by_pooling is not None and len(downsampling_by_pooling) > 0:
+         # By pooling
+         max_num_patches = num_patches
+         assert patch_ids is not None
+         h = pooling_downsample(h, max_num_patches, downsampling_by_pooling, patch_ids)
+     else:
+         # TODO: remove this condition
+         # By concatenating (fixed lengths patching)
+         assert patch_lengths is not None
+         h = concat_downsample(h, patch_lengths, patch_size)
+     return h
+
+
+ def causal_mask(b, h, q_idx, kv_idx):
+     return q_idx >= kv_idx
+
+
+ def create_causal_mask(seqlen, attn_impl, sliding_window):
+     if sliding_window is not None and attn_impl == "xformers":
+         return fmha.attn_bias.LocalAttentionFromBottomRightMask(
+             window_left=sliding_window - 1, window_right=0
+         )
+     elif attn_impl == "xformers":
+         return fmha.attn_bias.LowerTriangularMask()
+     elif attn_impl == "sdpa":
+         return "causal"
+     elif attn_impl == "flex_attention":
+         return create_block_mask(causal_mask, None, None, seqlen, seqlen)
+     elif attn_impl == "fmha":
+         return None
+     else:
+         raise NotImplementedError(
+             f"Attention {attn_impl} with {sliding_window} sliding window not implemented"
+         )
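A tiny worked example (not part of the commit) of patch_reduce and pooling_downsample above, with made-up numbers, showing how patch_ids route byte embeddings into per-patch slots and how multiple pooling modes are concatenated along the feature dimension.

import torch
from bytelatent.model.utils import patch_reduce, pooling_downsample

h = torch.tensor([[[1.0], [2.0], [3.0], [4.0]]])  # bs=1, seq_len=4, dim=1
patch_ids = torch.tensor([[0, 0, 1, 1]])          # first two bytes -> patch 0, last two -> patch 1

print(patch_reduce(h, 2, "mean", patch_ids))                 # tensor([[[1.5], [3.5]]])
print(pooling_downsample(h, 2, "min max", patch_ids).shape)  # torch.Size([1, 2, 2])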