letitiaaa committed
Commit
e66e8cc
1 Parent(s): d136541

Initial commit

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitattributes +2 -0
  2. .github/workflows/ci.yml +121 -0
  3. .github/workflows/clear-cache.yml +29 -0
  4. .github/workflows/python-publish.yml +37 -0
  5. .gitignore +153 -0
  6. CITATION.cff +33 -0
  7. HISTORY.md +223 -0
  8. LICENSE +23 -0
  9. MANIFEST.in +3 -0
  10. README.md +618 -0
  11. models.txt +2 -0
  12. pytest.ini +3 -0
  13. requirements.txt +8 -0
  14. src/open_clip/__init__.py +18 -0
  15. src/open_clip/coca_model.py +582 -0
  16. src/open_clip/constants.py +11 -0
  17. src/open_clip/convert.py +206 -0
  18. src/open_clip/factory.py +586 -0
  19. src/open_clip/hf_configs.py +67 -0
  20. src/open_clip/hf_model.py +193 -0
  21. src/open_clip/loss.py +447 -0
  22. src/open_clip/model.py +919 -0
  23. src/open_clip/model_configs/EVA01-g-14-plus.json +18 -0
  24. src/open_clip/model_configs/EVA01-g-14.json +18 -0
  25. src/open_clip/model_configs/EVA02-B-16.json +18 -0
  26. src/open_clip/model_configs/EVA02-E-14-plus.json +18 -0
  27. src/open_clip/model_configs/EVA02-E-14.json +18 -0
  28. src/open_clip/model_configs/EVA02-L-14-336.json +18 -0
  29. src/open_clip/model_configs/EVA02-L-14.json +18 -0
  30. src/open_clip/model_configs/MobileCLIP-B.json +21 -0
  31. src/open_clip/model_configs/MobileCLIP-S1.json +21 -0
  32. src/open_clip/model_configs/MobileCLIP-S2.json +21 -0
  33. src/open_clip/model_configs/RN101-quickgelu.json +22 -0
  34. src/open_clip/model_configs/RN101.json +21 -0
  35. src/open_clip/model_configs/RN50-quickgelu.json +22 -0
  36. src/open_clip/model_configs/RN50.json +21 -0
  37. src/open_clip/model_configs/RN50x16-quickgelu.json +22 -0
  38. src/open_clip/model_configs/RN50x16.json +21 -0
  39. src/open_clip/model_configs/RN50x4-quickgelu.json +22 -0
  40. src/open_clip/model_configs/RN50x4.json +21 -0
  41. src/open_clip/model_configs/RN50x64-quickgelu.json +22 -0
  42. src/open_clip/model_configs/RN50x64.json +21 -0
  43. src/open_clip/model_configs/ViT-B-16-SigLIP-256.json +29 -0
  44. src/open_clip/model_configs/ViT-B-16-SigLIP-384.json +29 -0
  45. src/open_clip/model_configs/ViT-B-16-SigLIP-512.json +29 -0
  46. src/open_clip/model_configs/ViT-B-16-SigLIP-i18n-256.json +29 -0
  47. src/open_clip/model_configs/ViT-B-16-SigLIP.json +29 -0
  48. src/open_clip/model_configs/ViT-B-16-SigLIP2-256.json +32 -0
  49. src/open_clip/model_configs/ViT-B-16-SigLIP2-384.json +32 -0
  50. src/open_clip/model_configs/ViT-B-16-SigLIP2-512.json +32 -0
.gitattributes ADDED
@@ -0,0 +1,2 @@
1
+ *.py linguist-language=python
2
+ *.ipynb linguist-documentation
.github/workflows/ci.yml ADDED
@@ -0,0 +1,121 @@
1
+ name: Continuous integration
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - main
7
+ paths-ignore:
8
+ - '**.md'
9
+ - 'CITATION.cff'
10
+ - 'LICENSE'
11
+ - '.gitignore'
12
+ - 'docs/**'
13
+ pull_request:
14
+ branches:
15
+ - main
16
+ paths-ignore:
17
+ - '**.md'
18
+ - 'CITATION.cff'
19
+ - 'LICENSE'
20
+ - '.gitignore'
21
+ - 'docs/**'
22
+ workflow_dispatch:
23
+ inputs:
24
+ manual_revision_reference:
25
+ required: false
26
+ type: string
27
+ manual_revision_test:
28
+ required: false
29
+ type: string
30
+
31
+ env:
32
+ REVISION_REFERENCE: v2.8.2
33
+ #9d31b2ec4df6d8228f370ff20c8267ec6ba39383 earliest compatible v2.7.0 + pretrained_hf param
34
+
35
+ jobs:
36
+ Tests:
37
+ strategy:
38
+ matrix:
39
+ os: [ ubuntu-latest ] #, macos-latest ]
40
+ python: [ 3.8 ]
41
+ job_num: [ 4 ]
42
+ job: [ 1, 2, 3, 4 ]
43
+ runs-on: ${{ matrix.os }}
44
+ steps:
45
+ - uses: actions/checkout@v3
46
+ with:
47
+ fetch-depth: 0
48
+ ref: ${{ inputs.manual_revision_test }}
49
+ - name: Set up Python ${{ matrix.python }}
50
+ id: pythonsetup
51
+ uses: actions/setup-python@v4
52
+ with:
53
+ python-version: ${{ matrix.python }}
54
+ - name: Venv cache
55
+ id: venv-cache
56
+ uses: actions/cache@v3
57
+ with:
58
+ path: .env
59
+ key: venv-${{ matrix.os }}-${{ steps.pythonsetup.outputs.python-version }}-${{ hashFiles('requirements*') }}
60
+ - name: Pytest durations cache
61
+ uses: actions/cache@v3
62
+ with:
63
+ path: .test_durations
64
+ key: test_durations-${{ matrix.os }}-${{ steps.pythonsetup.outputs.python-version }}-${{ matrix.job }}-${{ github.run_id }}
65
+ restore-keys: test_durations-0-
66
+ - name: Setup
67
+ if: steps.venv-cache.outputs.cache-hit != 'true'
68
+ run: |
69
+ python3 -m venv .env
70
+ source .env/bin/activate
71
+ pip install -e .[test]
72
+ - name: Prepare test data
73
+ run: |
74
+ source .env/bin/activate
75
+ python -m pytest \
76
+ --quiet --co \
77
+ --splitting-algorithm least_duration \
78
+ --splits ${{ matrix.job_num }} \
79
+ --group ${{ matrix.job }} \
80
+ -m regression_test \
81
+ tests \
82
+ | head -n -2 | grep -Po 'test_inference_with_data\[\K[^]]*(?=-False]|-True])' \
83
+ > models_gh_runner.txt
84
+ if [ -n "${{ inputs.manual_revision_reference }}" ]; then
85
+ REVISION_REFERENCE=${{ inputs.manual_revision_reference }}
86
+ fi
87
+ python tests/util_test.py \
88
+ --save_model_list models_gh_runner.txt \
89
+ --model_list models_gh_runner.txt \
90
+ --git_revision $REVISION_REFERENCE
91
+ - name: Unit tests
92
+ run: |
93
+ source .env/bin/activate
94
+ if [[ -f .test_durations ]]
95
+ then
96
+ cp .test_durations durations_1
97
+ mv .test_durations durations_2
98
+ fi
99
+ python -m pytest \
100
+ -x -s -v \
101
+ --splitting-algorithm least_duration \
102
+ --splits ${{ matrix.job_num }} \
103
+ --group ${{ matrix.job }} \
104
+ --store-durations \
105
+ --durations-path durations_1 \
106
+ --clean-durations \
107
+ -m "not regression_test" \
108
+ tests
109
+ OPEN_CLIP_TEST_REG_MODELS=models_gh_runner.txt python -m pytest \
110
+ -x -s -v \
111
+ --store-durations \
112
+ --durations-path durations_2 \
113
+ --clean-durations \
114
+ -m "regression_test" \
115
+ tests
116
+ jq -s -S 'add' durations_* > .test_durations
117
+ - name: Collect pytest durations
118
+ uses: actions/upload-artifact@v4
119
+ with:
120
+ name: pytest_durations_${{ matrix.os }}-${{ matrix.python }}-${{ matrix.job }}
121
+ path: .test_durations
.github/workflows/clear-cache.yml ADDED
@@ -0,0 +1,29 @@
1
+ name: Clear cache
2
+
3
+ on:
4
+ workflow_dispatch:
5
+
6
+ permissions:
7
+ actions: write
8
+
9
+ jobs:
10
+ clear-cache:
11
+ runs-on: ubuntu-latest
12
+ steps:
13
+ - name: Clear cache
14
+ uses: actions/github-script@v6
15
+ with:
16
+ script: |
17
+ const caches = await github.rest.actions.getActionsCacheList({
18
+ owner: context.repo.owner,
19
+ repo: context.repo.repo,
20
+ })
21
+ for (const cache of caches.data.actions_caches) {
22
+ console.log(cache)
23
+ await github.rest.actions.deleteActionsCacheById({
24
+ owner: context.repo.owner,
25
+ repo: context.repo.repo,
26
+ cache_id: cache.id,
27
+ })
28
+ }
29
+
.github/workflows/python-publish.yml ADDED
@@ -0,0 +1,37 @@
1
+ name: Release
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - main
7
+ jobs:
8
+ deploy:
9
+ runs-on: ubuntu-latest
10
+ steps:
11
+ - uses: actions/checkout@v2
12
+ - uses: actions-ecosystem/action-regex-match@v2
13
+ id: regex-match
14
+ with:
15
+ text: ${{ github.event.head_commit.message }}
16
+ regex: '^Release ([^ ]+)'
17
+ - name: Set up Python
18
+ uses: actions/setup-python@v2
19
+ with:
20
+ python-version: '3.8'
21
+ - name: Install dependencies
22
+ run: |
23
+ python -m pip install --upgrade pip
24
+ pip install setuptools wheel twine build
25
+ - name: Release
26
+ if: ${{ steps.regex-match.outputs.match != '' }}
27
+ uses: softprops/action-gh-release@v1
28
+ with:
29
+ tag_name: v${{ steps.regex-match.outputs.group1 }}
30
+ - name: Build and publish
31
+ if: ${{ steps.regex-match.outputs.match != '' }}
32
+ env:
33
+ TWINE_USERNAME: __token__
34
+ TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
35
+ run: |
36
+ python -m build
37
+ twine upload dist/*
.gitignore ADDED
@@ -0,0 +1,153 @@
1
+ **/logs/
2
+ **/wandb/
3
+ models/
4
+ features/
5
+ results/
6
+
7
+ tests/data/
8
+ *.pt
9
+
10
+ # Byte-compiled / optimized / DLL files
11
+ __pycache__/
12
+ *.py[cod]
13
+ *$py.class
14
+
15
+ # C extensions
16
+ *.so
17
+
18
+ # Distribution / packaging
19
+ .Python
20
+ build/
21
+ develop-eggs/
22
+ dist/
23
+ downloads/
24
+ eggs/
25
+ .eggs/
26
+ lib/
27
+ lib64/
28
+ parts/
29
+ sdist/
30
+ var/
31
+ wheels/
32
+ pip-wheel-metadata/
33
+ share/python-wheels/
34
+ *.egg-info/
35
+ .installed.cfg
36
+ *.egg
37
+ MANIFEST
38
+
39
+ # PyInstaller
40
+ # Usually these files are written by a python script from a template
41
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
42
+ *.manifest
43
+ *.spec
44
+
45
+ # Installer logs
46
+ pip-log.txt
47
+ pip-delete-this-directory.txt
48
+
49
+ # Unit test / coverage reports
50
+ htmlcov/
51
+ .tox/
52
+ .nox/
53
+ .coverage
54
+ .coverage.*
55
+ .cache
56
+ nosetests.xml
57
+ coverage.xml
58
+ *.cover
59
+ *.py,cover
60
+ .hypothesis/
61
+ .pytest_cache/
62
+
63
+ # Translations
64
+ *.mo
65
+ *.pot
66
+
67
+ # Django stuff:
68
+ *.log
69
+ local_settings.py
70
+ db.sqlite3
71
+ db.sqlite3-journal
72
+
73
+ # Flask stuff:
74
+ instance/
75
+ .webassets-cache
76
+
77
+ # Scrapy stuff:
78
+ .scrapy
79
+
80
+ # Sphinx documentation
81
+ docs/_build/
82
+
83
+ # PyBuilder
84
+ target/
85
+
86
+ # Jupyter Notebook
87
+ .ipynb_checkpoints
88
+
89
+ # IPython
90
+ profile_default/
91
+ ipython_config.py
92
+
93
+ # pyenv
94
+ .python-version
95
+
96
+ # pipenv
97
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
98
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
99
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
100
+ # install all needed dependencies.
101
+ #Pipfile.lock
102
+
103
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
104
+ __pypackages__/
105
+
106
+ # Celery stuff
107
+ celerybeat-schedule
108
+ celerybeat.pid
109
+
110
+ # SageMath parsed files
111
+ *.sage.py
112
+
113
+ # Environments
114
+ .env
115
+ .venv
116
+ env/
117
+ venv/
118
+ ENV/
119
+ env.bak/
120
+ venv.bak/
121
+
122
+ # Spyder project settings
123
+ .spyderproject
124
+ .spyproject
125
+
126
+ # Rope project settings
127
+ .ropeproject
128
+
129
+ # mkdocs documentation
130
+ /site
131
+
132
+ # mypy
133
+ .mypy_cache/
134
+ .dmypy.json
135
+ dmypy.json
136
+
137
+ # Pyre type checker
138
+ .pyre/
139
+ sync.sh
140
+ gpu1sync.sh
141
+ .idea
142
+ *.pdf
143
+ **/._*
144
+ **/*DS_*
145
+ **.jsonl
146
+ src/sbatch
147
+ src/misc
148
+ .vscode
149
+ src/debug
150
+ core.*
151
+
152
+ # Allow
153
+ !src/evaluation/misc/results_dbs/*
CITATION.cff ADDED
@@ -0,0 +1,33 @@
1
+ cff-version: 1.1.0
2
+ message: If you use this software, please cite it as below.
3
+ authors:
4
+ - family-names: Ilharco
5
+ given-names: Gabriel
6
+ - family-names: Wortsman
7
+ given-names: Mitchell
8
+ - family-names: Wightman
9
+ given-names: Ross
10
+ - family-names: Gordon
11
+ given-names: Cade
12
+ - family-names: Carlini
13
+ given-names: Nicholas
14
+ - family-names: Taori
15
+ given-names: Rohan
16
+ - family-names: Dave
17
+ given-names: Achal
18
+ - family-names: Shankar
19
+ given-names: Vaishaal
20
+ - family-names: Namkoong
21
+ given-names: Hongseok
22
+ - family-names: Miller
23
+ given-names: John
24
+ - family-names: Hajishirzi
25
+ given-names: Hannaneh
26
+ - family-names: Farhadi
27
+ given-names: Ali
28
+ - family-names: Schmidt
29
+ given-names: Ludwig
30
+ title: OpenCLIP
31
+ version: v0.1
32
+ doi: 10.5281/zenodo.5143773
33
+ date-released: 2021-07-28
HISTORY.md ADDED
@@ -0,0 +1,223 @@
1
+ ## 2.24.0
2
+
3
+ * Fix missing space in error message
4
+ * use model flag for normalizing embeddings
5
+ * init logit_bias for non siglip pretrained models
6
+ * Fix logit_bias load_checkpoint addition
7
+ * Make CoCa model match CLIP models for logit scale/bias init
8
+ * Fix missing return of "logit_bias" in CoCa.forward
9
+ * Add NLLB-CLIP with SigLIP models
10
+ * Add get_logits method and NLLB tokenizer
11
+ * Remove the empty file src/open_clip/generation_utils.py
12
+ * Update params.py: "BatchNorm" -> "LayerNorm" in the description string for "--lock-text-freeze-layer-norm"
13
+
14
+ ## 2.23.0
15
+
16
+ * Add CLIPA-v2 models
17
+ * Add SigLIP models
18
+ * Add MetaCLIP models
19
+ * Add NLLB-CLIP models
20
+ * CLIPA train code
21
+ * Minor changes/fixes
22
+ * Remove protobuf version limit
23
+ * Stop checking model name when loading CoCa models
24
+ * Log native wandb step
25
+ * Use bool instead of long masks
26
+
27
+ ## 2.21.0
28
+
29
+ * Add SigLIP loss + training support
30
+ * Add more DataComp models (B/16, B/32 and B/32@256)
31
+ * Update default num workers
32
+ * Update CoCa generation for `transformers>=4.31`
33
+ * PyTorch 2.0 `state_dict()` compatibility fix for compiled models
34
+ * Fix padding in `ResizeMaxSize`
35
+ * Convert JIT model on state dict load for `pretrained='filename…'`
36
+ * Other minor changes and fixes (typos, README, dependencies, CI)
37
+
38
+ ## 2.20.0
39
+
40
+ * Add EVA models
41
+ * Support serial worker training
42
+ * Fix Python 3.7 compatibility
43
+
44
+ ## 2.19.0
45
+
46
+ * Add DataComp models
47
+
48
+ ## 2.18.0
49
+
50
+ * Enable int8 inference without `.weight` attribute
51
+
52
+ ## 2.17.2
53
+
54
+ * Update push_to_hf_hub
55
+
56
+ ## 2.17.0
57
+
58
+ * Add int8 support
59
+ * Update notebook demo
60
+ * Refactor zero-shot classification code
61
+
62
+ ## 2.16.2
63
+
64
+ * Fixes for context_length and vocab_size attributes
65
+
66
+ ## 2.16.1
67
+
68
+ * Fixes for context_length and vocab_size attributes
69
+ * Fix --train-num-samples logic
70
+ * Add HF BERT configs for PubMed CLIP model
71
+
72
+ ## 2.16.0
73
+
74
+ * Add improved g-14 weights
75
+ * Update protobuf version
76
+
77
+ ## 2.15.0
78
+
79
+ * Add convnext_xxlarge weights
80
+ * Fixed import in readme
81
+ * Add samples per second per gpu logging
82
+ * Fix slurm example
83
+
84
+ ## 2.14.0
85
+
86
+ * Move dataset mixtures logic to shard level
87
+ * Fix CoCa accum-grad training
88
+ * Safer transformers import guard
89
+ * get_labels refactoring
90
+
91
+ ## 2.13.0
92
+
93
+ * Add support for dataset mixtures with different sampling weights
94
+ * Make transformers optional again
95
+
96
+ ## 2.12.0
97
+
98
+ * Updated convnext configs for consistency
99
+ * Added input_patchnorm option
100
+ * Clean and improve CoCa generation
101
+ * Support model distillation
102
+ * Add ConvNeXt-Large 320x320 fine-tune weights
103
+
104
+ ## 2.11.1
105
+
106
+ * Make transformers optional
107
+ * Add MSCOCO CoCa finetunes to pretrained models
108
+
109
+ ## 2.11.0
110
+
111
+ * coca support and weights
112
+ * ConvNeXt-Large weights
113
+
114
+ ## 2.10.1
115
+
116
+ * `hf-hub:org/model_id` support for loading models w/ config and weights in Hugging Face Hub
117
+
118
+ ## 2.10.0
119
+
120
+ * Added a ViT-bigG-14 model.
121
+ * Added an up-to-date example slurm script for large training jobs.
122
+ * Added a option to sync logs and checkpoints to S3 during training.
123
+ * New options for LR schedulers, constant and constant with cooldown
124
+ * Fix wandb autoresuming when resume is not set
125
+ * ConvNeXt `base` & `base_w` pretrained models added
126
+ * `timm-` model prefix removed from configs
127
+ * `timm` augmentation + regularization (dropout / drop-path) supported
128
+
129
+ ## 2.9.3
130
+
131
+ * Fix wandb collapsing multiple parallel runs into a single one
132
+
133
+ ## 2.9.2
134
+
135
+ * Fix braceexpand memory explosion for complex webdataset urls
136
+
137
+ ## 2.9.1
138
+
139
+ * Fix release
140
+
141
+ ## 2.9.0
142
+
143
+ * Add training feature to auto-resume from the latest checkpoint on restart via `--resume latest`
144
+ * Allow webp in webdataset
145
+ * Fix logging for number of samples when using gradient accumulation
146
+ * Add model configs for convnext xxlarge
147
+
148
+ ## 2.8.2
149
+
150
+ * wrapped patchdropout in a torch.nn.Module
151
+
152
+ ## 2.8.1
153
+
154
+ * relax protobuf dependency
155
+ * override the default patch dropout value in 'vision_cfg'
156
+
157
+ ## 2.8.0
158
+
159
+ * better support for HF models
160
+ * add support for gradient accumulation
161
+ * CI fixes
162
+ * add support for patch dropout
163
+ * add convnext configs
164
+
165
+
166
+ ## 2.7.0
167
+
168
+ * add multilingual H/14 xlm roberta large
169
+
170
+ ## 2.6.1
171
+
172
+ * fix setup.py _read_reqs
173
+
174
+ ## 2.6.0
175
+
176
+ * Make openclip training usable from pypi.
177
+ * Add xlm roberta large vit h 14 config.
178
+
179
+ ## 2.5.0
180
+
181
+ * pretrained B/32 xlm roberta base: first multilingual clip trained on laion5B
182
+ * pretrained B/32 roberta base: first clip trained using an HF text encoder
183
+
184
+ ## 2.4.1
185
+
186
+ * Add missing hf_tokenizer_name in CLIPTextCfg.
187
+
188
+ ## 2.4.0
189
+
190
+ * Fix #211, missing RN50x64 config. Fix type of dropout param for ResNet models
191
+ * Bring back LayerNorm impl that casts to input for non bf16/fp16
192
+ * zero_shot.py: set correct tokenizer based on args
193
+ * training/params.py: remove hf params and get them from model config
194
+
195
+ ## 2.3.1
196
+
197
+ * Implement grad checkpointing for hf model.
198
+ * custom_text: True if hf_model_name is set
199
+ * Disable hf tokenizer parallelism
200
+
201
+ ## 2.3.0
202
+
203
+ * Generalizable Text Transformer with HuggingFace Models (@iejMac)
204
+
205
+ ## 2.2.0
206
+
207
+ * Support for custom text tower
208
+ * Add checksum verification for pretrained model weights
209
+
210
+ ## 2.1.0
211
+
212
+ * lot including sota models, bfloat16 option, better loading, better metrics
213
+
214
+ ## 1.2.0
215
+
216
+ * ViT-B/32 trained on Laion2B-en
217
+ * add missing openai RN50x64 model
218
+
219
+ ## 1.1.1
220
+
221
+ * ViT-B/16+
222
+ * Add grad checkpointing support
223
+ * more robust data loader
LICENSE ADDED
@@ -0,0 +1,23 @@
1
+ Copyright (c) 2012-2021 Gabriel Ilharco, Mitchell Wortsman,
2
+ Nicholas Carlini, Rohan Taori, Achal Dave, Vaishaal Shankar,
3
+ John Miller, Hongseok Namkoong, Hannaneh Hajishirzi, Ali Farhadi,
4
+ Ludwig Schmidt
5
+
6
+ Permission is hereby granted, free of charge, to any person obtaining
7
+ a copy of this software and associated documentation files (the
8
+ "Software"), to deal in the Software without restriction, including
9
+ without limitation the rights to use, copy, modify, merge, publish,
10
+ distribute, sublicense, and/or sell copies of the Software, and to
11
+ permit persons to whom the Software is furnished to do so, subject to
12
+ the following conditions:
13
+
14
+ The above copyright notice and this permission notice shall be
15
+ included in all copies or substantial portions of the Software.
16
+
17
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
18
+ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
19
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
20
+ NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
21
+ LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
22
+ OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
23
+ WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
MANIFEST.in ADDED
@@ -0,0 +1,3 @@
1
+ include src/open_clip/bpe_simple_vocab_16e6.txt.gz
2
+ include src/open_clip/model_configs/*.json
3
+
README.md ADDED
@@ -0,0 +1,618 @@
1
+ # OpenCLIP
2
+
3
+ [[Paper]](https://arxiv.org/abs/2212.07143) [[Citations]](#citing) [[Clip Colab]](https://colab.research.google.com/github/mlfoundations/open_clip/blob/master/docs/Interacting_with_open_clip.ipynb) [[Coca Colab]](https://colab.research.google.com/github/mlfoundations/open_clip/blob/master/docs/Interacting_with_open_coca.ipynb)
4
+ [![pypi](https://img.shields.io/pypi/v/open_clip_torch.svg)](https://pypi.python.org/pypi/open_clip_torch)
5
+
6
+ Welcome to an open source implementation of OpenAI's [CLIP](https://arxiv.org/abs/2103.00020) (Contrastive Language-Image Pre-training).
7
+
8
+ Using this codebase, we have trained several models on a variety of data sources and compute budgets, ranging from [small-scale experiments](docs/LOW_ACC.md) to larger runs including models trained on datasets such as [LAION-400M](https://arxiv.org/abs/2111.02114), [LAION-2B](https://arxiv.org/abs/2210.08402) and [DataComp-1B](https://arxiv.org/abs/2304.14108).
9
+ Many of our models and their scaling properties are studied in detail in the paper [Reproducible Scaling Laws for Contrastive Language-Image Learning](https://arxiv.org/abs/2212.07143).
10
+ Some of the best models we've trained and their zero-shot ImageNet-1k accuracy are shown below, along with the ViT-L model trained by OpenAI and other state-of-the-art open source alternatives (all can be loaded via OpenCLIP).
11
+ We provide more details about our full collection of pretrained models [here](docs/PRETRAINED.md), and zero-shot results for 38 datasets [here](docs/openclip_results.csv).
12
+
13
+
14
+
15
+ | Model | Training data | Resolution | # of samples seen | ImageNet zero-shot acc. |
16
+ | -------- | ------- | ------- | ------- | ------- |
17
+ | ConvNext-Base | LAION-2B | 256px | 13B | 71.5% |
18
+ | ConvNext-Large | LAION-2B | 320px | 29B | 76.9% |
19
+ | ConvNext-XXLarge | LAION-2B | 256px | 34B | 79.5% |
20
+ | ViT-B/32 | DataComp-1B | 256px | 34B | 72.8% |
21
+ | ViT-B/16 | DataComp-1B | 224px | 13B | 73.5% |
22
+ | ViT-L/14 | LAION-2B | 224px | 32B | 75.3% |
23
+ | ViT-H/14 | LAION-2B | 224px | 32B | 78.0% |
24
+ | ViT-L/14 | DataComp-1B | 224px | 13B | 79.2% |
25
+ | ViT-G/14 | LAION-2B | 224px | 34B | 80.1% |
26
+ | | | | | |
27
+ | ViT-L/14-quickgelu [(Original CLIP)](https://arxiv.org/abs/2103.00020) | WIT | 224px | 13B | 75.5% |
28
+ | ViT-SO400M/14 [(SigLIP)](https://arxiv.org/abs/2303.15343) | WebLI | 224px | 45B | 82.0% |
29
+ | ViT-L/14 [(DFN)](https://arxiv.org/abs/2309.17425) | DFN-2B | 224px | 39B | 82.2% |
30
+ | ViT-SO400M-14-SigLIP-384 [(SigLIP)](https://arxiv.org/abs/2303.15343) | WebLI | 384px | 45B | 83.1% |
31
+ | ViT-H/14-quickgelu [(DFN)](https://arxiv.org/abs/2309.17425) | DFN-5B | 224px | 39B | 83.4% |
32
+ | ViT-H-14-378-quickgelu [(DFN)](https://arxiv.org/abs/2309.17425) | DFN-5B | 378px | 44B | 84.4% |
33
+
34
+ Model cards with additional model-specific details can be found on the Hugging Face Hub under the OpenCLIP library tag: https://huggingface.co/models?library=open_clip.
35
+
36
+ If you found this repository useful, please consider [citing](#citing).
37
+ We welcome anyone to submit an issue or send an email if you have any other requests or suggestions.
38
+
39
+ Note that portions of `src/open_clip/` modelling and tokenizer code are adaptations of OpenAI's official [repository](https://github.com/openai/CLIP).
40
+
41
+ ## Approach
42
+
43
+ | ![CLIP](https://raw.githubusercontent.com/mlfoundations/open_clip/main/docs/CLIP.png) |
44
+ |:--:|
45
+ | Image Credit: https://github.com/openai/CLIP |
46
+
47
+ ## Usage
48
+
49
+ ```
50
+ pip install open_clip_torch
51
+ ```
52
+
53
+ ```python
54
+ import torch
55
+ from PIL import Image
56
+ import open_clip
57
+
58
+ model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
59
+ model.eval() # model in train mode by default, impacts some models with BatchNorm or stochastic depth active
60
+ tokenizer = open_clip.get_tokenizer('ViT-B-32')
61
+
62
+ image = preprocess(Image.open("docs/CLIP.png")).unsqueeze(0)
63
+ text = tokenizer(["a diagram", "a dog", "a cat"])
64
+
65
+ with torch.no_grad(), torch.autocast("cuda"):
66
+     image_features = model.encode_image(image)
67
+     text_features = model.encode_text(text)
68
+     image_features /= image_features.norm(dim=-1, keepdim=True)
69
+     text_features /= text_features.norm(dim=-1, keepdim=True)
70
+
71
+ text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
72
+
73
+ print("Label probs:", text_probs) # prints: [[1., 0., 0.]]
74
+ ```
75
+
76
+ If the model uses `timm` image encoders (ConvNeXt, SigLIP, EVA, etc.), ensure the latest `timm` is installed. Upgrade `timm` if you see 'Unknown model' errors for the image encoder.
77
+
78
+ If the model uses `transformers` tokenizers, ensure `transformers` is installed.
79
+
80
+ See also this [[Clip Colab]](https://colab.research.google.com/github/mlfoundations/open_clip/blob/master/docs/Interacting_with_open_clip.ipynb).
81
+
82
+ To compute billions of embeddings efficiently, you can use [clip-retrieval](https://github.com/rom1504/clip-retrieval) which has openclip support.
83
+
84
+ ### Pretrained models
85
+
86
+ We offer a simple model interface to instantiate both pre-trained and untrained models.
87
+ To see which pretrained models are available, use the following code snippet.
88
+ More details about our pretrained models are available [here](docs/PRETRAINED.md).
89
+
90
+ ```python
91
+ >>> import open_clip
92
+ >>> open_clip.list_pretrained()
93
+ ```
94
+
95
+ You can find more about the models we support (e.g. number of parameters, FLOPs) in [this table](docs/model_profile.csv).
96
+
97
+ NOTE: Many existing checkpoints use the QuickGELU activation from the original OpenAI models. This activation is actually less efficient than native torch.nn.GELU in recent versions of PyTorch. The model defaults are now nn.GELU, so one should use model definitions with the `-quickgelu` postfix for the OpenCLIP pretrained weights that use QuickGELU. All OpenAI pretrained weights will always default to QuickGELU. One can also use the non `-quickgelu` model definitions with pretrained weights that use QuickGELU, but there will be an accuracy drop; for fine-tuning, that drop will likely vanish over longer runs.
98
+ Future trained models will use nn.GELU.
99
+
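For example, a brief sketch (the model and pretrained tags are ones listed by `open_clip.list_pretrained()`):

```python
import open_clip

# OpenAI weights were trained with QuickGELU, so pair them with a `-quickgelu` model definition.
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32-quickgelu', pretrained='openai')

# Checkpoints trained with nn.GELU (e.g. the LAION-2B ViT-B-32 weights above) use the plain model name.
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
```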
100
+ ### Loading models
101
+
102
+ Models can be loaded with `open_clip.create_model_and_transforms`, as shown in the example below. The model name and corresponding `pretrained` keys are compatible with the outputs of `open_clip.list_pretrained()`.
103
+
104
+ The `pretrained` argument also accepts local paths, for example `/path/to/my/b32.pt`.
105
+ You can also load checkpoints from the Hugging Face Hub this way. To do so, download the `open_clip_pytorch_model.bin` file (for example, [https://huggingface.co/laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K/tree/main](https://huggingface.co/laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K/blob/main/open_clip_pytorch_model.bin)), and use `pretrained=/path/to/open_clip_pytorch_model.bin`.
106
+
107
+ ```python
108
+ # pretrained also accepts local paths
109
+ model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
110
+ ```
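Models hosted on the Hugging Face Hub can also be referenced directly with the `hf-hub:` prefix (a sketch, using the repository linked above; config and weights are downloaded on first use):

```python
import open_clip

model, _, preprocess = open_clip.create_model_and_transforms('hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K')
tokenizer = open_clip.get_tokenizer('hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K')
```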
111
+
112
+ ## Fine-tuning on classification tasks
113
+
114
+ This repository is focused on training CLIP models. To fine-tune a *trained* zero-shot model on a downstream classification task such as ImageNet, please see [our other repository: WiSE-FT](https://github.com/mlfoundations/wise-ft). The [WiSE-FT repository](https://github.com/mlfoundations/wise-ft) contains code for our paper on [Robust Fine-tuning of Zero-shot Models](https://arxiv.org/abs/2109.01903), in which we introduce a technique for fine-tuning zero-shot models while preserving robustness under distribution shift.
115
+
116
+ ## Data
117
+
118
+ To download datasets as webdataset, we recommend [img2dataset](https://github.com/rom1504/img2dataset).
119
+
120
+ ### Conceptual Captions
121
+
122
+ See [cc3m img2dataset example](https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc3m.md).
123
+
124
+ ### YFCC and other datasets
125
+
126
+ In addition to specifying the training data via CSV files as mentioned above, our codebase also supports [webdataset](https://github.com/webdataset/webdataset), which is recommended for larger scale datasets. The expected format is a series of `.tar` files. Each of these `.tar` files should contain two files for each training example, one for the image and one for the corresponding text. Both files should have the same name but different extensions. For instance, `shard_001.tar` could contain files such as `abc.jpg` and `abc.txt`. You can learn more about `webdataset` at [https://github.com/webdataset/webdataset](https://github.com/webdataset/webdataset). We use `.tar` files with 1,000 data points each, which we create using [tarp](https://github.com/webdataset/tarp).
127
+
128
+ You can download the YFCC dataset from [Multimedia Commons](http://mmcommons.org/).
129
+ Similar to OpenAI, we used a subset of YFCC to reach the aforementioned accuracy numbers.
130
+ The indices of images in this subset are in [OpenAI's CLIP repository](https://github.com/openai/CLIP/blob/main/data/yfcc100m.md).
131
+
132
+
133
+ ## Training CLIP
134
+
135
+ ### Install
136
+
137
+ We advise you to first create a virtual environment with:
138
+
139
+ ```
140
+ python3 -m venv .env
141
+ source .env/bin/activate
142
+ pip install -U pip
143
+ ```
144
+
145
+ You can then install openclip for training with `pip install 'open_clip_torch[training]'`.
146
+
147
+ #### Development
148
+
149
+ If you want to make changes to contribute code, you can clone openclip and then run `make install` in the openclip folder (after creating a virtualenv).
150
+
151
+ Install PyTorch via pip as per https://pytorch.org/get-started/locally/.
152
+
153
+ You may run `make install-training` to install the training dependencies.
154
+
155
+ #### Testing
156
+
157
+ Tests can be run with `make install-test` followed by `make test`.
158
+
159
+ Run `python -m pytest -x -s -v tests -k "training"` to run a specific subset of tests.
160
+
161
+ Running regression tests against a specific git revision or tag:
162
+ 1. Generate testing data
163
+ ```sh
164
+ python tests/util_test.py --model RN50 RN101 --save_model_list models.txt --git_revision 9d31b2ec4df6d8228f370ff20c8267ec6ba39383
165
+ ```
166
+ **_WARNING_: This will invoke git and modify your working tree, but will reset it to the current state after data has been generated! \
167
+ Don't modify your working tree while test data is being generated this way.**
168
+
169
+ 2. Run regression tests
170
+ ```sh
171
+ OPEN_CLIP_TEST_REG_MODELS=models.txt python -m pytest -x -s -v -m regression_test
172
+ ```
173
+
174
+ ### Sample single-process running code:
175
+
176
+ ```bash
177
+ python -m open_clip_train.main \
178
+ --save-frequency 1 \
179
+ --zeroshot-frequency 1 \
180
+ --report-to tensorboard \
181
+ --train-data="/path/to/train_data.csv" \
182
+ --val-data="/path/to/validation_data.csv" \
183
+ --csv-img-key filepath \
184
+ --csv-caption-key title \
185
+ --imagenet-val=/path/to/imagenet/root/val/ \
186
+ --warmup 10000 \
187
+ --batch-size=128 \
188
+ --lr=1e-3 \
189
+ --wd=0.1 \
190
+ --epochs=30 \
191
+ --workers=8 \
192
+ --model RN50
193
+ ```
194
+
195
+ Note: `imagenet-val` is the path to the *validation* set of ImageNet for zero-shot evaluation, not the training set!
196
+ You can remove this argument if you do not want to perform zero-shot evaluation on ImageNet throughout training. Note that the `val` folder should contain subfolders. If it does not, please use [this script](https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh).
197
+
198
+ ### Multi-GPU and Beyond
199
+
200
+ This code has been battle tested up to 1024 A100s and offers a variety of solutions
201
+ for distributed training. We include native support for SLURM clusters.
202
+
203
+ As the number of devices used to train increases, so does the space complexity of
204
+ the logit matrix. Using a naïve all-gather scheme, space complexity will be
205
+ `O(n^2)`. Instead, complexity may become effectively linear if the flags
206
+ `--gather-with-grad` and `--local-loss` are used. This alteration yields numerical
207
+ results identical to the naïve method.
208
+
209
+ #### Epochs
210
+
211
+ For larger datasets (e.g. LAION-2B), we recommend setting `--train-num-samples` to a lower value than the full epoch, for example `--train-num-samples 135646078` for 1/16 of an epoch, in conjunction with `--dataset-resampled` to do sampling with replacement. This allows more frequent checkpointing, and hence more frequent evaluation.
212
+
213
+ #### Patch Dropout
214
+
215
+ <a href="https://arxiv.org/abs/2212.00794">Recent research</a> has shown that one can drop out half to three-quarters of the visual tokens, leading to up to 2-3x faster training without loss of accuracy.
216
+
217
+ You can set this on your visual transformer config with the key `patch_dropout`.
218
+
219
+ In the paper, they also fine-tuned without patch dropout at the end. You can do this with the command-line argument `--force-patch-dropout 0.`
220
+
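As a minimal sketch, a vision config with patch dropout enabled might look like the following (field names follow the `vision_cfg` blocks used by the model configs in this repository; the values are illustrative):

```json
"vision_cfg": {
    "image_size": 224,
    "patch_size": 32,
    "width": 768,
    "layers": 12,
    "patch_dropout": 0.5
}
```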
221
+ #### Multiple data sources
222
+
223
+ OpenCLIP supports using multiple data sources, by separating different data paths with `::`.
224
+ For instance, to train on CC12M and on LAION, one might use `--train-data "/data/cc12m/cc12m-train-{0000..2175}.tar::/data/LAION-400M/{00000..41455}.tar"`.
225
+ Using `--dataset-resampled` is recommended for these cases.
226
+
227
+ By default, the expected number of times the model sees a sample from each source is proportional to the size of the source.
228
+ For instance, when training on one data source with size 400M and one with size 10M, samples from the first source are 40x more likely to be seen in expectation.
229
+
230
+ We also support different weighting of the data sources, by using the `--train-data-upsampling-factors` flag.
231
+ For instance, using `--train-data-upsampling-factors=1::1` in the above scenario is equivalent to not using the flag, and `--train-data-upsampling-factors=1::2` is equivalent to upsampling the second data source twice.
232
+ If you want to sample from data sources with the same frequency, the upsampling factors should be inversely proportional to the sizes of the data sources.
233
+ For instance, if dataset `A` has 1000 samples and dataset `B` has 100 samples, you can use `--train-data-upsampling-factors=0.001::0.01` (or analogously, `--train-data-upsampling-factors=1::10`).
234
+
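Putting these flags together, a sketch of a two-source run (paths reuse the examples above; sample counts and upsampling factors are illustrative):

```bash
python -m open_clip_train.main \
    --train-data "/data/cc12m/cc12m-train-{0000..2175}.tar::/data/LAION-400M/{00000..41455}.tar" \
    --train-data-upsampling-factors=1::2 \
    --dataset-resampled \
    --dataset-type webdataset \
    --train-num-samples 10968539 \
    --batch-size 320 \
    --model ViT-B-32
```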
235
+ #### Single-Node
236
+
237
+ We make use of `torchrun` to launch distributed jobs. The following launches a
238
+ job on a node of 4 GPUs:
239
+
240
+ ```bash
241
+ cd open_clip/src
242
+ torchrun --nproc_per_node 4 -m open_clip_train.main \
243
+ --train-data '/data/cc12m/cc12m-train-{0000..2175}.tar' \
244
+ --train-num-samples 10968539 \
245
+ --dataset-type webdataset \
246
+ --batch-size 320 \
247
+ --precision amp \
248
+ --workers 4 \
249
+ --imagenet-val /data/imagenet/validation/
250
+ ```
251
+
252
+ #### Multi-Node
253
+
254
+ The same script above works, so long as users include information about the number
255
+ of nodes and host node.
256
+
257
+ ```bash
258
+ cd open_clip/src
259
+ torchrun --nproc_per_node=4 \
260
+ --rdzv_endpoint=$HOST_NODE_ADDR \
261
+ -m open_clip_train.main \
262
+ --train-data '/data/cc12m/cc12m-train-{0000..2175}.tar' \
263
+ --train-num-samples 10968539 \
264
+ --dataset-type webdataset \
265
+ --batch-size 320 \
266
+ --precision amp \
267
+ --workers 4 \
268
+ --imagenet-val /data/imagenet/validation/
269
+ ```
270
+
271
+ #### SLURM
272
+
273
+ This is likely the easiest solution to utilize. The following script was used to
274
+ train our largest models:
275
+
276
+ ```bash
277
+ #!/bin/bash -x
278
+ #SBATCH --nodes=32
279
+ #SBATCH --gres=gpu:4
280
+ #SBATCH --ntasks-per-node=4
281
+ #SBATCH --cpus-per-task=6
282
+ #SBATCH --wait-all-nodes=1
283
+ #SBATCH --job-name=open_clip
284
+ #SBATCH --account=ACCOUNT_NAME
285
+ #SBATCH --partition PARTITION_NAME
286
+
287
+ eval "$(/path/to/conda/bin/conda shell.bash hook)" # init conda
288
+ conda activate open_clip
289
+ export CUDA_VISIBLE_DEVICES=0,1,2,3
290
+ export MASTER_PORT=12802
291
+
292
+ master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
293
+ export MASTER_ADDR=$master_addr
294
+
295
+ cd /shared/open_clip
296
+ export PYTHONPATH="$PYTHONPATH:$PWD/src"
297
+ srun --cpu_bind=v --accel-bind=gn python -u src/open_clip_train/main.py \
298
+ --save-frequency 1 \
299
+ --report-to tensorboard \
300
+ --train-data="/data/LAION-400M/{00000..41455}.tar" \
301
+ --warmup 2000 \
302
+ --batch-size=256 \
303
+ --epochs=32 \
304
+ --workers=8 \
305
+ --model ViT-B-32 \
306
+ --name "ViT-B-32-Vanilla" \
307
+ --seed 0 \
308
+ --local-loss \
309
+ --gather-with-grad
310
+ ```
311
+
312
+ ### Resuming from a checkpoint:
313
+
314
+ ```bash
315
+ python -m open_clip_train.main \
316
+ --train-data="/path/to/train_data.csv" \
317
+ --val-data="/path/to/validation_data.csv" \
318
+ --resume /path/to/checkpoints/epoch_K.pt
319
+ ```
320
+
321
+ ### Training CoCa:
322
+ Training [CoCa](https://arxiv.org/abs/2205.01917) models is enabled through specifying a CoCa config using the ```--model``` parameter of the training script. Currently available configs are "coca_base", "coca_ViT-B-32", and "coca_roberta-ViT-B-32" (which uses RoBERTa as the text encoder). CoCa configs are different from CLIP configs because they have an additional "multimodal_cfg" component which specifies parameters for the multimodal text decoder. Here's an example from the coca_ViT-B-32 config:
323
+ ```json
324
+ "multimodal_cfg": {
325
+     "context_length": 76,
326
+     "vocab_size": 49408,
327
+     "width": 512,
328
+     "heads": 8,
329
+     "layers": 12,
330
+     "latent_dim": 512,
331
+     "attn_pooler_heads": 8
332
+ }
333
+ ```
334
+ Credit to [lucidrains](https://github.com/lucidrains) for [initial code](https://github.com/lucidrains/CoCa-pytorch), [gpucce](https://github.com/gpucce) for adapting the code to open_clip, and [iejMac](https://github.com/iejMac) for training the models.
335
+
336
+ ### Generating text with CoCa
337
+
338
+ ```python
339
+ import open_clip
340
+ import torch
341
+ from PIL import Image
342
+
343
+ model, _, transform = open_clip.create_model_and_transforms(
344
+     model_name="coca_ViT-L-14",
345
+     pretrained="mscoco_finetuned_laion2B-s13B-b90k"
346
+ )
347
+
348
+ im = Image.open("cat.jpg").convert("RGB")
349
+ im = transform(im).unsqueeze(0)
350
+
351
+ with torch.no_grad(), torch.cuda.amp.autocast():
352
+     generated = model.generate(im)
353
+
354
+ print(open_clip.decode(generated[0]).split("<end_of_text>")[0].replace("<start_of_text>", ""))
355
+ ```
356
+
357
+ See also this [[Coca Colab]](https://colab.research.google.com/github/mlfoundations/open_clip/blob/master/docs/Interacting_with_open_coca.ipynb)
358
+
359
+ ### Fine Tuning CoCa
360
+
361
+ To fine-tune CoCa on MS COCO, first create the dataset. One way is to use a CSV dataset, and perhaps the simplest way to build it is with [CLIP_benchmark](https://github.com/LAION-AI/CLIP_benchmark), which in turn uses [pycocotools](https://github.com/cocodataset/cocoapi) (which can also be used on its own).
362
+
363
+ ```python
364
+ from clip_benchmark.datasets.builder import build_dataset
365
+ import pandas as pd
366
+ import os
367
+
368
+ root_path = "path/to/data/dir" # set this to something meaningful
369
+ ds = build_dataset("mscoco_captions", root=root_path, split="train", task="captioning") # this downloads the dataset if it is not there already
370
+ coco = ds.coco
371
+ imgs = coco.loadImgs(coco.getImgIds())
372
+ future_df = {"filepath":[], "title":[]}
373
+ for img in imgs:
374
+     caps = coco.imgToAnns[img["id"]]
375
+     for cap in caps:
376
+         future_df["filepath"].append(img["file_name"])
377
+         future_df["title"].append(cap["caption"])
378
+ pd.DataFrame.from_dict(future_df).to_csv(
379
+     os.path.join(root_path, "train2014.csv"), index=False, sep="\t"
380
+ )
381
+ ```
382
+ This should create a CSV dataset that one can use to fine-tune CoCa with open_clip:
383
+ ```bash
384
+ python -m open_clip_train.main \
385
+ --dataset-type "csv" \
386
+ --train-data "path/to/data/dir/train2014.csv" \
387
+ --warmup 1000 \
388
+ --batch-size 128 \
389
+ --lr 1e-5 \
390
+ --wd 0.1 \
391
+ --epochs 1 \
392
+ --workers 3 \
393
+ --model "coca_ViT-L-14" \
394
+ --report-to "wandb" \
395
+ --coca-contrastive-loss-weight 0 \
396
+ --coca-caption-loss-weight 1 \
397
+ --log-every-n-steps 100
398
+ ```
399
+
400
+ This is a general setting; open_clip has many parameters that can be set (```python -m open_clip_train.main --help``` shows them all). The only relevant change compared to pre-training is the two arguments
401
+
402
+ ```bash
403
+ --coca-contrastive-loss-weight 0
404
+ --coca-caption-loss-weight 1
405
+ ```
406
+ which make the model only train the generative side.
407
+
408
+ ### Training with pre-trained language models as text encoder:
409
+
410
+ If you wish to use different language models as the text encoder for CLIP, you can do so by using one of the Hugging Face model configs in ```src/open_clip/model_configs``` and passing its name and tokenizer as the ```--model``` and ```--hf-tokenizer-name``` parameters, respectively. Currently we only support RoBERTa ("test-roberta" config), but adding new models should be trivial. You can also determine how many layers, from the end, to leave unfrozen with the ```--lock-text-unlocked-layers``` parameter. Here's an example command to train CLIP with the RoBERTa LM that has its last 10 layers unfrozen:
411
+ ```bash
412
+ python -m open_clip_train.main \
413
+ --train-data="pipe:aws s3 cp s3://s-mas/cc3m/{00000..00329}.tar -" \
414
+ --train-num-samples 3000000 \
415
+ --val-data="pipe:aws s3 cp s3://s-mas/cc3m/{00330..00331}.tar -" \
416
+ --val-num-samples 10000 \
417
+ --dataset-type webdataset \
418
+ --batch-size 256 \
419
+ --warmup 2000 \
420
+ --epochs 10 \
421
+ --lr 5e-4 \
422
+ --precision amp \
423
+ --workers 6 \
424
+ --model "roberta-ViT-B-32" \
425
+ --lock-text \
426
+ --lock-text-unlocked-layers 10 \
427
+ --name "10_unfrozen" \
428
+ --report-to "tensorboard" \
429
+ ```
430
+
431
+ ### Loss Curves
432
+
433
+ When run on a machine with 8 GPUs the command should produce the following training curve for Conceptual Captions:
434
+
435
+ ![CLIP zero shot training curve](https://raw.githubusercontent.com/mlfoundations/open_clip/main/docs/clip_zeroshot.png)
436
+
437
+ More detailed curves for Conceptual Captions are given at [/docs/clip_conceptual_captions.md](/docs/clip_conceptual_captions.md).
438
+
439
+ When training an RN50 on YFCC, the same hyperparameters as above are used, with the exception of `lr=5e-4` and `epochs=32`.
440
+
441
+ Note that to use another model, like `ViT-B/32`, `RN50x4`, `RN50x16`, or `ViT-B/16`, specify it with `--model`, e.g. `--model RN50x4`.
442
+
443
+ ### Logging
444
+
445
+ For tensorboard logging, run:
446
+ ```bash
447
+ tensorboard --logdir=logs/tensorboard/ --port=7777
448
+ ```
449
+
450
+ For wandb logging, we recommend looking at the `step` variable instead of `Step`, since the latter was not properly set in earlier versions of this codebase.
451
+ For older runs with models trained before https://github.com/mlfoundations/open_clip/pull/613, the `Step` variable should be ignored.
452
+ For newer runs, after that PR, the two variables are the same.
453
+
454
+ ## Evaluation / Zero-Shot
455
+
456
+ We recommend https://github.com/LAION-AI/CLIP_benchmark#how-to-use for systematic evaluation on 40 datasets.
457
+
458
+ ### Evaluating local checkpoint:
459
+
460
+ ```bash
461
+ python -m open_clip_train.main \
462
+ --val-data="/path/to/validation_data.csv" \
463
+ --model RN101 \
464
+ --pretrained /path/to/checkpoints/epoch_K.pt
465
+ ```
466
+
467
+ ### Evaluating hosted pretrained checkpoint on ImageNet zero-shot prediction:
468
+
469
+ ```bash
470
+ python -m open_clip_train.main \
471
+ --imagenet-val /path/to/imagenet/validation \
472
+ --model ViT-B-32-quickgelu \
473
+ --pretrained laion400m_e32
474
+ ```
475
+
476
+ ### Model distillation
477
+
478
+ You can distill from a pre-trained model by using `--distill-model` and `--distill-pretrained` to specify the model you'd like to distill from.
479
+ For instance, to distill from OpenAI ViT-L/14 use `--distill-model ViT-L-14 --distill-pretrained openai`.
480
+
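A fuller command sketch (data path and hyperparameters are illustrative, borrowed from the examples above):

```bash
python -m open_clip_train.main \
    --train-data '/data/cc12m/cc12m-train-{0000..2175}.tar' \
    --train-num-samples 10968539 \
    --dataset-type webdataset \
    --batch-size 128 \
    --precision amp \
    --model ViT-B-32 \
    --distill-model ViT-L-14 \
    --distill-pretrained openai
```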
481
+ ### Gradient accumulation
482
+
483
+ To simulate larger batches use `--accum-freq k`. If the per-GPU batch size, `--batch-size`, is `m`, then the effective batch size will be `k * m * num_gpus`.
484
+
485
+ When increasing `--accum-freq` from its default of 1, samples/s will remain approximately constant (batch size will double, as will time-per-batch). It is recommended to use other features to reduce batch size such as `--grad-checkpointing --local-loss --gather-with-grad` before increasing `--accum-freq`. `--accum-freq` can be used in addition to these features.
486
+
487
+ Instead of 1 forward pass per example, there are now 2 forward passes per example. However, the first is done with `torch.no_grad`.
488
+
489
+ There is some additional GPU memory required: the features and data from all `k` micro-batches are stored in memory.
490
+
491
+ There are also `k` loss computations instead of the usual 1.
492
+
493
+ For more information see Cui et al. (https://arxiv.org/abs/2112.09331) or Pham et al. (https://arxiv.org/abs/2111.10050).
494
+
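As a sketch, with 8 GPUs, `--batch-size 128`, and `--accum-freq 4`, the effective batch size is 4 * 128 * 8 = 4096 (the path and remaining flags are borrowed from the single-node example above):

```bash
cd open_clip/src
torchrun --nproc_per_node 8 -m open_clip_train.main \
    --train-data '/data/cc12m/cc12m-train-{0000..2175}.tar' \
    --train-num-samples 10968539 \
    --dataset-type webdataset \
    --batch-size 128 \
    --accum-freq 4 \
    --grad-checkpointing --local-loss --gather-with-grad \
    --precision amp \
    --workers 4
```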
495
+ ### Int8 Support
496
+
497
+ We have beta support for int8 training and inference.
498
+ You can enable int8 training with `--use-bnb-linear SwitchBackLinearGlobal` or `--use-bnb-linear SwitchBackLinearGlobalMemEfficient`.
499
+ Please see the bitsandbytes library for definitions for these layers.
500
+ For CLIP ViT-Huge this should currently correspond to a 10% training speedup with no accuracy loss.
501
+ More speedups are coming when the attention layer is refactored so that linear layers can be replaced there, too.
502
+
503
+ See the tutorial https://github.com/mlfoundations/open_clip/blob/main/tutorials/int8_tutorial.ipynb or [paper](https://arxiv.org/abs/2304.13013).
504
+
505
+ ### Support for remote loading/training
506
+
507
+ It is always possible to resume directly from a remote file, e.g., a file in an s3 bucket. Just set `--resume s3://<path-to-checkpoint>`.
508
+ This will work with any filesystem supported by `fsspec`.
509
+
510
+ It is also possible to train `open_clip` models while continuously backing up to s3. This can help to avoid slow local file systems.
511
+
512
+ Say that your node has a local SSD `/scratch` and an s3 bucket `s3://<path-to-bucket>`.
513
+
514
+ In that case, set `--logs /scratch` and `--remote-sync s3://<path-to-bucket>`. Then, a background process will sync `/scratch/<run-name>` to `s3://<path-to-bucket>/<run-name>`. After syncing, the background process will sleep for `--remote-sync-frequency` seconds, which defaults to 5 minutes.
515
+
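For example (a sketch; the training flags are borrowed from the examples above, and `--remote-sync-frequency` is given in seconds):

```bash
python -m open_clip_train.main \
    --train-data '/data/cc12m/cc12m-train-{0000..2175}.tar' \
    --train-num-samples 10968539 \
    --dataset-type webdataset \
    --batch-size 256 \
    --model ViT-B-32 \
    --logs /scratch \
    --remote-sync s3://<path-to-bucket> \
    --remote-sync-frequency 600
```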
516
+ There is also experimental support for syncing to other remote file systems, not just s3. To do so, specify `--remote-sync-protocol fsspec`. However, this is currently very slow and not recommended.
517
+
518
+ Also, to optionally avoid saving too many checkpoints locally when using these features, you can use `--delete-previous-checkpoint` which deletes the previous checkpoint after saving a new one.
519
+
520
+ Note: if you are using this feature with `--resume latest`, there are a few caveats. First, using it with `--save-most-recent` is not supported. Second, only `s3` is supported. Finally, since the sync happens in the background, it is possible that the most recent checkpoint may not have finished syncing to the remote.
521
+
522
+ ### Pushing Models to Hugging Face Hub
523
+
524
+ The module `open_clip.push_to_hf_hub` includes helpers for pushing models (weights and config) to the HF Hub.
525
+
526
+ The tool can be run from the command line, for example:
527
+ `python -m open_clip.push_to_hf_hub --model convnext_large_d_320 --pretrained /train/checkpoints/epoch_12.pt --repo-id laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft`
528
+
529
+
530
+
531
+ ## Acknowledgments
532
+
533
+ We gratefully acknowledge the Gauss Centre for Supercomputing e.V. (www.gauss-centre.eu) for funding this part of the work by providing computing time through the John von Neumann Institute for Computing (NIC) on the GCS Supercomputer JUWELS Booster at Jülich Supercomputing Centre (JSC).
534
+
535
+ ## The Team
536
+
537
+ Current development of this repository is led by [Ross Wightman](https://rwightman.com/), [Romain Beaumont](https://github.com/rom1504), [Cade Gordon](http://cadegordon.io/), and [Vaishaal Shankar](http://vaishaal.com/).
538
+
539
+ The original version of this repository is from a group of researchers at UW, Google, Stanford, Amazon, Columbia, and Berkeley.
540
+
541
+ [Gabriel Ilharco*](http://gabrielilharco.com/), [Mitchell Wortsman*](https://mitchellnw.github.io/), [Nicholas Carlini](https://nicholas.carlini.com/), [Rohan Taori](https://www.rohantaori.com/), [Achal Dave](http://www.achaldave.com/), [Vaishaal Shankar](http://vaishaal.com/), [John Miller](https://people.eecs.berkeley.edu/~miller_john/), [Hongseok Namkoong](https://hsnamkoong.github.io/), [Hannaneh Hajishirzi](https://homes.cs.washington.edu/~hannaneh/), [Ali Farhadi](https://homes.cs.washington.edu/~ali/), [Ludwig Schmidt](https://people.csail.mit.edu/ludwigs/)
542
+
543
+ Special thanks to [Jong Wook Kim](https://jongwook.kim/) and [Alec Radford](https://github.com/Newmu) for help with reproducing CLIP!
544
+
545
+ ## Citing
546
+
547
+ If you found this repository useful, please consider citing:
548
+ ```bibtex
549
+ @software{ilharco_gabriel_2021_5143773,
550
+ author = {Ilharco, Gabriel and
551
+ Wortsman, Mitchell and
552
+ Wightman, Ross and
553
+ Gordon, Cade and
554
+ Carlini, Nicholas and
555
+ Taori, Rohan and
556
+ Dave, Achal and
557
+ Shankar, Vaishaal and
558
+ Namkoong, Hongseok and
559
+ Miller, John and
560
+ Hajishirzi, Hannaneh and
561
+ Farhadi, Ali and
562
+ Schmidt, Ludwig},
563
+ title = {OpenCLIP},
564
+ month = jul,
565
+ year = 2021,
566
+ note = {If you use this software, please cite it as below.},
567
+ publisher = {Zenodo},
568
+ version = {0.1},
569
+ doi = {10.5281/zenodo.5143773},
570
+ url = {https://doi.org/10.5281/zenodo.5143773}
571
+ }
572
+ ```
573
+
574
+ ```bibtex
575
+ @inproceedings{cherti2023reproducible,
576
+ title={Reproducible scaling laws for contrastive language-image learning},
577
+ author={Cherti, Mehdi and Beaumont, Romain and Wightman, Ross and Wortsman, Mitchell and Ilharco, Gabriel and Gordon, Cade and Schuhmann, Christoph and Schmidt, Ludwig and Jitsev, Jenia},
578
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
579
+ pages={2818--2829},
580
+ year={2023}
581
+ }
582
+ ```
583
+
584
+ ```bibtex
585
+ @inproceedings{Radford2021LearningTV,
586
+ title={Learning Transferable Visual Models From Natural Language Supervision},
587
+ author={Alec Radford and Jong Wook Kim and Chris Hallacy and A. Ramesh and Gabriel Goh and Sandhini Agarwal and Girish Sastry and Amanda Askell and Pamela Mishkin and Jack Clark and Gretchen Krueger and Ilya Sutskever},
588
+ booktitle={ICML},
589
+ year={2021}
590
+ }
591
+ ```
592
+
593
+ ```bibtex
594
+ @inproceedings{schuhmann2022laionb,
595
+ title={{LAION}-5B: An open large-scale dataset for training next generation image-text models},
596
+ author={Christoph Schuhmann and
597
+ Romain Beaumont and
598
+ Richard Vencu and
599
+ Cade W Gordon and
600
+ Ross Wightman and
601
+ Mehdi Cherti and
602
+ Theo Coombes and
603
+ Aarush Katta and
604
+ Clayton Mullis and
605
+ Mitchell Wortsman and
606
+ Patrick Schramowski and
607
+ Srivatsa R Kundurthy and
608
+ Katherine Crowson and
609
+ Ludwig Schmidt and
610
+ Robert Kaczmarczyk and
611
+ Jenia Jitsev},
612
+ booktitle={Thirty-sixth Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
613
+ year={2022},
614
+ url={https://openreview.net/forum?id=M3Y74vmsMcY}
615
+ }
616
+ ```
617
+
618
+ [![DOI](https://zenodo.org/badge/390536799.svg)](https://zenodo.org/badge/latestdoi/390536799)
models.txt ADDED
@@ -0,0 +1,2 @@
1
+ RN101
2
+ RN50
pytest.ini ADDED
@@ -0,0 +1,3 @@
1
+ [pytest]
2
+ markers =
3
+ regression_test
requirements.txt ADDED
@@ -0,0 +1,8 @@
1
+ torch>=1.9.0
2
+ torchvision
3
+ regex
4
+ ftfy
5
+ tqdm
6
+ huggingface_hub
7
+ safetensors
8
+ timm
src/open_clip/__init__.py ADDED
@@ -0,0 +1,18 @@
1
+ from .version import __version__
2
+
3
+ from .coca_model import CoCa
4
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
5
+ from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
6
+ from .factory import list_models, add_model_config, get_model_config, load_checkpoint
7
+ from .loss import ClipLoss, DistillClipLoss, CoCaLoss
8
+ from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
9
+ convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype, \
10
+ get_model_tokenize_cfg, get_model_preprocess_cfg, set_model_preprocess_cfg
11
+ from .openai import load_openai_model, list_openai_models
12
+ from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
13
+ get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
14
+ from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
15
+ from .tokenizer import SimpleTokenizer, tokenize, decode
16
+ from .transform import image_transform, AugmentationCfg
17
+ from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy
18
+ from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES
src/open_clip/coca_model.py ADDED
@@ -0,0 +1,582 @@
1
+ from typing import Dict, List, Optional, Union
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ import numpy as np
7
+ from dataclasses import dataclass
8
+
9
+ from .transformer import (
10
+ LayerNormFp32,
11
+ LayerNorm,
12
+ QuickGELU,
13
+ MultimodalTransformer,
14
+ )
15
+ from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
16
+
17
+ try:
18
+ from transformers import (
19
+ BeamSearchScorer,
20
+ LogitsProcessorList,
21
+ TopPLogitsWarper,
22
+ TopKLogitsWarper,
23
+ RepetitionPenaltyLogitsProcessor,
24
+ MinLengthLogitsProcessor,
25
+ MaxLengthCriteria,
26
+ StopStringCriteria,
27
+ EosTokenCriteria,
28
+ StoppingCriteriaList
29
+ )
30
+
31
+ GENERATION_TYPES = {
32
+ "top_k": TopKLogitsWarper,
33
+ "top_p": TopPLogitsWarper,
34
+ "beam_search": "beam_search"
35
+ }
36
+ _has_transformers = True
37
+ except ImportError as e:
38
+ GENERATION_TYPES = {
39
+ "top_k": None,
40
+ "top_p": None,
41
+ "beam_search": "beam_search"
42
+ }
43
+ _has_transformers = False
44
+
45
+
46
+ @dataclass
47
+ class MultimodalCfg(CLIPTextCfg):
48
+ mlp_ratio: int = 4
49
+ dim_head: int = 64
50
+ heads: int = 8
51
+ n_queries: int = 256
52
+ attn_pooler_heads: int = 8
53
+
54
+
55
+ def _build_text_decoder_tower(
56
+ embed_dim,
57
+ multimodal_cfg,
58
+ quick_gelu: bool = False,
59
+ cast_dtype: Optional[torch.dtype] = None,
60
+ ):
61
+ multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
62
+ act_layer = QuickGELU if quick_gelu else nn.GELU
63
+ norm_layer = (
64
+ LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
65
+ )
66
+
67
+ decoder = MultimodalTransformer(
68
+ context_length=multimodal_cfg.context_length,
69
+ width=multimodal_cfg.width,
70
+ heads=multimodal_cfg.heads,
71
+ layers=multimodal_cfg.layers,
72
+ ls_init_value=multimodal_cfg.ls_init_value,
73
+ output_dim=embed_dim,
74
+ act_layer=act_layer,
75
+ norm_layer=norm_layer,
76
+ )
77
+
78
+ return decoder
79
+
80
+
81
+ def _token_to_tensor(token_id, device: str = "cpu") -> torch.Tensor:
82
+ if not isinstance(token_id, torch.Tensor):
83
+ if isinstance(token_id, int):
84
+ token_id = [token_id]
85
+ token_id = torch.tensor(token_id, device=device)
86
+ return token_id
87
+
88
+
89
+ class CoCa(nn.Module):
90
+ def __init__(
91
+ self,
92
+ embed_dim,
93
+ multimodal_cfg: MultimodalCfg,
94
+ text_cfg: CLIPTextCfg,
95
+ vision_cfg: CLIPVisionCfg,
96
+ quick_gelu: bool = False,
97
+ init_logit_scale: float = np.log(1 / 0.07),
98
+ init_logit_bias: Optional[float] = None,
99
+ nonscalar_logit_scale: bool = False,
100
+ cast_dtype: Optional[torch.dtype] = None,
101
+ pad_id: int = 0,
102
+ ):
103
+ super().__init__()
104
+ multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
105
+ text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
106
+ vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg
107
+
108
+ self.text = _build_text_tower(
109
+ embed_dim=embed_dim,
110
+ text_cfg=text_cfg,
111
+ quick_gelu=quick_gelu,
112
+ cast_dtype=cast_dtype,
113
+ )
114
+
115
+ vocab_size = (
116
+ text_cfg.vocab_size # for hf models
117
+ if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
118
+ else text_cfg.vocab_size
119
+ )
120
+
121
+ self.visual = _build_vision_tower(
122
+ embed_dim=embed_dim,
123
+ vision_cfg=vision_cfg,
124
+ quick_gelu=quick_gelu,
125
+ cast_dtype=cast_dtype,
126
+ )
127
+
128
+ self.text_decoder = _build_text_decoder_tower(
129
+ vocab_size,
130
+ multimodal_cfg=multimodal_cfg,
131
+ quick_gelu=quick_gelu,
132
+ cast_dtype=cast_dtype,
133
+ )
134
+
135
+ lshape = [1] if nonscalar_logit_scale else []
136
+ self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
137
+ if init_logit_bias is not None:
138
+ self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
139
+ else:
140
+ self.logit_bias = None
141
+ self.pad_id = pad_id
142
+
143
+ self.context_length = multimodal_cfg.context_length
144
+
145
+ @torch.jit.ignore
146
+ def set_grad_checkpointing(self, enable: bool = True):
147
+ self.visual.set_grad_checkpointing(enable)
148
+ self.text.set_grad_checkpointing(enable)
149
+ self.text_decoder.set_grad_checkpointing(enable)
150
+
151
+ def _encode_image(self, images, normalize: bool = True):
152
+ image_latent, tokens_embs = self.visual(images)
153
+ image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
154
+ return image_latent, tokens_embs
155
+
156
+ def _encode_text(self, text, normalize: bool = True):
157
+ text_latent, token_emb = self.text(text)
158
+ text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
159
+ return text_latent, token_emb
160
+
161
+ def encode_image(self, images, normalize: bool = True):
162
+ image_latent, _ = self._encode_image(images, normalize=normalize)
163
+ return image_latent
164
+
165
+ def encode_text(self, text, normalize: bool = True):
166
+ text_latent, _ = self._encode_text(text, normalize=normalize)
167
+ return text_latent
168
+
169
+ def forward_intermediates(
170
+ self,
171
+ image: Optional[torch.Tensor] = None,
172
+ text: Optional[torch.Tensor] = None,
173
+ image_indices: Optional[Union[int, List[int]]] = None,
174
+ text_indices: Optional[Union[int, List[int]]] = None,
175
+ stop_early: bool = False,
176
+ normalize: bool = True,
177
+ normalize_intermediates: bool = False,
178
+ intermediates_only: bool = False,
179
+ image_output_fmt: str = 'NCHW',
180
+ image_output_extra_tokens: bool = False,
181
+ text_output_fmt: str = 'NLC',
182
+ text_output_extra_tokens: bool = False,
183
+ output_logits: bool = False,
184
+ output_logit_scale_bias: bool = False,
185
+ ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
186
+ """ Forward features that returns intermediates.
187
+
188
+ Args:
189
+ image: Input image tensor
190
+ text: Input text tensor
191
+ image_indices: For image tower, take last n blocks if int, all if None, select matching indices if sequence
192
+ text_indices: Take last n blocks if int, all if None, select matching indices if sequence
193
+ stop_early: Stop iterating over blocks when last desired intermediate hit
194
+ normalize: L2 Normalize final image and text features (if present)
195
+ normalize_intermediates: Apply final encoder norm layer to all intermediates (if possible)
196
+ intermediates_only: Only return intermediate features, do not return final features
197
+ image_output_fmt: Shape of intermediate image feature outputs
198
+ image_output_extra_tokens: Return both prefix and spatial intermediate tokens
199
+ text_output_fmt: Shape of intermediate text feature outputs
200
+ text_output_extra_tokens: Return both prefix and spatial intermediate tokens
201
+ output_logits: Include logits in output
202
+ output_logit_scale_bias: Include the logit scale bias in the output
203
+ Returns:
204
+ Dictionary of requested intermediate features, plus final features and logit scale/bias when enabled.
205
+ """
206
+ output = {}
207
+ if intermediates_only:
208
+ # intermediates_only disables final feature normalization and logits output
209
+ normalize = False
210
+ output_logits = False
211
+ if output_logits:
212
+ assert False, 'FIXME, needs implementing'
213
+
214
+ if image is not None:
215
+ image_output = self.visual.forward_intermediates(
216
+ image,
217
+ indices=image_indices,
218
+ stop_early=stop_early,
219
+ normalize_intermediates=normalize_intermediates,
220
+ intermediates_only=intermediates_only,
221
+ output_fmt=image_output_fmt,
222
+ output_extra_tokens=image_output_extra_tokens,
223
+ )
224
+ if normalize and "image_features" in image_output:
225
+ image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1)
226
+ output.update(image_output)
227
+
228
+ if text is not None:
229
+ text_output = self.text.forward_intermediates(
230
+ text,
231
+ indices=text_indices,
232
+ stop_early=stop_early,
233
+ normalize_intermediates=normalize_intermediates,
234
+ intermediates_only=intermediates_only,
235
+ output_fmt=text_output_fmt,
236
+ output_extra_tokens=text_output_extra_tokens,
237
+ )
238
+ if normalize and "text_features" in text_output:
239
+ text_output["text_features"] = F.normalize(text_output["text_features"], dim=-1)
240
+ output.update(text_output)
241
+
242
+ # FIXME text decoder
243
+ logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None
244
+ if output_logit_scale_bias:
245
+ output["logit_scale"] = logit_scale_exp
246
+ if self.logit_bias is not None:
247
+ output['logit_bias'] = self.logit_bias
248
+
249
+ return output
250
+
251
+ def forward(
252
+ self,
253
+ image,
254
+ text: Optional[torch.Tensor] = None,
255
+ image_latent: Optional[torch.Tensor] = None,
256
+ image_embs: Optional[torch.Tensor] = None,
257
+ output_labels: bool = True,
258
+ ):
259
+ if image_latent is None or image_embs is None:
260
+ image_latent, image_embs = self._encode_image(image)
261
+
262
+ if text is None:
263
+ return {"image_features": image_latent, "image_embs": image_embs}
264
+
265
+ text_latent, token_embs = self._encode_text(text)
266
+
267
+ # FIXME this isn't an ideal solution, would like to improve -RW
268
+ labels: Optional[torch.Tensor] = text[:, 1:] if output_labels else None
269
+ if output_labels:
270
+ # align text_embs and thus logits with labels for teacher-forcing caption loss
271
+ token_embs = token_embs[:, :-1]
272
+
273
+ logits = self.text_decoder(image_embs, token_embs)
274
+ out_dict = {
275
+ "image_features": image_latent,
276
+ "text_features": text_latent,
277
+ "logits": logits,
278
+ "logit_scale": self.logit_scale.exp()
279
+ }
280
+ if labels is not None:
281
+ out_dict["labels"] = labels
282
+ if self.logit_bias is not None:
283
+ out_dict["logit_bias"] = self.logit_bias
284
+ return out_dict
285
+
286
+ def generate(
287
+ self,
288
+ image,
289
+ text=None,
290
+ seq_len=30,
291
+ max_seq_len=77,
292
+ temperature=1.,
293
+ generation_type="beam_search",
294
+ top_p=0.1, # keep tokens in the 1 - top_p quantile
295
+ top_k=1, # keeps the top_k most probable tokens
296
+ pad_token_id=None,
297
+ eos_token_id=None,
298
+ sot_token_id=None,
299
+ num_beams=6,
300
+ num_beam_groups=3,
301
+ min_seq_len=5,
302
+ stopping_criteria=None,
303
+ repetition_penalty=1.0,
304
+ fixed_output_length=False # if True output.shape == (batch_size, seq_len)
305
+ ):
306
+ # taking many ideas and components from HuggingFace GenerationMixin
307
+ # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
308
+ assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
309
+ assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
310
+ device = image.device
311
+
312
+ with torch.no_grad():
313
+ sot_token_id = _token_to_tensor(49406 if sot_token_id is None else sot_token_id, device=device)
314
+ eos_token_id = _token_to_tensor(49407 if eos_token_id is None else eos_token_id, device=device)
315
+ pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
316
+ logit_processor = LogitsProcessorList(
317
+ [
318
+ MinLengthLogitsProcessor(min_seq_len, eos_token_id),
319
+ RepetitionPenaltyLogitsProcessor(repetition_penalty),
320
+ ]
321
+ )
322
+
323
+ if stopping_criteria is None:
324
+ stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
325
+ stopping_criteria = StoppingCriteriaList(stopping_criteria)
326
+
327
+ if generation_type == "beam_search":
328
+ output = self._generate_beamsearch(
329
+ image_inputs=image,
330
+ pad_token_id=pad_token_id,
331
+ eos_token_id=eos_token_id,
332
+ sot_token_id=sot_token_id,
333
+ num_beams=num_beams,
334
+ num_beam_groups=num_beam_groups,
335
+ min_seq_len=min_seq_len,
336
+ stopping_criteria=stopping_criteria,
337
+ logit_processor=logit_processor,
338
+ )
339
+ if fixed_output_length and output.shape[1] < seq_len:
340
+ pad_len = seq_len - output.shape[1]
341
+ return torch.cat((
342
+ output,
343
+ torch.ones(output.shape[0], pad_len, device=device, dtype=output.dtype) * pad_token_id
344
+ ),
345
+ dim=1
346
+ )
347
+ return output
348
+
349
+ elif generation_type == "top_p":
350
+ logit_warper = GENERATION_TYPES[generation_type](top_p)
351
+ elif generation_type == "top_k":
352
+ logit_warper = GENERATION_TYPES[generation_type](top_k)
353
+ else:
354
+ raise ValueError(
355
+ f"generation_type has to be one of "
356
+ f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
357
+ )
358
+
359
+ image_latent, image_embs = self._encode_image(image)
360
+
361
+ if text is None:
362
+ text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id
363
+
364
+ was_training = self.training
365
+ num_dims = len(text.shape)
366
+
367
+ if num_dims == 1:
368
+ text = text[None, :]
369
+
370
+ self.eval()
371
+ out = text
372
+
373
+ while True:
374
+ x = out[:, -max_seq_len:]
375
+ cur_len = x.shape[1]
376
+ logits = self(
377
+ image,
378
+ x,
379
+ image_latent=image_latent,
380
+ image_embs=image_embs,
381
+ output_labels=False,
382
+ )["logits"][:, -1]
383
+ mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
384
+ sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id
385
+
386
+ if mask.all():
387
+ if not fixed_output_length:
388
+ break
389
+ else:
390
+ logits = logits[~mask, :]
391
+ filtered_logits = logit_processor(x[~mask, :], logits)
392
+ filtered_logits = logit_warper(x[~mask, :], filtered_logits)
393
+ probs = F.softmax(filtered_logits / temperature, dim=-1)
394
+
395
+ if (cur_len + 1 == seq_len):
396
+ sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
397
+ else:
398
+ sample[~mask, :] = torch.multinomial(probs, 1)
399
+
400
+ out = torch.cat((out, sample), dim=-1)
401
+
402
+ cur_len += 1
403
+
404
+ if all(stopping_criteria(out, None)):
405
+ break
406
+
407
+ if num_dims == 1:
408
+ out = out.squeeze(0)
409
+
410
+ self.train(was_training)
411
+ return out
412
+
413
+ def _generate_beamsearch(
414
+ self,
415
+ image_inputs,
416
+ pad_token_id=None,
417
+ eos_token_id=None,
418
+ sot_token_id=None,
419
+ num_beams=6,
420
+ num_beam_groups=3,
421
+ min_seq_len=5,
422
+ stopping_criteria=None,
423
+ logit_processor=None,
424
+ logit_warper=None,
425
+ ):
426
+ device = image_inputs.device
427
+ batch_size = image_inputs.shape[0]
428
+ image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
429
+ image_latent, image_embs = self._encode_image(image_inputs)
430
+
431
+ input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
432
+ input_ids = input_ids * sot_token_id
433
+ beam_scorer = BeamSearchScorer(
434
+ batch_size=batch_size,
435
+ num_beams=num_beams,
436
+ device=device,
437
+ num_beam_groups=num_beam_groups,
438
+ )
439
+ # instantiate logits processors
440
+ logits_processor = (
441
+ LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
442
+ if logit_processor is None
443
+ else logit_processor
444
+ )
445
+
446
+ num_beams = beam_scorer.num_beams
447
+ num_beam_groups = beam_scorer.num_beam_groups
448
+ num_sub_beams = num_beams // num_beam_groups
449
+ batch_size = len(beam_scorer._beam_hyps) // num_beam_groups
450
+ batch_beam_size, cur_len = input_ids.shape
451
+ beam_indices = None
452
+
453
+ if num_beams * batch_size != batch_beam_size:
454
+ raise ValueError(
455
+ f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
456
+ )
457
+
458
+ beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
459
+ # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
460
+ # the same group don't produce the same tokens every time.
461
+ beam_scores[:, ::num_sub_beams] = 0
462
+ beam_scores = beam_scores.view((batch_size * num_beams,))
463
+
464
+ while True:
465
+
466
+ # predicted tokens in cur_len step
467
+ current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
468
+
469
+ # indices which will form the beams in the next time step
470
+ reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)
471
+
472
+ # do one decoder step on all beams of all sentences in batch
473
+ model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
474
+ outputs = self(
475
+ model_inputs['images'],
476
+ model_inputs['text'],
477
+ image_latent=image_latent,
478
+ image_embs=image_embs,
479
+ output_labels=False,
480
+ )
481
+
482
+ for beam_group_idx in range(num_beam_groups):
483
+ group_start_idx = beam_group_idx * num_sub_beams
484
+ group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
485
+ group_size = group_end_idx - group_start_idx
486
+
487
+ # indices of beams of current group among all sentences in batch
488
+ batch_group_indices = []
489
+
490
+ for batch_idx in range(batch_size):
491
+ batch_group_indices.extend(
492
+ [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
493
+ )
494
+ group_input_ids = input_ids[batch_group_indices]
495
+
496
+ # select outputs of beams of current group only
497
+ next_token_logits = outputs['logits'][batch_group_indices, -1, :]
498
+ vocab_size = next_token_logits.shape[-1]
499
+
500
+ next_token_scores_processed = logits_processor(
501
+ group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
502
+ )
503
+ next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
504
+ next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
505
+
506
+ # reshape for beam search
507
+ next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
508
+
509
+ next_token_scores, next_tokens = torch.topk(
510
+ next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
511
+ )
512
+
513
+ next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
514
+ next_tokens = next_tokens % vocab_size
515
+
516
+ # stateless
517
+ process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
518
+ beam_outputs = beam_scorer.process(
519
+ group_input_ids,
520
+ next_token_scores,
521
+ next_tokens,
522
+ next_indices,
523
+ pad_token_id=pad_token_id,
524
+ eos_token_id=eos_token_id,
525
+ beam_indices=process_beam_indices,
526
+ group_index=beam_group_idx,
527
+ )
528
+ beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
529
+ beam_next_tokens = beam_outputs["next_beam_tokens"]
530
+ beam_idx = beam_outputs["next_beam_indices"]
531
+
532
+ input_ids[batch_group_indices] = group_input_ids[beam_idx]
533
+ group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
534
+ current_tokens[batch_group_indices] = group_input_ids[:, -1]
535
+
536
+ # (beam_idx // group_size) -> batch_idx
537
+ # (beam_idx % group_size) -> offset of idx inside the group
538
+ reordering_indices[batch_group_indices] = (
539
+ num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
540
+ )
541
+
542
+ input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
543
+
544
+ # increase cur_len
545
+ cur_len = cur_len + 1
546
+ if beam_scorer.is_done or all(stopping_criteria(input_ids, None)):
547
+ break
548
+
549
+ final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
550
+ sequence_outputs = beam_scorer.finalize(
551
+ input_ids,
552
+ beam_scores,
553
+ next_tokens,
554
+ next_indices,
555
+ pad_token_id=pad_token_id,
556
+ eos_token_id=eos_token_id,
557
+ max_length=stopping_criteria.max_length,
558
+ beam_indices=final_beam_indices,
559
+ )
560
+ return sequence_outputs['sequences']
561
+
562
+
563
+ def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
564
+ if past:
565
+ input_ids = input_ids[:, -1].unsqueeze(-1)
566
+
567
+ attention_mask = kwargs.get("attention_mask", None)
568
+ position_ids = kwargs.get("position_ids", None)
569
+
570
+ if attention_mask is not None and position_ids is None:
571
+ # create position_ids on the fly for batch generation
572
+ position_ids = attention_mask.long().cumsum(-1) - 1
573
+ position_ids.masked_fill_(attention_mask == 0, 1)
574
+ else:
575
+ position_ids = None
576
+ return {
577
+ "text": input_ids,
578
+ "images": image_inputs,
579
+ "past_key_values": past,
580
+ "position_ids": position_ids,
581
+ "attention_mask": attention_mask,
582
+ }
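A hedged generation sketch against the CoCa class above: the 'coca_ViT-B-32' config name and pretrained tag are assumptions for illustration, and `generate()` needs the optional `transformers` dependency installed.

```python
import torch
from PIL import Image
import open_clip

# Assumed example names; any CoCa config/weights pairing works the same way.
model, _, preprocess = open_clip.create_model_and_transforms(
    'coca_ViT-B-32', pretrained='mscoco_finetuned_laion2b_s13b_b90k')
model.eval()

im = preprocess(Image.open('cat.png')).unsqueeze(0)
with torch.no_grad():
    generated = model.generate(im, generation_type='beam_search', seq_len=20)

# Strip the special tokens that bracket the generated caption.
caption = open_clip.decode(generated[0]).split('<end_of_text>')[0].replace('<start_of_text>', '')
print(caption)
```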
src/open_clip/constants.py ADDED
@@ -0,0 +1,11 @@
1
+ OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
2
+ OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
3
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
4
+ IMAGENET_STD = (0.229, 0.224, 0.225)
5
+ INCEPTION_MEAN = (0.5, 0.5, 0.5)
6
+ INCEPTION_STD = (0.5, 0.5, 0.5)
7
+
8
+ # Default name for a weights file hosted on the Huggingface Hub.
9
+ HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl
10
+ HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version
11
+ HF_CONFIG_NAME = 'open_clip_config.json'
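These constants are the normalization statistics consumed by the bundled transforms; below is a short sketch of how they would plug into a hand-rolled torchvision pipeline (the 224px size is an assumption for ViT-B style models, not dictated by this file).

```python
from torchvision import transforms
from open_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD

# Roughly the default inference preprocessing for an OpenAI-normalized 224px model.
preprocess = transforms.Compose([
    transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD),
])
```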
src/open_clip/convert.py ADDED
@@ -0,0 +1,206 @@
1
+ """ Conversion functions for 3rd part state-dicts and non-torch native checkpoint formats.
2
+ """
3
+ from typing import Union
4
+
5
+ import torch
6
+ import numpy as np
7
+
8
+ from .model import CLIP, CustomTextCLIP
9
+ from .transformer import TextTransformer, Transformer
10
+
11
+
12
+ @torch.no_grad()
13
+ def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str):
14
+ """ Load weights from .npz checkpoints for official Google big_vision image-text models
15
+
16
+ Currently, the SigLIP source models are supported and a CustomTextCLIP destination model
17
+ w/ timm image encoder.
18
+ """
19
+ from timm.layers import resample_patch_embed, resample_abs_pos_embed
20
+
21
+ def _n2p(w, t=True, idx=None):
22
+ if idx is not None:
23
+ w = w[idx]
24
+ if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
25
+ w = w.flatten()
26
+ if t:
27
+ if w.ndim == 4:
28
+ w = w.transpose([3, 2, 0, 1])
29
+ elif w.ndim == 3:
30
+ w = w.transpose([2, 0, 1])
31
+ elif w.ndim == 2:
32
+ w = w.transpose([1, 0])
33
+ return torch.from_numpy(w)
34
+
35
+ w = np.load(checkpoint_path)
36
+ interpolation = 'bilinear'
37
+ antialias = False
38
+
39
+ def _convert_timm_img(module, prefix):
40
+ embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
41
+ if embed_conv_w.shape[-2:] != module.patch_embed.proj.weight.shape[-2:]:
42
+ embed_conv_w = resample_patch_embed(
43
+ embed_conv_w,
44
+ module.patch_embed.proj.weight.shape[-2:],
45
+ interpolation=interpolation,
46
+ antialias=antialias,
47
+ verbose=True,
48
+ )
49
+ module.patch_embed.proj.weight.copy_(embed_conv_w)
50
+ module.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
51
+
52
+ if module.cls_token is not None:
53
+ module.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
54
+
55
+ pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False)
56
+ if pos_embed_w.shape != module.pos_embed.shape:
57
+ assert False, f'{pos_embed_w.shape}, {module.pos_embed.shape}'
58
+ num_prefix_tokens = 0 if getattr(module, 'no_embed_class', False) else getattr(module, 'num_prefix_tokens', 1)
59
+ pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights
60
+ pos_embed_w,
61
+ new_size=module.patch_embed.grid_size,
62
+ num_prefix_tokens=num_prefix_tokens,
63
+ interpolation=interpolation,
64
+ antialias=antialias,
65
+ verbose=True,
66
+ )
67
+ module.pos_embed.copy_(pos_embed_w)
68
+
69
+ mha_sub, b_sub, ln1_sub = (0, 0, 1)
70
+ for i, block in enumerate(module.blocks.children()):
71
+ if f'{prefix}Transformer/encoderblock/LayerNorm_0/scale' in w:
72
+ block_prefix = f'{prefix}Transformer/encoderblock/'
73
+ idx = i
74
+ else:
75
+ block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
76
+ idx = None
77
+ mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
78
+ block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'], idx=idx))
79
+ block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'], idx=idx))
80
+ block.attn.qkv.weight.copy_(torch.cat([
81
+ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False, idx=idx).flatten(1).T for n in ('query', 'key', 'value')]))
82
+ block.attn.qkv.bias.copy_(torch.cat([
83
+ _n2p(w[f'{mha_prefix}{n}/bias'], t=False, idx=idx).reshape(-1) for n in ('query', 'key', 'value')]))
84
+ block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel'], idx=idx).flatten(1))
85
+ block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'], idx=idx))
86
+ block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'], idx=idx))
87
+ block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'], idx=idx))
88
+ for r in range(2):
89
+ getattr(block.mlp, f'fc{r + 1}').weight.copy_(
90
+ _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'], idx=idx))
91
+ getattr(block.mlp, f'fc{r + 1}').bias.copy_(
92
+ _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'], idx=idx))
93
+
94
+ module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
95
+ module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
96
+
97
+ if module.attn_pool is not None:
98
+ block_prefix = f'{prefix}MAPHead_0/'
99
+ mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
100
+ module.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False))
101
+ module.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T)
102
+ module.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1))
103
+ module.attn_pool.kv.weight.copy_(torch.cat([
104
+ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')]))
105
+ module.attn_pool.kv.bias.copy_(torch.cat([
106
+ _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')]))
107
+ module.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
108
+ module.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
109
+ module.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
110
+ module.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
111
+ for r in range(2):
112
+ getattr(module.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel']))
113
+ getattr(module.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias']))
114
+
115
+ def _convert_openclip_transformer(module: Transformer, prefix):
116
+ for i, block in enumerate(module.resblocks.children()):
117
+ if f'{prefix}encoderblock/LayerNorm_0/scale' in w:
118
+ block_prefix = f'{prefix}encoderblock/'
119
+ idx = i
120
+ else:
121
+ block_prefix = f'{prefix}encoderblock_{i}/'
122
+ idx = None
123
+ mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
124
+ block.ln_1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'], idx=idx))
125
+ block.ln_1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'], idx=idx))
126
+ block.attn.in_proj_weight.copy_(torch.cat([
127
+ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False, idx=idx).flatten(1).T for n in ('query', 'key', 'value')]))
128
+ block.attn.in_proj_bias.copy_(torch.cat([
129
+ _n2p(w[f'{mha_prefix}{n}/bias'], t=False, idx=idx).reshape(-1) for n in ('query', 'key', 'value')]))
130
+ block.attn.out_proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel'], idx=idx).flatten(1))
131
+ block.attn.out_proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'], idx=idx))
132
+ block.ln_2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/scale'], idx=idx))
133
+ block.ln_2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/bias'], idx=idx))
134
+ block.mlp.c_fc.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/kernel'], idx=idx))
135
+ block.mlp.c_fc.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/bias'], idx=idx))
136
+ block.mlp.c_proj.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/kernel'], idx=idx))
137
+ block.mlp.c_proj.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/bias'], idx=idx))
138
+
139
+ def _convert_openclip_txt(module: TextTransformer, prefix):
140
+ module.token_embedding.weight.copy_(_n2p(w[f'{prefix}Embed_0/embedding'], t=False))
141
+ pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False).squeeze(0)
142
+ module.positional_embedding.copy_(pos_embed_w)
143
+ _convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/')
144
+ module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale']))
145
+ module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias']))
146
+ if module.text_projection is not None:
147
+ module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
148
+ module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias']))
149
+
150
+ root_prefix = 'params/' if 'params/b' in w else ''
151
+ _convert_timm_img(model.visual.trunk, f'{root_prefix}img/')
152
+ _convert_openclip_txt(model.text, f'{root_prefix}txt/')
153
+ model.logit_bias.copy_(_n2p(w[f'{root_prefix}b'])[0])
154
+ model.logit_scale.copy_(_n2p(w[f'{root_prefix}t'])[0])
155
+
156
+
157
+ @torch.no_grad()
158
+ def convert_mobile_clip_state_dict(model: CustomTextCLIP, state_dict, fastvit = True):
159
+
160
+ def _convert_timm_img(state_dict):
161
+ if fastvit:
162
+ from timm.models.fastvit import checkpoint_filter_fn
163
+ else:
164
+ from timm.models.vision_transformer_hybrid import checkpoint_filter_fn
165
+ timm_state_dict = checkpoint_filter_fn(state_dict, model.visual.trunk)
166
+ timm_state_dict = {'visual.trunk.' + k: v for k, v in timm_state_dict.items()}
167
+ return timm_state_dict
168
+
169
+ def _convert_openclip_txt(state_dict, prefix='text_encoder.'):
170
+ text_dict = {}
171
+ for k, v in state_dict.items():
172
+ if not k.startswith(prefix):
173
+ continue
174
+ k = k.replace(prefix, '')
175
+ k = k.replace('projection_layer', 'text_projection')
176
+ k = k.replace('embedding_layer', 'token_embedding')
177
+ if k.startswith('positional_embedding.pos_embed.pos_embed'):
178
+ k = k.replace('positional_embedding.pos_embed.pos_embed', 'positional_embedding')
179
+ v = v.squeeze()
180
+ k = k.replace('final_layer_norm', 'ln_final')
181
+ k = k.replace('pre_norm_mha.0', 'ln_1')
182
+ k = k.replace('pre_norm_mha.1', 'attn')
183
+ k = k.replace('pre_norm_ffn.0', 'ln_2')
184
+ k = k.replace('pre_norm_ffn.1', 'mlp.c_fc')
185
+ k = k.replace('pre_norm_ffn.4', 'mlp.c_proj')
186
+ k = k.replace('qkv_proj.weight', 'in_proj_weight')
187
+ k = k.replace('qkv_proj.bias', 'in_proj_bias')
188
+ k = k.replace('transformer.', 'transformer.resblocks.')
189
+ text_dict['text.' + k] = v
190
+ return text_dict
191
+
192
+ image_dict = _convert_timm_img(state_dict)
193
+ text_dict = _convert_openclip_txt(state_dict)
194
+ out_dict = {**image_dict, **text_dict}
195
+ out_dict['logit_scale'] = state_dict['logit_scale']
196
+ return out_dict
197
+
198
+
199
+ def convert_state_dict(model: Union[CustomTextCLIP, CLIP], state_dict):
200
+ if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict:
201
+ # Apple MobileCLIP s1 & s2 state_dicts (s0 and b not currently supported)
202
+ state_dict = convert_mobile_clip_state_dict(model, state_dict)
203
+ if 'image_encoder.model.patch_emb.0.block.conv.weight' in state_dict:
204
+ # convert b model
205
+ state_dict = convert_mobile_clip_state_dict(model, state_dict, fastvit=False)
206
+ return state_dict
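`convert_state_dict()` is normally invoked for you inside `load_checkpoint()`; a hedged sketch of that end-to-end path ('mobileclip_s1.pt' is an assumed local Apple checkpoint, not shipped with this commit).

```python
import open_clip

# 'MobileCLIP-S1' is a bundled model config; the checkpoint filename is an assumed example.
model = open_clip.create_model('MobileCLIP-S1')
open_clip.load_checkpoint(model, 'mobileclip_s1.pt')  # Apple keys detected and remapped via convert_state_dict()
```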
src/open_clip/factory.py ADDED
@@ -0,0 +1,586 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ import re
5
+ import warnings
6
+ from copy import deepcopy
7
+ from dataclasses import asdict
8
+ from pathlib import Path
9
+ from typing import Any, Dict, Optional, Tuple, Union
10
+
11
+ import torch
12
+
13
+ from .convert import convert_state_dict
14
+ from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
15
+ resize_pos_embed, get_cast_dtype, resize_text_pos_embed, set_model_preprocess_cfg
16
+ from .coca_model import CoCa
17
+ from .loss import ClipLoss, DistillClipLoss, CoCaLoss, SigLipLoss
18
+ from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
19
+ list_pretrained_tags_by_model, download_pretrained_from_hf
20
+ from .transform import image_transform_v2, AugmentationCfg, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs
21
+ from .tokenizer import HFTokenizer, SimpleTokenizer, SigLipTokenizer, DEFAULT_CONTEXT_LENGTH
22
+
23
+ HF_HUB_PREFIX = 'hf-hub:'
24
+ _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
25
+ _MODEL_CONFIGS = {} # dictionary (model_name: config) of model architecture configs
26
+
27
+
28
+ def _natural_key(string_):
29
+ return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
30
+
31
+
32
+ def _rescan_model_configs():
33
+ global _MODEL_CONFIGS
34
+
35
+ config_ext = ('.json',)
36
+ config_files = []
37
+ for config_path in _MODEL_CONFIG_PATHS:
38
+ if config_path.is_file() and config_path.suffix in config_ext:
39
+ config_files.append(config_path)
40
+ elif config_path.is_dir():
41
+ for ext in config_ext:
42
+ config_files.extend(config_path.glob(f'*{ext}'))
43
+
44
+ for cf in config_files:
45
+ with open(cf, 'r') as f:
46
+ model_cfg = json.load(f)
47
+ if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
48
+ _MODEL_CONFIGS[cf.stem] = model_cfg
49
+
50
+ _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}
51
+
52
+
53
+ _rescan_model_configs() # initial populate of model config registry
54
+
55
+
56
+ def list_models():
57
+ """ enumerate available model architectures based on config files """
58
+ return list(_MODEL_CONFIGS.keys())
59
+
60
+
61
+ def add_model_config(path):
62
+ """ add model config path or file and update registry """
63
+ if not isinstance(path, Path):
64
+ path = Path(path)
65
+ _MODEL_CONFIG_PATHS.append(path)
66
+ _rescan_model_configs()
67
+
68
+
69
+ def get_model_config(model_name):
70
+ """ Fetch model config from builtin (local library) configs.
71
+ """
72
+ if model_name in _MODEL_CONFIGS:
73
+ return deepcopy(_MODEL_CONFIGS[model_name])
74
+ else:
75
+ return None
76
+
77
+
78
+ def _get_hf_config(
79
+ model_id: str,
80
+ cache_dir: Optional[str] = None,
81
+ ):
82
+ """ Fetch model config from HuggingFace Hub.
83
+ """
84
+ config_path = download_pretrained_from_hf(
85
+ model_id,
86
+ filename='open_clip_config.json',
87
+ cache_dir=cache_dir,
88
+ )
89
+ with open(config_path, 'r', encoding='utf-8') as f:
90
+ config = json.load(f)
91
+ return config
92
+
93
+
94
+ def get_tokenizer(
95
+ model_name: str = '',
96
+ context_length: Optional[int] = None,
97
+ cache_dir: Optional[str] = None,
98
+ **kwargs,
99
+ ):
100
+ if model_name.startswith(HF_HUB_PREFIX):
101
+ model_name = model_name[len(HF_HUB_PREFIX):]
102
+ try:
103
+ config = _get_hf_config(model_name, cache_dir=cache_dir)['model_cfg']
104
+ except Exception:
105
+ tokenizer = HFTokenizer(
106
+ model_name,
107
+ context_length=context_length or DEFAULT_CONTEXT_LENGTH,
108
+ cache_dir=cache_dir,
109
+ **kwargs,
110
+ )
111
+ return tokenizer
112
+ else:
113
+ config = get_model_config(model_name)
114
+ assert config is not None, f"No valid model config found for {model_name}."
115
+
116
+ text_config = config.get('text_cfg', {})
117
+ if 'tokenizer_kwargs' in text_config:
118
+ tokenizer_kwargs = dict(text_config['tokenizer_kwargs'], **kwargs)
119
+ else:
120
+ tokenizer_kwargs = kwargs
121
+
122
+ if context_length is None:
123
+ context_length = text_config.get('context_length', DEFAULT_CONTEXT_LENGTH)
124
+
125
+ model_name = model_name.lower()
126
+ if text_config.get('hf_tokenizer_name', ''):
127
+ tokenizer = HFTokenizer(
128
+ text_config['hf_tokenizer_name'],
129
+ context_length=context_length,
130
+ cache_dir=cache_dir,
131
+ **tokenizer_kwargs,
132
+ )
133
+ elif 'siglip' in model_name:
134
+ tn = 'gemma' if 'siglip2' in model_name else 'mc4' if 'i18n' in model_name else 'c4-en'
135
+ tokenizer = SigLipTokenizer(
136
+ tn,
137
+ context_length=context_length,
138
+ # **tokenizer_kwargs,
139
+ )
140
+ else:
141
+ tokenizer = SimpleTokenizer(
142
+ context_length=context_length,
143
+ **tokenizer_kwargs,
144
+ )
145
+
146
+ return tokenizer
147
+
148
+
149
+ def load_state_dict(
150
+ checkpoint_path: str,
151
+ device='cpu',
152
+ weights_only=True,
153
+ ):
154
+ # Check if safetensors or not and load weights accordingly
155
+ if str(checkpoint_path).endswith(".safetensors"):
156
+ from safetensors.torch import load_file
157
+ checkpoint = load_file(checkpoint_path, device=device)
158
+ else:
159
+ try:
160
+ checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=weights_only)
161
+ except TypeError:
162
+ checkpoint = torch.load(checkpoint_path, map_location=device)
163
+
164
+ if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
165
+ state_dict = checkpoint['state_dict']
166
+ elif isinstance(checkpoint, torch.jit.ScriptModule):
167
+ state_dict = checkpoint.state_dict()
168
+ for key in ["input_resolution", "context_length", "vocab_size"]:
169
+ state_dict.pop(key, None)
170
+ else:
171
+ state_dict = checkpoint
172
+ if next(iter(state_dict.items()))[0].startswith('module'):
173
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
174
+ return state_dict
175
+
176
+
177
+ def load_checkpoint(
178
+ model: Union[CLIP, CustomTextCLIP],
179
+ checkpoint_path: str,
180
+ strict: bool = True,
181
+ weights_only: bool = True,
182
+ device='cpu',
183
+ ):
184
+ if Path(checkpoint_path).suffix in ('.npz', '.npy'):
185
+ # Separate path loading numpy big_vision (SigLIP) weights
186
+ from open_clip.convert import load_big_vision_weights
187
+ load_big_vision_weights(model, checkpoint_path)
188
+ return {}
189
+
190
+ state_dict = load_state_dict(checkpoint_path, device=device, weights_only=weights_only)
191
+
192
+ # Detect & convert 3rd party state_dicts -> open_clip
193
+ state_dict = convert_state_dict(model, state_dict)
194
+
195
+ # Detect old format and make compatible with new format
196
+ if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
197
+ state_dict = convert_to_custom_text_state_dict(state_dict)
198
+
199
+ # correct if logit_scale differs in being scalar vs 1d param
200
+ if 'logit_scale' in state_dict and model.logit_scale.ndim != state_dict['logit_scale'].ndim:
201
+ state_dict['logit_scale'] = state_dict['logit_scale'].reshape(model.logit_scale.shape)
202
+
203
+ # correct if logit_bias differs in being scalar vs 1d param
204
+ if 'logit_bias' in state_dict and model.logit_bias.ndim != state_dict['logit_bias'].ndim:
205
+ state_dict['logit_bias'] = state_dict['logit_bias'].reshape(model.logit_bias.shape)
206
+
207
+ # If loading a non-SigLIP model for SigLIP training. See https://github.com/mlfoundations/open_clip/issues/712
208
+ if 'logit_bias' not in state_dict and model.logit_bias is not None:
209
+ state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"])
210
+
211
+ # Certain text transformers no longer expect position_ids after transformers==4.31
212
+ position_id_key = 'text.transformer.embeddings.position_ids'
213
+ if position_id_key in state_dict and not hasattr(model, position_id_key):
214
+ del state_dict[position_id_key]
215
+
216
+ resize_pos_embed(state_dict, model)
217
+ resize_text_pos_embed(state_dict, model)
218
+
219
+ # Finally, load the massaged state_dict into model
220
+ incompatible_keys = model.load_state_dict(state_dict, strict=strict)
221
+ return incompatible_keys
222
+
223
+
224
+ def create_model(
225
+ model_name: str,
226
+ pretrained: Optional[str] = None,
227
+ precision: str = 'fp32',
228
+ device: Union[str, torch.device] = 'cpu',
229
+ jit: bool = False,
230
+ force_quick_gelu: bool = False,
231
+ force_custom_text: bool = False,
232
+ force_patch_dropout: Optional[float] = None,
233
+ force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
234
+ force_preprocess_cfg: Optional[Dict[str, Any]] = None,
235
+ pretrained_image: bool = False,
236
+ pretrained_hf: bool = True,
237
+ cache_dir: Optional[str] = None,
238
+ output_dict: Optional[bool] = None,
239
+ require_pretrained: bool = False,
240
+ load_weights_only: bool = True,
241
+ **model_kwargs,
242
+ ):
243
+ """Creates and configures a contrastive vision-language model.
244
+
245
+ Args:
246
+ model_name: Name of the model architecture to create. Can be a local model name
247
+ or a Hugging Face model ID prefixed with 'hf-hub:'.
248
+ pretrained: Tag/path for pretrained model weights. Can be:
249
+ - A pretrained tag name (e.g., 'openai')
250
+ - A path to local weights
251
+ - None to initialize with random weights
252
+ precision: Model precision/AMP configuration. Options:
253
+ - 'fp32': 32-bit floating point
254
+ - 'fp16'/'bf16': Mixed precision with FP32 for certain layers
255
+ - 'pure_fp16'/'pure_bf16': Pure 16-bit precision
256
+ device: Device to load the model on ('cpu', 'cuda', or torch.device object)
257
+ jit: If True, JIT compile the model
258
+ force_quick_gelu: Force use of QuickGELU activation
259
+ force_custom_text: Force use of custom text encoder
260
+ force_patch_dropout: Override default patch dropout value
261
+ force_image_size: Override default image size for vision encoder
262
+ force_preprocess_cfg: Override default preprocessing configuration
263
+ pretrained_image: Load pretrained weights for timm vision models
264
+ pretrained_hf: Load pretrained weights for HF text models when not loading CLIP weights
265
+ cache_dir: Override default cache directory for downloaded model files
266
+ output_dict: If True and model supports it, return dictionary of features
267
+ require_pretrained: Raise error if pretrained weights cannot be loaded
268
+ load_weights_only: Only deserialize model weights when unpickling torch checkpoints (for safety)
269
+ **model_kwargs: Additional keyword arguments passed to model constructor
270
+
271
+ Returns:
272
+ Created and configured model instance
273
+
274
+ Raises:
275
+ RuntimeError: If model config is not found or required pretrained weights
276
+ cannot be loaded
277
+
278
+ Examples:
279
+ # Create basic CLIP model
280
+ model = create_model('ViT-B/32')
281
+
282
+ # Create CLIP model with mixed precision on GPU
283
+ model = create_model('ViT-B/32', precision='fp16', device='cuda')
284
+
285
+ # Load pretrained OpenAI weights
286
+ model = create_model('ViT-B/32', pretrained='openai')
287
+
288
+ # Load Hugging Face model
289
+ model = create_model('hf-hub:organization/model-name')
290
+ """
291
+
292
+ force_preprocess_cfg = force_preprocess_cfg or {}
293
+ preprocess_cfg = asdict(PreprocessCfg())
294
+ has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
295
+ if has_hf_hub_prefix:
296
+ model_id = model_name[len(HF_HUB_PREFIX):]
297
+ checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
298
+ config = _get_hf_config(model_id, cache_dir=cache_dir)
299
+ preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg'])
300
+ model_cfg = config['model_cfg']
301
+ pretrained_hf = False # override, no need to load original HF text weights
302
+ else:
303
+ model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
304
+ checkpoint_path = None
305
+ model_cfg = None
306
+
307
+ if isinstance(device, str):
308
+ device = torch.device(device)
309
+
310
+ model_cfg = model_cfg or get_model_config(model_name)
311
+ if model_cfg is not None:
312
+ logging.info(f'Loaded {model_name} model config.')
313
+ else:
314
+ logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
315
+ raise RuntimeError(f'Model config for {model_name} not found.')
316
+
317
+ if force_quick_gelu:
318
+ # override for use of QuickGELU on non-OpenAI transformer models
319
+ model_cfg["quick_gelu"] = True
320
+
321
+ if force_patch_dropout is not None:
322
+ # override the default patch dropout value
323
+ model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
324
+
325
+ if force_image_size is not None:
326
+ # override model config's image size
327
+ model_cfg["vision_cfg"]["image_size"] = force_image_size
328
+
329
+ is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})
330
+ if pretrained_image:
331
+ if is_timm_model:
332
+ # pretrained weight loading for timm models set via vision_cfg
333
+ model_cfg['vision_cfg']['timm_model_pretrained'] = True
334
+ else:
335
+ assert False, 'pretrained image towers currently only supported for timm models'
336
+
337
+ # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes
338
+ cast_dtype = get_cast_dtype(precision)
339
+ is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
340
+ if is_hf_model:
341
+ # load pretrained weights for HF text model IFF no CLIP weights being loaded
342
+ model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained
343
+ custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model
344
+
345
+ model_cfg = dict(model_cfg, **model_kwargs) # merge cfg dict w/ kwargs (kwargs overrides cfg)
346
+ if custom_text:
347
+ if "multimodal_cfg" in model_cfg:
348
+ model = CoCa(**model_cfg, cast_dtype=cast_dtype)
349
+ else:
350
+ model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
351
+ else:
352
+ model = CLIP(**model_cfg, cast_dtype=cast_dtype)
353
+
354
+ if precision in ("fp16", "bf16"):
355
+ dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
356
+ # manual mixed precision that matches original OpenAI behaviour
357
+ if is_timm_model:
358
+ # FIXME this is a bit janky, create timm based model in low-precision and
359
+ # then cast only LayerNormFp32 instances back to float32 so they don't break.
360
+ # Why? The convert_weights_to_lp fn only works with native models.
361
+ model.to(device=device, dtype=dtype)
362
+ from .transformer import LayerNormFp32
363
+
364
+ def _convert_ln(m):
365
+ if isinstance(m, LayerNormFp32):
366
+ m.weight.data = m.weight.data.to(torch.float32)
367
+ m.bias.data = m.bias.data.to(torch.float32)
368
+ model.apply(_convert_ln)
369
+ else:
370
+ model.to(device=device)
371
+ convert_weights_to_lp(model, dtype=dtype)
372
+ elif precision in ("pure_fp16", "pure_bf16"):
373
+ dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
374
+ model.to(device=device, dtype=dtype)
375
+ else:
376
+ model.to(device=device)
377
+
378
+ pretrained_loaded = False
379
+ if pretrained:
380
+ checkpoint_path = ''
381
+ pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
382
+ if pretrained_cfg:
383
+ checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
384
+ preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg)
385
+ pretrained_quick_gelu = pretrained_cfg.get('quick_gelu', False)
386
+ model_quick_gelu = model_cfg.get('quick_gelu', False)
387
+ if pretrained_quick_gelu and not model_quick_gelu:
388
+ warnings.warn(
389
+ f'These pretrained weights were trained with QuickGELU activation but the model config does '
390
+ f'not have that enabled. Consider using a model config with a "-quickgelu" suffix or enable with a flag.')
391
+ elif not pretrained_quick_gelu and model_quick_gelu:
392
+ warnings.warn(
393
+ f'The pretrained weights were not trained with QuickGELU but this activation is enabled in the '
394
+ f'model config, consider using a model config without QuickGELU or disable override flags.')
395
+ elif os.path.exists(pretrained):
396
+ checkpoint_path = pretrained
397
+
398
+ if checkpoint_path:
399
+ logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
400
+ load_checkpoint(model, checkpoint_path, weights_only=load_weights_only)
401
+ else:
402
+ error_str = (
403
+ f'Pretrained weights ({pretrained}) not found for model {model_name}.'
404
+ f' Available pretrained tags: {list_pretrained_tags_by_model(model_name)}.')
405
+ logging.warning(error_str)
406
+ raise RuntimeError(error_str)
407
+ pretrained_loaded = True
408
+ elif has_hf_hub_prefix:
409
+ logging.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).')
410
+ load_checkpoint(model, checkpoint_path, weights_only=load_weights_only)
411
+ pretrained_loaded = True
412
+
413
+ if require_pretrained and not pretrained_loaded:
414
+ # callers of create_model_from_pretrained always expect pretrained weights
415
+ raise RuntimeError(
416
+ f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')
417
+
418
+ if output_dict and hasattr(model, "output_dict"):
419
+ model.output_dict = True
420
+
421
+ if jit:
422
+ model = torch.jit.script(model)
423
+
424
+ # set image preprocessing configuration in model attributes for convenience
425
+ if getattr(model.visual, 'image_size', None) is not None:
426
+ # use image_size set on model creation (via config or force_image_size arg)
427
+ force_preprocess_cfg['size'] = model.visual.image_size
428
+ set_model_preprocess_cfg(model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg))
429
+
430
+ return model
431
+
432
+
433
+ def create_loss(args):
434
+ if args.distill:
435
+ return DistillClipLoss(
436
+ local_loss=args.local_loss,
437
+ gather_with_grad=args.gather_with_grad,
438
+ cache_labels=True,
439
+ rank=args.rank,
440
+ world_size=args.world_size,
441
+ use_horovod=args.horovod,
442
+ )
443
+ elif "coca" in args.model.lower():
444
+ return CoCaLoss(
445
+ caption_loss_weight=args.coca_caption_loss_weight,
446
+ clip_loss_weight=args.coca_contrastive_loss_weight,
447
+ local_loss=args.local_loss,
448
+ gather_with_grad=args.gather_with_grad,
449
+ cache_labels=True,
450
+ rank=args.rank,
451
+ world_size=args.world_size,
452
+ use_horovod=args.horovod,
453
+ )
454
+ elif args.siglip:
455
+ assert not args.horovod, "Horovod not currently supported for SigLip"
456
+ return SigLipLoss(
457
+ rank=args.rank,
458
+ world_size=args.world_size,
459
+ dist_impl=args.loss_dist_impl, # siglip has multiple distributed implementations to choose from
460
+ )
461
+
462
+ return ClipLoss(
463
+ local_loss=args.local_loss,
464
+ gather_with_grad=args.gather_with_grad,
465
+ cache_labels=True,
466
+ rank=args.rank,
467
+ world_size=args.world_size,
468
+ use_horovod=args.horovod,
469
+ )
470
+
471
+
472
+ def create_model_and_transforms(
473
+ model_name: str,
474
+ pretrained: Optional[str] = None,
475
+ precision: str = 'fp32',
476
+ device: Union[str, torch.device] = 'cpu',
477
+ jit: bool = False,
478
+ force_quick_gelu: bool = False,
479
+ force_custom_text: bool = False,
480
+ force_patch_dropout: Optional[float] = None,
481
+ force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
482
+ image_mean: Optional[Tuple[float, ...]] = None,
483
+ image_std: Optional[Tuple[float, ...]] = None,
484
+ image_interpolation: Optional[str] = None,
485
+ image_resize_mode: Optional[str] = None, # only effective for inference
486
+ aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
487
+ pretrained_image: bool = False,
488
+ pretrained_hf: bool = True,
489
+ cache_dir: Optional[str] = None,
490
+ output_dict: Optional[bool] = None,
491
+ load_weights_only: bool = True,
492
+ **model_kwargs,
493
+ ):
494
+ force_preprocess_cfg = merge_preprocess_kwargs(
495
+ {},
496
+ mean=image_mean,
497
+ std=image_std,
498
+ interpolation=image_interpolation,
499
+ resize_mode=image_resize_mode,
500
+ )
501
+
502
+ model = create_model(
503
+ model_name,
504
+ pretrained,
505
+ precision=precision,
506
+ device=device,
507
+ jit=jit,
508
+ force_quick_gelu=force_quick_gelu,
509
+ force_custom_text=force_custom_text,
510
+ force_patch_dropout=force_patch_dropout,
511
+ force_image_size=force_image_size,
512
+ force_preprocess_cfg=force_preprocess_cfg,
513
+ pretrained_image=pretrained_image,
514
+ pretrained_hf=pretrained_hf,
515
+ cache_dir=cache_dir,
516
+ output_dict=output_dict,
517
+ load_weights_only=load_weights_only,
518
+ **model_kwargs,
519
+ )
520
+
521
+ pp_cfg = PreprocessCfg(**model.visual.preprocess_cfg)
522
+
523
+ preprocess_train = image_transform_v2(
524
+ pp_cfg,
525
+ is_train=True,
526
+ aug_cfg=aug_cfg,
527
+ )
528
+ preprocess_val = image_transform_v2(
529
+ pp_cfg,
530
+ is_train=False,
531
+ )
532
+
533
+ return model, preprocess_train, preprocess_val
534
+
535
+
536
+ def create_model_from_pretrained(
537
+ model_name: str,
538
+ pretrained: Optional[str] = None,
539
+ precision: str = 'fp32',
540
+ device: Union[str, torch.device] = 'cpu',
541
+ jit: bool = False,
542
+ force_quick_gelu: bool = False,
543
+ force_custom_text: bool = False,
544
+ force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
545
+ image_mean: Optional[Tuple[float, ...]] = None,
546
+ image_std: Optional[Tuple[float, ...]] = None,
547
+ image_interpolation: Optional[str] = None,
548
+ image_resize_mode: Optional[str] = None, # only effective for inference
549
+ return_transform: bool = True,
550
+ cache_dir: Optional[str] = None,
551
+ load_weights_only: bool = True,
552
+ **model_kwargs,
553
+ ):
554
+ force_preprocess_cfg = merge_preprocess_kwargs(
555
+ {},
556
+ mean=image_mean,
557
+ std=image_std,
558
+ interpolation=image_interpolation,
559
+ resize_mode=image_resize_mode,
560
+ )
561
+
562
+ model = create_model(
563
+ model_name,
564
+ pretrained,
565
+ precision=precision,
566
+ device=device,
567
+ jit=jit,
568
+ force_quick_gelu=force_quick_gelu,
569
+ force_custom_text=force_custom_text,
570
+ force_image_size=force_image_size,
571
+ force_preprocess_cfg=force_preprocess_cfg,
572
+ cache_dir=cache_dir,
573
+ require_pretrained=True,
574
+ load_weights_only=load_weights_only,
575
+ **model_kwargs,
576
+ )
577
+
578
+ if not return_transform:
579
+ return model
580
+
581
+ preprocess = image_transform_v2(
582
+ PreprocessCfg(**model.visual.preprocess_cfg),
583
+ is_train=False,
584
+ )
585
+
586
+ return model, preprocess
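Beyond the docstring examples above, `add_model_config()` lets external configs join the registry; a hedged sketch follows (the './my_configs' directory and the 'My-ViT' config inside it are assumptions).

```python
import open_clip

# The JSON must carry at least embed_dim, vision_cfg and text_cfg keys to be registered.
open_clip.add_model_config('./my_configs')
assert 'My-ViT' in open_clip.list_models()
model = open_clip.create_model('My-ViT', precision='fp32')
```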
src/open_clip/hf_configs.py ADDED
@@ -0,0 +1,67 @@
1
+ # HF architecture dict:
2
+ arch_dict = {
3
+ # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
4
+ "roberta": {
5
+ "config_names": {
6
+ "context_length": "max_position_embeddings",
7
+ "vocab_size": "vocab_size",
8
+ "width": "hidden_size",
9
+ "heads": "num_attention_heads",
10
+ "layers": "num_hidden_layers",
11
+ "layer_attr": "layer",
12
+ "token_embeddings_attr": "embeddings"
13
+ },
14
+ "pooler": "mean_pooler",
15
+ },
16
+ # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
17
+ "xlm-roberta": {
18
+ "config_names": {
19
+ "context_length": "max_position_embeddings",
20
+ "vocab_size": "vocab_size",
21
+ "width": "hidden_size",
22
+ "heads": "num_attention_heads",
23
+ "layers": "num_hidden_layers",
24
+ "layer_attr": "layer",
25
+ "token_embeddings_attr": "embeddings"
26
+ },
27
+ "pooler": "mean_pooler",
28
+ },
29
+ # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
30
+ "mt5": {
31
+ "config_names": {
32
+ # unlimited seqlen
33
+ # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
34
+ # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
35
+ "context_length": "",
36
+ "vocab_size": "vocab_size",
37
+ "width": "d_model",
38
+ "heads": "num_heads",
39
+ "layers": "num_layers",
40
+ "layer_attr": "block",
41
+ "token_embeddings_attr": "embed_tokens"
42
+ },
43
+ "pooler": "mean_pooler",
44
+ },
45
+ # https://huggingface.co/docs/transformers/model_doc/bert
46
+ "bert": {
47
+ "config_names": {
48
+ "context_length": "max_position_embeddings",
49
+ "vocab_size": "vocab_size",
50
+ "width": "hidden_size",
51
+ "heads": "num_attention_heads",
52
+ "layers": "num_hidden_layers",
53
+ },
54
+ "pooler": "cls_pooler",
55
+ },
56
+ # https://huggingface.co/docs/transformers/model_doc/m2m_100
57
+ "m2m_100": {
58
+ "config_names": {
59
+ "context_length": "max_position_embeddings",
60
+ "vocab_size": "vocab_size",
61
+ "width": "d_model",
62
+ "heads": "encoder_attention_heads",
63
+ "layers": "encoder_layers",
64
+ },
65
+ "pooler": "cls_pooler",
66
+ },
67
+ }
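A rough sketch of how arch_dict is consumed: the generic config keys are resolved to architecture-specific attribute names on a HuggingFace config object. This assumes transformers is installed; 'xlm-roberta-base' is just an illustrative checkpoint name.

from transformers import AutoConfig
from open_clip.hf_configs import arch_dict

config = AutoConfig.from_pretrained('xlm-roberta-base')   # config.model_type == 'xlm-roberta'
names = arch_dict[config.model_type]['config_names']

width = getattr(config, names['width'])     # hidden_size for (XLM-)RoBERTa
layers = getattr(config, names['layers'])   # num_hidden_layers
default_pooler = arch_dict[config.model_type]['pooler']   # 'mean_pooler'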
src/open_clip/hf_model.py ADDED
@@ -0,0 +1,193 @@
1
+ """ huggingface model adapter
2
+
3
+ Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in a CLIP model.
4
+ """
5
+ import re
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch import TensorType
10
+
11
+ try:
12
+ import transformers
13
+ from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
14
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
15
+ BaseModelOutputWithPoolingAndCrossAttentions
16
+ except ImportError as e:
17
+ transformers = None
18
+
19
+
20
+ class BaseModelOutput:
21
+ pass
22
+
23
+
24
+ class PretrainedConfig:
25
+ pass
26
+
27
+ from .hf_configs import arch_dict
28
+
29
+
30
+ # utils
31
+ def _camel2snake(s):
32
+ return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
33
+
34
+
35
+ # TODO: ?last - for gpt-like models
36
+ _POOLERS = {}
37
+
38
+
39
+ def register_pooler(cls):
40
+ """Decorator registering pooler class"""
41
+ _POOLERS[_camel2snake(cls.__name__)] = cls
42
+ return cls
43
+
44
+
45
+ @register_pooler
46
+ class MeanPooler(nn.Module):
47
+ """Mean pooling"""
48
+
49
+ def forward(self, x: BaseModelOutput, attention_mask: TensorType):
50
+ masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
51
+ return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
52
+
53
+
54
+ @register_pooler
55
+ class MaxPooler(nn.Module):
56
+ """Max pooling"""
57
+
58
+ def forward(self, x: BaseModelOutput, attention_mask: TensorType):
59
+ # mask out padded positions (attention_mask == 0) with -inf before taking the max
+ masked_output = x.last_hidden_state.masked_fill((attention_mask == 0).unsqueeze(-1), -torch.inf)
60
+ return masked_output.max(1).values
61
+
62
+
63
+ @register_pooler
64
+ class ClsPooler(nn.Module):
65
+ """CLS token pooling"""
66
+
67
+ def __init__(self, use_pooler_output=True):
68
+ super().__init__()
69
+ self.cls_token_position = 0
70
+ self.use_pooler_output = use_pooler_output
71
+
72
+ def forward(self, x: BaseModelOutput, attention_mask: TensorType):
73
+ if (self.use_pooler_output and
74
+ isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
75
+ (x.pooler_output is not None)
76
+ ):
77
+ return x.pooler_output
78
+
79
+ return x.last_hidden_state[:, self.cls_token_position, :]
80
+
81
+
82
+ @register_pooler
83
+ class ClsLastHiddenStatePooler(nn.Module):
84
+ """CLS token pooling
85
+ NOTE: this is equivalent to ClsPooler above with use_pooler_output=False
86
+ """
87
+
88
+ def __init__(self):
89
+ super().__init__()
90
+ self.cls_token_position = 0
91
+
92
+ def forward(self, x: BaseModelOutput, attention_mask: TensorType):
93
+ return x.last_hidden_state[:, self.cls_token_position, :]
94
+
95
+
96
+ class HFTextEncoder(nn.Module):
97
+ """HuggingFace model adapter"""
98
+ output_tokens: torch.jit.Final[bool]
99
+
100
+ def __init__(
101
+ self,
102
+ model_name_or_path: str,
103
+ output_dim: int,
104
+ config: PretrainedConfig = None,
105
+ pooler_type: str = None,
106
+ proj_type: str = None,
107
+ pretrained: bool = True,
108
+ output_tokens: bool = False,
109
+ ):
110
+ super().__init__()
111
+ self.output_tokens = output_tokens
112
+ self.output_dim = output_dim
113
+
114
+ # TODO: find better way to get this information
115
+ uses_transformer_pooler = (pooler_type == "cls_pooler")
116
+
117
+ if transformers is None:
118
+ raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
119
+ if config is None:
120
+ self.config = AutoConfig.from_pretrained(model_name_or_path)
121
+ create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
122
+ AutoModel.from_config, self.config)
123
+ # TODO: do all model configs have this attribute? PretrainedConfig does, so presumably yes.
124
+ if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
125
+ self.transformer = create_func(model_args)
126
+ self.transformer = self.transformer.encoder
127
+ else:
128
+ self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
129
+ else:
130
+ self.config = config
131
+ self.transformer = AutoModel.from_config(config)
132
+ if pooler_type is None: # get default arch pooler
133
+ pooler_type = (arch_dict[self.config.model_type]["pooler"])
134
+
135
+ # FIXME downstream users of OpenCLIP models use these attr, need to verify valid across all models
136
+ self.vocab_size = getattr(self.config, 'vocab_size', 0)
137
+ self.context_length = getattr(self.config, 'max_position_embeddings', 0)
138
+
139
+ self.pooler = _POOLERS[pooler_type]()
140
+
141
+ d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
142
+ if (d_model == output_dim) and (proj_type is None): # do we always need a proj?
143
+ self.proj = nn.Identity()
144
+ elif proj_type == 'linear':
145
+ self.proj = nn.Linear(d_model, output_dim, bias=False)
146
+ elif proj_type == 'mlp':
147
+ hidden_size = (d_model + output_dim) // 2
148
+ self.proj = nn.Sequential(
149
+ nn.Linear(d_model, hidden_size, bias=False),
150
+ nn.GELU(),
151
+ nn.Linear(hidden_size, output_dim, bias=False),
152
+ )
153
+
154
+ def forward(self, x: TensorType):
155
+ attn_mask = (x != self.config.pad_token_id).long()
156
+ out = self.transformer(input_ids=x, attention_mask=attn_mask)
157
+ pooled_out = self.pooler(out, attn_mask)
158
+ projected = self.proj(pooled_out)
159
+
160
+ seq_len = out.last_hidden_state.shape[1]
161
+ tokens = (
162
+ out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :]
163
+ if type(self.pooler) == ClsPooler
164
+ else out.last_hidden_state
165
+ )
166
+
167
+ if self.output_tokens:
168
+ return projected, tokens
169
+ return projected
170
+
171
+ def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
172
+ if not unlocked_layers: # full freezing
173
+ for n, p in self.transformer.named_parameters():
174
+ p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
175
+ return
176
+
177
+ encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
178
+ layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
179
+ print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
180
+ embeddings = getattr(
181
+ self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
182
+ modules = [embeddings, *layer_list][:-unlocked_layers]
183
+ # freeze layers
184
+ for module in modules:
185
+ for n, p in module.named_parameters():
186
+ p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
187
+
188
+ @torch.jit.ignore
189
+ def set_grad_checkpointing(self, enable=True):
190
+ self.transformer.gradient_checkpointing_enable()
191
+
192
+ def init_parameters(self):
193
+ pass
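A hedged usage sketch of HFTextEncoder as a standalone text tower (the checkpoint name, output_dim, and sequence length are illustrative only; weights are fetched from the HuggingFace hub):

import torch
from transformers import AutoTokenizer
from open_clip.hf_model import HFTextEncoder

# Wrap a placeholder HF checkpoint and project its pooled output to a 512-d embedding space.
text_tower = HFTextEncoder('xlm-roberta-base', output_dim=512, proj_type='mlp', pooler_type='mean_pooler')

tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')
ids = tokenizer(['a photo of a cat'], return_tensors='pt', padding='max_length', max_length=64).input_ids

with torch.no_grad():
    features = text_tower(ids)   # shape (1, 512): mean-pooled, MLP-projected features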
src/open_clip/loss.py ADDED
@@ -0,0 +1,447 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+
7
+ try:
8
+ import torch.distributed.nn
9
+ from torch import distributed as dist
10
+
11
+ has_distributed = True
12
+ except ImportError:
13
+ has_distributed = False
14
+
15
+ try:
16
+ import horovod.torch as hvd
17
+ except ImportError:
18
+ hvd = None
19
+
20
+
21
+ def gather_features(
22
+ image_features,
23
+ text_features,
24
+ local_loss=False,
25
+ gather_with_grad=False,
26
+ rank=0,
27
+ world_size=1,
28
+ use_horovod=False
29
+ ):
30
+ assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
31
+ if use_horovod:
32
+ assert hvd is not None, 'Please install horovod'
33
+ if gather_with_grad:
34
+ all_image_features = hvd.allgather(image_features)
35
+ all_text_features = hvd.allgather(text_features)
36
+ else:
37
+ with torch.no_grad():
38
+ all_image_features = hvd.allgather(image_features)
39
+ all_text_features = hvd.allgather(text_features)
40
+ if not local_loss:
41
+ # ensure grads for local rank when all_* features don't have a gradient
42
+ gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
43
+ gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
44
+ gathered_image_features[rank] = image_features
45
+ gathered_text_features[rank] = text_features
46
+ all_image_features = torch.cat(gathered_image_features, dim=0)
47
+ all_text_features = torch.cat(gathered_text_features, dim=0)
48
+ else:
49
+ # We gather tensors from all gpus
50
+ if gather_with_grad:
51
+ all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
52
+ all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
53
+ else:
54
+ gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
55
+ gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
56
+ dist.all_gather(gathered_image_features, image_features)
57
+ dist.all_gather(gathered_text_features, text_features)
58
+ if not local_loss:
59
+ # ensure grads for local rank when all_* features don't have a gradient
60
+ gathered_image_features[rank] = image_features
61
+ gathered_text_features[rank] = text_features
62
+ all_image_features = torch.cat(gathered_image_features, dim=0)
63
+ all_text_features = torch.cat(gathered_text_features, dim=0)
64
+
65
+ return all_image_features, all_text_features
66
+
67
+
68
+ class ClipLoss(nn.Module):
69
+
70
+ def __init__(
71
+ self,
72
+ local_loss=False,
73
+ gather_with_grad=False,
74
+ cache_labels=False,
75
+ rank=0,
76
+ world_size=1,
77
+ use_horovod=False,
78
+ ):
79
+ super().__init__()
80
+ self.local_loss = local_loss
81
+ self.gather_with_grad = gather_with_grad
82
+ self.cache_labels = cache_labels
83
+ self.rank = rank
84
+ self.world_size = world_size
85
+ self.use_horovod = use_horovod
86
+
87
+ # cache state
88
+ self.prev_num_logits = 0
89
+ self.labels = {}
90
+
91
+ def get_ground_truth(self, device, num_logits) -> torch.Tensor:
92
+ # calculate ground-truth and cache if enabled
93
+ if self.prev_num_logits != num_logits or device not in self.labels:
94
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
95
+ if self.world_size > 1 and self.local_loss:
96
+ labels = labels + num_logits * self.rank
97
+ if self.cache_labels:
98
+ self.labels[device] = labels
99
+ self.prev_num_logits = num_logits
100
+ else:
101
+ labels = self.labels[device]
102
+ return labels
103
+
104
+ def get_logits(self, image_features, text_features, logit_scale):
105
+ if self.world_size > 1:
106
+ all_image_features, all_text_features = gather_features(
107
+ image_features,
108
+ text_features,
109
+ local_loss=self.local_loss,
110
+ gather_with_grad=self.gather_with_grad,
111
+ rank=self.rank,
112
+ world_size=self.world_size,
113
+ use_horovod=self.use_horovod,
114
+ )
115
+
116
+ if self.local_loss:
117
+ logits_per_image = logit_scale * image_features @ all_text_features.T
118
+ logits_per_text = logit_scale * text_features @ all_image_features.T
119
+ else:
120
+ logits_per_image = logit_scale * all_image_features @ all_text_features.T
121
+ logits_per_text = logits_per_image.T
122
+ else:
123
+ logits_per_image = logit_scale * image_features @ text_features.T
124
+ logits_per_text = logit_scale * text_features @ image_features.T
125
+
126
+ return logits_per_image, logits_per_text
127
+
128
+ def forward(self, image_features, text_features, logit_scale, output_dict=False):
129
+ device = image_features.device
130
+ logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)
131
+ labels = self.get_ground_truth(device, logits_per_image.shape[0])
132
+
133
+ total_loss = (
134
+ F.cross_entropy(logits_per_image, labels) +
135
+ F.cross_entropy(logits_per_text, labels)
136
+ ) / 2
137
+
138
+ return {"contrastive_loss": total_loss} if output_dict else total_loss
139
+
140
+
141
+ class CoCaLoss(ClipLoss):
142
+ def __init__(
143
+ self,
144
+ caption_loss_weight,
145
+ clip_loss_weight,
146
+ pad_id=0, # pad_token for open_clip custom tokenizer
147
+ local_loss=False,
148
+ gather_with_grad=False,
149
+ cache_labels=False,
150
+ rank=0,
151
+ world_size=1,
152
+ use_horovod=False,
153
+ ):
154
+ super().__init__(
155
+ local_loss=local_loss,
156
+ gather_with_grad=gather_with_grad,
157
+ cache_labels=cache_labels,
158
+ rank=rank,
159
+ world_size=world_size,
160
+ use_horovod=use_horovod
161
+ )
162
+
163
+ self.clip_loss_weight = clip_loss_weight
164
+ self.caption_loss_weight = caption_loss_weight
165
+ self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)
166
+
167
+ def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):
168
+ if self.clip_loss_weight:
169
+ clip_loss = super().forward(image_features, text_features, logit_scale)
170
+ clip_loss = self.clip_loss_weight * clip_loss
171
+ else:
172
+ clip_loss = torch.tensor(0, device=logits.device)
173
+
174
+ caption_loss = self.caption_loss(
175
+ logits.permute(0, 2, 1),
176
+ labels,
177
+ )
178
+ caption_loss = caption_loss * self.caption_loss_weight
179
+
180
+ if output_dict:
181
+ return {"contrastive_loss": clip_loss, "caption_loss": caption_loss}
182
+
183
+ return clip_loss, caption_loss
184
+
185
+
186
+ class DistillClipLoss(ClipLoss):
187
+
188
+ def dist_loss(self, teacher_logits, student_logits):
189
+ return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0)
190
+
191
+ def forward(
192
+ self,
193
+ image_features,
194
+ text_features,
195
+ logit_scale,
196
+ dist_image_features,
197
+ dist_text_features,
198
+ dist_logit_scale,
199
+ output_dict=False,
200
+ ):
201
+ logits_per_image, logits_per_text = \
202
+ self.get_logits(image_features, text_features, logit_scale)
203
+
204
+ dist_logits_per_image, dist_logits_per_text = \
205
+ self.get_logits(dist_image_features, dist_text_features, dist_logit_scale)
206
+
207
+ labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0])
208
+
209
+ contrastive_loss = (
210
+ F.cross_entropy(logits_per_image, labels) +
211
+ F.cross_entropy(logits_per_text, labels)
212
+ ) / 2
213
+
214
+ distill_loss = (
215
+ self.dist_loss(dist_logits_per_image, logits_per_image) +
216
+ self.dist_loss(dist_logits_per_text, logits_per_text)
217
+ ) / 2
218
+
219
+ if output_dict:
220
+ return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss}
221
+
222
+ return contrastive_loss, distill_loss
223
+
224
+
225
+ def neighbour_exchange(from_rank, to_rank, tensor, group=None):
226
+ tensor_recv = torch.zeros_like(tensor)
227
+ send_op = torch.distributed.P2POp(
228
+ torch.distributed.isend,
229
+ tensor,
230
+ to_rank,
231
+ group=group,
232
+ )
233
+ recv_op = torch.distributed.P2POp(
234
+ torch.distributed.irecv,
235
+ tensor_recv,
236
+ from_rank,
237
+ group=group,
238
+ )
239
+ reqs = torch.distributed.batch_isend_irecv([send_op, recv_op])
240
+ for req in reqs:
241
+ req.wait()
242
+ return tensor_recv
243
+
244
+
245
+ def neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None):
246
+ tensor_from_left = torch.zeros_like(tensor_to_right)
247
+ tensor_from_right = torch.zeros_like(tensor_to_left)
248
+ send_op_left = torch.distributed.P2POp(
249
+ torch.distributed.isend,
250
+ tensor_to_left,
251
+ left_rank,
252
+ group=group,
253
+ )
254
+ send_op_right = torch.distributed.P2POp(
255
+ torch.distributed.isend,
256
+ tensor_to_right,
257
+ right_rank,
258
+ group=group,
259
+ )
260
+ recv_op_left = torch.distributed.P2POp(
261
+ torch.distributed.irecv,
262
+ tensor_from_left,
263
+ left_rank,
264
+ group=group,
265
+ )
266
+ recv_op_right = torch.distributed.P2POp(
267
+ torch.distributed.irecv,
268
+ tensor_from_right,
269
+ right_rank,
270
+ group=group,
271
+ )
272
+ reqs = torch.distributed.batch_isend_irecv([send_op_right, send_op_left, recv_op_right, recv_op_left])
273
+ for req in reqs:
274
+ req.wait()
275
+ return tensor_from_right, tensor_from_left
276
+
277
+
278
+ class NeighbourExchange(torch.autograd.Function):
279
+ @staticmethod
280
+ def forward(ctx, from_rank, to_rank, group, tensor):
281
+ ctx.group = group
282
+ ctx.from_rank = from_rank
283
+ ctx.to_rank = to_rank
284
+ return neighbour_exchange(from_rank, to_rank, tensor, group=group)
285
+
286
+ @staticmethod
287
+ def backward(ctx, grad_output):
288
+ return (None, None, None) + (NeighbourExchange.apply(ctx.to_rank, ctx.from_rank, ctx.group, grad_output),)
289
+
290
+
291
+ def neighbour_exchange_with_grad(from_rank, to_rank, tensor, group=None):
292
+ return NeighbourExchange.apply(from_rank, to_rank, group, tensor)
293
+
294
+
295
+ class NeighbourExchangeBidir(torch.autograd.Function):
296
+ @staticmethod
297
+ def forward(ctx, left_rank, right_rank, group, tensor_to_left, tensor_to_right):
298
+ ctx.group = group
299
+ ctx.left_rank = left_rank
300
+ ctx.right_rank = right_rank
301
+ return neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=group)
302
+
303
+ @staticmethod
304
+ def backward(ctx, *grad_outputs):
305
+ return (None, None, None) + \
306
+ NeighbourExchangeBidir.apply(ctx.right_rank, ctx.left_rank, ctx.group, *grad_outputs)
307
+
308
+
309
+ def neighbour_exchange_bidir_with_grad(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None):
310
+ return NeighbourExchangeBidir.apply(left_rank, right_rank, group, tensor_to_left, tensor_to_right)
311
+
312
+
313
+ class SigLipLoss(nn.Module):
314
+ """ Sigmoid Loss for Language Image Pre-Training (SigLIP) - https://arxiv.org/abs/2303.15343
315
+
316
+ @article{zhai2023sigmoid,
317
+ title={Sigmoid loss for language image pre-training},
318
+ author={Zhai, Xiaohua and Mustafa, Basil and Kolesnikov, Alexander and Beyer, Lucas},
319
+ journal={arXiv preprint arXiv:2303.15343},
320
+ year={2023}
321
+ }
322
+ """
323
+ def __init__(
324
+ self,
325
+ cache_labels: bool = False,
326
+ rank: int = 0,
327
+ world_size: int = 1,
328
+ dist_impl: Optional[str] = None,
329
+ ):
330
+ super().__init__()
331
+ self.cache_labels = cache_labels
332
+ self.rank = rank
333
+ self.world_size = world_size
334
+ self.dist_impl = dist_impl or 'bidir' # default to bidir exchange for now, this will likely change
335
+ assert self.dist_impl in ('bidir', 'shift', 'reduce', 'gather')
336
+
337
+ # cache state. FIXME: cache is not currently used; is it worthwhile?
338
+ self.prev_num_logits = 0
339
+ self.labels = {}
340
+
341
+ def get_ground_truth(self, device, dtype, num_logits, negative_only=False) -> torch.Tensor:
342
+ labels = -torch.ones((num_logits, num_logits), device=device, dtype=dtype)
343
+ if not negative_only:
344
+ labels = 2 * torch.eye(num_logits, device=device, dtype=dtype) + labels
345
+ return labels
346
+
347
+ def get_logits(self, image_features, text_features, logit_scale, logit_bias=None):
348
+ logits = logit_scale * image_features @ text_features.T
349
+ if logit_bias is not None:
350
+ logits += logit_bias
351
+ return logits
352
+
353
+ def _loss(self, image_features, text_features, logit_scale, logit_bias=None, negative_only=False):
354
+ logits = self.get_logits(image_features, text_features, logit_scale, logit_bias)
355
+ labels = self.get_ground_truth(
356
+ image_features.device,
357
+ image_features.dtype,
358
+ image_features.shape[0],
359
+ negative_only=negative_only,
360
+ )
361
+ loss = -F.logsigmoid(labels * logits).sum() / image_features.shape[0]
362
+ return loss
363
+
364
+ def forward(self, image_features, text_features, logit_scale, logit_bias, output_dict=False):
365
+ loss = self._loss(image_features, text_features, logit_scale, logit_bias)
366
+
367
+ if self.world_size > 1:
368
+ if self.dist_impl == 'bidir':
369
+ right_rank = (self.rank + 1) % self.world_size
370
+ left_rank = (self.rank - 1 + self.world_size) % self.world_size
371
+ text_features_to_right = text_features_to_left = text_features
372
+ num_bidir, remainder = divmod(self.world_size - 1, 2)
373
+ for i in range(num_bidir):
374
+ text_features_recv = neighbour_exchange_bidir_with_grad(
375
+ left_rank,
376
+ right_rank,
377
+ text_features_to_left,
378
+ text_features_to_right,
379
+ )
380
+ for f in text_features_recv:
381
+ loss += self._loss(
382
+ image_features,
383
+ f,
384
+ logit_scale,
385
+ logit_bias,
386
+ negative_only=True,
387
+ )
388
+ text_features_to_left, text_features_to_right = text_features_recv
389
+
390
+ if remainder:
391
+ text_features_recv = neighbour_exchange_with_grad(
392
+ left_rank,
393
+ right_rank,
394
+ text_features_to_right
395
+ )
396
+ loss += self._loss(
397
+ image_features,
398
+ text_features_recv,
399
+ logit_scale,
400
+ logit_bias,
401
+ negative_only=True,
402
+ )
403
+ elif self.dist_impl == "shift":
404
+ right_rank = (self.rank + 1) % self.world_size
405
+ left_rank = (self.rank - 1 + self.world_size) % self.world_size
406
+ text_features_to_right = text_features
407
+ for i in range(self.world_size - 1):
408
+ text_features_from_left = neighbour_exchange_with_grad(
409
+ left_rank,
410
+ right_rank,
411
+ text_features_to_right,
412
+ )
413
+ loss += self._loss(
414
+ image_features,
415
+ text_features_from_left,
416
+ logit_scale,
417
+ logit_bias,
418
+ negative_only=True,
419
+ )
420
+ text_features_to_right = text_features_from_left
421
+ elif self.dist_impl == "reduce":
422
+ for i in range(self.world_size):
423
+ text_from_other = torch.distributed.nn.all_reduce(
424
+ text_features * (self.rank == i),
425
+ torch.distributed.ReduceOp.SUM,
426
+ )
427
+ loss += float(i != self.rank) * self._loss(
428
+ image_features,
429
+ text_from_other,
430
+ logit_scale,
431
+ logit_bias,
432
+ negative_only=True,
433
+ )
434
+ elif self.dist_impl == "gather":
435
+ all_text = torch.distributed.nn.all_gather(text_features)
436
+ for i in range(self.world_size):
437
+ loss += float(i != self.rank) * self._loss(
438
+ image_features,
439
+ all_text[i],
440
+ logit_scale,
441
+ logit_bias,
442
+ negative_only=True,
443
+ )
444
+ else:
445
+ assert False
446
+
447
+ return {"contrastive_loss": loss} if output_dict else loss
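As a single-process sanity check of the two main objectives in this file (ClipLoss and SigLipLoss), a minimal sketch with random features; the feature dimension, logit scale, and bias are arbitrary illustrative values, not trained ones.

import torch
import torch.nn.functional as F
from open_clip.loss import ClipLoss, SigLipLoss

# Toy batch of 8 paired image/text embeddings, L2-normalized as the towers would produce.
img = F.normalize(torch.randn(8, 512), dim=-1)
txt = F.normalize(torch.randn(8, 512), dim=-1)

clip_loss = ClipLoss()   # world_size=1, so no cross-GPU gathering is triggered
print(clip_loss(img, txt, logit_scale=torch.tensor(100.0)))

siglip_loss = SigLipLoss()   # pairwise sigmoid loss; takes a logit scale and a logit bias
print(siglip_loss(img, txt, logit_scale=torch.tensor(10.0), logit_bias=torch.tensor(-10.0)))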
src/open_clip/model.py ADDED
@@ -0,0 +1,919 @@
1
+ """ CLIP Model
2
+
3
+ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+ import copy
6
+ import logging
7
+ import math
8
+ from dataclasses import dataclass
9
+ from typing import Any, Dict, List, Optional, Tuple, Union
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from torch import nn
15
+ from torch.utils.checkpoint import checkpoint
16
+ from functools import partial
17
+
18
+ from .hf_model import HFTextEncoder
19
+ from .modified_resnet import ModifiedResNet
20
+ from .timm_model import TimmModel
21
+ from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer,\
22
+ text_global_pool
23
+ from .utils import to_2tuple
24
+
25
+
26
+ @dataclass
27
+ class CLIPVisionCfg:
28
+ layers: Union[Tuple[int, int, int, int], int] = 12
29
+ width: int = 768
30
+ head_width: int = 64
31
+ mlp_ratio: float = 4.0
32
+ patch_size: int = 16
33
+ image_size: Union[Tuple[int, int], int] = 224
34
+ in_chans: int = 3
35
+
36
+ ls_init_value: Optional[float] = None # layer scale initial value
37
+ patch_dropout: float = 0. # fraction of patches to drop out during training (0 disables patch dropout); 0.5-0.75 recommended in the paper for optimal results
38
+ attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer (overrides pool_type)
39
+ attn_pooler_queries: int = 256 # n_queries for attentional pooler
40
+ attn_pooler_heads: int = 8 # n heads for attentional_pooling
41
+ no_ln_pre: bool = False # disable pre transformer LayerNorm
42
+ pos_embed_type: str = 'learnable'
43
+ final_ln_after_pool: bool = False # apply final LayerNorm after pooling
44
+ pool_type: str = 'tok'
45
+ output_tokens: bool = False
46
+ act_kwargs: Optional[dict] = None
47
+ norm_kwargs: Optional[dict] = None
48
+
49
+ timm_model_name: Optional[str] = None # a valid model name overrides layers, width, patch_size
50
+ timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
51
+ timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
52
+ timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
53
+ timm_proj_bias: bool = False # enable bias final projection
54
+ timm_drop: float = 0. # head dropout
55
+ timm_drop_path: Optional[float] = None # backbone stochastic depth
56
+
57
+
58
+ @dataclass
59
+ class CLIPTextCfg:
60
+ context_length: int = 77
61
+ vocab_size: int = 49408
62
+ hf_tokenizer_name: Optional[str] = None
63
+ tokenizer_kwargs: Optional[dict] = None
64
+
65
+ width: int = 512
66
+ heads: int = 8
67
+ layers: int = 12
68
+ mlp_ratio: float = 4.0
69
+ ls_init_value: Optional[float] = None # layer scale initial value
70
+ embed_cls: bool = False
71
+ pad_id: int = 0
72
+ no_causal_mask: bool = False # disable causal masking
73
+ final_ln_after_pool: bool = False # apply final LayerNorm after pooling
74
+ pool_type: str = 'argmax'
75
+ proj_bias: bool = False
76
+ proj_type: str = 'linear' # control final text projection, 'none' forces no projection
77
+ output_tokens: bool = False
78
+ act_kwargs: Optional[dict] = None
79
+ norm_kwargs: Optional[dict] = None
80
+
81
+ # HuggingFace specific text tower config
82
+ hf_model_name: Optional[str] = None
83
+ hf_model_pretrained: bool = True
84
+ hf_proj_type: str = 'mlp'
85
+ hf_pooler_type: str = 'mean_pooler' # pooler type for HF text models ('mean_pooler', 'cls_pooler', ...)
86
+
87
+
88
+ def get_cast_dtype(precision: str):
89
+ cast_dtype = None
90
+ if precision == 'bf16':
91
+ cast_dtype = torch.bfloat16
92
+ elif precision == 'fp16':
93
+ cast_dtype = torch.float16
94
+ return cast_dtype
95
+
96
+
97
+ def get_input_dtype(precision: str):
98
+ input_dtype = None
99
+ if precision in ('bf16', 'pure_bf16'):
100
+ input_dtype = torch.bfloat16
101
+ elif precision in ('fp16', 'pure_fp16'):
102
+ input_dtype = torch.float16
103
+ return input_dtype
104
+
105
+
106
+ def _build_vision_tower(
107
+ embed_dim: int,
108
+ vision_cfg: CLIPVisionCfg,
109
+ quick_gelu: bool = False,
110
+ cast_dtype: Optional[torch.dtype] = None
111
+ ):
112
+ if isinstance(vision_cfg, dict):
113
+ vision_cfg = CLIPVisionCfg(**vision_cfg)
114
+
115
+ # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
116
+ # memory efficient in recent PyTorch releases (>= 1.10).
117
+ # NOTE: timm models always use native GELU regardless of quick_gelu flag.
118
+ act_layer = QuickGELU if quick_gelu else nn.GELU
119
+
120
+ if vision_cfg.timm_model_name:
121
+ visual = TimmModel(
122
+ vision_cfg.timm_model_name,
123
+ pretrained=vision_cfg.timm_model_pretrained,
124
+ pool=vision_cfg.timm_pool,
125
+ proj=vision_cfg.timm_proj,
126
+ proj_bias=vision_cfg.timm_proj_bias,
127
+ drop=vision_cfg.timm_drop,
128
+ drop_path=vision_cfg.timm_drop_path,
129
+ patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None,
130
+ embed_dim=embed_dim,
131
+ image_size=vision_cfg.image_size,
132
+ )
133
+ elif isinstance(vision_cfg.layers, (tuple, list)):
134
+ vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
135
+ visual = ModifiedResNet(
136
+ layers=vision_cfg.layers,
137
+ output_dim=embed_dim,
138
+ heads=vision_heads,
139
+ image_size=vision_cfg.image_size,
140
+ width=vision_cfg.width,
141
+ )
142
+ else:
143
+ vision_heads = vision_cfg.width // vision_cfg.head_width
144
+ norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
145
+ if vision_cfg.norm_kwargs:
146
+ norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs)
147
+ if vision_cfg.act_kwargs is not None:
148
+ act_layer = partial(act_layer, **vision_cfg.act_kwargs)
149
+
150
+ visual = VisionTransformer(
151
+ image_size=vision_cfg.image_size,
152
+ patch_size=vision_cfg.patch_size,
153
+ width=vision_cfg.width,
154
+ layers=vision_cfg.layers,
155
+ heads=vision_heads,
156
+ mlp_ratio=vision_cfg.mlp_ratio,
157
+ ls_init_value=vision_cfg.ls_init_value,
158
+ patch_dropout=vision_cfg.patch_dropout,
159
+ attentional_pool=vision_cfg.attentional_pool,
160
+ attn_pooler_queries=vision_cfg.attn_pooler_queries,
161
+ attn_pooler_heads=vision_cfg.attn_pooler_heads,
162
+ pos_embed_type=vision_cfg.pos_embed_type,
163
+ no_ln_pre=vision_cfg.no_ln_pre,
164
+ final_ln_after_pool=vision_cfg.final_ln_after_pool,
165
+ pool_type=vision_cfg.pool_type,
166
+ output_tokens=vision_cfg.output_tokens,
167
+ output_dim=embed_dim,
168
+ act_layer=act_layer,
169
+ norm_layer=norm_layer,
170
+ in_chans=vision_cfg.in_chans,
171
+ )
172
+
173
+ return visual
174
+
175
+
176
+ def _build_text_tower(
177
+ embed_dim: int,
178
+ text_cfg: CLIPTextCfg,
179
+ quick_gelu: bool = False,
180
+ cast_dtype: Optional[torch.dtype] = None,
181
+ ):
182
+ if isinstance(text_cfg, dict):
183
+ text_cfg = CLIPTextCfg(**text_cfg)
184
+
185
+ if text_cfg.hf_model_name:
186
+ text = HFTextEncoder(
187
+ text_cfg.hf_model_name,
188
+ output_dim=embed_dim,
189
+ proj_type=text_cfg.hf_proj_type,
190
+ pooler_type=text_cfg.hf_pooler_type,
191
+ pretrained=text_cfg.hf_model_pretrained,
192
+ output_tokens=text_cfg.output_tokens,
193
+ )
194
+ else:
195
+ act_layer = QuickGELU if quick_gelu else nn.GELU
196
+ norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
197
+ if text_cfg.norm_kwargs:
198
+ norm_layer = partial(norm_layer, **text_cfg.norm_kwargs)
199
+ if text_cfg.act_kwargs is not None:
200
+ act_layer = partial(act_layer, **text_cfg.act_kwargs)
201
+
202
+ text = TextTransformer(
203
+ context_length=text_cfg.context_length,
204
+ vocab_size=text_cfg.vocab_size,
205
+ width=text_cfg.width,
206
+ heads=text_cfg.heads,
207
+ layers=text_cfg.layers,
208
+ mlp_ratio=text_cfg.mlp_ratio,
209
+ ls_init_value=text_cfg.ls_init_value,
210
+ output_dim=embed_dim,
211
+ embed_cls=text_cfg.embed_cls,
212
+ no_causal_mask=text_cfg.no_causal_mask,
213
+ pad_id=text_cfg.pad_id,
214
+ pool_type=text_cfg.pool_type,
215
+ proj_type=text_cfg.proj_type,
216
+ proj_bias=text_cfg.proj_bias,
217
+ output_tokens=text_cfg.output_tokens,
218
+ act_layer=act_layer,
219
+ norm_layer=norm_layer,
220
+ )
221
+ return text
222
+
223
+
224
+
225
+ class TrunkNet(nn.Module):
226
+ def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
227
+ super().__init__()
228
+ self.net = nn.Sequential(
229
+ nn.Linear(input_dim, hidden_dim),
230
+ LayerNorm(hidden_dim),
231
+ nn.GELU(),
232
+ nn.Linear(hidden_dim, hidden_dim),
233
+ LayerNorm(hidden_dim),
234
+ nn.GELU(),
235
+ nn.Linear(hidden_dim, output_dim)
236
+ )
237
+
238
+ def forward(self, x):
239
+ return self.net(x)
244
+
245
+
246
+ class MultiTrunkNet(nn.Module):
247
+ def __init__(self, embed_dim: int):
248
+ super().__init__()
249
+ self.embed_dim = embed_dim
250
+
251
+ self.compound_trunk = TrunkNet(input_dim=159, hidden_dim=embed_dim, output_dim=embed_dim)
252
+ self.concentration_trunk = TrunkNet(input_dim=2, hidden_dim=embed_dim, output_dim=embed_dim)
253
+ self.time_trunk = TrunkNet(input_dim=1, hidden_dim=embed_dim, output_dim=embed_dim)
254
+
255
+ total_dim = embed_dim * 3
256
+ self.projection = nn.Linear(total_dim, embed_dim)
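+ # NOTE: this projection over the concatenated trunk outputs is not used in forward();
+ # the three feature sets are returned separately and concatenated in CLIP.encode_text.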
257
+
258
+ def forward(self, compound_embedding: torch.Tensor, concentration: torch.Tensor, time: torch.Tensor):
259
+
260
+ # Process each input through its own trunk
261
+ compound_features = self.compound_trunk(compound_embedding)
262
+
263
+ concentration_features = self.concentration_trunk(concentration)
264
+
265
+ time = time.unsqueeze(-1) if time.dim() == 1 else time
266
+ time_features = self.time_trunk(time)
267
+
268
+ # Return per-modality features; CLIP.encode_text concatenates them with the text features
269
+ return compound_features, concentration_features, time_features
270
+
271
+
272
+ class CLIP(nn.Module):
273
+ output_dict: torch.jit.Final[bool]
274
+
275
+ def __init__(
276
+ self,
277
+ embed_dim: int,
278
+ vision_cfg: CLIPVisionCfg,
279
+ text_cfg: CLIPTextCfg,
280
+ quick_gelu: bool = False,
281
+ init_logit_scale: float = np.log(1 / 0.07),
282
+ init_logit_bias: Optional[float] = None,
283
+ nonscalar_logit_scale: bool = False,
284
+ cast_dtype: Optional[torch.dtype] = None,
285
+ output_dict: bool = False,
286
+ ):
287
+ super().__init__()
288
+ self.output_dict = output_dict
289
+
290
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
291
+
292
+ text = _build_text_tower(int(embed_dim/4), text_cfg, quick_gelu, cast_dtype)
293
+ self.transformer = text.transformer
294
+ self.context_length = text.context_length
295
+ self.vocab_size = text.vocab_size
296
+ self.token_embedding = text.token_embedding
297
+ self.positional_embedding = text.positional_embedding
298
+ self.ln_final = text.ln_final
299
+ self.text_projection = text.text_projection
300
+ self.text_pool_type = text.pool_type
301
+ self.register_buffer('attn_mask', text.attn_mask, persistent=False)
302
+
303
+ # Multi-trunk net for compound/concentration/time inputs; each trunk outputs embed_dim/4 features
+ # that encode_text concatenates with the embed_dim/4 text features to match the image embed_dim
304
+ self.multi_trunk = MultiTrunkNet(int(embed_dim/4))
305
+
306
+ # # Add projection layer for concatenated features
307
+ # self.fusion_projection = nn.Linear(embed_dim * 4, embed_dim)
308
+
309
+ lshape = [1] if nonscalar_logit_scale else []
310
+ self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
311
+ if init_logit_bias is not None:
312
+ self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
313
+ else:
314
+ self.logit_bias = None
315
+
316
+ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
317
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
318
+ self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
319
+
320
+ @torch.jit.ignore
321
+ def set_grad_checkpointing(self, enable=True):
322
+ self.visual.set_grad_checkpointing(enable)
323
+ self.transformer.grad_checkpointing = enable
324
+
325
+ @torch.jit.ignore
326
+ def no_weight_decay(self):
327
+ # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default
328
+ no_wd = {'positional_embedding'}
329
+ if hasattr(self.visual, 'no_weight_decay'):
330
+ for n in self.visual.no_weight_decay():
331
+ no_wd.add('visual.' + n)
332
+ return no_wd
333
+
334
+ def encode_image(self, image, normalize: bool = False):
335
+ features = self.visual(image)
336
+ return F.normalize(features, dim=-1) if normalize else features
337
+
338
+ def encode_text(self, text, normalize: bool = False, concentration: Optional[torch.Tensor] = None,
339
+ time: Optional[torch.Tensor] = None, compound_embedding: Optional[torch.Tensor] = None):
340
+ cast_dtype = self.transformer.get_cast_dtype()
341
+
342
+ x = self.token_embedding(text).to(cast_dtype)
343
+ x = x + self.positional_embedding.to(cast_dtype)
344
+ x = self.transformer(x, attn_mask=self.attn_mask)
345
+ x = self.ln_final(x)
346
+ x = text_global_pool(x, text, self.text_pool_type)
347
+
348
+ if self.text_projection is not None:
349
+ if isinstance(self.text_projection, nn.Linear):
350
+ x = self.text_projection(x)
351
+ else:
352
+ x = x @ self.text_projection
353
+
354
+ if compound_embedding is not None and concentration is not None and time is not None:
355
+ compound_features, concentration_features, time_features = self.multi_trunk(compound_embedding, concentration, time)
356
+ x = torch.cat([x, compound_features, concentration_features, time_features], dim=-1)
357
+
358
+ if normalize:
359
+ x = F.normalize(x, dim=-1)
360
+
361
+ return x
362
+
363
+ def get_logits(self, image, text, concentration: Optional[torch.Tensor] = None,
364
+ time: Optional[torch.Tensor] = None,
365
+ compound_embedding: Optional[torch.Tensor] = None):
366
+ image_features = self.encode_image(image, normalize=True)
367
+ text_features = self.encode_text(text, normalize=True,
368
+ concentration=concentration,
369
+ time=time,
370
+ compound_embedding=compound_embedding)
371
+ image_logits = self.logit_scale.exp() * image_features @ text_features.T
372
+ if self.logit_bias is not None:
373
+ image_logits += self.logit_bias
374
+ text_logits = image_logits.T
375
+ return image_logits, text_logits
376
+
377
+ def forward_intermediates(
378
+ self,
379
+ image: Optional[torch.Tensor] = None,
380
+ text: Optional[torch.Tensor] = None,
381
+ image_indices: Optional[Union[int, List[int]]] = None,
382
+ text_indices: Optional[Union[int, List[int]]] = None,
383
+ stop_early: bool = False,
384
+ normalize: bool = True,
385
+ normalize_intermediates: bool = False,
386
+ intermediates_only: bool = False,
387
+ image_output_fmt: str = 'NCHW',
388
+ image_output_extra_tokens: bool = False,
389
+ text_output_fmt: str = 'NLC',
390
+ text_output_extra_tokens: bool = False,
391
+ output_logits: bool = False,
392
+ output_logit_scale_bias: bool = False,
393
+ ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
394
+ """ Forward features that returns intermediates.
395
+
396
+ Args:
397
+ image: Input image tensor
398
+ text: Input text tensor
399
+ image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence
400
+ text_indices: Take last n blocks if int, all if None, select matching indices if sequence
401
+ stop_early: Stop iterating over blocks when last desired intermediate hit
402
+ normalize_intermediates: Apply final norm layer to all intermediates
403
+ normalize: L2 Normalize final features
404
+ intermediates_only: Only return intermediate features, do not return final features
405
+ image_output_fmt: Shape of intermediate image feature outputs
406
+ image_output_extra_tokens: Return both prefix and spatial intermediate tokens
407
+ text_output_fmt: Shape of intermediate text feature outputs (ignored for this model)
408
+ text_output_extra_tokens: Return both prefix and spatial intermediate tokens (ignored for this model)
409
+ output_logits: Include logits in output
410
+ output_logit_scale_bias: Include the logit scale bias in the output
411
+ Returns:
412
+ Dict of requested intermediate features, plus final features and logits when enabled.
413
+ """
414
+ output = {}
415
+ if intermediates_only:
416
+ # intermediates-only mode disables final feature normalization and logit output
417
+ normalize = False
418
+ output_logits = False
419
+ if output_logits:
420
+ assert image is not None and text is not None, 'Both image and text inputs are required to compute logits'
421
+
422
+ if image is not None:
423
+ image_output = self.visual.forward_intermediates(
424
+ image,
425
+ indices=image_indices,
426
+ stop_early=stop_early,
427
+ normalize_intermediates=normalize_intermediates,
428
+ intermediates_only=intermediates_only,
429
+ output_fmt=image_output_fmt,
430
+ output_extra_tokens=image_output_extra_tokens,
431
+ )
432
+ if normalize and "image_features" in image_output:
433
+ image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1)
434
+ output.update(image_output)
435
+
436
+ if text is not None:
437
+ cast_dtype = self.transformer.get_cast_dtype()
438
+ x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
439
+ x = x + self.positional_embedding.to(cast_dtype)
440
+ x, intermediates = self.transformer.forward_intermediates(
441
+ x,
442
+ attn_mask=self.attn_mask,
443
+ indices=text_indices
444
+ )
445
+ if normalize_intermediates:
446
+ intermediates = [self.ln_final(xi) for xi in intermediates]
447
+
448
+ # NOTE this model doesn't support cls embed in text transformer, no need for extra intermediate tokens
449
+ output["text_intermediates"] = intermediates
450
+
451
+ if not intermediates_only:
452
+ x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
453
+ x = text_global_pool(x, text, self.text_pool_type)
454
+ if self.text_projection is not None:
455
+ if isinstance(self.text_projection, nn.Linear):
456
+ x = self.text_projection(x)
457
+ else:
458
+ x = x @ self.text_projection
459
+ if normalize:
460
+ x = F.normalize(x, dim=-1)
461
+ output["text_features"] = x
462
+
463
+ logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None
464
+
465
+ if output_logits:
466
+ image_logits = logit_scale_exp * output["image_features"] @ output["text_features"].T
467
+ if self.logit_bias is not None:
468
+ image_logits += self.logit_bias
469
+ text_logits = image_logits.T
470
+ output["image_logits"] = image_logits
471
+ output["text_logits"] = text_logits
472
+
473
+ if output_logit_scale_bias:
474
+ output["logit_scale"] = logit_scale_exp
475
+ if self.logit_bias is not None:
476
+ output['logit_bias'] = self.logit_bias
477
+
478
+ return output
479
+
480
+
481
+ def forward(
482
+ self,
483
+ image: Optional[torch.Tensor] = None,
484
+ text: Optional[torch.Tensor] = None,
485
+ concentration: Optional[torch.Tensor] = None,
486
+ time: Optional[torch.Tensor] = None,
487
+ compound_embedding: Optional[torch.Tensor] = None,
488
+ ):
489
+
490
+ image_features = self.encode_image(image, normalize=True) if image is not None else None
491
+ text_features = self.encode_text(text, normalize=True, concentration=concentration, time=time, compound_embedding=compound_embedding) if text is not None else None
492
+ if self.output_dict:
493
+ out_dict = {
494
+ "image_features": image_features,
495
+ "text_features": text_features,
496
+ "logit_scale": self.logit_scale.exp()
497
+ }
498
+ if self.logit_bias is not None:
499
+ out_dict['logit_bias'] = self.logit_bias
500
+ return out_dict
501
+
502
+ if self.logit_bias is not None:
503
+ return image_features, text_features, self.logit_scale.exp(), self.logit_bias
504
+ return image_features, text_features, self.logit_scale.exp()
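+ # Usage sketch for the extended forward interface (shapes mirror trace_model further below;
+ # batch size B is arbitrary): image (B, C, H, W), text (B, context_length) token ids,
+ # concentration (B, 2), time (B, 1), compound_embedding (B, 159):
+ # out = model(image, text, concentration, time, compound_embedding)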
505
+
506
+
507
+ class CustomTextCLIP(nn.Module):
508
+ output_dict: torch.jit.Final[bool]
509
+
510
+ def __init__(
511
+ self,
512
+ embed_dim: int,
513
+ vision_cfg: CLIPVisionCfg,
514
+ text_cfg: CLIPTextCfg,
515
+ quick_gelu: bool = False,
516
+ init_logit_scale: float = np.log(1 / 0.07),
517
+ init_logit_bias: Optional[float] = None,
518
+ nonscalar_logit_scale: bool = False,
519
+ cast_dtype: Optional[torch.dtype] = None,
520
+ output_dict: bool = False,
521
+ ):
522
+ super().__init__()
523
+ self.output_dict = output_dict
524
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
525
+ self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
526
+ self.context_length = self.text.context_length
527
+ self.vocab_size = self.text.vocab_size
528
+
529
+ lshape = [1] if nonscalar_logit_scale else []
530
+ self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
531
+ if init_logit_bias is not None:
532
+ self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
533
+ else:
534
+ self.logit_bias = None
535
+
536
+ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
537
+ # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
538
+ self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
539
+
540
+ def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
541
+ self.text.lock(unlocked_layers, freeze_layer_norm)
542
+
543
+ @torch.jit.ignore
544
+ def set_grad_checkpointing(self, enable=True):
545
+ self.visual.set_grad_checkpointing(enable)
546
+ self.text.set_grad_checkpointing(enable)
547
+
548
+ @torch.jit.ignore
549
+ def no_weight_decay(self):
550
+ # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default
551
+ no_wd = set()
552
+ if hasattr(self.visual, 'no_weight_decay'):
553
+ for n in self.visual.no_weight_decay():
554
+ no_wd.add('visual.' + n)
555
+ if hasattr(self.text, 'no_weight_decay'):
556
+ for n in self.text.no_weight_decay():
557
+ no_wd.add('text.' + n)
558
+ return no_wd
559
+
560
+ def encode_image(self, image, normalize: bool = False):
561
+ features = self.visual(image)
562
+ return F.normalize(features, dim=-1) if normalize else features
563
+
564
+ def encode_text(self, text, normalize: bool = False):
565
+ features = self.text(text)
566
+ return F.normalize(features, dim=-1) if normalize else features
567
+
568
+ def get_logits(self, image, text):
569
+ image_features = self.encode_image(image, normalize=True)
570
+ text_features = self.encode_text(text, normalize=True)
571
+ image_logits = self.logit_scale.exp() * image_features @ text_features.T
572
+ if self.logit_bias is not None:
573
+ image_logits += self.logit_bias
574
+ text_logits = image_logits.T
575
+ return image_logits, text_logits
576
+
577
+ def forward_intermediates(
578
+ self,
579
+ image: Optional[torch.Tensor] = None,
580
+ text: Optional[torch.Tensor] = None,
581
+ image_indices: Optional[Union[int, List[int]]] = None,
582
+ text_indices: Optional[Union[int, List[int]]] = None,
583
+ stop_early: bool = False,
584
+ normalize: bool = True,
585
+ normalize_intermediates: bool = False,
586
+ intermediates_only: bool = False,
587
+ image_output_fmt: str = 'NCHW',
588
+ image_output_extra_tokens: bool = False,
589
+ text_output_fmt: str = 'NLC',
590
+ text_output_extra_tokens: bool = False,
591
+ output_logits: bool = False,
592
+ output_logit_scale_bias: bool = False,
593
+ ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
594
+ """ Forward features that returns intermediates.
595
+
596
+ Args:
597
+ image: Input image tensor
598
+ text: Input text tensor
599
+ image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence
600
+ text_indices: Take last n blocks if int, all if None, select matching indices if sequence
601
+ stop_early: Stop iterating over blocks when last desired intermediate hit
602
+ normalize: L2 Normalize final image and text features (if present)
603
+ normalize_intermediates: Apply final encoder norm layer to all intermediates (if possible)
604
+ intermediates_only: Only return intermediate features, do not return final features
605
+ image_output_fmt: Shape of intermediate image feature outputs
606
+ image_output_extra_tokens: Return both prefix and spatial intermediate tokens
607
+ text_output_fmt: Shape of intermediate text feature outputs
608
+ text_output_extra_tokens: Return both prefix and spatial intermediate tokens
609
+ output_logits: Include logits in output
610
+ output_logit_scale_bias: Include the logit scale bias in the output
611
+ Returns:
612
+ Dict of requested intermediate features, plus final features and logits when enabled.
613
+ """
614
+ output = {}
615
+ if intermediates_only:
616
+ # intermediates-only mode disables final feature normalization and logit output
617
+ normalize = False
618
+ output_logits = False
619
+ if output_logits:
620
+ assert image is not None and text is not None, 'Both image and text inputs are required to compute logits'
621
+
622
+ if image is not None:
623
+ image_output = self.visual.forward_intermediates(
624
+ image,
625
+ indices=image_indices,
626
+ stop_early=stop_early,
627
+ normalize_intermediates=normalize_intermediates,
628
+ intermediates_only=intermediates_only,
629
+ output_fmt=image_output_fmt,
630
+ output_extra_tokens=image_output_extra_tokens,
631
+ )
632
+ if normalize and "image_features" in image_output:
633
+ image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1)
634
+ output.update(image_output)
635
+
636
+ if text is not None:
637
+ text_output = self.text.forward_intermediates(
638
+ text,
639
+ indices=text_indices,
640
+ stop_early=stop_early,
641
+ normalize_intermediates=normalize_intermediates,
642
+ intermediates_only=intermediates_only,
643
+ output_fmt=text_output_fmt,
644
+ output_extra_tokens=text_output_extra_tokens,
645
+ )
646
+ if normalize and "text_features" in text_output:
647
+ text_output["text_features"] = F.normalize(text_output["text_features"], dim=-1)
648
+ output.update(text_output)
649
+
650
+ logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None
651
+
652
+ if output_logits:
653
+ image_logits = logit_scale_exp * output["image_features"] @ output["text_features"].T
654
+ if self.logit_bias is not None:
655
+ image_logits += self.logit_bias
656
+ text_logits = image_logits.T
657
+ output["image_logits"] = image_logits
658
+ output["text_logits"] = text_logits
659
+
660
+ if output_logit_scale_bias:
661
+ output["logit_scale"] = logit_scale_exp
662
+ if self.logit_bias is not None:
663
+ output['logit_bias'] = self.logit_bias
664
+
665
+ return output
666
+
667
+ def forward(
668
+ self,
669
+ image: Optional[torch.Tensor] = None,
670
+ text: Optional[torch.Tensor] = None,
671
+ ):
672
+ image_features = self.encode_image(image, normalize=True) if image is not None else None
673
+ text_features = self.encode_text(text, normalize=True) if text is not None else None
674
+
675
+ if self.output_dict:
676
+ out_dict = {
677
+ "image_features": image_features,
678
+ "text_features": text_features,
679
+ "logit_scale": self.logit_scale.exp()
680
+ }
681
+ if self.logit_bias is not None:
682
+ out_dict['logit_bias'] = self.logit_bias
683
+ return out_dict
684
+
685
+ if self.logit_bias is not None:
686
+ return image_features, text_features, self.logit_scale.exp(), self.logit_bias
687
+ return image_features, text_features, self.logit_scale.exp()
688
+
689
+
690
+ def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
691
+ """Convert applicable model parameters to low-precision (bf16 or fp16)"""
692
+
693
+ def _convert_weights(l):
694
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
695
+ l.weight.data = l.weight.data.to(dtype)
696
+ if l.bias is not None:
697
+ l.bias.data = l.bias.data.to(dtype)
698
+
699
+ if isinstance(l, (nn.MultiheadAttention, Attention)):
700
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
701
+ tensor = getattr(l, attr)
702
+ if tensor is not None:
703
+ tensor.data = tensor.data.to(dtype)
704
+
705
+ if isinstance(l, (CLIP, TextTransformer)):
706
+ # convert text nn.Parameter projections
707
+ attr = getattr(l, "text_projection", None)
708
+ if attr is not None:
709
+ attr.data = attr.data.to(dtype)
710
+
711
+ if isinstance(l, VisionTransformer):
712
+ # convert vision nn.Parameter projections
713
+ attr = getattr(l, "proj", None)
714
+ if attr is not None:
715
+ attr.data = attr.data.to(dtype)
716
+
717
+ model.apply(_convert_weights)
718
+
719
+
720
+ convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
721
+
722
+
723
+ # used to maintain checkpoint compatibility
724
+ def convert_to_custom_text_state_dict(state_dict: dict):
725
+ if 'text_projection' in state_dict:
726
+ # old format state_dict, move text tower -> .text
727
+ new_state_dict = {}
728
+ for k, v in state_dict.items():
729
+ if any(k.startswith(p) for p in (
730
+ 'text_projection',
731
+ 'positional_embedding',
732
+ 'token_embedding',
733
+ 'transformer',
734
+ 'ln_final',
735
+ )):
736
+ k = 'text.' + k
737
+ new_state_dict[k] = v
738
+ return new_state_dict
739
+ return state_dict
740
+
741
+
742
+ def build_model_from_openai_state_dict(
743
+ state_dict: dict,
744
+ quick_gelu=True,
745
+ cast_dtype=torch.float16,
746
+ ):
747
+ vit = "visual.proj" in state_dict
748
+
749
+ if vit:
750
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
751
+ vision_layers = len(
752
+ [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
753
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
754
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
755
+ image_size = vision_patch_size * grid_size
756
+ else:
757
+ counts: list = [
758
+ len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
759
+ vision_layers = tuple(counts)
760
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
761
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
762
+ vision_patch_size = None
763
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
764
+ image_size = output_width * 32
765
+
766
+ embed_dim = state_dict["text_projection"].shape[1]
767
+ context_length = state_dict["positional_embedding"].shape[0]
768
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
769
+ transformer_width = state_dict["ln_final.weight"].shape[0]
770
+ transformer_heads = transformer_width // 64
771
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
772
+
773
+ vision_cfg = CLIPVisionCfg(
774
+ layers=vision_layers,
775
+ width=vision_width,
776
+ patch_size=vision_patch_size,
777
+ image_size=image_size,
778
+ )
779
+ text_cfg = CLIPTextCfg(
780
+ context_length=context_length,
781
+ vocab_size=vocab_size,
782
+ width=transformer_width,
783
+ heads=transformer_heads,
784
+ layers=transformer_layers,
785
+ )
786
+ model = CLIP(
787
+ embed_dim,
788
+ vision_cfg=vision_cfg,
789
+ text_cfg=text_cfg,
790
+ quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
791
+ cast_dtype=cast_dtype,
792
+ )
793
+
794
+ for key in ["input_resolution", "context_length", "vocab_size"]:
795
+ state_dict.pop(key, None)
796
+ convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
797
+ model.load_state_dict(state_dict)
798
+ return model.eval()
799
+
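A sketch of how this builder might be driven. The path "ViT-B-16.pt" is hypothetical, and the JIT-then-plain fallback mirrors the way OpenAI distributes its CLIP weights (TorchScript archives that also load as plain state dicts):

import torch

try:
    jit_model = torch.jit.load("ViT-B-16.pt", map_location="cpu")
    state_dict = jit_model.state_dict()
except RuntimeError:
    state_dict = torch.load("ViT-B-16.pt", map_location="cpu")

model = build_model_from_openai_state_dict(state_dict, quick_gelu=True, cast_dtype=torch.float16)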
800
+
801
+ def trace_model(model, batch_size=256, device=torch.device('cpu')):
802
+ model.eval()
803
+ image_size = model.visual.image_size
804
+ example_images = torch.ones((batch_size, 2, image_size, image_size), device=device)
805
+ example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
806
+ example_concentration = torch.rand((batch_size, 2), device=device)
807
+ example_time = torch.rand((batch_size, 1), device=device)
808
+ example_compound_embedding = torch.rand((batch_size, 159), device=device)
809
+ model = torch.jit.trace_module(
810
+ model,
811
+ inputs=dict(
812
+ forward=(example_images, example_text, example_concentration, example_time, example_compound_embedding),
813
+ encode_text=(example_text, True, example_concentration, example_time, example_compound_embedding),
814
+ encode_image=(example_images,)
815
+ ))
816
+ model.visual.image_size = image_size
817
+ return model
818
+
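Sketch only: exporting a TorchScript module for inference. This assumes the traced model accepts the extra concentration, time, and compound-embedding inputs that the example tensors above are shaped for ((B, 2), (B, 1), and (B, 159)); the output path is hypothetical.

import torch

traced = trace_model(model, batch_size=8, device=torch.device("cpu"))
torch.jit.save(traced, "clip_traced.pt")  # hypothetical output path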
819
+
820
+ def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
821
+ # Rescale the grid of position embeddings when loading from state_dict
822
+ old_pos_embed = state_dict.get('visual.positional_embedding', None)
823
+ if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
824
+ return
825
+ grid_size = to_2tuple(model.visual.grid_size)
826
+ extra_tokens = 1 # FIXME detect different token configs (i.e. no class token, or more)
827
+ new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
828
+ if new_seq_len == old_pos_embed.shape[0]:
829
+ return
830
+
831
+ if extra_tokens:
832
+ pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
833
+ else:
834
+ pos_emb_tok, pos_emb_img = None, old_pos_embed
835
+ old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
836
+
837
+ logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
838
+ pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
839
+ pos_emb_img = F.interpolate(
840
+ pos_emb_img,
841
+ size=grid_size,
842
+ mode=interpolation,
843
+ antialias=antialias,
844
+ align_corners=False,
845
+ )
846
+ pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
847
+ if pos_emb_tok is not None:
848
+ new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
849
+ else:
850
+ new_pos_embed = pos_emb_img
851
+ state_dict['visual.positional_embedding'] = new_pos_embed
852
+
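To make the resizing concrete, a worked example under the assumption of a ViT-B/16 checkpoint trained at 224px and loaded into a model built for 336px:

patch = 16
old_seq = (224 // patch) ** 2 + 1  # 14*14 grid + 1 class token = 197 positions in the checkpoint
new_seq = (336 // patch) ** 2 + 1  # 21*21 grid + 1 class token = 442 positions expected by the model
# resize_pos_embed() bicubically interpolates the 14x14 grid up to 21x21 and keeps the
# class-token embedding as-is, so the checkpoint stays loadable at the new resolution.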
853
+
854
+ def resize_text_pos_embed(state_dict, model, interpolation: str = 'linear', antialias: bool = False):
855
+ old_pos_embed = state_dict.get('positional_embedding', None)
856
+ if old_pos_embed is None:
857
+ return
858
+ # FIXME add support for text cls_token
859
+ model_pos_embed = getattr(model, 'positional_embedding', None)
860
+ if model_pos_embed is None:
861
+ model_pos_embed = getattr(model.text, 'positional_embedding', None)
862
+
863
+ old_num_pos = old_pos_embed.shape[0]
864
+ old_width = old_pos_embed.shape[1]
865
+ num_pos = model_pos_embed.shape[0]
866
+ width = model_pos_embed.shape[1]
867
+ assert old_width == width, 'text pos_embed width changed!'
868
+ if old_num_pos == num_pos:
869
+ return
870
+
871
+ logging.info('Resizing text position embedding num_pos from %s to %s', old_num_pos, num_pos)
872
+ old_pos_embed = old_pos_embed.reshape(1, old_num_pos, old_width).permute(0, 2, 1)
873
+ old_pos_embed = F.interpolate(
874
+ old_pos_embed,
875
+ size=num_pos,
876
+ mode=interpolation,
877
+ antialias=antialias,
878
+ align_corners=False,
879
+ )
880
+ old_pos_embed = old_pos_embed.permute(0, 2, 1)[0]
881
+ new_pos_embed = old_pos_embed
882
+
883
+ state_dict['positional_embedding'] = new_pos_embed
884
+
885
+
886
+ def get_model_preprocess_cfg(model):
887
+ module = getattr(model, 'visual', model)
888
+ preprocess_cfg = getattr(module, 'preprocess_cfg', {})
889
+ if not preprocess_cfg:
890
+ # use separate legacy attributes if preprocess_cfg dict not found
891
+ size = getattr(module, 'image_size')
892
+ if size is not None:
893
+ preprocess_cfg['size'] = size
894
+ mean = getattr(module, 'image_mean', None)
895
+ if mean is not None:
896
+ preprocess_cfg['mean'] = mean
897
+ std = getattr(module, 'image_std', None)
898
+ if std is not None:
899
+ preprocess_cfg['std'] = std
900
+ return preprocess_cfg
901
+
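The returned dict is what the package's own transform helpers normally consume. Purely to illustrate the keys ('size', 'mean', 'std'), a hand-rolled torchvision pipeline could look like this, assuming 'mean' and 'std' are present:

from torchvision import transforms

cfg = get_model_preprocess_cfg(model)  # e.g. {'size': 224, 'mean': (...), 'std': (...)}
size = cfg['size'] if isinstance(cfg['size'], int) else cfg['size'][0]
preprocess = transforms.Compose([
    transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(size),
    transforms.ToTensor(),
    transforms.Normalize(mean=cfg['mean'], std=cfg['std']),
])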
902
+
903
+ def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]):
904
+ module = getattr(model, 'visual', model)
905
+ module.image_mean = preprocess_cfg['mean'] # legacy attribute, keeping for bwd compat
906
+ module.image_std = preprocess_cfg['std'] # legacy attribute, keeping for bwd compat
907
+ module.preprocess_cfg = copy.deepcopy(preprocess_cfg) # new attr, packages the full preprocess cfg as one dict
908
+
909
+
910
+ def get_model_tokenize_cfg(model):
911
+ module = getattr(model, 'text', model)
912
+ cfg = {}
913
+ context_length = getattr(module, 'context_length', None)
914
+ if context_length is not None:
915
+ cfg['context_length'] = context_length
916
+ vocab_size = getattr(module, 'vocab_size', None)
917
+ if vocab_size is not None:
918
+ cfg['vocab_size'] = vocab_size
919
+ return cfg
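A small illustration of the tokenizer-config helper, assuming `model` was created as in the earlier sketches; the values line up with the model configs added below:

cfg = get_model_tokenize_cfg(model)
# e.g. {'context_length': 77, 'vocab_size': 49408} for the CLIP-tokenizer configs,
# or {'context_length': 64, 'vocab_size': 32000} for the SigLIP ones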
src/open_clip/model_configs/EVA01-g-14-plus.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "timm_model_name": "eva_giant_patch14_224",
6
+ "timm_model_pretrained": false,
7
+ "timm_pool": "token",
8
+ "timm_proj": null
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 1024,
14
+ "heads": 16,
15
+ "layers": 24
16
+ },
17
+ "custom_text": true
18
+ }
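The EVA and MobileCLIP configs added here pair a timm-built vision tower (`timm_model_name`) with the package's own text transformer, so timm must be installed for them. Assuming the factory API in this repo mirrors upstream open_clip, a config can be instantiated by name, architecture only, with no pretrained weights:

import open_clip

model = open_clip.create_model("EVA01-g-14-plus", pretrained=None)
tokenizer = open_clip.get_tokenizer("EVA01-g-14-plus")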
src/open_clip/model_configs/EVA01-g-14.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "timm_model_name": "eva_giant_patch14_224",
6
+ "timm_model_pretrained": false,
7
+ "timm_pool": "token",
8
+ "timm_proj": null
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 768,
14
+ "heads": 12,
15
+ "layers": 12
16
+ },
17
+ "custom_text": true
18
+ }
src/open_clip/model_configs/EVA02-B-16.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "timm_model_name": "eva02_base_patch16_clip_224",
6
+ "timm_model_pretrained": false,
7
+ "timm_pool": "token",
8
+ "timm_proj": null
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 512,
14
+ "heads": 8,
15
+ "layers": 12
16
+ },
17
+ "custom_text": true
18
+ }
src/open_clip/model_configs/EVA02-E-14-plus.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "timm_model_name": "eva02_enormous_patch14_clip_224",
6
+ "timm_model_pretrained": false,
7
+ "timm_pool": "token",
8
+ "timm_proj": null
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 1280,
14
+ "heads": 20,
15
+ "layers": 32
16
+ },
17
+ "custom_text": true
18
+ }
src/open_clip/model_configs/EVA02-E-14.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "timm_model_name": "eva02_enormous_patch14_clip_224",
6
+ "timm_model_pretrained": false,
7
+ "timm_pool": "token",
8
+ "timm_proj": null
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 1024,
14
+ "heads": 16,
15
+ "layers": 24
16
+ },
17
+ "custom_text": true
18
+ }
src/open_clip/model_configs/EVA02-L-14-336.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 336,
5
+ "timm_model_name": "eva02_large_patch14_clip_336",
6
+ "timm_model_pretrained": false,
7
+ "timm_pool": "token",
8
+ "timm_proj": null
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 768,
14
+ "heads": 12,
15
+ "layers": 12
16
+ },
17
+ "custom_text": true
18
+ }
src/open_clip/model_configs/EVA02-L-14.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "timm_model_name": "eva02_large_patch14_clip_224",
6
+ "timm_model_pretrained": false,
7
+ "timm_pool": "token",
8
+ "timm_proj": null
9
+ },
10
+ "text_cfg": {
11
+ "context_length": 77,
12
+ "vocab_size": 49408,
13
+ "width": 768,
14
+ "heads": 12,
15
+ "layers": 12
16
+ },
17
+ "custom_text": true
18
+ }
src/open_clip/model_configs/MobileCLIP-B.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "timm_model_name": "vit_base_mci_224",
5
+ "timm_model_pretrained": false,
6
+ "timm_pool": "token",
7
+ "timm_proj": null,
8
+ "timm_drop": 0.0,
9
+ "timm_drop_path": 0.0,
10
+ "image_size": 224
11
+ },
12
+ "text_cfg": {
13
+ "context_length": 77,
14
+ "vocab_size": 49408,
15
+ "width": 512,
16
+ "heads": 8,
17
+ "layers": 12,
18
+ "no_causal_mask": false
19
+ },
20
+ "custom_text": true
21
+ }
src/open_clip/model_configs/MobileCLIP-S1.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "timm_model_name": "fastvit_mci1",
5
+ "timm_model_pretrained": false,
6
+ "timm_pool": "avg",
7
+ "timm_proj": null,
8
+ "timm_drop": 0.0,
9
+ "timm_drop_path": 0.0,
10
+ "image_size": 256
11
+ },
12
+ "text_cfg": {
13
+ "context_length": 77,
14
+ "vocab_size": 49408,
15
+ "width": 512,
16
+ "heads": 8,
17
+ "layers": 12,
18
+ "no_causal_mask": true
19
+ },
20
+ "custom_text": true
21
+ }
src/open_clip/model_configs/MobileCLIP-S2.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "timm_model_name": "fastvit_mci2",
5
+ "timm_model_pretrained": false,
6
+ "timm_pool": "avg",
7
+ "timm_proj": null,
8
+ "timm_drop": 0.0,
9
+ "timm_drop_path": 0.0,
10
+ "image_size": 256
11
+ },
12
+ "text_cfg": {
13
+ "context_length": 77,
14
+ "vocab_size": 49408,
15
+ "width": 512,
16
+ "heads": 8,
17
+ "layers": 12,
18
+ "no_causal_mask": true
19
+ },
20
+ "custom_text": true
21
+ }
src/open_clip/model_configs/RN101-quickgelu.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "quick_gelu": true,
4
+ "vision_cfg": {
5
+ "image_size": 224,
6
+ "layers": [
7
+ 3,
8
+ 4,
9
+ 23,
10
+ 3
11
+ ],
12
+ "width": 64,
13
+ "patch_size": null
14
+ },
15
+ "text_cfg": {
16
+ "context_length": 77,
17
+ "vocab_size": 49408,
18
+ "width": 512,
19
+ "heads": 8,
20
+ "layers": 12
21
+ }
22
+ }
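This and the other `-quickgelu` configs that follow differ from their plain counterparts only by the `"quick_gelu": true` flag, matching the note in build_model_from_openai_state_dict above that OpenAI's released weights were trained with QuickGELU. For reference, that activation is simply (my own rendering, not code from this commit):

import torch
from torch import nn

class QuickGELU(nn.Module):
    # x * sigmoid(1.702 * x), the GELU approximation OpenAI's original CLIP was trained with
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(1.702 * x)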
src/open_clip/model_configs/RN101.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 512,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": [
6
+ 3,
7
+ 4,
8
+ 23,
9
+ 3
10
+ ],
11
+ "width": 64,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 512,
18
+ "heads": 8,
19
+ "layers": 12
20
+ }
21
+ }
src/open_clip/model_configs/RN50-quickgelu.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "quick_gelu": true,
4
+ "vision_cfg": {
5
+ "image_size": 224,
6
+ "layers": [
7
+ 3,
8
+ 4,
9
+ 6,
10
+ 3
11
+ ],
12
+ "width": 64,
13
+ "patch_size": null
14
+ },
15
+ "text_cfg": {
16
+ "context_length": 77,
17
+ "vocab_size": 49408,
18
+ "width": 512,
19
+ "heads": 8,
20
+ "layers": 12
21
+ }
22
+ }
src/open_clip/model_configs/RN50.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 224,
5
+ "layers": [
6
+ 3,
7
+ 4,
8
+ 6,
9
+ 3
10
+ ],
11
+ "width": 64,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 512,
18
+ "heads": 8,
19
+ "layers": 12
20
+ }
21
+ }
src/open_clip/model_configs/RN50x16-quickgelu.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "quick_gelu": true,
4
+ "vision_cfg": {
5
+ "image_size": 384,
6
+ "layers": [
7
+ 6,
8
+ 8,
9
+ 18,
10
+ 8
11
+ ],
12
+ "width": 96,
13
+ "patch_size": null
14
+ },
15
+ "text_cfg": {
16
+ "context_length": 77,
17
+ "vocab_size": 49408,
18
+ "width": 768,
19
+ "heads": 12,
20
+ "layers": 12
21
+ }
22
+ }
src/open_clip/model_configs/RN50x16.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "vision_cfg": {
4
+ "image_size": 384,
5
+ "layers": [
6
+ 6,
7
+ 8,
8
+ 18,
9
+ 8
10
+ ],
11
+ "width": 96,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 768,
18
+ "heads": 12,
19
+ "layers": 12
20
+ }
21
+ }
src/open_clip/model_configs/RN50x4-quickgelu.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 640,
3
+ "quick_gelu": true,
4
+ "vision_cfg": {
5
+ "image_size": 288,
6
+ "layers": [
7
+ 4,
8
+ 6,
9
+ 10,
10
+ 6
11
+ ],
12
+ "width": 80,
13
+ "patch_size": null
14
+ },
15
+ "text_cfg": {
16
+ "context_length": 77,
17
+ "vocab_size": 49408,
18
+ "width": 640,
19
+ "heads": 10,
20
+ "layers": 12
21
+ }
22
+ }
src/open_clip/model_configs/RN50x4.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 640,
3
+ "vision_cfg": {
4
+ "image_size": 288,
5
+ "layers": [
6
+ 4,
7
+ 6,
8
+ 10,
9
+ 6
10
+ ],
11
+ "width": 80,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 640,
18
+ "heads": 10,
19
+ "layers": 12
20
+ }
21
+ }
src/open_clip/model_configs/RN50x64-quickgelu.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "quick_gelu": true,
4
+ "vision_cfg": {
5
+ "image_size": 448,
6
+ "layers": [
7
+ 3,
8
+ 15,
9
+ 36,
10
+ 10
11
+ ],
12
+ "width": 128,
13
+ "patch_size": null
14
+ },
15
+ "text_cfg": {
16
+ "context_length": 77,
17
+ "vocab_size": 49408,
18
+ "width": 1024,
19
+ "heads": 16,
20
+ "layers": 12
21
+ }
22
+ }
src/open_clip/model_configs/RN50x64.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 1024,
3
+ "vision_cfg": {
4
+ "image_size": 448,
5
+ "layers": [
6
+ 3,
7
+ 15,
8
+ 36,
9
+ 10
10
+ ],
11
+ "width": 128,
12
+ "patch_size": null
13
+ },
14
+ "text_cfg": {
15
+ "context_length": 77,
16
+ "vocab_size": 49408,
17
+ "width": 1024,
18
+ "heads": 16,
19
+ "layers": 12
20
+ }
21
+ }
src/open_clip/model_configs/ViT-B-16-SigLIP-256.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "init_logit_bias": -10,
4
+ "custom_text": true,
5
+ "vision_cfg": {
6
+ "image_size": 256,
7
+ "timm_model_name": "vit_base_patch16_siglip_256",
8
+ "timm_model_pretrained": false,
9
+ "timm_pool": "map",
10
+ "timm_proj": "none"
11
+ },
12
+ "text_cfg": {
13
+ "context_length": 64,
14
+ "vocab_size": 32000,
15
+ "hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
16
+ "tokenizer_kwargs": {
17
+ "clean": "canonicalize"
18
+ },
19
+ "width": 768,
20
+ "heads": 12,
21
+ "layers": 12,
22
+ "no_causal_mask": true,
23
+ "proj_bias": true,
24
+ "pool_type": "last",
25
+ "norm_kwargs":{
26
+ "eps": 1e-6
27
+ }
28
+ }
29
+ }
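The SigLIP configs initialize `logit_bias` at -10 and disable the causal text mask; under a pairwise sigmoid objective, a large negative bias keeps the initial loss small because most image-text pairs in a batch are negatives. A minimal sketch of how such logits become pairwise probabilities (my own illustration, not part of this commit):

import torch
import torch.nn.functional as F

def siglip_pair_probs(img, txt, logit_scale, logit_bias):
    """Pairwise matching probabilities under a SigLIP-style sigmoid objective."""
    img = F.normalize(img, dim=-1)
    txt = F.normalize(txt, dim=-1)
    logits = logit_scale * img @ txt.T + logit_bias
    return torch.sigmoid(logits)  # each image-text pair scored independently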
src/open_clip/model_configs/ViT-B-16-SigLIP-384.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "init_logit_bias": -10,
4
+ "custom_text": true,
5
+ "vision_cfg": {
6
+ "image_size": 384,
7
+ "timm_model_name": "vit_base_patch16_siglip_384",
8
+ "timm_model_pretrained": false,
9
+ "timm_pool": "map",
10
+ "timm_proj": "none"
11
+ },
12
+ "text_cfg": {
13
+ "context_length": 64,
14
+ "vocab_size": 32000,
15
+ "hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
16
+ "tokenizer_kwargs": {
17
+ "clean": "canonicalize"
18
+ },
19
+ "width": 768,
20
+ "heads": 12,
21
+ "layers": 12,
22
+ "no_causal_mask": true,
23
+ "proj_bias": true,
24
+ "pool_type": "last",
25
+ "norm_kwargs":{
26
+ "eps": 1e-6
27
+ }
28
+ }
29
+ }
src/open_clip/model_configs/ViT-B-16-SigLIP-512.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "init_logit_bias": -10,
4
+ "custom_text": true,
5
+ "vision_cfg": {
6
+ "image_size": 512,
7
+ "timm_model_name": "vit_base_patch16_siglip_512",
8
+ "timm_model_pretrained": false,
9
+ "timm_pool": "map",
10
+ "timm_proj": "none"
11
+ },
12
+ "text_cfg": {
13
+ "context_length": 64,
14
+ "vocab_size": 32000,
15
+ "hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
16
+ "tokenizer_kwargs": {
17
+ "clean": "canonicalize"
18
+ },
19
+ "width": 768,
20
+ "heads": 12,
21
+ "layers": 12,
22
+ "no_causal_mask": true,
23
+ "proj_bias": true,
24
+ "pool_type": "last",
25
+ "norm_kwargs":{
26
+ "eps": 1e-6
27
+ }
28
+ }
29
+ }
src/open_clip/model_configs/ViT-B-16-SigLIP-i18n-256.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "init_logit_bias": -10,
4
+ "custom_text": true,
5
+ "vision_cfg": {
6
+ "image_size": 256,
7
+ "timm_model_name": "vit_base_patch16_siglip_256",
8
+ "timm_model_pretrained": false,
9
+ "timm_pool": "map",
10
+ "timm_proj": "none"
11
+ },
12
+ "text_cfg": {
13
+ "context_length": 64,
14
+ "vocab_size": 250000,
15
+ "hf_tokenizer_name": "timm/ViT-B-16-SigLIP-i18n-256",
16
+ "tokenizer_kwargs": {
17
+ "clean": "canonicalize"
18
+ },
19
+ "width": 768,
20
+ "heads": 12,
21
+ "layers": 12,
22
+ "no_causal_mask": true,
23
+ "proj_bias": true,
24
+ "pool_type": "last",
25
+ "norm_kwargs":{
26
+ "eps": 1e-6
27
+ }
28
+ }
29
+ }
src/open_clip/model_configs/ViT-B-16-SigLIP.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "init_logit_bias": -10,
4
+ "custom_text": true,
5
+ "vision_cfg": {
6
+ "image_size": 224,
7
+ "timm_model_name": "vit_base_patch16_siglip_224",
8
+ "timm_model_pretrained": false,
9
+ "timm_pool": "map",
10
+ "timm_proj": "none"
11
+ },
12
+ "text_cfg": {
13
+ "context_length": 64,
14
+ "vocab_size": 32000,
15
+ "hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
16
+ "tokenizer_kwargs": {
17
+ "clean": "canonicalize"
18
+ },
19
+ "width": 768,
20
+ "heads": 12,
21
+ "layers": 12,
22
+ "no_causal_mask": true,
23
+ "proj_bias": true,
24
+ "pool_type": "last",
25
+ "norm_kwargs":{
26
+ "eps": 1e-6
27
+ }
28
+ }
29
+ }
src/open_clip/model_configs/ViT-B-16-SigLIP2-256.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "init_logit_bias": -10,
4
+ "custom_text": true,
5
+ "vision_cfg": {
6
+ "image_size": 256,
7
+ "timm_model_name": "vit_base_patch16_siglip_256",
8
+ "timm_model_pretrained": false,
9
+ "timm_pool": "map",
10
+ "timm_proj": "none"
11
+ },
12
+ "text_cfg": {
13
+ "context_length": 64,
14
+ "vocab_size": 256000,
15
+ "hf_tokenizer_name": "timm/ViT-B-16-SigLIP2-256",
16
+ "tokenizer_kwargs": {
17
+ "clean": "canonicalize"
18
+ },
19
+ "width": 768,
20
+ "heads": 12,
21
+ "layers": 12,
22
+ "no_causal_mask": true,
23
+ "proj_bias": true,
24
+ "pool_type": "last",
25
+ "norm_kwargs":{
26
+ "eps": 1e-6
27
+ },
28
+ "act_kwargs": {
29
+ "approximate": "tanh"
30
+ }
31
+ }
32
+ }
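Relative to the SigLIP configs above, the SigLIP2 configs swap in a 256000-token multilingual vocabulary/tokenizer and add an `act_kwargs` block; presumably the latter selects PyTorch's tanh-approximate GELU:

from torch import nn

gelu_tanh = nn.GELU(approximate="tanh")  # what {"approximate": "tanh"} presumably maps to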
src/open_clip/model_configs/ViT-B-16-SigLIP2-384.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "init_logit_bias": -10,
4
+ "custom_text": true,
5
+ "vision_cfg": {
6
+ "image_size": 384,
7
+ "timm_model_name": "vit_base_patch16_siglip_384",
8
+ "timm_model_pretrained": false,
9
+ "timm_pool": "map",
10
+ "timm_proj": "none"
11
+ },
12
+ "text_cfg": {
13
+ "context_length": 64,
14
+ "vocab_size": 256000,
15
+ "hf_tokenizer_name": "timm/ViT-B-16-SigLIP2-384",
16
+ "tokenizer_kwargs": {
17
+ "clean": "canonicalize"
18
+ },
19
+ "width": 768,
20
+ "heads": 12,
21
+ "layers": 12,
22
+ "no_causal_mask": true,
23
+ "proj_bias": true,
24
+ "pool_type": "last",
25
+ "norm_kwargs":{
26
+ "eps": 1e-6
27
+ },
28
+ "act_kwargs": {
29
+ "approximate": "tanh"
30
+ }
31
+ }
32
+ }
src/open_clip/model_configs/ViT-B-16-SigLIP2-512.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "embed_dim": 768,
3
+ "init_logit_bias": -10,
4
+ "custom_text": true,
5
+ "vision_cfg": {
6
+ "image_size": 512,
7
+ "timm_model_name": "vit_base_patch16_siglip_512",
8
+ "timm_model_pretrained": false,
9
+ "timm_pool": "map",
10
+ "timm_proj": "none"
11
+ },
12
+ "text_cfg": {
13
+ "context_length": 64,
14
+ "vocab_size": 256000,
15
+ "hf_tokenizer_name": "timm/ViT-B-16-SigLIP2-512",
16
+ "tokenizer_kwargs": {
17
+ "clean": "canonicalize"
18
+ },
19
+ "width": 768,
20
+ "heads": 12,
21
+ "layers": 12,
22
+ "no_causal_mask": true,
23
+ "proj_bias": true,
24
+ "pool_type": "last",
25
+ "norm_kwargs":{
26
+ "eps": 1e-6
27
+ },
28
+ "act_kwargs": {
29
+ "approximate": "tanh"
30
+ }
31
+ }
32
+ }