Initial commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .girattributes +2 -0
- .github/workflows/ci.yml +121 -0
- .github/workflows/clear-cache.yml +29 -0
- .github/workflows/python-publish.yml +37 -0
- .gitignore +153 -0
- CITATION.cff +33 -0
- HISTORY.md +223 -0
- LICENSE +23 -0
- MANIFEST.in +3 -0
- README.md +618 -0
- models.txt +2 -0
- pytest.ini +3 -0
- requirements.txt +8 -0
- src/open_clip/__init__.py +18 -0
- src/open_clip/coca_model.py +582 -0
- src/open_clip/constants.py +11 -0
- src/open_clip/convert.py +206 -0
- src/open_clip/factory.py +586 -0
- src/open_clip/hf_configs.py +67 -0
- src/open_clip/hf_model.py +193 -0
- src/open_clip/loss.py +447 -0
- src/open_clip/model.py +919 -0
- src/open_clip/model_configs/EVA01-g-14-plus.json +18 -0
- src/open_clip/model_configs/EVA01-g-14.json +18 -0
- src/open_clip/model_configs/EVA02-B-16.json +18 -0
- src/open_clip/model_configs/EVA02-E-14-plus.json +18 -0
- src/open_clip/model_configs/EVA02-E-14.json +18 -0
- src/open_clip/model_configs/EVA02-L-14-336.json +18 -0
- src/open_clip/model_configs/EVA02-L-14.json +18 -0
- src/open_clip/model_configs/MobileCLIP-B.json +21 -0
- src/open_clip/model_configs/MobileCLIP-S1.json +21 -0
- src/open_clip/model_configs/MobileCLIP-S2.json +21 -0
- src/open_clip/model_configs/RN101-quickgelu.json +22 -0
- src/open_clip/model_configs/RN101.json +21 -0
- src/open_clip/model_configs/RN50-quickgelu.json +22 -0
- src/open_clip/model_configs/RN50.json +21 -0
- src/open_clip/model_configs/RN50x16-quickgelu.json +22 -0
- src/open_clip/model_configs/RN50x16.json +21 -0
- src/open_clip/model_configs/RN50x4-quickgelu.json +22 -0
- src/open_clip/model_configs/RN50x4.json +21 -0
- src/open_clip/model_configs/RN50x64-quickgelu.json +22 -0
- src/open_clip/model_configs/RN50x64.json +21 -0
- src/open_clip/model_configs/ViT-B-16-SigLIP-256.json +29 -0
- src/open_clip/model_configs/ViT-B-16-SigLIP-384.json +29 -0
- src/open_clip/model_configs/ViT-B-16-SigLIP-512.json +29 -0
- src/open_clip/model_configs/ViT-B-16-SigLIP-i18n-256.json +29 -0
- src/open_clip/model_configs/ViT-B-16-SigLIP.json +29 -0
- src/open_clip/model_configs/ViT-B-16-SigLIP2-256.json +32 -0
- src/open_clip/model_configs/ViT-B-16-SigLIP2-384.json +32 -0
- src/open_clip/model_configs/ViT-B-16-SigLIP2-512.json +32 -0
.girattributes
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
*.py linguist-language=python
|
2 |
+
*.ipynb linguist-documentation
|
.github/workflows/ci.yml
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Continuous integration
|
2 |
+
|
3 |
+
on:
|
4 |
+
push:
|
5 |
+
branches:
|
6 |
+
- main
|
7 |
+
paths-ignore:
|
8 |
+
- '**.md'
|
9 |
+
- 'CITATION.cff'
|
10 |
+
- 'LICENSE'
|
11 |
+
- '.gitignore'
|
12 |
+
- 'docs/**'
|
13 |
+
pull_request:
|
14 |
+
branches:
|
15 |
+
- main
|
16 |
+
paths-ignore:
|
17 |
+
- '**.md'
|
18 |
+
- 'CITATION.cff'
|
19 |
+
- 'LICENSE'
|
20 |
+
- '.gitignore'
|
21 |
+
- 'docs/**'
|
22 |
+
workflow_dispatch:
|
23 |
+
inputs:
|
24 |
+
manual_revision_reference:
|
25 |
+
required: false
|
26 |
+
type: string
|
27 |
+
manual_revision_test:
|
28 |
+
required: false
|
29 |
+
type: string
|
30 |
+
|
31 |
+
env:
|
32 |
+
REVISION_REFERENCE: v2.8.2
|
33 |
+
#9d31b2ec4df6d8228f370ff20c8267ec6ba39383 earliest compatible v2.7.0 + pretrained_hf param
|
34 |
+
|
35 |
+
jobs:
|
36 |
+
Tests:
|
37 |
+
strategy:
|
38 |
+
matrix:
|
39 |
+
os: [ ubuntu-latest ] #, macos-latest ]
|
40 |
+
python: [ 3.8 ]
|
41 |
+
job_num: [ 4 ]
|
42 |
+
job: [ 1, 2, 3, 4 ]
|
43 |
+
runs-on: ${{ matrix.os }}
|
44 |
+
steps:
|
45 |
+
- uses: actions/checkout@v3
|
46 |
+
with:
|
47 |
+
fetch-depth: 0
|
48 |
+
ref: ${{ inputs.manual_revision_test }}
|
49 |
+
- name: Set up Python ${{ matrix.python }}
|
50 |
+
id: pythonsetup
|
51 |
+
uses: actions/setup-python@v4
|
52 |
+
with:
|
53 |
+
python-version: ${{ matrix.python }}
|
54 |
+
- name: Venv cache
|
55 |
+
id: venv-cache
|
56 |
+
uses: actions/cache@v3
|
57 |
+
with:
|
58 |
+
path: .env
|
59 |
+
key: venv-${{ matrix.os }}-${{ steps.pythonsetup.outputs.python-version }}-${{ hashFiles('requirements*') }}
|
60 |
+
- name: Pytest durations cache
|
61 |
+
uses: actions/cache@v3
|
62 |
+
with:
|
63 |
+
path: .test_durations
|
64 |
+
key: test_durations-${{ matrix.os }}-${{ steps.pythonsetup.outputs.python-version }}-${{ matrix.job }}-${{ github.run_id }}
|
65 |
+
restore-keys: test_durations-0-
|
66 |
+
- name: Setup
|
67 |
+
if: steps.venv-cache.outputs.cache-hit != 'true'
|
68 |
+
run: |
|
69 |
+
python3 -m venv .env
|
70 |
+
source .env/bin/activate
|
71 |
+
pip install -e .[test]
|
72 |
+
- name: Prepare test data
|
73 |
+
run: |
|
74 |
+
source .env/bin/activate
|
75 |
+
python -m pytest \
|
76 |
+
--quiet --co \
|
77 |
+
--splitting-algorithm least_duration \
|
78 |
+
--splits ${{ matrix.job_num }} \
|
79 |
+
--group ${{ matrix.job }} \
|
80 |
+
-m regression_test \
|
81 |
+
tests \
|
82 |
+
| head -n -2 | grep -Po 'test_inference_with_data\[\K[^]]*(?=-False]|-True])' \
|
83 |
+
> models_gh_runner.txt
|
84 |
+
if [ -n "${{ inputs.manual_revision_reference }}" ]; then
|
85 |
+
REVISION_REFERENCE=${{ inputs.manual_revision_reference }}
|
86 |
+
fi
|
87 |
+
python tests/util_test.py \
|
88 |
+
--save_model_list models_gh_runner.txt \
|
89 |
+
--model_list models_gh_runner.txt \
|
90 |
+
--git_revision $REVISION_REFERENCE
|
91 |
+
- name: Unit tests
|
92 |
+
run: |
|
93 |
+
source .env/bin/activate
|
94 |
+
if [[ -f .test_durations ]]
|
95 |
+
then
|
96 |
+
cp .test_durations durations_1
|
97 |
+
mv .test_durations durations_2
|
98 |
+
fi
|
99 |
+
python -m pytest \
|
100 |
+
-x -s -v \
|
101 |
+
--splitting-algorithm least_duration \
|
102 |
+
--splits ${{ matrix.job_num }} \
|
103 |
+
--group ${{ matrix.job }} \
|
104 |
+
--store-durations \
|
105 |
+
--durations-path durations_1 \
|
106 |
+
--clean-durations \
|
107 |
+
-m "not regression_test" \
|
108 |
+
tests
|
109 |
+
OPEN_CLIP_TEST_REG_MODELS=models_gh_runner.txt python -m pytest \
|
110 |
+
-x -s -v \
|
111 |
+
--store-durations \
|
112 |
+
--durations-path durations_2 \
|
113 |
+
--clean-durations \
|
114 |
+
-m "regression_test" \
|
115 |
+
tests
|
116 |
+
jq -s -S 'add' durations_* > .test_durations
|
117 |
+
- name: Collect pytest durations
|
118 |
+
uses: actions/upload-artifact@v4
|
119 |
+
with:
|
120 |
+
name: pytest_durations_${{ matrix.os }}-${{ matrix.python }}-${{ matrix.job }}
|
121 |
+
path: .test_durations
|
.github/workflows/clear-cache.yml
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Clear cache
|
2 |
+
|
3 |
+
on:
|
4 |
+
workflow_dispatch:
|
5 |
+
|
6 |
+
permissions:
|
7 |
+
actions: write
|
8 |
+
|
9 |
+
jobs:
|
10 |
+
clear-cache:
|
11 |
+
runs-on: ubuntu-latest
|
12 |
+
steps:
|
13 |
+
- name: Clear cache
|
14 |
+
uses: actions/github-script@v6
|
15 |
+
with:
|
16 |
+
script: |
|
17 |
+
const caches = await github.rest.actions.getActionsCacheList({
|
18 |
+
owner: context.repo.owner,
|
19 |
+
repo: context.repo.repo,
|
20 |
+
})
|
21 |
+
for (const cache of caches.data.actions_caches) {
|
22 |
+
console.log(cache)
|
23 |
+
await github.rest.actions.deleteActionsCacheById({
|
24 |
+
owner: context.repo.owner,
|
25 |
+
repo: context.repo.repo,
|
26 |
+
cache_id: cache.id,
|
27 |
+
})
|
28 |
+
}
|
29 |
+
|
.github/workflows/python-publish.yml
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Release
|
2 |
+
|
3 |
+
on:
|
4 |
+
push:
|
5 |
+
branches:
|
6 |
+
- main
|
7 |
+
jobs:
|
8 |
+
deploy:
|
9 |
+
runs-on: ubuntu-latest
|
10 |
+
steps:
|
11 |
+
- uses: actions/checkout@v2
|
12 |
+
- uses: actions-ecosystem/action-regex-match@v2
|
13 |
+
id: regex-match
|
14 |
+
with:
|
15 |
+
text: ${{ github.event.head_commit.message }}
|
16 |
+
regex: '^Release ([^ ]+)'
|
17 |
+
- name: Set up Python
|
18 |
+
uses: actions/setup-python@v2
|
19 |
+
with:
|
20 |
+
python-version: '3.8'
|
21 |
+
- name: Install dependencies
|
22 |
+
run: |
|
23 |
+
python -m pip install --upgrade pip
|
24 |
+
pip install setuptools wheel twine build
|
25 |
+
- name: Release
|
26 |
+
if: ${{ steps.regex-match.outputs.match != '' }}
|
27 |
+
uses: softprops/action-gh-release@v1
|
28 |
+
with:
|
29 |
+
tag_name: v${{ steps.regex-match.outputs.group1 }}
|
30 |
+
- name: Build and publish
|
31 |
+
if: ${{ steps.regex-match.outputs.match != '' }}
|
32 |
+
env:
|
33 |
+
TWINE_USERNAME: __token__
|
34 |
+
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
|
35 |
+
run: |
|
36 |
+
python -m build
|
37 |
+
twine upload dist/*
|
.gitignore
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
**/logs/
|
2 |
+
**/wandb/
|
3 |
+
models/
|
4 |
+
features/
|
5 |
+
results/
|
6 |
+
|
7 |
+
tests/data/
|
8 |
+
*.pt
|
9 |
+
|
10 |
+
# Byte-compiled / optimized / DLL files
|
11 |
+
__pycache__/
|
12 |
+
*.py[cod]
|
13 |
+
*$py.class
|
14 |
+
|
15 |
+
# C extensions
|
16 |
+
*.so
|
17 |
+
|
18 |
+
# Distribution / packaging
|
19 |
+
.Python
|
20 |
+
build/
|
21 |
+
develop-eggs/
|
22 |
+
dist/
|
23 |
+
downloads/
|
24 |
+
eggs/
|
25 |
+
.eggs/
|
26 |
+
lib/
|
27 |
+
lib64/
|
28 |
+
parts/
|
29 |
+
sdist/
|
30 |
+
var/
|
31 |
+
wheels/
|
32 |
+
pip-wheel-metadata/
|
33 |
+
share/python-wheels/
|
34 |
+
*.egg-info/
|
35 |
+
.installed.cfg
|
36 |
+
*.egg
|
37 |
+
MANIFEST
|
38 |
+
|
39 |
+
# PyInstaller
|
40 |
+
# Usually these files are written by a python script from a template
|
41 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
42 |
+
*.manifest
|
43 |
+
*.spec
|
44 |
+
|
45 |
+
# Installer logs
|
46 |
+
pip-log.txt
|
47 |
+
pip-delete-this-directory.txt
|
48 |
+
|
49 |
+
# Unit test / coverage reports
|
50 |
+
htmlcov/
|
51 |
+
.tox/
|
52 |
+
.nox/
|
53 |
+
.coverage
|
54 |
+
.coverage.*
|
55 |
+
.cache
|
56 |
+
nosetests.xml
|
57 |
+
coverage.xml
|
58 |
+
*.cover
|
59 |
+
*.py,cover
|
60 |
+
.hypothesis/
|
61 |
+
.pytest_cache/
|
62 |
+
|
63 |
+
# Translations
|
64 |
+
*.mo
|
65 |
+
*.pot
|
66 |
+
|
67 |
+
# Django stuff:
|
68 |
+
*.log
|
69 |
+
local_settings.py
|
70 |
+
db.sqlite3
|
71 |
+
db.sqlite3-journal
|
72 |
+
|
73 |
+
# Flask stuff:
|
74 |
+
instance/
|
75 |
+
.webassets-cache
|
76 |
+
|
77 |
+
# Scrapy stuff:
|
78 |
+
.scrapy
|
79 |
+
|
80 |
+
# Sphinx documentation
|
81 |
+
docs/_build/
|
82 |
+
|
83 |
+
# PyBuilder
|
84 |
+
target/
|
85 |
+
|
86 |
+
# Jupyter Notebook
|
87 |
+
.ipynb_checkpoints
|
88 |
+
|
89 |
+
# IPython
|
90 |
+
profile_default/
|
91 |
+
ipython_config.py
|
92 |
+
|
93 |
+
# pyenv
|
94 |
+
.python-version
|
95 |
+
|
96 |
+
# pipenv
|
97 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
98 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
99 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
100 |
+
# install all needed dependencies.
|
101 |
+
#Pipfile.lock
|
102 |
+
|
103 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
104 |
+
__pypackages__/
|
105 |
+
|
106 |
+
# Celery stuff
|
107 |
+
celerybeat-schedule
|
108 |
+
celerybeat.pid
|
109 |
+
|
110 |
+
# SageMath parsed files
|
111 |
+
*.sage.py
|
112 |
+
|
113 |
+
# Environments
|
114 |
+
.env
|
115 |
+
.venv
|
116 |
+
env/
|
117 |
+
venv/
|
118 |
+
ENV/
|
119 |
+
env.bak/
|
120 |
+
venv.bak/
|
121 |
+
|
122 |
+
# Spyder project settings
|
123 |
+
.spyderproject
|
124 |
+
.spyproject
|
125 |
+
|
126 |
+
# Rope project settings
|
127 |
+
.ropeproject
|
128 |
+
|
129 |
+
# mkdocs documentation
|
130 |
+
/site
|
131 |
+
|
132 |
+
# mypy
|
133 |
+
.mypy_cache/
|
134 |
+
.dmypy.json
|
135 |
+
dmypy.json
|
136 |
+
|
137 |
+
# Pyre type checker
|
138 |
+
.pyre/
|
139 |
+
sync.sh
|
140 |
+
gpu1sync.sh
|
141 |
+
.idea
|
142 |
+
*.pdf
|
143 |
+
**/._*
|
144 |
+
**/*DS_*
|
145 |
+
**.jsonl
|
146 |
+
src/sbatch
|
147 |
+
src/misc
|
148 |
+
.vscode
|
149 |
+
src/debug
|
150 |
+
core.*
|
151 |
+
|
152 |
+
# Allow
|
153 |
+
!src/evaluation/misc/results_dbs/*
|
CITATION.cff
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
cff-version: 1.1.0
|
2 |
+
message: If you use this software, please cite it as below.
|
3 |
+
authors:
|
4 |
+
- family-names: Ilharco
|
5 |
+
given-names: Gabriel
|
6 |
+
- family-names: Wortsman
|
7 |
+
given-names: Mitchell
|
8 |
+
- family-names: Wightman
|
9 |
+
given-names: Ross
|
10 |
+
- family-names: Gordon
|
11 |
+
given-names: Cade
|
12 |
+
- family-names: Carlini
|
13 |
+
given-names: Nicholas
|
14 |
+
- family-names: Taori
|
15 |
+
given-names: Rohan
|
16 |
+
- family-names: Dave
|
17 |
+
given-names: Achal
|
18 |
+
- family-names: Shankar
|
19 |
+
given-names: Vaishaal
|
20 |
+
- family-names: Namkoong
|
21 |
+
given-names: Hongseok
|
22 |
+
- family-names: Miller
|
23 |
+
given-names: John
|
24 |
+
- family-names: Hajishirzi
|
25 |
+
given-names: Hannaneh
|
26 |
+
- family-names: Farhadi
|
27 |
+
given-names: Ali
|
28 |
+
- family-names: Schmidt
|
29 |
+
given-names: Ludwig
|
30 |
+
title: OpenCLIP
|
31 |
+
version: v0.1
|
32 |
+
doi: 10.5281/zenodo.5143773
|
33 |
+
date-released: 2021-07-28
|
HISTORY.md
ADDED
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## 2.24.0
|
2 |
+
|
3 |
+
* Fix missing space in error message
|
4 |
+
* use model flag for normalizing embeddings
|
5 |
+
* init logit_bias for non siglip pretrained models
|
6 |
+
* Fix logit_bias load_checkpoint addition
|
7 |
+
* Make CoCa model match CLIP models for logit scale/bias init
|
8 |
+
* Fix missing return of "logit_bias" in CoCa.forward
|
9 |
+
* Add NLLB-CLIP with SigLIP models
|
10 |
+
* Add get_logits method and NLLB tokenizer
|
11 |
+
* Remove the empty file src/open_clip/generation_utils.py
|
12 |
+
* Update params.py: "BatchNorm" -> "LayerNorm" in the description string for "--lock-text-freeze-layer-norm"
|
13 |
+
|
14 |
+
## 2.23.0
|
15 |
+
|
16 |
+
* Add CLIPA-v2 models
|
17 |
+
* Add SigLIP models
|
18 |
+
* Add MetaCLIP models
|
19 |
+
* Add NLLB-CLIP models
|
20 |
+
* CLIPA train code
|
21 |
+
* Minor changes/fixes
|
22 |
+
* Remove protobuf version limit
|
23 |
+
* Stop checking model name when loading CoCa models
|
24 |
+
* Log native wandb step
|
25 |
+
* Use bool instead of long masks
|
26 |
+
|
27 |
+
## 2.21.0
|
28 |
+
|
29 |
+
* Add SigLIP loss + training support
|
30 |
+
* Add more DataComp models (B/16, B/32 and B/32@256)
|
31 |
+
* Update default num workers
|
32 |
+
* Update CoCa generation for `transformers>=4.31`
|
33 |
+
* PyTorch 2.0 `state_dict()` compatibility fix for compiled models
|
34 |
+
* Fix padding in `ResizeMaxSize`
|
35 |
+
* Convert JIT model on state dict load for `pretrained='filename…'`
|
36 |
+
* Other minor changes and fixes (typos, README, dependencies, CI)
|
37 |
+
|
38 |
+
## 2.20.0
|
39 |
+
|
40 |
+
* Add EVA models
|
41 |
+
* Support serial worker training
|
42 |
+
* Fix Python 3.7 compatibility
|
43 |
+
|
44 |
+
## 2.19.0
|
45 |
+
|
46 |
+
* Add DataComp models
|
47 |
+
|
48 |
+
## 2.18.0
|
49 |
+
|
50 |
+
* Enable int8 inference without `.weight` attribute
|
51 |
+
|
52 |
+
## 2.17.2
|
53 |
+
|
54 |
+
* Update push_to_hf_hub
|
55 |
+
|
56 |
+
## 2.17.0
|
57 |
+
|
58 |
+
* Add int8 support
|
59 |
+
* Update notebook demo
|
60 |
+
* Refactor zero-shot classification code
|
61 |
+
|
62 |
+
## 2.16.2
|
63 |
+
|
64 |
+
* Fixes for context_length and vocab_size attributes
|
65 |
+
|
66 |
+
## 2.16.1
|
67 |
+
|
68 |
+
* Fixes for context_length and vocab_size attributes
|
69 |
+
* Fix --train-num-samples logic
|
70 |
+
* Add HF BERT configs for PubMed CLIP model
|
71 |
+
|
72 |
+
## 2.16.0
|
73 |
+
|
74 |
+
* Add improved g-14 weights
|
75 |
+
* Update protobuf version
|
76 |
+
|
77 |
+
## 2.15.0
|
78 |
+
|
79 |
+
* Add convnext_xxlarge weights
|
80 |
+
* Fixed import in readme
|
81 |
+
* Add samples per second per gpu logging
|
82 |
+
* Fix slurm example
|
83 |
+
|
84 |
+
## 2.14.0
|
85 |
+
|
86 |
+
* Move dataset mixtures logic to shard level
|
87 |
+
* Fix CoCa accum-grad training
|
88 |
+
* Safer transformers import guard
|
89 |
+
* get_labels refactoring
|
90 |
+
|
91 |
+
## 2.13.0
|
92 |
+
|
93 |
+
* Add support for dataset mixtures with different sampling weights
|
94 |
+
* Make transformers optional again
|
95 |
+
|
96 |
+
## 2.12.0
|
97 |
+
|
98 |
+
* Updated convnext configs for consistency
|
99 |
+
* Added input_patchnorm option
|
100 |
+
* Clean and improve CoCa generation
|
101 |
+
* Support model distillation
|
102 |
+
* Add ConvNeXt-Large 320x320 fine-tune weights
|
103 |
+
|
104 |
+
## 2.11.1
|
105 |
+
|
106 |
+
* Make transformers optional
|
107 |
+
* Add MSCOCO CoCa finetunes to pretrained models
|
108 |
+
|
109 |
+
## 2.11.0
|
110 |
+
|
111 |
+
* coca support and weights
|
112 |
+
* ConvNeXt-Large weights
|
113 |
+
|
114 |
+
## 2.10.1
|
115 |
+
|
116 |
+
* `hf-hub:org/model_id` support for loading models w/ config and weights in Hugging Face Hub
|
117 |
+
|
118 |
+
## 2.10.0
|
119 |
+
|
120 |
+
* Added a ViT-bigG-14 model.
|
121 |
+
* Added an up-to-date example slurm script for large training jobs.
|
122 |
+
* Added a option to sync logs and checkpoints to S3 during training.
|
123 |
+
* New options for LR schedulers, constant and constant with cooldown
|
124 |
+
* Fix wandb autoresuming when resume is not set
|
125 |
+
* ConvNeXt `base` & `base_w` pretrained models added
|
126 |
+
* `timm-` model prefix removed from configs
|
127 |
+
* `timm` augmentation + regularization (dropout / drop-path) supported
|
128 |
+
|
129 |
+
## 2.9.3
|
130 |
+
|
131 |
+
* Fix wandb collapsing multiple parallel runs into a single one
|
132 |
+
|
133 |
+
## 2.9.2
|
134 |
+
|
135 |
+
* Fix braceexpand memory explosion for complex webdataset urls
|
136 |
+
|
137 |
+
## 2.9.1
|
138 |
+
|
139 |
+
* Fix release
|
140 |
+
|
141 |
+
## 2.9.0
|
142 |
+
|
143 |
+
* Add training feature to auto-resume from the latest checkpoint on restart via `--resume latest`
|
144 |
+
* Allow webp in webdataset
|
145 |
+
* Fix logging for number of samples when using gradient accumulation
|
146 |
+
* Add model configs for convnext xxlarge
|
147 |
+
|
148 |
+
## 2.8.2
|
149 |
+
|
150 |
+
* wrapped patchdropout in a torch.nn.Module
|
151 |
+
|
152 |
+
## 2.8.1
|
153 |
+
|
154 |
+
* relax protobuf dependency
|
155 |
+
* override the default patch dropout value in 'vision_cfg'
|
156 |
+
|
157 |
+
## 2.8.0
|
158 |
+
|
159 |
+
* better support for HF models
|
160 |
+
* add support for gradient accumulation
|
161 |
+
* CI fixes
|
162 |
+
* add support for patch dropout
|
163 |
+
* add convnext configs
|
164 |
+
|
165 |
+
|
166 |
+
## 2.7.0
|
167 |
+
|
168 |
+
* add multilingual H/14 xlm roberta large
|
169 |
+
|
170 |
+
## 2.6.1
|
171 |
+
|
172 |
+
* fix setup.py _read_reqs
|
173 |
+
|
174 |
+
## 2.6.0
|
175 |
+
|
176 |
+
* Make openclip training usable from pypi.
|
177 |
+
* Add xlm roberta large vit h 14 config.
|
178 |
+
|
179 |
+
## 2.5.0
|
180 |
+
|
181 |
+
* pretrained B/32 xlm roberta base: first multilingual clip trained on laion5B
|
182 |
+
* pretrained B/32 roberta base: first clip trained using an HF text encoder
|
183 |
+
|
184 |
+
## 2.4.1
|
185 |
+
|
186 |
+
* Add missing hf_tokenizer_name in CLIPTextCfg.
|
187 |
+
|
188 |
+
## 2.4.0
|
189 |
+
|
190 |
+
* Fix #211, missing RN50x64 config. Fix type of dropout param for ResNet models
|
191 |
+
* Bring back LayerNorm impl that casts to input for non bf16/fp16
|
192 |
+
* zero_shot.py: set correct tokenizer based on args
|
193 |
+
* training/params.py: remove hf params and get them from model config
|
194 |
+
|
195 |
+
## 2.3.1
|
196 |
+
|
197 |
+
* Implement grad checkpointing for hf model.
|
198 |
+
* custom_text: True if hf_model_name is set
|
199 |
+
* Disable hf tokenizer parallelism
|
200 |
+
|
201 |
+
## 2.3.0
|
202 |
+
|
203 |
+
* Generalizable Text Transformer with HuggingFace Models (@iejMac)
|
204 |
+
|
205 |
+
## 2.2.0
|
206 |
+
|
207 |
+
* Support for custom text tower
|
208 |
+
* Add checksum verification for pretrained model weights
|
209 |
+
|
210 |
+
## 2.1.0
|
211 |
+
|
212 |
+
* lot including sota models, bfloat16 option, better loading, better metrics
|
213 |
+
|
214 |
+
## 1.2.0
|
215 |
+
|
216 |
+
* ViT-B/32 trained on Laion2B-en
|
217 |
+
* add missing openai RN50x64 model
|
218 |
+
|
219 |
+
## 1.1.1
|
220 |
+
|
221 |
+
* ViT-B/16+
|
222 |
+
* Add grad checkpointing support
|
223 |
+
* more robust data loader
|
LICENSE
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright (c) 2012-2021 Gabriel Ilharco, Mitchell Wortsman,
|
2 |
+
Nicholas Carlini, Rohan Taori, Achal Dave, Vaishaal Shankar,
|
3 |
+
John Miller, Hongseok Namkoong, Hannaneh Hajishirzi, Ali Farhadi,
|
4 |
+
Ludwig Schmidt
|
5 |
+
|
6 |
+
Permission is hereby granted, free of charge, to any person obtaining
|
7 |
+
a copy of this software and associated documentation files (the
|
8 |
+
"Software"), to deal in the Software without restriction, including
|
9 |
+
without limitation the rights to use, copy, modify, merge, publish,
|
10 |
+
distribute, sublicense, and/or sell copies of the Software, and to
|
11 |
+
permit persons to whom the Software is furnished to do so, subject to
|
12 |
+
the following conditions:
|
13 |
+
|
14 |
+
The above copyright notice and this permission notice shall be
|
15 |
+
included in all copies or substantial portions of the Software.
|
16 |
+
|
17 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
18 |
+
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
19 |
+
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
20 |
+
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
21 |
+
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
22 |
+
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
23 |
+
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
MANIFEST.in
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
include src/open_clip/bpe_simple_vocab_16e6.txt.gz
|
2 |
+
include src/open_clip/model_configs/*.json
|
3 |
+
|
README.md
ADDED
@@ -0,0 +1,618 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# OpenCLIP
|
2 |
+
|
3 |
+
[[Paper]](https://arxiv.org/abs/2212.07143) [[Citations]](#citing) [[Clip Colab]](https://colab.research.google.com/github/mlfoundations/open_clip/blob/master/docs/Interacting_with_open_clip.ipynb) [[Coca Colab]](https://colab.research.google.com/github/mlfoundations/open_clip/blob/master/docs/Interacting_with_open_coca.ipynb)
|
4 |
+
[](https://pypi.python.org/pypi/open_clip_torch)
|
5 |
+
|
6 |
+
Welcome to an open source implementation of OpenAI's [CLIP](https://arxiv.org/abs/2103.00020) (Contrastive Language-Image Pre-training).
|
7 |
+
|
8 |
+
Using this codebase, we have trained several models on a variety of data sources and compute budgets, ranging from [small-scale experiments](docs/LOW_ACC.md) to larger runs including models trained on datasets such as [LAION-400M](https://arxiv.org/abs/2111.02114), [LAION-2B](https://arxiv.org/abs/2210.08402) and [DataComp-1B](https://arxiv.org/abs/2304.14108).
|
9 |
+
Many of our models and their scaling properties are studied in detail in the paper [reproducible scaling laws for contrastive language-image learning](https://arxiv.org/abs/2212.07143).
|
10 |
+
Some of the best models we've trained and their zero-shot ImageNet-1k accuracy are shown below, along with the ViT-L model trained by OpenAI and other state-of-the-art open source alternatives (all can be loaded via OpenCLIP).
|
11 |
+
We provide more details about our full collection of pretrained models [here](docs/PRETRAINED.md), and zero-shot results for 38 datasets [here](docs/openclip_results.csv).
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
| Model | Training data | Resolution | # of samples seen | ImageNet zero-shot acc. |
|
16 |
+
| -------- | ------- | ------- | ------- | ------- |
|
17 |
+
| ConvNext-Base | LAION-2B | 256px | 13B | 71.5% |
|
18 |
+
| ConvNext-Large | LAION-2B | 320px | 29B | 76.9% |
|
19 |
+
| ConvNext-XXLarge | LAION-2B | 256px | 34B | 79.5% |
|
20 |
+
| ViT-B/32 | DataComp-1B | 256px | 34B | 72.8% |
|
21 |
+
| ViT-B/16 | DataComp-1B | 224px | 13B | 73.5% |
|
22 |
+
| ViT-L/14 | LAION-2B | 224px | 32B | 75.3% |
|
23 |
+
| ViT-H/14 | LAION-2B | 224px | 32B | 78.0% |
|
24 |
+
| ViT-L/14 | DataComp-1B | 224px | 13B | 79.2% |
|
25 |
+
| ViT-G/14 | LAION-2B | 224px | 34B | 80.1% |
|
26 |
+
| | | | | |
|
27 |
+
| ViT-L/14-quickgelu [(Original CLIP)](https://arxiv.org/abs/2103.00020) | WIT | 224px | 13B | 75.5% |
|
28 |
+
| ViT-SO400M/14 [(SigLIP)](https://arxiv.org/abs/2303.15343) | WebLI | 224px | 45B | 82.0% |
|
29 |
+
| ViT-L/14 [(DFN)](https://arxiv.org/abs/2309.17425) | DFN-2B | 224px | 39B | 82.2% |
|
30 |
+
| ViT-SO400M-14-SigLIP-384 [(SigLIP)](https://arxiv.org/abs/2303.15343) | WebLI | 384px | 45B | 83.1% |
|
31 |
+
| ViT-H/14-quickgelu [(DFN)](https://arxiv.org/abs/2309.17425) | DFN-5B | 224px | 39B | 83.4% |
|
32 |
+
| ViT-H-14-378-quickgelu [(DFN)](https://arxiv.org/abs/2309.17425) | DFN-5B | 378px | 44B | 84.4% |
|
33 |
+
|
34 |
+
Model cards with additional model specific details can be found on the Hugging Face Hub under the OpenCLIP library tag: https://huggingface.co/models?library=open_clip.
|
35 |
+
|
36 |
+
If you found this repository useful, please consider [citing](#citing).
|
37 |
+
We welcome anyone to submit an issue or send an email if you have any other requests or suggestions.
|
38 |
+
|
39 |
+
Note that portions of `src/open_clip/` modelling and tokenizer code are adaptations of OpenAI's official [repository](https://github.com/openai/CLIP).
|
40 |
+
|
41 |
+
## Approach
|
42 |
+
|
43 |
+
|  |
|
44 |
+
|:--:|
|
45 |
+
| Image Credit: https://github.com/openai/CLIP |
|
46 |
+
|
47 |
+
## Usage
|
48 |
+
|
49 |
+
```
|
50 |
+
pip install open_clip_torch
|
51 |
+
```
|
52 |
+
|
53 |
+
```python
|
54 |
+
import torch
|
55 |
+
from PIL import Image
|
56 |
+
import open_clip
|
57 |
+
|
58 |
+
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
|
59 |
+
model.eval() # model in train mode by default, impacts some models with BatchNorm or stochastic depth active
|
60 |
+
tokenizer = open_clip.get_tokenizer('ViT-B-32')
|
61 |
+
|
62 |
+
image = preprocess(Image.open("docs/CLIP.png")).unsqueeze(0)
|
63 |
+
text = tokenizer(["a diagram", "a dog", "a cat"])
|
64 |
+
|
65 |
+
with torch.no_grad(), torch.autocast("cuda"):
|
66 |
+
image_features = model.encode_image(image)
|
67 |
+
text_features = model.encode_text(text)
|
68 |
+
image_features /= image_features.norm(dim=-1, keepdim=True)
|
69 |
+
text_features /= text_features.norm(dim=-1, keepdim=True)
|
70 |
+
|
71 |
+
text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
|
72 |
+
|
73 |
+
print("Label probs:", text_probs) # prints: [[1., 0., 0.]]
|
74 |
+
```
|
75 |
+
|
76 |
+
If model uses `timm` image encoders (convnext, siglip, eva, etc) ensure the latest timm is installed. Upgrade `timm` if you see 'Unknown model' errors for the image encoder.
|
77 |
+
|
78 |
+
If model uses transformers tokenizers, ensure `transformers` is installed.
|
79 |
+
|
80 |
+
See also this [[Clip Colab]](https://colab.research.google.com/github/mlfoundations/open_clip/blob/master/docs/Interacting_with_open_clip.ipynb).
|
81 |
+
|
82 |
+
To compute billions of embeddings efficiently, you can use [clip-retrieval](https://github.com/rom1504/clip-retrieval) which has openclip support.
|
83 |
+
|
84 |
+
### Pretrained models
|
85 |
+
|
86 |
+
We offer a simple model interface to instantiate both pre-trained and untrained models.
|
87 |
+
To see which pretrained models are available, use the following code snippet.
|
88 |
+
More details about our pretrained models are available [here](docs/PRETRAINED.md).
|
89 |
+
|
90 |
+
```python
|
91 |
+
>>> import open_clip
|
92 |
+
>>> open_clip.list_pretrained()
|
93 |
+
```
|
94 |
+
|
95 |
+
You can find more about the models we support (e.g. number of parameters, FLOPs) in [this table](docs/model_profile.csv).
|
96 |
+
|
97 |
+
NOTE: Many existing checkpoints use the QuickGELU activation from the original OpenAI models. This activation is actually less efficient than native torch.nn.GELU in recent versions of PyTorch. The model defaults are now nn.GELU, so one should use model definitions with `-quickgelu` postfix for the OpenCLIP pretrained weights. All OpenAI pretrained weights will always default to QuickGELU. One can also use the non `-quickgelu` model definitions with pretrained weights using QuickGELU but there will be an accuracy drop, for fine-tune that will likely vanish for longer runs.
|
98 |
+
Future trained models will use nn.GELU.
|
99 |
+
|
100 |
+
### Loading models
|
101 |
+
|
102 |
+
Models can be loaded with `open_clip.create_model_and_transforms`, as shown in the example below. The model name and corresponding `pretrained` keys are compatible with the outputs of `open_clip.list_pretrained()`.
|
103 |
+
|
104 |
+
The `pretrained` argument also accepts local paths, for example `/path/to/my/b32.pt`.
|
105 |
+
You can also load checkpoints from huggingface this way. To do so, download the `open_clip_pytorch_model.bin` file (for example, [https://huggingface.co/laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K/tree/main](https://huggingface.co/laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K/blob/main/open_clip_pytorch_model.bin)), and use `pretrained=/path/to/open_clip_pytorch_model.bin`.
|
106 |
+
|
107 |
+
```python
|
108 |
+
# pretrained also accepts local paths
|
109 |
+
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
|
110 |
+
```
|
111 |
+
|
112 |
+
## Fine-tuning on classification tasks
|
113 |
+
|
114 |
+
This repository is focused on training CLIP models. To fine-tune a *trained* zero-shot model on a downstream classification task such as ImageNet, please see [our other repository: WiSE-FT](https://github.com/mlfoundations/wise-ft). The [WiSE-FT repository](https://github.com/mlfoundations/wise-ft) contains code for our paper on [Robust Fine-tuning of Zero-shot Models](https://arxiv.org/abs/2109.01903), in which we introduce a technique for fine-tuning zero-shot models while preserving robustness under distribution shift.
|
115 |
+
|
116 |
+
## Data
|
117 |
+
|
118 |
+
To download datasets as webdataset, we recommend [img2dataset](https://github.com/rom1504/img2dataset).
|
119 |
+
|
120 |
+
### Conceptual Captions
|
121 |
+
|
122 |
+
See [cc3m img2dataset example](https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc3m.md).
|
123 |
+
|
124 |
+
### YFCC and other datasets
|
125 |
+
|
126 |
+
In addition to specifying the training data via CSV files as mentioned above, our codebase also supports [webdataset](https://github.com/webdataset/webdataset), which is recommended for larger scale datasets. The expected format is a series of `.tar` files. Each of these `.tar` files should contain two files for each training example, one for the image and one for the corresponding text. Both files should have the same name but different extensions. For instance, `shard_001.tar` could contain files such as `abc.jpg` and `abc.txt`. You can learn more about `webdataset` at [https://github.com/webdataset/webdataset](https://github.com/webdataset/webdataset). We use `.tar` files with 1,000 data points each, which we create using [tarp](https://github.com/webdataset/tarp).
|
127 |
+
|
128 |
+
You can download the YFCC dataset from [Multimedia Commons](http://mmcommons.org/).
|
129 |
+
Similar to OpenAI, we used a subset of YFCC to reach the aforementioned accuracy numbers.
|
130 |
+
The indices of images in this subset are in [OpenAI's CLIP repository](https://github.com/openai/CLIP/blob/main/data/yfcc100m.md).
|
131 |
+
|
132 |
+
|
133 |
+
## Training CLIP
|
134 |
+
|
135 |
+
### Install
|
136 |
+
|
137 |
+
We advise you first create a virtual environment with:
|
138 |
+
|
139 |
+
```
|
140 |
+
python3 -m venv .env
|
141 |
+
source .env/bin/activate
|
142 |
+
pip install -U pip
|
143 |
+
```
|
144 |
+
|
145 |
+
You can then install openclip for training with `pip install 'open_clip_torch[training]'`.
|
146 |
+
|
147 |
+
#### Development
|
148 |
+
|
149 |
+
If you want to make changes to contribute code, you can clone openclip then run `make install` in openclip folder (after creating a virtualenv)
|
150 |
+
|
151 |
+
Install pip PyTorch as per https://pytorch.org/get-started/locally/
|
152 |
+
|
153 |
+
You may run `make install-training` to install training deps
|
154 |
+
|
155 |
+
#### Testing
|
156 |
+
|
157 |
+
Test can be run with `make install-test` then `make test`
|
158 |
+
|
159 |
+
`python -m pytest -x -s -v tests -k "training"` to run a specific test
|
160 |
+
|
161 |
+
Running regression tests against a specific git revision or tag:
|
162 |
+
1. Generate testing data
|
163 |
+
```sh
|
164 |
+
python tests/util_test.py --model RN50 RN101 --save_model_list models.txt --git_revision 9d31b2ec4df6d8228f370ff20c8267ec6ba39383
|
165 |
+
```
|
166 |
+
**_WARNING_: This will invoke git and modify your working tree, but will reset it to the current state after data has been generated! \
|
167 |
+
Don't modify your working tree while test data is being generated this way.**
|
168 |
+
|
169 |
+
2. Run regression tests
|
170 |
+
```sh
|
171 |
+
OPEN_CLIP_TEST_REG_MODELS=models.txt python -m pytest -x -s -v -m regression_test
|
172 |
+
```
|
173 |
+
|
174 |
+
### Sample single-process running code:
|
175 |
+
|
176 |
+
```bash
|
177 |
+
python -m open_clip_train.main \
|
178 |
+
--save-frequency 1 \
|
179 |
+
--zeroshot-frequency 1 \
|
180 |
+
--report-to tensorboard \
|
181 |
+
--train-data="/path/to/train_data.csv" \
|
182 |
+
--val-data="/path/to/validation_data.csv" \
|
183 |
+
--csv-img-key filepath \
|
184 |
+
--csv-caption-key title \
|
185 |
+
--imagenet-val=/path/to/imagenet/root/val/ \
|
186 |
+
--warmup 10000 \
|
187 |
+
--batch-size=128 \
|
188 |
+
--lr=1e-3 \
|
189 |
+
--wd=0.1 \
|
190 |
+
--epochs=30 \
|
191 |
+
--workers=8 \
|
192 |
+
--model RN50
|
193 |
+
```
|
194 |
+
|
195 |
+
Note: `imagenet-val` is the path to the *validation* set of ImageNet for zero-shot evaluation, not the training set!
|
196 |
+
You can remove this argument if you do not want to perform zero-shot evaluation on ImageNet throughout training. Note that the `val` folder should contain subfolders. If it does not, please use [this script](https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh).
|
197 |
+
|
198 |
+
### Multi-GPU and Beyond
|
199 |
+
|
200 |
+
This code has been battle tested up to 1024 A100s and offers a variety of solutions
|
201 |
+
for distributed training. We include native support for SLURM clusters.
|
202 |
+
|
203 |
+
As the number of devices used to train increases, so does the space complexity of
|
204 |
+
the the logit matrix. Using a naïve all-gather scheme, space complexity will be
|
205 |
+
`O(n^2)`. Instead, complexity may become effectively linear if the flags
|
206 |
+
`--gather-with-grad` and `--local-loss` are used. This alteration results in one-to-one
|
207 |
+
numerical results as the naïve method.
|
208 |
+
|
209 |
+
#### Epochs
|
210 |
+
|
211 |
+
For larger datasets (eg Laion2B), we recommend setting `--train-num-samples` to a lower value than the full epoch, for example `--train-num-samples 135646078` to 1/16 of an epoch in conjunction with `--dataset-resampled` to do sampling with replacement. This allows having frequent checkpoints to evaluate more often.
|
212 |
+
|
213 |
+
#### Patch Dropout
|
214 |
+
|
215 |
+
<a href="https://arxiv.org/abs/2212.00794">Recent research</a> has shown that one can dropout half to three-quarters of the visual tokens, leading to up to 2-3x training speeds without loss of accuracy.
|
216 |
+
|
217 |
+
You can set this on your visual transformer config with the key `patch_dropout`.
|
218 |
+
|
219 |
+
In the paper, they also finetuned without the patch dropout at the end. You can do this with the command-line argument `--force-patch-dropout 0.`
|
220 |
+
|
221 |
+
#### Multiple data sources
|
222 |
+
|
223 |
+
OpenCLIP supports using multiple data sources, by separating different data paths with `::`.
|
224 |
+
For instance, to train on CC12M and on LAION, one might use `--train-data "/data/cc12m/cc12m-train-{0000..2175}.tar::/data/LAION-400M/{00000..41455}.tar"`.
|
225 |
+
Using `--dataset-resampled` is recommended for these cases.
|
226 |
+
|
227 |
+
By default, on expectation the amount of times the model will see a sample from each source is proportional to the size of the source.
|
228 |
+
For instance, when training on one data source with size 400M and one with size 10M, samples from the first source are 40x more likely to be seen in expectation.
|
229 |
+
|
230 |
+
We also support different weighting of the data sources, by using the `--train-data-upsampling-factors` flag.
|
231 |
+
For instance, using `--train-data-upsampling-factors=1::1` in the above scenario is equivalent to not using the flag, and `--train-data-upsampling-factors=1::2` is equivalent to upsampling the second data source twice.
|
232 |
+
If you want to sample from data sources with the same frequency, the upsampling factors should be inversely proportional to the sizes of the data sources.
|
233 |
+
For instance, if dataset `A` has 1000 samples and dataset `B` has 100 samples, you can use `--train-data-upsampling-factors=0.001::0.01` (or analogously, `--train-data-upsampling-factors=1::10`).
|
234 |
+
|
235 |
+
#### Single-Node
|
236 |
+
|
237 |
+
We make use of `torchrun` to launch distributed jobs. The following launches a
|
238 |
+
a job on a node of 4 GPUs:
|
239 |
+
|
240 |
+
```bash
|
241 |
+
cd open_clip/src
|
242 |
+
torchrun --nproc_per_node 4 -m open_clip_train.main \
|
243 |
+
--train-data '/data/cc12m/cc12m-train-{0000..2175}.tar' \
|
244 |
+
--train-num-samples 10968539 \
|
245 |
+
--dataset-type webdataset \
|
246 |
+
--batch-size 320 \
|
247 |
+
--precision amp \
|
248 |
+
--workers 4 \
|
249 |
+
--imagenet-val /data/imagenet/validation/
|
250 |
+
```
|
251 |
+
|
252 |
+
#### Multi-Node
|
253 |
+
|
254 |
+
The same script above works, so long as users include information about the number
|
255 |
+
of nodes and host node.
|
256 |
+
|
257 |
+
```bash
|
258 |
+
cd open_clip/src
|
259 |
+
torchrun --nproc_per_node=4 \
|
260 |
+
--rdzv_endpoint=$HOSTE_NODE_ADDR \
|
261 |
+
-m open_clip_train.main \
|
262 |
+
--train-data '/data/cc12m/cc12m-train-{0000..2175}.tar' \
|
263 |
+
--train-num-samples 10968539 \
|
264 |
+
--dataset-type webdataset \
|
265 |
+
--batch-size 320 \
|
266 |
+
--precision amp \
|
267 |
+
--workers 4 \
|
268 |
+
--imagenet-val /data/imagenet/validation/
|
269 |
+
```
|
270 |
+
|
271 |
+
#### SLURM
|
272 |
+
|
273 |
+
This is likely the easiest solution to utilize. The following script was used to
|
274 |
+
train our largest models:
|
275 |
+
|
276 |
+
```bash
|
277 |
+
#!/bin/bash -x
|
278 |
+
#SBATCH --nodes=32
|
279 |
+
#SBATCH --gres=gpu:4
|
280 |
+
#SBATCH --ntasks-per-node=4
|
281 |
+
#SBATCH --cpus-per-task=6
|
282 |
+
#SBATCH --wait-all-nodes=1
|
283 |
+
#SBATCH --job-name=open_clip
|
284 |
+
#SBATCH --account=ACCOUNT_NAME
|
285 |
+
#SBATCH --partition PARTITION_NAME
|
286 |
+
|
287 |
+
eval "$(/path/to/conda/bin/conda shell.bash hook)" # init conda
|
288 |
+
conda activate open_clip
|
289 |
+
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
290 |
+
export MASTER_PORT=12802
|
291 |
+
|
292 |
+
master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
|
293 |
+
export MASTER_ADDR=$master_addr
|
294 |
+
|
295 |
+
cd /shared/open_clip
|
296 |
+
export PYTHONPATH="$PYTHONPATH:$PWD/src"
|
297 |
+
srun --cpu_bind=v --accel-bind=gn python -u src/open_clip_train/main.py \
|
298 |
+
--save-frequency 1 \
|
299 |
+
--report-to tensorboard \
|
300 |
+
--train-data="/data/LAION-400M/{00000..41455}.tar" \
|
301 |
+
--warmup 2000 \
|
302 |
+
--batch-size=256 \
|
303 |
+
--epochs=32 \
|
304 |
+
--workers=8 \
|
305 |
+
--model ViT-B-32 \
|
306 |
+
--name "ViT-B-32-Vanilla" \
|
307 |
+
--seed 0 \
|
308 |
+
--local-loss \
|
309 |
+
--gather-with-grad
|
310 |
+
```
|
311 |
+
|
312 |
+
### Resuming from a checkpoint:
|
313 |
+
|
314 |
+
```bash
|
315 |
+
python -m open_clip_train.main \
|
316 |
+
--train-data="/path/to/train_data.csv" \
|
317 |
+
--val-data="/path/to/validation_data.csv" \
|
318 |
+
--resume /path/to/checkpoints/epoch_K.pt
|
319 |
+
```
|
320 |
+
|
321 |
+
### Training CoCa:
|
322 |
+
Training [CoCa](https://arxiv.org/abs/2205.01917) models is enabled through specifying a CoCa config using the ```--model``` parameter of the training script. Currently available configs are "coca_base", "coca_ViT-B-32", and "coca_roberta-ViT-B-32" (which uses RoBERTa as the text encoder). CoCa configs are different from CLIP configs because they have an additional "multimodal_cfg" component which specifies parameters for the multimodal text decoder. Here's an example from the coca_ViT-B-32 config:
|
323 |
+
```json
|
324 |
+
"multimodal_cfg": {
|
325 |
+
"context_length": 76,
|
326 |
+
"vocab_size": 49408,
|
327 |
+
"width": 512,
|
328 |
+
"heads": 8,
|
329 |
+
"layers": 12,
|
330 |
+
"latent_dim": 512,
|
331 |
+
"attn_pooler_heads": 8
|
332 |
+
}
|
333 |
+
```
|
334 |
+
Credit to [lucidrains](https://github.com/lucidrains) for [initial code](https://github.com/lucidrains/CoCa-pytorch), [gpucce](https://github.com/gpucce) for adapting the code to open_clip, and [iejMac](https://github.com/iejMac) for training the models.
|
335 |
+
|
336 |
+
### Generating text with CoCa
|
337 |
+
|
338 |
+
```python
|
339 |
+
import open_clip
|
340 |
+
import torch
|
341 |
+
from PIL import Image
|
342 |
+
|
343 |
+
model, _, transform = open_clip.create_model_and_transforms(
|
344 |
+
model_name="coca_ViT-L-14",
|
345 |
+
pretrained="mscoco_finetuned_laion2B-s13B-b90k"
|
346 |
+
)
|
347 |
+
|
348 |
+
im = Image.open("cat.jpg").convert("RGB")
|
349 |
+
im = transform(im).unsqueeze(0)
|
350 |
+
|
351 |
+
with torch.no_grad(), torch.cuda.amp.autocast():
|
352 |
+
generated = model.generate(im)
|
353 |
+
|
354 |
+
print(open_clip.decode(generated[0]).split("<end_of_text>")[0].replace("<start_of_text>", ""))
|
355 |
+
```
|
356 |
+
|
357 |
+
See also this [[Coca Colab]](https://colab.research.google.com/github/mlfoundations/open_clip/blob/master/docs/Interacting_with_open_coca.ipynb)
|
358 |
+
|
359 |
+
### Fine Tuning CoCa
|
360 |
+
|
361 |
+
To fine-tune coca on mscoco, first create the dataset, one way is using a csvdataset and perhaps the simplest way to do it is using [CLIP_benchmark](https://github.com/LAION-AI/CLIP_benchmark) which in turn uses [pycocotools](https://github.com/cocodataset/cocoapi) (that can be used also by itself).
|
362 |
+
|
363 |
+
```python
|
364 |
+
from clip_benchmark.datasets.builder import build_dataset
|
365 |
+
import pandas as pd
|
366 |
+
import os
|
367 |
+
|
368 |
+
root_path = "path/to/data/dir" # set this to smth meaningful
|
369 |
+
ds = build_dataset("mscoco_captions", root=root_path, split="train", task="captioning") # this downloads the dataset if it is not there already
|
370 |
+
coco = ds.coco
|
371 |
+
imgs = coco.loadImgs(coco.getImgIds())
|
372 |
+
future_df = {"filepath":[], "title":[]}
|
373 |
+
for img in imgs:
|
374 |
+
caps = coco.imgToAnns[img["id"]]
|
375 |
+
for cap in caps:
|
376 |
+
future_df["filepath"].append(img["file_name"])
|
377 |
+
future_df["title"].append(cap["caption"])
|
378 |
+
pd.DataFrame.from_dict(future_df).to_csv(
|
379 |
+
os.path.join(root_path, "train2014.csv"), index=False, sep="\t"
|
380 |
+
)
|
381 |
+
```
|
382 |
+
This should create a csv dataset that one can use to fine-tune coca with open_clip
|
383 |
+
```bash
|
384 |
+
python -m open_clip_train.main \
|
385 |
+
--dataset-type "csv" \
|
386 |
+
--train-data "path/to/data/dir/train2014.csv" \
|
387 |
+
--warmup 1000 \
|
388 |
+
--batch-size 128 \
|
389 |
+
--lr 1e-5 \
|
390 |
+
--wd 0.1 \
|
391 |
+
--epochs 1 \
|
392 |
+
--workers 3 \
|
393 |
+
--model "coca_ViT-L-14" \
|
394 |
+
--report-to "wandb" \
|
395 |
+
--coca-contrastive-loss-weight 0 \
|
396 |
+
--coca-caption-loss-weight 1 \
|
397 |
+
--log-every-n-steps 100
|
398 |
+
```
|
399 |
+
|
400 |
+
This is a general setting, open_clip has very parameters that can be set, ```python -m open_clip_train.main --help``` should show them. The only relevant change compared to pre-training are the two arguments
|
401 |
+
|
402 |
+
```bash
|
403 |
+
--coca-contrastive-loss-weight 0
|
404 |
+
--coca-caption-loss-weight 1
|
405 |
+
```
|
406 |
+
which make the model only train the generative side.
|
407 |
+
|
408 |
+
### Training with pre-trained language models as text encoder:
|
409 |
+
|
410 |
+
If you wish to use different language models as the text encoder for CLIP you can do so by using one of the Hugging Face model configs in ```src/open_clip/model_configs``` and passing in it's tokenizer as the ```--model``` and ```--hf-tokenizer-name``` parameters respectively. Currently we only support RoBERTa ("test-roberta" config), however adding new models should be trivial. You can also determine how many layers, from the end, to leave unfrozen with the ```--lock-text-unlocked-layers``` parameter. Here's an example command to train CLIP with the RoBERTa LM that has it's last 10 layers unfrozen:
|
411 |
+
```bash
|
412 |
+
python -m open_clip_train.main \
|
413 |
+
--train-data="pipe:aws s3 cp s3://s-mas/cc3m/{00000..00329}.tar -" \
|
414 |
+
--train-num-samples 3000000 \
|
415 |
+
--val-data="pipe:aws s3 cp s3://s-mas/cc3m/{00330..00331}.tar -" \
|
416 |
+
--val-num-samples 10000 \
|
417 |
+
--dataset-type webdataset \
|
418 |
+
--batch-size 256 \
|
419 |
+
--warmup 2000 \
|
420 |
+
--epochs 10 \
|
421 |
+
--lr 5e-4 \
|
422 |
+
--precision amp \
|
423 |
+
--workers 6 \
|
424 |
+
--model "roberta-ViT-B-32" \
|
425 |
+
--lock-text \
|
426 |
+
--lock-text-unlocked-layers 10 \
|
427 |
+
--name "10_unfrozen" \
|
428 |
+
--report-to "tensorboard" \
|
429 |
+
```
|
430 |
+
|
431 |
+
### Loss Curves
|
432 |
+
|
433 |
+
When run on a machine with 8 GPUs the command should produce the following training curve for Conceptual Captions:
|
434 |
+
|
435 |
+

|
436 |
+
|
437 |
+
More detailed curves for Conceptual Captions are given at [/docs/clip_conceptual_captions.md](/docs/clip_conceptual_captions.md).
|
438 |
+
|
439 |
+
When training a RN50 on YFCC the same hyperparameters as above are used, with the exception of `lr=5e-4` and `epochs=32`.
|
440 |
+
|
441 |
+
Note that to use another model, like `ViT-B/32` or `RN50x4` or `RN50x16` or `ViT-B/16`, specify with `--model RN50x4`.
|
442 |
+
|
443 |
+
### Logging
|
444 |
+
|
445 |
+
For tensorboard logging, run:
|
446 |
+
```bash
|
447 |
+
tensorboard --logdir=logs/tensorboard/ --port=7777
|
448 |
+
```
|
449 |
+
|
450 |
+
For wandb logging, we recommend looking at the `step` variable instead of `Step`, since the later was not properly set in earlier versions of this codebase.
|
451 |
+
For older runs with models trained before https://github.com/mlfoundations/open_clip/pull/613, the `Step` variable should be ignored.
|
452 |
+
For newer runs, after that PR, the two variables are the same.
|
453 |
+
|
454 |
+
## Evaluation / Zero-Shot
|
455 |
+
|
456 |
+
We recommend https://github.com/LAION-AI/CLIP_benchmark#how-to-use for systematic evaluation on 40 datasets.
|
457 |
+
|
458 |
+
### Evaluating local checkpoint:
|
459 |
+
|
460 |
+
```bash
|
461 |
+
python -m open_clip_train.main \
|
462 |
+
--val-data="/path/to/validation_data.csv" \
|
463 |
+
--model RN101 \
|
464 |
+
--pretrained /path/to/checkpoints/epoch_K.pt
|
465 |
+
```
|
466 |
+
|
467 |
+
### Evaluating hosted pretrained checkpoint on ImageNet zero-shot prediction:
|
468 |
+
|
469 |
+
```bash
|
470 |
+
python -m open_clip_train.main \
|
471 |
+
--imagenet-val /path/to/imagenet/validation \
|
472 |
+
--model ViT-B-32-quickgelu \
|
473 |
+
--pretrained laion400m_e32
|
474 |
+
```
|
475 |
+
|
476 |
+
### Model distillation
|
477 |
+
|
478 |
+
You can distill from a pre-trained by using `--distill-model` and `--distill-pretrained` to specify the model you'd like to distill from.
|
479 |
+
For instance, to distill from OpenAI ViT-L/14 use `--distill-model ViT-L-14 --distill-pretrained openai`.
|
480 |
+
|
481 |
+
### Gradient accumulation
|
482 |
+
|
483 |
+
To simulate larger batches use `--accum-freq k`. If per gpu batch size, `--batch-size`, is `m`, then the effective batch size will be `k * m * num_gpus`.
|
484 |
+
|
485 |
+
When increasing `--accum-freq` from its default of 1, samples/s will remain approximately constant (batch size will double, as will time-per-batch). It is recommended to use other features to reduce batch size such as `--grad-checkpointing --local-loss --gather-with-grad` before increasing `--accum-freq`. `--accum-freq` can be used in addition to these features.
|
486 |
+
|
487 |
+
Instead of 1 forward pass per example, there are now 2 forward passes per-example. However, the first is done with `torch.no_grad`.
|
488 |
+
|
489 |
+
There is some additional GPU memory required --- the features and data from all `m` batches are stored in memory.
|
490 |
+
|
491 |
+
There are also `m` loss computations instead of the usual 1.
|
492 |
+
|
493 |
+
For more information see Cui et al. (https://arxiv.org/abs/2112.09331) or Pham et al. (https://arxiv.org/abs/2111.10050).
|
494 |
+
|
495 |
+
### Int8 Support
|
496 |
+
|
497 |
+
We have beta support for int8 training and inference.
|
498 |
+
You can enable int8 training with `--use-bnb-linear SwitchBackLinearGlobal` or `--use-bnb-linear SwitchBackLinearGlobalMemEfficient`.
|
499 |
+
Please see the bitsandbytes library for definitions for these layers.
|
500 |
+
For CLIP VIT-Huge this should currently correspond to a 10% training speedup with no accuracy loss.
|
501 |
+
More speedups comin when the attention layer is refactored so that linear layers man be replaced there, too.
|
502 |
+
|
503 |
+
See the tutorial https://github.com/mlfoundations/open_clip/blob/main/tutorials/int8_tutorial.ipynb or [paper](https://arxiv.org/abs/2304.13013).
|
504 |
+
|
505 |
+
### Support for remote loading/training
|
506 |
+
|
507 |
+
It is always possible to resume directly from a remote file, e.g., a file in an s3 bucket. Just set `--resume s3://<path-to-checkpoint> `.
|
508 |
+
This will work with any filesystem supported by `fsspec`.
|
509 |
+
|
510 |
+
It is also possible to train `open_clip` models while continuously backing up to s3. This can help to avoid slow local file systems.
|
511 |
+
|
512 |
+
Say that your node has a local ssd `/scratch`, an s3 bucket `s3://<path-to-bucket>`.
|
513 |
+
|
514 |
+
In that case, set `--logs /scratch` and `--remote-sync s3://<path-to-bucket>`. Then, a background process will sync `/scratch/<run-name>` to `s3://<path-to-bucket>/<run-name>`. After syncing, the background process will sleep for `--remote-sync-frequency` seconds, which defaults to 5 minutes.
|
515 |
+
|
516 |
+
There is also experimental support for syncing to other remote file systems, not just s3. To do so, specify `--remote-sync-protocol fsspec`. However, this is currently very slow and not recommended.
|
517 |
+
|
518 |
+
Also, to optionally avoid saving too many checkpoints locally when using these features, you can use `--delete-previous-checkpoint` which deletes the previous checkpoint after saving a new one.
|
519 |
+
|
520 |
+
Note: if you are using this feature with `--resume latest`, there are a few warnings. First, use with `--save-most-recent` is not supported. Second, only `s3` is supported. Finally, since the sync happens in the background, it is possible that the most recent checkpoint may not be finished syncing to the remote.
|
521 |
+
|
522 |
+
### Pushing Models to Hugging Face Hub
|
523 |
+
|
524 |
+
The module `open_clip.push_to_hf_hub` includes helpers for pushing models /w weights and config to the HF Hub.
|
525 |
+
|
526 |
+
The tool can be run from command line, ex:
|
527 |
+
`python -m open_clip.push_to_hf_hub --model convnext_large_d_320 --pretrained /train/checkpoints/epoch_12.pt --repo-id laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft`
|
528 |
+
|
529 |
+
|
530 |
+
|
531 |
+
## Acknowledgments
|
532 |
+
|
533 |
+
We gratefully acknowledge the Gauss Centre for Supercomputing e.V. (www.gauss-centre.eu) for funding this part of work by providing computing time through the John von Neumann Institute for Computing (NIC) on the GCS Supercomputer JUWELS Booster at Jülich Supercomputing Centre (JSC).
|
534 |
+
|
535 |
+
## The Team
|
536 |
+
|
537 |
+
Current development of this repository is led by [Ross Wightman](https://rwightman.com/), [Romain Beaumont](https://github.com/rom1504), [Cade Gordon](http://cadegordon.io/), and [Vaishaal Shankar](http://vaishaal.com/).
|
538 |
+
|
539 |
+
The original version of this repository is from a group of researchers at UW, Google, Stanford, Amazon, Columbia, and Berkeley.
|
540 |
+
|
541 |
+
[Gabriel Ilharco*](http://gabrielilharco.com/), [Mitchell Wortsman*](https://mitchellnw.github.io/), [Nicholas Carlini](https://nicholas.carlini.com/), [Rohan Taori](https://www.rohantaori.com/), [Achal Dave](http://www.achaldave.com/), [Vaishaal Shankar](http://vaishaal.com/), [John Miller](https://people.eecs.berkeley.edu/~miller_john/), [Hongseok Namkoong](https://hsnamkoong.github.io/), [Hannaneh Hajishirzi](https://homes.cs.washington.edu/~hannaneh/), [Ali Farhadi](https://homes.cs.washington.edu/~ali/), [Ludwig Schmidt](https://people.csail.mit.edu/ludwigs/)
|
542 |
+
|
543 |
+
Special thanks to [Jong Wook Kim](https://jongwook.kim/) and [Alec Radford](https://github.com/Newmu) for help with reproducing CLIP!
|
544 |
+
|
545 |
+
## Citing
|
546 |
+
|
547 |
+
If you found this repository useful, please consider citing:
|
548 |
+
```bibtex
|
549 |
+
@software{ilharco_gabriel_2021_5143773,
|
550 |
+
author = {Ilharco, Gabriel and
|
551 |
+
Wortsman, Mitchell and
|
552 |
+
Wightman, Ross and
|
553 |
+
Gordon, Cade and
|
554 |
+
Carlini, Nicholas and
|
555 |
+
Taori, Rohan and
|
556 |
+
Dave, Achal and
|
557 |
+
Shankar, Vaishaal and
|
558 |
+
Namkoong, Hongseok and
|
559 |
+
Miller, John and
|
560 |
+
Hajishirzi, Hannaneh and
|
561 |
+
Farhadi, Ali and
|
562 |
+
Schmidt, Ludwig},
|
563 |
+
title = {OpenCLIP},
|
564 |
+
month = jul,
|
565 |
+
year = 2021,
|
566 |
+
note = {If you use this software, please cite it as below.},
|
567 |
+
publisher = {Zenodo},
|
568 |
+
version = {0.1},
|
569 |
+
doi = {10.5281/zenodo.5143773},
|
570 |
+
url = {https://doi.org/10.5281/zenodo.5143773}
|
571 |
+
}
|
572 |
+
```
|
573 |
+
|
574 |
+
```bibtex
|
575 |
+
@inproceedings{cherti2023reproducible,
|
576 |
+
title={Reproducible scaling laws for contrastive language-image learning},
|
577 |
+
author={Cherti, Mehdi and Beaumont, Romain and Wightman, Ross and Wortsman, Mitchell and Ilharco, Gabriel and Gordon, Cade and Schuhmann, Christoph and Schmidt, Ludwig and Jitsev, Jenia},
|
578 |
+
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
|
579 |
+
pages={2818--2829},
|
580 |
+
year={2023}
|
581 |
+
}
|
582 |
+
```
|
583 |
+
|
584 |
+
```bibtex
|
585 |
+
@inproceedings{Radford2021LearningTV,
|
586 |
+
title={Learning Transferable Visual Models From Natural Language Supervision},
|
587 |
+
author={Alec Radford and Jong Wook Kim and Chris Hallacy and A. Ramesh and Gabriel Goh and Sandhini Agarwal and Girish Sastry and Amanda Askell and Pamela Mishkin and Jack Clark and Gretchen Krueger and Ilya Sutskever},
|
588 |
+
booktitle={ICML},
|
589 |
+
year={2021}
|
590 |
+
}
|
591 |
+
```
|
592 |
+
|
593 |
+
```bibtex
|
594 |
+
@inproceedings{schuhmann2022laionb,
|
595 |
+
title={{LAION}-5B: An open large-scale dataset for training next generation image-text models},
|
596 |
+
author={Christoph Schuhmann and
|
597 |
+
Romain Beaumont and
|
598 |
+
Richard Vencu and
|
599 |
+
Cade W Gordon and
|
600 |
+
Ross Wightman and
|
601 |
+
Mehdi Cherti and
|
602 |
+
Theo Coombes and
|
603 |
+
Aarush Katta and
|
604 |
+
Clayton Mullis and
|
605 |
+
Mitchell Wortsman and
|
606 |
+
Patrick Schramowski and
|
607 |
+
Srivatsa R Kundurthy and
|
608 |
+
Katherine Crowson and
|
609 |
+
Ludwig Schmidt and
|
610 |
+
Robert Kaczmarczyk and
|
611 |
+
Jenia Jitsev},
|
612 |
+
booktitle={Thirty-sixth Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
|
613 |
+
year={2022},
|
614 |
+
url={https://openreview.net/forum?id=M3Y74vmsMcY}
|
615 |
+
}
|
616 |
+
```
|
617 |
+
|
618 |
+
[](https://zenodo.org/badge/latestdoi/390536799)
|
models.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
RN101
|
2 |
+
RN50
|
pytest.ini
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
[pytest]
|
2 |
+
markers =
|
3 |
+
regression_test
|
requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch>=1.9.0
|
2 |
+
torchvision
|
3 |
+
regex
|
4 |
+
ftfy
|
5 |
+
tqdm
|
6 |
+
huggingface_hub
|
7 |
+
safetensors
|
8 |
+
timm
|
src/open_clip/__init__.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .version import __version__
|
2 |
+
|
3 |
+
from .coca_model import CoCa
|
4 |
+
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
5 |
+
from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
|
6 |
+
from .factory import list_models, add_model_config, get_model_config, load_checkpoint
|
7 |
+
from .loss import ClipLoss, DistillClipLoss, CoCaLoss
|
8 |
+
from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
|
9 |
+
convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype, \
|
10 |
+
get_model_tokenize_cfg, get_model_preprocess_cfg, set_model_preprocess_cfg
|
11 |
+
from .openai import load_openai_model, list_openai_models
|
12 |
+
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
|
13 |
+
get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
|
14 |
+
from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
|
15 |
+
from .tokenizer import SimpleTokenizer, tokenize, decode
|
16 |
+
from .transform import image_transform, AugmentationCfg
|
17 |
+
from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy
|
18 |
+
from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES
|
src/open_clip/coca_model.py
ADDED
@@ -0,0 +1,582 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, Optional, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
import numpy as np
|
7 |
+
from dataclasses import dataclass
|
8 |
+
|
9 |
+
from .transformer import (
|
10 |
+
LayerNormFp32,
|
11 |
+
LayerNorm,
|
12 |
+
QuickGELU,
|
13 |
+
MultimodalTransformer,
|
14 |
+
)
|
15 |
+
from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
|
16 |
+
|
17 |
+
try:
|
18 |
+
from transformers import (
|
19 |
+
BeamSearchScorer,
|
20 |
+
LogitsProcessorList,
|
21 |
+
TopPLogitsWarper,
|
22 |
+
TopKLogitsWarper,
|
23 |
+
RepetitionPenaltyLogitsProcessor,
|
24 |
+
MinLengthLogitsProcessor,
|
25 |
+
MaxLengthCriteria,
|
26 |
+
StopStringCriteria,
|
27 |
+
EosTokenCriteria,
|
28 |
+
StoppingCriteriaList
|
29 |
+
)
|
30 |
+
|
31 |
+
GENERATION_TYPES = {
|
32 |
+
"top_k": TopKLogitsWarper,
|
33 |
+
"top_p": TopPLogitsWarper,
|
34 |
+
"beam_search": "beam_search"
|
35 |
+
}
|
36 |
+
_has_transformers = True
|
37 |
+
except ImportError as e:
|
38 |
+
GENERATION_TYPES = {
|
39 |
+
"top_k": None,
|
40 |
+
"top_p": None,
|
41 |
+
"beam_search": "beam_search"
|
42 |
+
}
|
43 |
+
_has_transformers = False
|
44 |
+
|
45 |
+
|
46 |
+
@dataclass
|
47 |
+
class MultimodalCfg(CLIPTextCfg):
|
48 |
+
mlp_ratio: int = 4
|
49 |
+
dim_head: int = 64
|
50 |
+
heads: int = 8
|
51 |
+
n_queries: int = 256
|
52 |
+
attn_pooler_heads: int = 8
|
53 |
+
|
54 |
+
|
55 |
+
def _build_text_decoder_tower(
|
56 |
+
embed_dim,
|
57 |
+
multimodal_cfg,
|
58 |
+
quick_gelu: bool = False,
|
59 |
+
cast_dtype: Optional[torch.dtype] = None,
|
60 |
+
):
|
61 |
+
multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
|
62 |
+
act_layer = QuickGELU if quick_gelu else nn.GELU
|
63 |
+
norm_layer = (
|
64 |
+
LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
|
65 |
+
)
|
66 |
+
|
67 |
+
decoder = MultimodalTransformer(
|
68 |
+
context_length=multimodal_cfg.context_length,
|
69 |
+
width=multimodal_cfg.width,
|
70 |
+
heads=multimodal_cfg.heads,
|
71 |
+
layers=multimodal_cfg.layers,
|
72 |
+
ls_init_value=multimodal_cfg.ls_init_value,
|
73 |
+
output_dim=embed_dim,
|
74 |
+
act_layer=act_layer,
|
75 |
+
norm_layer=norm_layer,
|
76 |
+
)
|
77 |
+
|
78 |
+
return decoder
|
79 |
+
|
80 |
+
|
81 |
+
def _token_to_tensor(token_id, device: str = "cpu") -> torch.Tensor:
|
82 |
+
if not isinstance(token_id, torch.Tensor):
|
83 |
+
if isinstance(token_id, int):
|
84 |
+
token_id = [token_id]
|
85 |
+
token_id = torch.tensor(token_id, device=device)
|
86 |
+
return token_id
|
87 |
+
|
88 |
+
|
89 |
+
class CoCa(nn.Module):
|
90 |
+
def __init__(
|
91 |
+
self,
|
92 |
+
embed_dim,
|
93 |
+
multimodal_cfg: MultimodalCfg,
|
94 |
+
text_cfg: CLIPTextCfg,
|
95 |
+
vision_cfg: CLIPVisionCfg,
|
96 |
+
quick_gelu: bool = False,
|
97 |
+
init_logit_scale: float = np.log(1 / 0.07),
|
98 |
+
init_logit_bias: Optional[float] = None,
|
99 |
+
nonscalar_logit_scale: bool = False,
|
100 |
+
cast_dtype: Optional[torch.dtype] = None,
|
101 |
+
pad_id: int = 0,
|
102 |
+
):
|
103 |
+
super().__init__()
|
104 |
+
multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
|
105 |
+
text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
|
106 |
+
vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg
|
107 |
+
|
108 |
+
self.text = _build_text_tower(
|
109 |
+
embed_dim=embed_dim,
|
110 |
+
text_cfg=text_cfg,
|
111 |
+
quick_gelu=quick_gelu,
|
112 |
+
cast_dtype=cast_dtype,
|
113 |
+
)
|
114 |
+
|
115 |
+
vocab_size = (
|
116 |
+
text_cfg.vocab_size # for hf models
|
117 |
+
if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
|
118 |
+
else text_cfg.vocab_size
|
119 |
+
)
|
120 |
+
|
121 |
+
self.visual = _build_vision_tower(
|
122 |
+
embed_dim=embed_dim,
|
123 |
+
vision_cfg=vision_cfg,
|
124 |
+
quick_gelu=quick_gelu,
|
125 |
+
cast_dtype=cast_dtype,
|
126 |
+
)
|
127 |
+
|
128 |
+
self.text_decoder = _build_text_decoder_tower(
|
129 |
+
vocab_size,
|
130 |
+
multimodal_cfg=multimodal_cfg,
|
131 |
+
quick_gelu=quick_gelu,
|
132 |
+
cast_dtype=cast_dtype,
|
133 |
+
)
|
134 |
+
|
135 |
+
lshape = [1] if nonscalar_logit_scale else []
|
136 |
+
self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
|
137 |
+
if init_logit_bias is not None:
|
138 |
+
self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
|
139 |
+
else:
|
140 |
+
self.logit_bias = None
|
141 |
+
self.pad_id = pad_id
|
142 |
+
|
143 |
+
self.context_length = multimodal_cfg.context_length
|
144 |
+
|
145 |
+
@torch.jit.ignore
|
146 |
+
def set_grad_checkpointing(self, enable: bool = True):
|
147 |
+
self.visual.set_grad_checkpointing(enable)
|
148 |
+
self.text.set_grad_checkpointing(enable)
|
149 |
+
self.text_decoder.set_grad_checkpointing(enable)
|
150 |
+
|
151 |
+
def _encode_image(self, images, normalize: bool = True):
|
152 |
+
image_latent, tokens_embs = self.visual(images)
|
153 |
+
image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
|
154 |
+
return image_latent, tokens_embs
|
155 |
+
|
156 |
+
def _encode_text(self, text, normalize: bool = True):
|
157 |
+
text_latent, token_emb = self.text(text)
|
158 |
+
text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
|
159 |
+
return text_latent, token_emb
|
160 |
+
|
161 |
+
def encode_image(self, images, normalize: bool = True):
|
162 |
+
image_latent, _ = self._encode_image(images, normalize=normalize)
|
163 |
+
return image_latent
|
164 |
+
|
165 |
+
def encode_text(self, text, normalize: bool = True):
|
166 |
+
text_latent, _ = self._encode_text(text, normalize=normalize)
|
167 |
+
return text_latent
|
168 |
+
|
169 |
+
def forward_intermediates(
|
170 |
+
self,
|
171 |
+
image: Optional[torch.Tensor] = None,
|
172 |
+
text: Optional[torch.Tensor] = None,
|
173 |
+
image_indices: Optional[Union[int, List[int]]] = None,
|
174 |
+
text_indices: Optional[Union[int, List[int]]] = None,
|
175 |
+
stop_early: bool = False,
|
176 |
+
normalize: bool = True,
|
177 |
+
normalize_intermediates: bool = False,
|
178 |
+
intermediates_only: bool = False,
|
179 |
+
image_output_fmt: str = 'NCHW',
|
180 |
+
image_output_extra_tokens: bool = False,
|
181 |
+
text_output_fmt: str = 'NLC',
|
182 |
+
text_output_extra_tokens: bool = False,
|
183 |
+
output_logits: bool = False,
|
184 |
+
output_logit_scale_bias: bool = False,
|
185 |
+
) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
|
186 |
+
""" Forward features that returns intermediates.
|
187 |
+
|
188 |
+
Args:
|
189 |
+
image: Input image tensor
|
190 |
+
text: Input text tensor
|
191 |
+
image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence
|
192 |
+
text_indices: Take last n blocks if int, all if None, select matching indices if sequence
|
193 |
+
stop_early: Stop iterating over blocks when last desired intermediate hit
|
194 |
+
normalize: L2 Normalize final image and text features (if present)
|
195 |
+
normalize_intermediates: Apply final encoder norm layer to all intermediates (if possible)
|
196 |
+
intermediates_only: Only return intermediate features, do not return final features
|
197 |
+
image_output_fmt: Shape of intermediate image feature outputs
|
198 |
+
image_output_extra_tokens: Return both prefix and spatial intermediate tokens
|
199 |
+
text_output_fmt: Shape of intermediate text feature outputs
|
200 |
+
text_output_extra_tokens: Return both prefix and spatial intermediate tokens
|
201 |
+
output_logits: Include logits in output
|
202 |
+
output_logit_scale_bias: Include the logit scale bias in the output
|
203 |
+
Returns:
|
204 |
+
|
205 |
+
"""
|
206 |
+
output = {}
|
207 |
+
if intermediates_only:
|
208 |
+
# intermediates only disables final feature normalization, and include logits
|
209 |
+
normalize = False
|
210 |
+
output_logits = False
|
211 |
+
if output_logits:
|
212 |
+
assert False, 'FIXME, needs implementing'
|
213 |
+
|
214 |
+
if image is not None:
|
215 |
+
image_output = self.visual.forward_intermediates(
|
216 |
+
image,
|
217 |
+
indices=image_indices,
|
218 |
+
stop_early=stop_early,
|
219 |
+
normalize_intermediates=normalize_intermediates,
|
220 |
+
intermediates_only=intermediates_only,
|
221 |
+
output_fmt=image_output_fmt,
|
222 |
+
output_extra_tokens=image_output_extra_tokens,
|
223 |
+
)
|
224 |
+
if normalize and "image_features" in image_output:
|
225 |
+
image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1)
|
226 |
+
output.update(image_output)
|
227 |
+
|
228 |
+
if text is not None:
|
229 |
+
text_output = self.text.forward_intermediates(
|
230 |
+
text,
|
231 |
+
indices=text_indices,
|
232 |
+
stop_early=stop_early,
|
233 |
+
normalize_intermediates=normalize_intermediates,
|
234 |
+
intermediates_only=intermediates_only,
|
235 |
+
output_fmt=text_output_fmt,
|
236 |
+
output_extra_tokens=text_output_extra_tokens,
|
237 |
+
)
|
238 |
+
if normalize and "text_features" in text_output:
|
239 |
+
text_output["text_features"] = F.normalize(text_output["text_features"], dim=-1)
|
240 |
+
output.update(text_output)
|
241 |
+
|
242 |
+
# FIXME text decoder
|
243 |
+
logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None
|
244 |
+
if output_logit_scale_bias:
|
245 |
+
output["logit_scale"] = logit_scale_exp
|
246 |
+
if self.logit_bias is not None:
|
247 |
+
output['logit_bias'] = self.logit_bias
|
248 |
+
|
249 |
+
return output
|
250 |
+
|
251 |
+
def forward(
|
252 |
+
self,
|
253 |
+
image,
|
254 |
+
text: Optional[torch.Tensor] = None,
|
255 |
+
image_latent: Optional[torch.Tensor] = None,
|
256 |
+
image_embs: Optional[torch.Tensor] = None,
|
257 |
+
output_labels: bool = True,
|
258 |
+
):
|
259 |
+
if image_latent is None or image_embs is None:
|
260 |
+
image_latent, image_embs = self._encode_image(image)
|
261 |
+
|
262 |
+
if text is None:
|
263 |
+
return {"image_features": image_latent, "image_embs": image_embs}
|
264 |
+
|
265 |
+
text_latent, token_embs = self._encode_text(text)
|
266 |
+
|
267 |
+
# FIXME this isn't an ideal solution, would like to improve -RW
|
268 |
+
labels: Optional[torch.Tensor] = text[:, 1:] if output_labels else None
|
269 |
+
if output_labels:
|
270 |
+
# align text_embs and thus logits with labels for teacher-forcing caption loss
|
271 |
+
token_embs = token_embs[:, :-1]
|
272 |
+
|
273 |
+
logits = self.text_decoder(image_embs, token_embs)
|
274 |
+
out_dict = {
|
275 |
+
"image_features": image_latent,
|
276 |
+
"text_features": text_latent,
|
277 |
+
"logits": logits,
|
278 |
+
"logit_scale": self.logit_scale.exp()
|
279 |
+
}
|
280 |
+
if labels is not None:
|
281 |
+
out_dict["labels"] = labels
|
282 |
+
if self.logit_bias is not None:
|
283 |
+
out_dict["logit_bias"] = self.logit_bias
|
284 |
+
return out_dict
|
285 |
+
|
286 |
+
def generate(
|
287 |
+
self,
|
288 |
+
image,
|
289 |
+
text=None,
|
290 |
+
seq_len=30,
|
291 |
+
max_seq_len=77,
|
292 |
+
temperature=1.,
|
293 |
+
generation_type="beam_search",
|
294 |
+
top_p=0.1, # keep tokens in the 1 - top_p quantile
|
295 |
+
top_k=1, # keeps the top_k most probable tokens
|
296 |
+
pad_token_id=None,
|
297 |
+
eos_token_id=None,
|
298 |
+
sot_token_id=None,
|
299 |
+
num_beams=6,
|
300 |
+
num_beam_groups=3,
|
301 |
+
min_seq_len=5,
|
302 |
+
stopping_criteria=None,
|
303 |
+
repetition_penalty=1.0,
|
304 |
+
fixed_output_length=False # if True output.shape == (batch_size, seq_len)
|
305 |
+
):
|
306 |
+
# taking many ideas and components from HuggingFace GenerationMixin
|
307 |
+
# https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
|
308 |
+
assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
|
309 |
+
assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
|
310 |
+
device = image.device
|
311 |
+
|
312 |
+
with torch.no_grad():
|
313 |
+
sot_token_id = _token_to_tensor(49406 if sot_token_id is None else sot_token_id, device=device)
|
314 |
+
eos_token_id = _token_to_tensor(49407 if eos_token_id is None else eos_token_id, device=device)
|
315 |
+
pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
|
316 |
+
logit_processor = LogitsProcessorList(
|
317 |
+
[
|
318 |
+
MinLengthLogitsProcessor(min_seq_len, eos_token_id),
|
319 |
+
RepetitionPenaltyLogitsProcessor(repetition_penalty),
|
320 |
+
]
|
321 |
+
)
|
322 |
+
|
323 |
+
if stopping_criteria is None:
|
324 |
+
stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
|
325 |
+
stopping_criteria = StoppingCriteriaList(stopping_criteria)
|
326 |
+
|
327 |
+
if generation_type == "beam_search":
|
328 |
+
output = self._generate_beamsearch(
|
329 |
+
image_inputs=image,
|
330 |
+
pad_token_id=pad_token_id,
|
331 |
+
eos_token_id=eos_token_id,
|
332 |
+
sot_token_id=sot_token_id,
|
333 |
+
num_beams=num_beams,
|
334 |
+
num_beam_groups=num_beam_groups,
|
335 |
+
min_seq_len=min_seq_len,
|
336 |
+
stopping_criteria=stopping_criteria,
|
337 |
+
logit_processor=logit_processor,
|
338 |
+
)
|
339 |
+
if fixed_output_length and output.shape[1] < seq_len:
|
340 |
+
pad_len = seq_len - output.shape[1]
|
341 |
+
return torch.cat((
|
342 |
+
output,
|
343 |
+
torch.ones(output.shape[0], pad_len, device=device, dtype=output.dtype) * pad_token_id
|
344 |
+
),
|
345 |
+
dim=1
|
346 |
+
)
|
347 |
+
return output
|
348 |
+
|
349 |
+
elif generation_type == "top_p":
|
350 |
+
logit_warper = GENERATION_TYPES[generation_type](top_p)
|
351 |
+
elif generation_type == "top_k":
|
352 |
+
logit_warper = GENERATION_TYPES[generation_type](top_k)
|
353 |
+
else:
|
354 |
+
raise ValueError(
|
355 |
+
f"generation_type has to be one of "
|
356 |
+
f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
|
357 |
+
)
|
358 |
+
|
359 |
+
image_latent, image_embs = self._encode_image(image)
|
360 |
+
|
361 |
+
if text is None:
|
362 |
+
text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id
|
363 |
+
|
364 |
+
was_training = self.training
|
365 |
+
num_dims = len(text.shape)
|
366 |
+
|
367 |
+
if num_dims == 1:
|
368 |
+
text = text[None, :]
|
369 |
+
|
370 |
+
self.eval()
|
371 |
+
out = text
|
372 |
+
|
373 |
+
while True:
|
374 |
+
x = out[:, -max_seq_len:]
|
375 |
+
cur_len = x.shape[1]
|
376 |
+
logits = self(
|
377 |
+
image,
|
378 |
+
x,
|
379 |
+
image_latent=image_latent,
|
380 |
+
image_embs=image_embs,
|
381 |
+
output_labels=False,
|
382 |
+
)["logits"][:, -1]
|
383 |
+
mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
|
384 |
+
sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id
|
385 |
+
|
386 |
+
if mask.all():
|
387 |
+
if not fixed_output_length:
|
388 |
+
break
|
389 |
+
else:
|
390 |
+
logits = logits[~mask, :]
|
391 |
+
filtered_logits = logit_processor(x[~mask, :], logits)
|
392 |
+
filtered_logits = logit_warper(x[~mask, :], filtered_logits)
|
393 |
+
probs = F.softmax(filtered_logits / temperature, dim=-1)
|
394 |
+
|
395 |
+
if (cur_len + 1 == seq_len):
|
396 |
+
sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
|
397 |
+
else:
|
398 |
+
sample[~mask, :] = torch.multinomial(probs, 1)
|
399 |
+
|
400 |
+
out = torch.cat((out, sample), dim=-1)
|
401 |
+
|
402 |
+
cur_len += 1
|
403 |
+
|
404 |
+
if all(stopping_criteria(out, None)):
|
405 |
+
break
|
406 |
+
|
407 |
+
if num_dims == 1:
|
408 |
+
out = out.squeeze(0)
|
409 |
+
|
410 |
+
self.train(was_training)
|
411 |
+
return out
|
412 |
+
|
413 |
+
def _generate_beamsearch(
|
414 |
+
self,
|
415 |
+
image_inputs,
|
416 |
+
pad_token_id=None,
|
417 |
+
eos_token_id=None,
|
418 |
+
sot_token_id=None,
|
419 |
+
num_beams=6,
|
420 |
+
num_beam_groups=3,
|
421 |
+
min_seq_len=5,
|
422 |
+
stopping_criteria=None,
|
423 |
+
logit_processor=None,
|
424 |
+
logit_warper=None,
|
425 |
+
):
|
426 |
+
device = image_inputs.device
|
427 |
+
batch_size = image_inputs.shape[0]
|
428 |
+
image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
|
429 |
+
image_latent, image_embs = self._encode_image(image_inputs)
|
430 |
+
|
431 |
+
input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
|
432 |
+
input_ids = input_ids * sot_token_id
|
433 |
+
beam_scorer = BeamSearchScorer(
|
434 |
+
batch_size=batch_size,
|
435 |
+
num_beams=num_beams,
|
436 |
+
device=device,
|
437 |
+
num_beam_groups=num_beam_groups,
|
438 |
+
)
|
439 |
+
# instantiate logits processors
|
440 |
+
logits_processor = (
|
441 |
+
LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
|
442 |
+
if logit_processor is None
|
443 |
+
else logit_processor
|
444 |
+
)
|
445 |
+
|
446 |
+
num_beams = beam_scorer.num_beams
|
447 |
+
num_beam_groups = beam_scorer.num_beam_groups
|
448 |
+
num_sub_beams = num_beams // num_beam_groups
|
449 |
+
batch_size = len(beam_scorer._beam_hyps) // num_beam_groups
|
450 |
+
batch_beam_size, cur_len = input_ids.shape
|
451 |
+
beam_indices = None
|
452 |
+
|
453 |
+
if num_beams * batch_size != batch_beam_size:
|
454 |
+
raise ValueError(
|
455 |
+
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
456 |
+
)
|
457 |
+
|
458 |
+
beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
|
459 |
+
# initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in
|
460 |
+
# the same group don't produce same tokens everytime.
|
461 |
+
beam_scores[:, ::num_sub_beams] = 0
|
462 |
+
beam_scores = beam_scores.view((batch_size * num_beams,))
|
463 |
+
|
464 |
+
while True:
|
465 |
+
|
466 |
+
# predicted tokens in cur_len step
|
467 |
+
current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
|
468 |
+
|
469 |
+
# indices which will form the beams in the next time step
|
470 |
+
reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)
|
471 |
+
|
472 |
+
# do one decoder step on all beams of all sentences in batch
|
473 |
+
model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
|
474 |
+
outputs = self(
|
475 |
+
model_inputs['images'],
|
476 |
+
model_inputs['text'],
|
477 |
+
image_latent=image_latent,
|
478 |
+
image_embs=image_embs,
|
479 |
+
output_labels=False,
|
480 |
+
)
|
481 |
+
|
482 |
+
for beam_group_idx in range(num_beam_groups):
|
483 |
+
group_start_idx = beam_group_idx * num_sub_beams
|
484 |
+
group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
|
485 |
+
group_size = group_end_idx - group_start_idx
|
486 |
+
|
487 |
+
# indices of beams of current group among all sentences in batch
|
488 |
+
batch_group_indices = []
|
489 |
+
|
490 |
+
for batch_idx in range(batch_size):
|
491 |
+
batch_group_indices.extend(
|
492 |
+
[batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
|
493 |
+
)
|
494 |
+
group_input_ids = input_ids[batch_group_indices]
|
495 |
+
|
496 |
+
# select outputs of beams of currentg group only
|
497 |
+
next_token_logits = outputs['logits'][batch_group_indices, -1, :]
|
498 |
+
vocab_size = next_token_logits.shape[-1]
|
499 |
+
|
500 |
+
next_token_scores_processed = logits_processor(
|
501 |
+
group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
|
502 |
+
)
|
503 |
+
next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
|
504 |
+
next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
|
505 |
+
|
506 |
+
# reshape for beam search
|
507 |
+
next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
|
508 |
+
|
509 |
+
next_token_scores, next_tokens = torch.topk(
|
510 |
+
next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
|
511 |
+
)
|
512 |
+
|
513 |
+
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
|
514 |
+
next_tokens = next_tokens % vocab_size
|
515 |
+
|
516 |
+
# stateless
|
517 |
+
process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
|
518 |
+
beam_outputs = beam_scorer.process(
|
519 |
+
group_input_ids,
|
520 |
+
next_token_scores,
|
521 |
+
next_tokens,
|
522 |
+
next_indices,
|
523 |
+
pad_token_id=pad_token_id,
|
524 |
+
eos_token_id=eos_token_id,
|
525 |
+
beam_indices=process_beam_indices,
|
526 |
+
group_index=beam_group_idx,
|
527 |
+
)
|
528 |
+
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
|
529 |
+
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
530 |
+
beam_idx = beam_outputs["next_beam_indices"]
|
531 |
+
|
532 |
+
input_ids[batch_group_indices] = group_input_ids[beam_idx]
|
533 |
+
group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
|
534 |
+
current_tokens[batch_group_indices] = group_input_ids[:, -1]
|
535 |
+
|
536 |
+
# (beam_idx // group_size) -> batch_idx
|
537 |
+
# (beam_idx % group_size) -> offset of idx inside the group
|
538 |
+
reordering_indices[batch_group_indices] = (
|
539 |
+
num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
|
540 |
+
)
|
541 |
+
|
542 |
+
input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
|
543 |
+
|
544 |
+
# increase cur_len
|
545 |
+
cur_len = cur_len + 1
|
546 |
+
if beam_scorer.is_done or all(stopping_criteria(input_ids, None)):
|
547 |
+
break
|
548 |
+
|
549 |
+
final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
|
550 |
+
sequence_outputs = beam_scorer.finalize(
|
551 |
+
input_ids,
|
552 |
+
beam_scores,
|
553 |
+
next_tokens,
|
554 |
+
next_indices,
|
555 |
+
pad_token_id=pad_token_id,
|
556 |
+
eos_token_id=eos_token_id,
|
557 |
+
max_length=stopping_criteria.max_length,
|
558 |
+
beam_indices=final_beam_indices,
|
559 |
+
)
|
560 |
+
return sequence_outputs['sequences']
|
561 |
+
|
562 |
+
|
563 |
+
def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
|
564 |
+
if past:
|
565 |
+
input_ids = input_ids[:, -1].unsqueeze(-1)
|
566 |
+
|
567 |
+
attention_mask = kwargs.get("attention_mask", None)
|
568 |
+
position_ids = kwargs.get("position_ids", None)
|
569 |
+
|
570 |
+
if attention_mask is not None and position_ids is None:
|
571 |
+
# create position_ids on the fly for batch generation
|
572 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
573 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
574 |
+
else:
|
575 |
+
position_ids = None
|
576 |
+
return {
|
577 |
+
"text": input_ids,
|
578 |
+
"images": image_inputs,
|
579 |
+
"past_key_values": past,
|
580 |
+
"position_ids": position_ids,
|
581 |
+
"attention_mask": attention_mask,
|
582 |
+
}
|
src/open_clip/constants.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
|
2 |
+
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
|
3 |
+
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
4 |
+
IMAGENET_STD = (0.229, 0.224, 0.225)
|
5 |
+
INCEPTION_MEAN = (0.5, 0.5, 0.5)
|
6 |
+
INCEPTION_STD = (0.5, 0.5, 0.5)
|
7 |
+
|
8 |
+
# Default name for a weights file hosted on the Huggingface Hub.
|
9 |
+
HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl
|
10 |
+
HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version
|
11 |
+
HF_CONFIG_NAME = 'open_clip_config.json'
|
src/open_clip/convert.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Conversion functions for 3rd part state-dicts and non-torch native checkpoint formats.
|
2 |
+
"""
|
3 |
+
from typing import Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from .model import CLIP, CustomTextCLIP
|
9 |
+
from .transformer import TextTransformer, Transformer
|
10 |
+
|
11 |
+
|
12 |
+
@torch.no_grad()
|
13 |
+
def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str):
|
14 |
+
""" Load weights from .npz checkpoints for official Google big_vision image-text models
|
15 |
+
|
16 |
+
Currently, the SigLIP source models are supported and a CustomTextCLIP destination model
|
17 |
+
w/ timm image encoder.
|
18 |
+
"""
|
19 |
+
from timm.layers import resample_patch_embed, resample_abs_pos_embed
|
20 |
+
|
21 |
+
def _n2p(w, t=True, idx=None):
|
22 |
+
if idx is not None:
|
23 |
+
w = w[idx]
|
24 |
+
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
|
25 |
+
w = w.flatten()
|
26 |
+
if t:
|
27 |
+
if w.ndim == 4:
|
28 |
+
w = w.transpose([3, 2, 0, 1])
|
29 |
+
elif w.ndim == 3:
|
30 |
+
w = w.transpose([2, 0, 1])
|
31 |
+
elif w.ndim == 2:
|
32 |
+
w = w.transpose([1, 0])
|
33 |
+
return torch.from_numpy(w)
|
34 |
+
|
35 |
+
w = np.load(checkpoint_path)
|
36 |
+
interpolation = 'bilinear'
|
37 |
+
antialias = False
|
38 |
+
|
39 |
+
def _convert_timm_img(module, prefix):
|
40 |
+
embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
|
41 |
+
if embed_conv_w.shape[-2:] != module.patch_embed.proj.weight.shape[-2:]:
|
42 |
+
embed_conv_w = resample_patch_embed(
|
43 |
+
embed_conv_w,
|
44 |
+
module.patch_embed.proj.weight.shape[-2:],
|
45 |
+
interpolation=interpolation,
|
46 |
+
antialias=antialias,
|
47 |
+
verbose=True,
|
48 |
+
)
|
49 |
+
module.patch_embed.proj.weight.copy_(embed_conv_w)
|
50 |
+
module.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
|
51 |
+
|
52 |
+
if module.cls_token is not None:
|
53 |
+
module.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
|
54 |
+
|
55 |
+
pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False)
|
56 |
+
if pos_embed_w.shape != module.pos_embed.shape:
|
57 |
+
assert False, f'{pos_embed_w.shape}, {module.pos_embed.shape}'
|
58 |
+
num_prefix_tokens = 0 if getattr(module, 'no_embed_class', False) else getattr(module, 'num_prefix_tokens', 1)
|
59 |
+
pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights
|
60 |
+
pos_embed_w,
|
61 |
+
new_size=module.patch_embed.grid_size,
|
62 |
+
num_prefix_tokens=num_prefix_tokens,
|
63 |
+
interpolation=interpolation,
|
64 |
+
antialias=antialias,
|
65 |
+
verbose=True,
|
66 |
+
)
|
67 |
+
module.pos_embed.copy_(pos_embed_w)
|
68 |
+
|
69 |
+
mha_sub, b_sub, ln1_sub = (0, 0, 1)
|
70 |
+
for i, block in enumerate(module.blocks.children()):
|
71 |
+
if f'{prefix}Transformer/encoderblock/LayerNorm_0/scale' in w:
|
72 |
+
block_prefix = f'{prefix}Transformer/encoderblock/'
|
73 |
+
idx = i
|
74 |
+
else:
|
75 |
+
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
|
76 |
+
idx = None
|
77 |
+
mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
|
78 |
+
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'], idx=idx))
|
79 |
+
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'], idx=idx))
|
80 |
+
block.attn.qkv.weight.copy_(torch.cat([
|
81 |
+
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False, idx=idx).flatten(1).T for n in ('query', 'key', 'value')]))
|
82 |
+
block.attn.qkv.bias.copy_(torch.cat([
|
83 |
+
_n2p(w[f'{mha_prefix}{n}/bias'], t=False, idx=idx).reshape(-1) for n in ('query', 'key', 'value')]))
|
84 |
+
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel'], idx=idx).flatten(1))
|
85 |
+
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'], idx=idx))
|
86 |
+
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'], idx=idx))
|
87 |
+
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'], idx=idx))
|
88 |
+
for r in range(2):
|
89 |
+
getattr(block.mlp, f'fc{r + 1}').weight.copy_(
|
90 |
+
_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'], idx=idx))
|
91 |
+
getattr(block.mlp, f'fc{r + 1}').bias.copy_(
|
92 |
+
_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'], idx=idx))
|
93 |
+
|
94 |
+
module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
|
95 |
+
module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
|
96 |
+
|
97 |
+
if module.attn_pool is not None:
|
98 |
+
block_prefix = f'{prefix}MAPHead_0/'
|
99 |
+
mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
|
100 |
+
module.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False))
|
101 |
+
module.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T)
|
102 |
+
module.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1))
|
103 |
+
module.attn_pool.kv.weight.copy_(torch.cat([
|
104 |
+
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')]))
|
105 |
+
module.attn_pool.kv.bias.copy_(torch.cat([
|
106 |
+
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')]))
|
107 |
+
module.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
|
108 |
+
module.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
|
109 |
+
module.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
|
110 |
+
module.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
|
111 |
+
for r in range(2):
|
112 |
+
getattr(module.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel']))
|
113 |
+
getattr(module.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias']))
|
114 |
+
|
115 |
+
def _convert_openclip_transformer(module: Transformer, prefix):
|
116 |
+
for i, block in enumerate(module.resblocks.children()):
|
117 |
+
if f'{prefix}encoderblock/LayerNorm_0/scale' in w:
|
118 |
+
block_prefix = f'{prefix}encoderblock/'
|
119 |
+
idx = i
|
120 |
+
else:
|
121 |
+
block_prefix = f'{prefix}encoderblock_{i}/'
|
122 |
+
idx = None
|
123 |
+
mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
|
124 |
+
block.ln_1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'], idx=idx))
|
125 |
+
block.ln_1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'], idx=idx))
|
126 |
+
block.attn.in_proj_weight.copy_(torch.cat([
|
127 |
+
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False, idx=idx).flatten(1).T for n in ('query', 'key', 'value')]))
|
128 |
+
block.attn.in_proj_bias.copy_(torch.cat([
|
129 |
+
_n2p(w[f'{mha_prefix}{n}/bias'], t=False, idx=idx).reshape(-1) for n in ('query', 'key', 'value')]))
|
130 |
+
block.attn.out_proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel'], idx=idx).flatten(1))
|
131 |
+
block.attn.out_proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'], idx=idx))
|
132 |
+
block.ln_2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/scale'], idx=idx))
|
133 |
+
block.ln_2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/bias'], idx=idx))
|
134 |
+
block.mlp.c_fc.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/kernel'], idx=idx))
|
135 |
+
block.mlp.c_fc.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/bias'], idx=idx))
|
136 |
+
block.mlp.c_proj.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/kernel'], idx=idx))
|
137 |
+
block.mlp.c_proj.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/bias'], idx=idx))
|
138 |
+
|
139 |
+
def _convert_openclip_txt(module: TextTransformer, prefix):
|
140 |
+
module.token_embedding.weight.copy_(_n2p(w[f'{prefix}Embed_0/embedding'], t=False))
|
141 |
+
pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False).squeeze(0)
|
142 |
+
module.positional_embedding.copy_(pos_embed_w)
|
143 |
+
_convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/')
|
144 |
+
module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale']))
|
145 |
+
module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias']))
|
146 |
+
if module.text_projection is not None:
|
147 |
+
module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
|
148 |
+
module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias']))
|
149 |
+
|
150 |
+
root_prefix = 'params/' if 'params/b' in w else ''
|
151 |
+
_convert_timm_img(model.visual.trunk, f'{root_prefix}img/')
|
152 |
+
_convert_openclip_txt(model.text, f'{root_prefix}txt/')
|
153 |
+
model.logit_bias.copy_(_n2p(w[f'{root_prefix}b'])[0])
|
154 |
+
model.logit_scale.copy_(_n2p(w[f'{root_prefix}t'])[0])
|
155 |
+
|
156 |
+
|
157 |
+
@torch.no_grad()
|
158 |
+
def convert_mobile_clip_state_dict(model: CustomTextCLIP, state_dict, fastvit = True):
|
159 |
+
|
160 |
+
def _convert_timm_img(state_dict):
|
161 |
+
if fastvit:
|
162 |
+
from timm.models.fastvit import checkpoint_filter_fn
|
163 |
+
else:
|
164 |
+
from timm.models.vision_transformer_hybrid import checkpoint_filter_fn
|
165 |
+
timm_state_dict = checkpoint_filter_fn(state_dict, model.visual.trunk)
|
166 |
+
timm_state_dict = {'visual.trunk.' + k: v for k, v in timm_state_dict.items()}
|
167 |
+
return timm_state_dict
|
168 |
+
|
169 |
+
def _convert_openclip_txt(state_dict, prefix='text_encoder.'):
|
170 |
+
text_dict = {}
|
171 |
+
for k, v in state_dict.items():
|
172 |
+
if not k.startswith(prefix):
|
173 |
+
continue
|
174 |
+
k = k.replace(prefix, '')
|
175 |
+
k = k.replace('projection_layer', 'text_projection')
|
176 |
+
k = k.replace('embedding_layer', 'token_embedding')
|
177 |
+
if k.startswith('positional_embedding.pos_embed.pos_embed'):
|
178 |
+
k = k.replace('positional_embedding.pos_embed.pos_embed', 'positional_embedding')
|
179 |
+
v = v.squeeze()
|
180 |
+
k = k.replace('final_layer_norm', 'ln_final')
|
181 |
+
k = k.replace('pre_norm_mha.0', 'ln_1')
|
182 |
+
k = k.replace('pre_norm_mha.1', 'attn')
|
183 |
+
k = k.replace('pre_norm_ffn.0', 'ln_2')
|
184 |
+
k = k.replace('pre_norm_ffn.1', 'mlp.c_fc')
|
185 |
+
k = k.replace('pre_norm_ffn.4', 'mlp.c_proj')
|
186 |
+
k = k.replace('qkv_proj.weight', 'in_proj_weight')
|
187 |
+
k = k.replace('qkv_proj.bias', 'in_proj_bias')
|
188 |
+
k = k.replace('transformer.', 'transformer.resblocks.')
|
189 |
+
text_dict['text.' + k] = v
|
190 |
+
return text_dict
|
191 |
+
|
192 |
+
image_dict = _convert_timm_img(state_dict)
|
193 |
+
text_dict = _convert_openclip_txt(state_dict)
|
194 |
+
out_dict = {**image_dict, **text_dict}
|
195 |
+
out_dict['logit_scale'] = state_dict['logit_scale']
|
196 |
+
return out_dict
|
197 |
+
|
198 |
+
|
199 |
+
def convert_state_dict(model: Union[CustomTextCLIP, CLIP], state_dict):
|
200 |
+
if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict:
|
201 |
+
# Apple MobileCLIP s1 & s2 state_dicts (s0 and b not currently supported)
|
202 |
+
state_dict = convert_mobile_clip_state_dict(model, state_dict)
|
203 |
+
if 'image_encoder.model.patch_emb.0.block.conv.weight' in state_dict:
|
204 |
+
# convert b model
|
205 |
+
state_dict = convert_mobile_clip_state_dict(model, state_dict, fastvit=False)
|
206 |
+
return state_dict
|
src/open_clip/factory.py
ADDED
@@ -0,0 +1,586 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
import warnings
|
6 |
+
from copy import deepcopy
|
7 |
+
from dataclasses import asdict
|
8 |
+
from pathlib import Path
|
9 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
10 |
+
|
11 |
+
import torch
|
12 |
+
|
13 |
+
from .convert import convert_state_dict
|
14 |
+
from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
|
15 |
+
resize_pos_embed, get_cast_dtype, resize_text_pos_embed, set_model_preprocess_cfg
|
16 |
+
from .coca_model import CoCa
|
17 |
+
from .loss import ClipLoss, DistillClipLoss, CoCaLoss, SigLipLoss
|
18 |
+
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
|
19 |
+
list_pretrained_tags_by_model, download_pretrained_from_hf
|
20 |
+
from .transform import image_transform_v2, AugmentationCfg, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs
|
21 |
+
from .tokenizer import HFTokenizer, SimpleTokenizer, SigLipTokenizer, DEFAULT_CONTEXT_LENGTH
|
22 |
+
|
23 |
+
HF_HUB_PREFIX = 'hf-hub:'
|
24 |
+
_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
|
25 |
+
_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
|
26 |
+
|
27 |
+
|
28 |
+
def _natural_key(string_):
|
29 |
+
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
|
30 |
+
|
31 |
+
|
32 |
+
def _rescan_model_configs():
|
33 |
+
global _MODEL_CONFIGS
|
34 |
+
|
35 |
+
config_ext = ('.json',)
|
36 |
+
config_files = []
|
37 |
+
for config_path in _MODEL_CONFIG_PATHS:
|
38 |
+
if config_path.is_file() and config_path.suffix in config_ext:
|
39 |
+
config_files.append(config_path)
|
40 |
+
elif config_path.is_dir():
|
41 |
+
for ext in config_ext:
|
42 |
+
config_files.extend(config_path.glob(f'*{ext}'))
|
43 |
+
|
44 |
+
for cf in config_files:
|
45 |
+
with open(cf, 'r') as f:
|
46 |
+
model_cfg = json.load(f)
|
47 |
+
if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
|
48 |
+
_MODEL_CONFIGS[cf.stem] = model_cfg
|
49 |
+
|
50 |
+
_MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}
|
51 |
+
|
52 |
+
|
53 |
+
_rescan_model_configs() # initial populate of model config registry
|
54 |
+
|
55 |
+
|
56 |
+
def list_models():
|
57 |
+
""" enumerate available model architectures based on config files """
|
58 |
+
return list(_MODEL_CONFIGS.keys())
|
59 |
+
|
60 |
+
|
61 |
+
def add_model_config(path):
|
62 |
+
""" add model config path or file and update registry """
|
63 |
+
if not isinstance(path, Path):
|
64 |
+
path = Path(path)
|
65 |
+
_MODEL_CONFIG_PATHS.append(path)
|
66 |
+
_rescan_model_configs()
|
67 |
+
|
68 |
+
|
69 |
+
def get_model_config(model_name):
|
70 |
+
""" Fetch model config from builtin (local library) configs.
|
71 |
+
"""
|
72 |
+
if model_name in _MODEL_CONFIGS:
|
73 |
+
return deepcopy(_MODEL_CONFIGS[model_name])
|
74 |
+
else:
|
75 |
+
return None
|
76 |
+
|
77 |
+
|
78 |
+
def _get_hf_config(
|
79 |
+
model_id: str,
|
80 |
+
cache_dir: Optional[str] = None,
|
81 |
+
):
|
82 |
+
""" Fetch model config from HuggingFace Hub.
|
83 |
+
"""
|
84 |
+
config_path = download_pretrained_from_hf(
|
85 |
+
model_id,
|
86 |
+
filename='open_clip_config.json',
|
87 |
+
cache_dir=cache_dir,
|
88 |
+
)
|
89 |
+
with open(config_path, 'r', encoding='utf-8') as f:
|
90 |
+
config = json.load(f)
|
91 |
+
return config
|
92 |
+
|
93 |
+
|
94 |
+
def get_tokenizer(
|
95 |
+
model_name: str = '',
|
96 |
+
context_length: Optional[int] = None,
|
97 |
+
cache_dir: Optional[str] = None,
|
98 |
+
**kwargs,
|
99 |
+
):
|
100 |
+
if model_name.startswith(HF_HUB_PREFIX):
|
101 |
+
model_name = model_name[len(HF_HUB_PREFIX):]
|
102 |
+
try:
|
103 |
+
config = _get_hf_config(model_name, cache_dir=cache_dir)['model_cfg']
|
104 |
+
except Exception:
|
105 |
+
tokenizer = HFTokenizer(
|
106 |
+
model_name,
|
107 |
+
context_length=context_length or DEFAULT_CONTEXT_LENGTH,
|
108 |
+
cache_dir=cache_dir,
|
109 |
+
**kwargs,
|
110 |
+
)
|
111 |
+
return tokenizer
|
112 |
+
else:
|
113 |
+
config = get_model_config(model_name)
|
114 |
+
assert config is not None, f"No valid model config found for {model_name}."
|
115 |
+
|
116 |
+
text_config = config.get('text_cfg', {})
|
117 |
+
if 'tokenizer_kwargs' in text_config:
|
118 |
+
tokenizer_kwargs = dict(text_config['tokenizer_kwargs'], **kwargs)
|
119 |
+
else:
|
120 |
+
tokenizer_kwargs = kwargs
|
121 |
+
|
122 |
+
if context_length is None:
|
123 |
+
context_length = text_config.get('context_length', DEFAULT_CONTEXT_LENGTH)
|
124 |
+
|
125 |
+
model_name = model_name.lower()
|
126 |
+
if text_config.get('hf_tokenizer_name', ''):
|
127 |
+
tokenizer = HFTokenizer(
|
128 |
+
text_config['hf_tokenizer_name'],
|
129 |
+
context_length=context_length,
|
130 |
+
cache_dir=cache_dir,
|
131 |
+
**tokenizer_kwargs,
|
132 |
+
)
|
133 |
+
elif 'siglip' in model_name:
|
134 |
+
tn = 'gemma' if 'siglip2' in model_name else 'mc4' if 'i18n' in model_name else 'c4-en'
|
135 |
+
tokenizer = SigLipTokenizer(
|
136 |
+
tn,
|
137 |
+
context_length=context_length,
|
138 |
+
# **tokenizer_kwargs,
|
139 |
+
)
|
140 |
+
else:
|
141 |
+
tokenizer = SimpleTokenizer(
|
142 |
+
context_length=context_length,
|
143 |
+
**tokenizer_kwargs,
|
144 |
+
)
|
145 |
+
|
146 |
+
return tokenizer
|
147 |
+
|
148 |
+
|
149 |
+
def load_state_dict(
|
150 |
+
checkpoint_path: str,
|
151 |
+
device='cpu',
|
152 |
+
weights_only=True,
|
153 |
+
):
|
154 |
+
# Check if safetensors or not and load weights accordingly
|
155 |
+
if str(checkpoint_path).endswith(".safetensors"):
|
156 |
+
from safetensors.torch import load_file
|
157 |
+
checkpoint = load_file(checkpoint_path, device=device)
|
158 |
+
else:
|
159 |
+
try:
|
160 |
+
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=weights_only)
|
161 |
+
except TypeError:
|
162 |
+
checkpoint = torch.load(checkpoint_path, map_location=device)
|
163 |
+
|
164 |
+
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
165 |
+
state_dict = checkpoint['state_dict']
|
166 |
+
elif isinstance(checkpoint, torch.jit.ScriptModule):
|
167 |
+
state_dict = checkpoint.state_dict()
|
168 |
+
for key in ["input_resolution", "context_length", "vocab_size"]:
|
169 |
+
state_dict.pop(key, None)
|
170 |
+
else:
|
171 |
+
state_dict = checkpoint
|
172 |
+
if next(iter(state_dict.items()))[0].startswith('module'):
|
173 |
+
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
174 |
+
return state_dict
|
175 |
+
|
176 |
+
|
177 |
+
def load_checkpoint(
|
178 |
+
model: Union[CLIP, CustomTextCLIP],
|
179 |
+
checkpoint_path: str,
|
180 |
+
strict: bool = True,
|
181 |
+
weights_only: bool = True,
|
182 |
+
device='cpu',
|
183 |
+
):
|
184 |
+
if Path(checkpoint_path).suffix in ('.npz', '.npy'):
|
185 |
+
# Separate path loading numpy big_vision (SigLIP) weights
|
186 |
+
from open_clip.convert import load_big_vision_weights
|
187 |
+
load_big_vision_weights(model, checkpoint_path)
|
188 |
+
return {}
|
189 |
+
|
190 |
+
state_dict = load_state_dict(checkpoint_path, device=device, weights_only=weights_only)
|
191 |
+
|
192 |
+
# Detect & convert 3rd party state_dicts -> open_clip
|
193 |
+
state_dict = convert_state_dict(model, state_dict)
|
194 |
+
|
195 |
+
# Detect old format and make compatible with new format
|
196 |
+
if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
|
197 |
+
state_dict = convert_to_custom_text_state_dict(state_dict)
|
198 |
+
|
199 |
+
# correct if logit_scale differs in being scaler vs 1d param
|
200 |
+
if 'logit_scale' in state_dict and model.logit_scale.ndim != state_dict['logit_scale'].ndim:
|
201 |
+
state_dict['logit_scale'] = state_dict['logit_scale'].reshape(model.logit_scale.shape)
|
202 |
+
|
203 |
+
# correct if logit_bias differs in being scaler vs 1d param
|
204 |
+
if 'logit_bias' in state_dict and model.logit_bias.ndim != state_dict['logit_bias'].ndim:
|
205 |
+
state_dict['logit_bias'] = state_dict['logit_bias'].reshape(model.logit_bias.shape)
|
206 |
+
|
207 |
+
# If loading a non-SigLIP model for SigLIP training. See https://github.com/mlfoundations/open_clip/issues/712
|
208 |
+
if 'logit_bias' not in state_dict and model.logit_bias is not None:
|
209 |
+
state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"])
|
210 |
+
|
211 |
+
# Certain text transformers no longer expect position_ids after transformers==4.31
|
212 |
+
position_id_key = 'text.transformer.embeddings.position_ids'
|
213 |
+
if position_id_key in state_dict and not hasattr(model, position_id_key):
|
214 |
+
del state_dict[position_id_key]
|
215 |
+
|
216 |
+
resize_pos_embed(state_dict, model)
|
217 |
+
resize_text_pos_embed(state_dict, model)
|
218 |
+
|
219 |
+
# Finally, load the massaged state_dict into model
|
220 |
+
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
|
221 |
+
return incompatible_keys
|
222 |
+
|
223 |
+
|
224 |
+
def create_model(
|
225 |
+
model_name: str,
|
226 |
+
pretrained: Optional[str] = None,
|
227 |
+
precision: str = 'fp32',
|
228 |
+
device: Union[str, torch.device] = 'cpu',
|
229 |
+
jit: bool = False,
|
230 |
+
force_quick_gelu: bool = False,
|
231 |
+
force_custom_text: bool = False,
|
232 |
+
force_patch_dropout: Optional[float] = None,
|
233 |
+
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
|
234 |
+
force_preprocess_cfg: Optional[Dict[str, Any]] = None,
|
235 |
+
pretrained_image: bool = False,
|
236 |
+
pretrained_hf: bool = True,
|
237 |
+
cache_dir: Optional[str] = None,
|
238 |
+
output_dict: Optional[bool] = None,
|
239 |
+
require_pretrained: bool = False,
|
240 |
+
load_weights_only: bool = True,
|
241 |
+
**model_kwargs,
|
242 |
+
):
|
243 |
+
"""Creates and configures a contrastive vision-language model.
|
244 |
+
|
245 |
+
Args:
|
246 |
+
model_name: Name of the model architecture to create. Can be a local model name
|
247 |
+
or a Hugging Face model ID prefixed with 'hf-hub:'.
|
248 |
+
pretrained: Tag/path for pretrained model weights. Can be:
|
249 |
+
- A pretrained tag name (e.g., 'openai')
|
250 |
+
- A path to local weights
|
251 |
+
- None to initialize with random weights
|
252 |
+
precision: Model precision/AMP configuration. Options:
|
253 |
+
- 'fp32': 32-bit floating point
|
254 |
+
- 'fp16'/'bf16': Mixed precision with FP32 for certain layers
|
255 |
+
- 'pure_fp16'/'pure_bf16': Pure 16-bit precision
|
256 |
+
device: Device to load the model on ('cpu', 'cuda', or torch.device object)
|
257 |
+
jit: If True, JIT compile the model
|
258 |
+
force_quick_gelu: Force use of QuickGELU activation
|
259 |
+
force_custom_text: Force use of custom text encoder
|
260 |
+
force_patch_dropout: Override default patch dropout value
|
261 |
+
force_image_size: Override default image size for vision encoder
|
262 |
+
force_preprocess_cfg: Override default preprocessing configuration
|
263 |
+
pretrained_image: Load pretrained weights for timm vision models
|
264 |
+
pretrained_hf: Load pretrained weights for HF text models when not loading CLIP weights
|
265 |
+
cache_dir: Override default cache directory for downloaded model files
|
266 |
+
output_dict: If True and model supports it, return dictionary of features
|
267 |
+
require_pretrained: Raise error if pretrained weights cannot be loaded
|
268 |
+
load_weights_only: Only deserialize model weights and unpickling torch checkpoints (for safety)
|
269 |
+
**model_kwargs: Additional keyword arguments passed to model constructor
|
270 |
+
|
271 |
+
Returns:
|
272 |
+
Created and configured model instance
|
273 |
+
|
274 |
+
Raises:
|
275 |
+
RuntimeError: If model config is not found or required pretrained weights
|
276 |
+
cannot be loaded
|
277 |
+
|
278 |
+
Examples:
|
279 |
+
# Create basic CLIP model
|
280 |
+
model = create_model('ViT-B/32')
|
281 |
+
|
282 |
+
# Create CLIP model with mixed precision on GPU
|
283 |
+
model = create_model('ViT-B/32', precision='fp16', device='cuda')
|
284 |
+
|
285 |
+
# Load pretrained OpenAI weights
|
286 |
+
model = create_model('ViT-B/32', pretrained='openai')
|
287 |
+
|
288 |
+
# Load Hugging Face model
|
289 |
+
model = create_model('hf-hub:organization/model-name')
|
290 |
+
"""
|
291 |
+
|
292 |
+
force_preprocess_cfg = force_preprocess_cfg or {}
|
293 |
+
preprocess_cfg = asdict(PreprocessCfg())
|
294 |
+
has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
|
295 |
+
if has_hf_hub_prefix:
|
296 |
+
model_id = model_name[len(HF_HUB_PREFIX):]
|
297 |
+
checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
|
298 |
+
config = _get_hf_config(model_id, cache_dir=cache_dir)
|
299 |
+
preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg'])
|
300 |
+
model_cfg = config['model_cfg']
|
301 |
+
pretrained_hf = False # override, no need to load original HF text weights
|
302 |
+
else:
|
303 |
+
model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
|
304 |
+
checkpoint_path = None
|
305 |
+
model_cfg = None
|
306 |
+
|
307 |
+
if isinstance(device, str):
|
308 |
+
device = torch.device(device)
|
309 |
+
|
310 |
+
model_cfg = model_cfg or get_model_config(model_name)
|
311 |
+
if model_cfg is not None:
|
312 |
+
logging.info(f'Loaded {model_name} model config.')
|
313 |
+
else:
|
314 |
+
logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
|
315 |
+
raise RuntimeError(f'Model config for {model_name} not found.')
|
316 |
+
|
317 |
+
if force_quick_gelu:
|
318 |
+
# override for use of QuickGELU on non-OpenAI transformer models
|
319 |
+
model_cfg["quick_gelu"] = True
|
320 |
+
|
321 |
+
if force_patch_dropout is not None:
|
322 |
+
# override the default patch dropout value
|
323 |
+
model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
|
324 |
+
|
325 |
+
if force_image_size is not None:
|
326 |
+
# override model config's image size
|
327 |
+
model_cfg["vision_cfg"]["image_size"] = force_image_size
|
328 |
+
|
329 |
+
is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})
|
330 |
+
if pretrained_image:
|
331 |
+
if is_timm_model:
|
332 |
+
# pretrained weight loading for timm models set via vision_cfg
|
333 |
+
model_cfg['vision_cfg']['timm_model_pretrained'] = True
|
334 |
+
else:
|
335 |
+
assert False, 'pretrained image towers currently only supported for timm models'
|
336 |
+
|
337 |
+
# cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes
|
338 |
+
cast_dtype = get_cast_dtype(precision)
|
339 |
+
is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
|
340 |
+
if is_hf_model:
|
341 |
+
# load pretrained weights for HF text model IFF no CLIP weights being loaded
|
342 |
+
model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained
|
343 |
+
custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model
|
344 |
+
|
345 |
+
model_cfg = dict(model_cfg, **model_kwargs) # merge cfg dict w/ kwargs (kwargs overrides cfg)
|
346 |
+
if custom_text:
|
347 |
+
if "multimodal_cfg" in model_cfg:
|
348 |
+
model = CoCa(**model_cfg, cast_dtype=cast_dtype)
|
349 |
+
else:
|
350 |
+
model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
|
351 |
+
else:
|
352 |
+
model = CLIP(**model_cfg, cast_dtype=cast_dtype)
|
353 |
+
|
354 |
+
if precision in ("fp16", "bf16"):
|
355 |
+
dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
|
356 |
+
# manual mixed precision that matches original OpenAI behaviour
|
357 |
+
if is_timm_model:
|
358 |
+
# FIXME this is a bit janky, create timm based model in low-precision and
|
359 |
+
# then cast only LayerNormFp32 instances back to float32 so they don't break.
|
360 |
+
# Why? The convert_weights_to_lp fn only works with native models.
|
361 |
+
model.to(device=device, dtype=dtype)
|
362 |
+
from .transformer import LayerNormFp32
|
363 |
+
|
364 |
+
def _convert_ln(m):
|
365 |
+
if isinstance(m, LayerNormFp32):
|
366 |
+
m.weight.data = m.weight.data.to(torch.float32)
|
367 |
+
m.bias.data = m.bias.data.to(torch.float32)
|
368 |
+
model.apply(_convert_ln)
|
369 |
+
else:
|
370 |
+
model.to(device=device)
|
371 |
+
convert_weights_to_lp(model, dtype=dtype)
|
372 |
+
elif precision in ("pure_fp16", "pure_bf16"):
|
373 |
+
dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
|
374 |
+
model.to(device=device, dtype=dtype)
|
375 |
+
else:
|
376 |
+
model.to(device=device)
|
377 |
+
|
378 |
+
pretrained_loaded = False
|
379 |
+
if pretrained:
|
380 |
+
checkpoint_path = ''
|
381 |
+
pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
|
382 |
+
if pretrained_cfg:
|
383 |
+
checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
|
384 |
+
preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg)
|
385 |
+
pretrained_quick_gelu = pretrained_cfg.get('quick_gelu', False)
|
386 |
+
model_quick_gelu = model_cfg.get('quick_gelu', False)
|
387 |
+
if pretrained_quick_gelu and not model_quick_gelu:
|
388 |
+
warnings.warn(
|
389 |
+
f'These pretrained weights were trained with QuickGELU activation but the model config does '
|
390 |
+
f'not have that enabled. Consider using a model config with a "-quickgelu" suffix or enable with a flag.')
|
391 |
+
elif not pretrained_quick_gelu and model_quick_gelu:
|
392 |
+
warnings.warn(
|
393 |
+
f'The pretrained weights were not trained with QuickGELU but this activation is enabled in the '
|
394 |
+
f'model config, consider using a model config without QuickGELU or disable override flags.')
|
395 |
+
elif os.path.exists(pretrained):
|
396 |
+
checkpoint_path = pretrained
|
397 |
+
|
398 |
+
if checkpoint_path:
|
399 |
+
logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
|
400 |
+
load_checkpoint(model, checkpoint_path, weights_only=load_weights_only)
|
401 |
+
else:
|
402 |
+
error_str = (
|
403 |
+
f'Pretrained weights ({pretrained}) not found for model {model_name}.'
|
404 |
+
f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
|
405 |
+
logging.warning(error_str)
|
406 |
+
raise RuntimeError(error_str)
|
407 |
+
pretrained_loaded = True
|
408 |
+
elif has_hf_hub_prefix:
|
409 |
+
logging.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).')
|
410 |
+
load_checkpoint(model, checkpoint_path, weights_only=load_weights_only)
|
411 |
+
pretrained_loaded = True
|
412 |
+
|
413 |
+
if require_pretrained and not pretrained_loaded:
|
414 |
+
# callers of create_model_from_pretrained always expect pretrained weights
|
415 |
+
raise RuntimeError(
|
416 |
+
f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')
|
417 |
+
|
418 |
+
if output_dict and hasattr(model, "output_dict"):
|
419 |
+
model.output_dict = True
|
420 |
+
|
421 |
+
if jit:
|
422 |
+
model = torch.jit.script(model)
|
423 |
+
|
424 |
+
# set image preprocessing configuration in model attributes for convenience
|
425 |
+
if getattr(model.visual, 'image_size', None) is not None:
|
426 |
+
# use image_size set on model creation (via config or force_image_size arg)
|
427 |
+
force_preprocess_cfg['size'] = model.visual.image_size
|
428 |
+
set_model_preprocess_cfg(model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg))
|
429 |
+
|
430 |
+
return model
|
431 |
+
|
432 |
+
|
433 |
+
def create_loss(args):
|
434 |
+
if args.distill:
|
435 |
+
return DistillClipLoss(
|
436 |
+
local_loss=args.local_loss,
|
437 |
+
gather_with_grad=args.gather_with_grad,
|
438 |
+
cache_labels=True,
|
439 |
+
rank=args.rank,
|
440 |
+
world_size=args.world_size,
|
441 |
+
use_horovod=args.horovod,
|
442 |
+
)
|
443 |
+
elif "coca" in args.model.lower():
|
444 |
+
return CoCaLoss(
|
445 |
+
caption_loss_weight=args.coca_caption_loss_weight,
|
446 |
+
clip_loss_weight=args.coca_contrastive_loss_weight,
|
447 |
+
local_loss=args.local_loss,
|
448 |
+
gather_with_grad=args.gather_with_grad,
|
449 |
+
cache_labels=True,
|
450 |
+
rank=args.rank,
|
451 |
+
world_size=args.world_size,
|
452 |
+
use_horovod=args.horovod,
|
453 |
+
)
|
454 |
+
elif args.siglip:
|
455 |
+
assert not args.horovod, "Horovod not currently supported for SigLip"
|
456 |
+
return SigLipLoss(
|
457 |
+
rank=args.rank,
|
458 |
+
world_size=args.world_size,
|
459 |
+
dist_impl=args.loss_dist_impl, # siglip has multiple distributed implementations to choose from
|
460 |
+
)
|
461 |
+
|
462 |
+
return ClipLoss(
|
463 |
+
local_loss=args.local_loss,
|
464 |
+
gather_with_grad=args.gather_with_grad,
|
465 |
+
cache_labels=True,
|
466 |
+
rank=args.rank,
|
467 |
+
world_size=args.world_size,
|
468 |
+
use_horovod=args.horovod,
|
469 |
+
)
|
470 |
+
|
471 |
+
|
472 |
+
def create_model_and_transforms(
|
473 |
+
model_name: str,
|
474 |
+
pretrained: Optional[str] = None,
|
475 |
+
precision: str = 'fp32',
|
476 |
+
device: Union[str, torch.device] = 'cpu',
|
477 |
+
jit: bool = False,
|
478 |
+
force_quick_gelu: bool = False,
|
479 |
+
force_custom_text: bool = False,
|
480 |
+
force_patch_dropout: Optional[float] = None,
|
481 |
+
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
|
482 |
+
image_mean: Optional[Tuple[float, ...]] = None,
|
483 |
+
image_std: Optional[Tuple[float, ...]] = None,
|
484 |
+
image_interpolation: Optional[str] = None,
|
485 |
+
image_resize_mode: Optional[str] = None, # only effective for inference
|
486 |
+
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
|
487 |
+
pretrained_image: bool = False,
|
488 |
+
pretrained_hf: bool = True,
|
489 |
+
cache_dir: Optional[str] = None,
|
490 |
+
output_dict: Optional[bool] = None,
|
491 |
+
load_weights_only: bool = True,
|
492 |
+
**model_kwargs,
|
493 |
+
):
|
494 |
+
force_preprocess_cfg = merge_preprocess_kwargs(
|
495 |
+
{},
|
496 |
+
mean=image_mean,
|
497 |
+
std=image_std,
|
498 |
+
interpolation=image_interpolation,
|
499 |
+
resize_mode=image_resize_mode,
|
500 |
+
)
|
501 |
+
|
502 |
+
model = create_model(
|
503 |
+
model_name,
|
504 |
+
pretrained,
|
505 |
+
precision=precision,
|
506 |
+
device=device,
|
507 |
+
jit=jit,
|
508 |
+
force_quick_gelu=force_quick_gelu,
|
509 |
+
force_custom_text=force_custom_text,
|
510 |
+
force_patch_dropout=force_patch_dropout,
|
511 |
+
force_image_size=force_image_size,
|
512 |
+
force_preprocess_cfg=force_preprocess_cfg,
|
513 |
+
pretrained_image=pretrained_image,
|
514 |
+
pretrained_hf=pretrained_hf,
|
515 |
+
cache_dir=cache_dir,
|
516 |
+
output_dict=output_dict,
|
517 |
+
load_weights_only=load_weights_only,
|
518 |
+
**model_kwargs,
|
519 |
+
)
|
520 |
+
|
521 |
+
pp_cfg = PreprocessCfg(**model.visual.preprocess_cfg)
|
522 |
+
|
523 |
+
preprocess_train = image_transform_v2(
|
524 |
+
pp_cfg,
|
525 |
+
is_train=True,
|
526 |
+
aug_cfg=aug_cfg,
|
527 |
+
)
|
528 |
+
preprocess_val = image_transform_v2(
|
529 |
+
pp_cfg,
|
530 |
+
is_train=False,
|
531 |
+
)
|
532 |
+
|
533 |
+
return model, preprocess_train, preprocess_val
|
534 |
+
|
535 |
+
|
536 |
+
def create_model_from_pretrained(
|
537 |
+
model_name: str,
|
538 |
+
pretrained: Optional[str] = None,
|
539 |
+
precision: str = 'fp32',
|
540 |
+
device: Union[str, torch.device] = 'cpu',
|
541 |
+
jit: bool = False,
|
542 |
+
force_quick_gelu: bool = False,
|
543 |
+
force_custom_text: bool = False,
|
544 |
+
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
|
545 |
+
image_mean: Optional[Tuple[float, ...]] = None,
|
546 |
+
image_std: Optional[Tuple[float, ...]] = None,
|
547 |
+
image_interpolation: Optional[str] = None,
|
548 |
+
image_resize_mode: Optional[str] = None, # only effective for inference
|
549 |
+
return_transform: bool = True,
|
550 |
+
cache_dir: Optional[str] = None,
|
551 |
+
load_weights_only: bool = True,
|
552 |
+
**model_kwargs,
|
553 |
+
):
|
554 |
+
force_preprocess_cfg = merge_preprocess_kwargs(
|
555 |
+
{},
|
556 |
+
mean=image_mean,
|
557 |
+
std=image_std,
|
558 |
+
interpolation=image_interpolation,
|
559 |
+
resize_mode=image_resize_mode,
|
560 |
+
)
|
561 |
+
|
562 |
+
model = create_model(
|
563 |
+
model_name,
|
564 |
+
pretrained,
|
565 |
+
precision=precision,
|
566 |
+
device=device,
|
567 |
+
jit=jit,
|
568 |
+
force_quick_gelu=force_quick_gelu,
|
569 |
+
force_custom_text=force_custom_text,
|
570 |
+
force_image_size=force_image_size,
|
571 |
+
force_preprocess_cfg=force_preprocess_cfg,
|
572 |
+
cache_dir=cache_dir,
|
573 |
+
require_pretrained=True,
|
574 |
+
load_weights_only=load_weights_only,
|
575 |
+
**model_kwargs,
|
576 |
+
)
|
577 |
+
|
578 |
+
if not return_transform:
|
579 |
+
return model
|
580 |
+
|
581 |
+
preprocess = image_transform_v2(
|
582 |
+
PreprocessCfg(**model.visual.preprocess_cfg),
|
583 |
+
is_train=False,
|
584 |
+
)
|
585 |
+
|
586 |
+
return model, preprocess
|
src/open_clip/hf_configs.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# HF architecture dict:
|
2 |
+
arch_dict = {
|
3 |
+
# https://huggingface.co/docs/transformers/model_doc/roberta#roberta
|
4 |
+
"roberta": {
|
5 |
+
"config_names": {
|
6 |
+
"context_length": "max_position_embeddings",
|
7 |
+
"vocab_size": "vocab_size",
|
8 |
+
"width": "hidden_size",
|
9 |
+
"heads": "num_attention_heads",
|
10 |
+
"layers": "num_hidden_layers",
|
11 |
+
"layer_attr": "layer",
|
12 |
+
"token_embeddings_attr": "embeddings"
|
13 |
+
},
|
14 |
+
"pooler": "mean_pooler",
|
15 |
+
},
|
16 |
+
# https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
|
17 |
+
"xlm-roberta": {
|
18 |
+
"config_names": {
|
19 |
+
"context_length": "max_position_embeddings",
|
20 |
+
"vocab_size": "vocab_size",
|
21 |
+
"width": "hidden_size",
|
22 |
+
"heads": "num_attention_heads",
|
23 |
+
"layers": "num_hidden_layers",
|
24 |
+
"layer_attr": "layer",
|
25 |
+
"token_embeddings_attr": "embeddings"
|
26 |
+
},
|
27 |
+
"pooler": "mean_pooler",
|
28 |
+
},
|
29 |
+
# https://huggingface.co/docs/transformers/model_doc/mt5#mt5
|
30 |
+
"mt5": {
|
31 |
+
"config_names": {
|
32 |
+
# unlimited seqlen
|
33 |
+
# https://github.com/google-research/text-to-text-transfer-transformer/issues/273
|
34 |
+
# https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
|
35 |
+
"context_length": "",
|
36 |
+
"vocab_size": "vocab_size",
|
37 |
+
"width": "d_model",
|
38 |
+
"heads": "num_heads",
|
39 |
+
"layers": "num_layers",
|
40 |
+
"layer_attr": "block",
|
41 |
+
"token_embeddings_attr": "embed_tokens"
|
42 |
+
},
|
43 |
+
"pooler": "mean_pooler",
|
44 |
+
},
|
45 |
+
# https://huggingface.co/docs/transformers/model_doc/bert
|
46 |
+
"bert": {
|
47 |
+
"config_names": {
|
48 |
+
"context_length": "max_position_embeddings",
|
49 |
+
"vocab_size": "vocab_size",
|
50 |
+
"width": "hidden_size",
|
51 |
+
"heads": "num_attention_heads",
|
52 |
+
"layers": "num_hidden_layers",
|
53 |
+
},
|
54 |
+
"pooler": "cls_pooler",
|
55 |
+
},
|
56 |
+
# https://huggingface.co/docs/transformers/model_doc/m2m_100
|
57 |
+
"m2m_100": {
|
58 |
+
"config_names": {
|
59 |
+
"context_length": "max_position_embeddings",
|
60 |
+
"vocab_size": "vocab_size",
|
61 |
+
"width": "d_model",
|
62 |
+
"heads": "encoder_attention_heads",
|
63 |
+
"layers": "encoder_layers",
|
64 |
+
},
|
65 |
+
"pooler": "cls_pooler",
|
66 |
+
},
|
67 |
+
}
|
src/open_clip/hf_model.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" huggingface model adapter
|
2 |
+
|
3 |
+
Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
|
4 |
+
"""
|
5 |
+
import re
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
from torch import TensorType
|
10 |
+
|
11 |
+
try:
|
12 |
+
import transformers
|
13 |
+
from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
|
14 |
+
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
|
15 |
+
BaseModelOutputWithPoolingAndCrossAttentions
|
16 |
+
except ImportError as e:
|
17 |
+
transformers = None
|
18 |
+
|
19 |
+
|
20 |
+
class BaseModelOutput:
|
21 |
+
pass
|
22 |
+
|
23 |
+
|
24 |
+
class PretrainedConfig:
|
25 |
+
pass
|
26 |
+
|
27 |
+
from .hf_configs import arch_dict
|
28 |
+
|
29 |
+
|
30 |
+
# utils
|
31 |
+
def _camel2snake(s):
|
32 |
+
return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
|
33 |
+
|
34 |
+
|
35 |
+
# TODO: ?last - for gpt-like models
|
36 |
+
_POOLERS = {}
|
37 |
+
|
38 |
+
|
39 |
+
def register_pooler(cls):
|
40 |
+
"""Decorator registering pooler class"""
|
41 |
+
_POOLERS[_camel2snake(cls.__name__)] = cls
|
42 |
+
return cls
|
43 |
+
|
44 |
+
|
45 |
+
@register_pooler
|
46 |
+
class MeanPooler(nn.Module):
|
47 |
+
"""Mean pooling"""
|
48 |
+
|
49 |
+
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
|
50 |
+
masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
|
51 |
+
return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
|
52 |
+
|
53 |
+
|
54 |
+
@register_pooler
|
55 |
+
class MaxPooler(nn.Module):
|
56 |
+
"""Max pooling"""
|
57 |
+
|
58 |
+
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
|
59 |
+
masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1), -torch.inf)
|
60 |
+
return masked_output.max(1).values
|
61 |
+
|
62 |
+
|
63 |
+
@register_pooler
|
64 |
+
class ClsPooler(nn.Module):
|
65 |
+
"""CLS token pooling"""
|
66 |
+
|
67 |
+
def __init__(self, use_pooler_output=True):
|
68 |
+
super().__init__()
|
69 |
+
self.cls_token_position = 0
|
70 |
+
self.use_pooler_output = use_pooler_output
|
71 |
+
|
72 |
+
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
|
73 |
+
if (self.use_pooler_output and
|
74 |
+
isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
|
75 |
+
(x.pooler_output is not None)
|
76 |
+
):
|
77 |
+
return x.pooler_output
|
78 |
+
|
79 |
+
return x.last_hidden_state[:, self.cls_token_position, :]
|
80 |
+
|
81 |
+
|
82 |
+
@register_pooler
|
83 |
+
class ClsLastHiddenStatePooler(nn.Module):
|
84 |
+
"""CLS token pooling
|
85 |
+
NOTE: this is equivalent to ClsPooler above with use_pooler_output=False
|
86 |
+
"""
|
87 |
+
|
88 |
+
def __init__(self):
|
89 |
+
super().__init__()
|
90 |
+
self.cls_token_position = 0
|
91 |
+
|
92 |
+
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
|
93 |
+
return x.last_hidden_state[:, self.cls_token_position, :]
|
94 |
+
|
95 |
+
|
96 |
+
class HFTextEncoder(nn.Module):
|
97 |
+
"""HuggingFace model adapter"""
|
98 |
+
output_tokens: torch.jit.Final[bool]
|
99 |
+
|
100 |
+
def __init__(
|
101 |
+
self,
|
102 |
+
model_name_or_path: str,
|
103 |
+
output_dim: int,
|
104 |
+
config: PretrainedConfig = None,
|
105 |
+
pooler_type: str = None,
|
106 |
+
proj_type: str = None,
|
107 |
+
pretrained: bool = True,
|
108 |
+
output_tokens: bool = False,
|
109 |
+
):
|
110 |
+
super().__init__()
|
111 |
+
self.output_tokens = output_tokens
|
112 |
+
self.output_dim = output_dim
|
113 |
+
|
114 |
+
# TODO: find better way to get this information
|
115 |
+
uses_transformer_pooler = (pooler_type == "cls_pooler")
|
116 |
+
|
117 |
+
if transformers is None:
|
118 |
+
raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
|
119 |
+
if config is None:
|
120 |
+
self.config = AutoConfig.from_pretrained(model_name_or_path)
|
121 |
+
create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
|
122 |
+
AutoModel.from_config, self.config)
|
123 |
+
# TODO: do all model configs have this attribute? PretrainedConfig does so yes??
|
124 |
+
if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
|
125 |
+
self.transformer = create_func(model_args)
|
126 |
+
self.transformer = self.transformer.encoder
|
127 |
+
else:
|
128 |
+
self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
|
129 |
+
else:
|
130 |
+
self.config = config
|
131 |
+
self.transformer = AutoModel.from_config(config)
|
132 |
+
if pooler_type is None: # get default arch pooler
|
133 |
+
pooler_type = (arch_dict[self.config.model_type]["pooler"])
|
134 |
+
|
135 |
+
# FIXME downstream users of OpenCLIP models use these attr, need to verify valid across all models
|
136 |
+
self.vocab_size = getattr(self.config, 'vocab_size', 0)
|
137 |
+
self.context_length = getattr(self.config, 'max_position_embeddings', 0)
|
138 |
+
|
139 |
+
self.pooler = _POOLERS[pooler_type]()
|
140 |
+
|
141 |
+
d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
|
142 |
+
if (d_model == output_dim) and (proj_type is None): # do we always need a proj?
|
143 |
+
self.proj = nn.Identity()
|
144 |
+
elif proj_type == 'linear':
|
145 |
+
self.proj = nn.Linear(d_model, output_dim, bias=False)
|
146 |
+
elif proj_type == 'mlp':
|
147 |
+
hidden_size = (d_model + output_dim) // 2
|
148 |
+
self.proj = nn.Sequential(
|
149 |
+
nn.Linear(d_model, hidden_size, bias=False),
|
150 |
+
nn.GELU(),
|
151 |
+
nn.Linear(hidden_size, output_dim, bias=False),
|
152 |
+
)
|
153 |
+
|
154 |
+
def forward(self, x: TensorType):
|
155 |
+
attn_mask = (x != self.config.pad_token_id).long()
|
156 |
+
out = self.transformer(input_ids=x, attention_mask=attn_mask)
|
157 |
+
pooled_out = self.pooler(out, attn_mask)
|
158 |
+
projected = self.proj(pooled_out)
|
159 |
+
|
160 |
+
seq_len = out.last_hidden_state.shape[1]
|
161 |
+
tokens = (
|
162 |
+
out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :]
|
163 |
+
if type(self.pooler) == ClsPooler
|
164 |
+
else out.last_hidden_state
|
165 |
+
)
|
166 |
+
|
167 |
+
if self.output_tokens:
|
168 |
+
return projected, tokens
|
169 |
+
return projected
|
170 |
+
|
171 |
+
def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
|
172 |
+
if not unlocked_layers: # full freezing
|
173 |
+
for n, p in self.transformer.named_parameters():
|
174 |
+
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
|
175 |
+
return
|
176 |
+
|
177 |
+
encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
|
178 |
+
layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
|
179 |
+
print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
|
180 |
+
embeddings = getattr(
|
181 |
+
self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
|
182 |
+
modules = [embeddings, *layer_list][:-unlocked_layers]
|
183 |
+
# freeze layers
|
184 |
+
for module in modules:
|
185 |
+
for n, p in module.named_parameters():
|
186 |
+
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
|
187 |
+
|
188 |
+
@torch.jit.ignore
|
189 |
+
def set_grad_checkpointing(self, enable=True):
|
190 |
+
self.transformer.gradient_checkpointing_enable()
|
191 |
+
|
192 |
+
def init_parameters(self):
|
193 |
+
pass
|
src/open_clip/loss.py
ADDED
@@ -0,0 +1,447 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
|
7 |
+
try:
|
8 |
+
import torch.distributed.nn
|
9 |
+
from torch import distributed as dist
|
10 |
+
|
11 |
+
has_distributed = True
|
12 |
+
except ImportError:
|
13 |
+
has_distributed = False
|
14 |
+
|
15 |
+
try:
|
16 |
+
import horovod.torch as hvd
|
17 |
+
except ImportError:
|
18 |
+
hvd = None
|
19 |
+
|
20 |
+
|
21 |
+
def gather_features(
|
22 |
+
image_features,
|
23 |
+
text_features,
|
24 |
+
local_loss=False,
|
25 |
+
gather_with_grad=False,
|
26 |
+
rank=0,
|
27 |
+
world_size=1,
|
28 |
+
use_horovod=False
|
29 |
+
):
|
30 |
+
assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
|
31 |
+
if use_horovod:
|
32 |
+
assert hvd is not None, 'Please install horovod'
|
33 |
+
if gather_with_grad:
|
34 |
+
all_image_features = hvd.allgather(image_features)
|
35 |
+
all_text_features = hvd.allgather(text_features)
|
36 |
+
else:
|
37 |
+
with torch.no_grad():
|
38 |
+
all_image_features = hvd.allgather(image_features)
|
39 |
+
all_text_features = hvd.allgather(text_features)
|
40 |
+
if not local_loss:
|
41 |
+
# ensure grads for local rank when all_* features don't have a gradient
|
42 |
+
gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
|
43 |
+
gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
|
44 |
+
gathered_image_features[rank] = image_features
|
45 |
+
gathered_text_features[rank] = text_features
|
46 |
+
all_image_features = torch.cat(gathered_image_features, dim=0)
|
47 |
+
all_text_features = torch.cat(gathered_text_features, dim=0)
|
48 |
+
else:
|
49 |
+
# We gather tensors from all gpus
|
50 |
+
if gather_with_grad:
|
51 |
+
all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
|
52 |
+
all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
|
53 |
+
else:
|
54 |
+
gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
|
55 |
+
gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
|
56 |
+
dist.all_gather(gathered_image_features, image_features)
|
57 |
+
dist.all_gather(gathered_text_features, text_features)
|
58 |
+
if not local_loss:
|
59 |
+
# ensure grads for local rank when all_* features don't have a gradient
|
60 |
+
gathered_image_features[rank] = image_features
|
61 |
+
gathered_text_features[rank] = text_features
|
62 |
+
all_image_features = torch.cat(gathered_image_features, dim=0)
|
63 |
+
all_text_features = torch.cat(gathered_text_features, dim=0)
|
64 |
+
|
65 |
+
return all_image_features, all_text_features
|
66 |
+
|
67 |
+
|
68 |
+
class ClipLoss(nn.Module):
|
69 |
+
|
70 |
+
def __init__(
|
71 |
+
self,
|
72 |
+
local_loss=False,
|
73 |
+
gather_with_grad=False,
|
74 |
+
cache_labels=False,
|
75 |
+
rank=0,
|
76 |
+
world_size=1,
|
77 |
+
use_horovod=False,
|
78 |
+
):
|
79 |
+
super().__init__()
|
80 |
+
self.local_loss = local_loss
|
81 |
+
self.gather_with_grad = gather_with_grad
|
82 |
+
self.cache_labels = cache_labels
|
83 |
+
self.rank = rank
|
84 |
+
self.world_size = world_size
|
85 |
+
self.use_horovod = use_horovod
|
86 |
+
|
87 |
+
# cache state
|
88 |
+
self.prev_num_logits = 0
|
89 |
+
self.labels = {}
|
90 |
+
|
91 |
+
def get_ground_truth(self, device, num_logits) -> torch.Tensor:
|
92 |
+
# calculated ground-truth and cache if enabled
|
93 |
+
if self.prev_num_logits != num_logits or device not in self.labels:
|
94 |
+
labels = torch.arange(num_logits, device=device, dtype=torch.long)
|
95 |
+
if self.world_size > 1 and self.local_loss:
|
96 |
+
labels = labels + num_logits * self.rank
|
97 |
+
if self.cache_labels:
|
98 |
+
self.labels[device] = labels
|
99 |
+
self.prev_num_logits = num_logits
|
100 |
+
else:
|
101 |
+
labels = self.labels[device]
|
102 |
+
return labels
|
103 |
+
|
104 |
+
def get_logits(self, image_features, text_features, logit_scale):
|
105 |
+
if self.world_size > 1:
|
106 |
+
all_image_features, all_text_features = gather_features(
|
107 |
+
image_features,
|
108 |
+
text_features,
|
109 |
+
local_loss=self.local_loss,
|
110 |
+
gather_with_grad=self.gather_with_grad,
|
111 |
+
rank=self.rank,
|
112 |
+
world_size=self.world_size,
|
113 |
+
use_horovod=self.use_horovod,
|
114 |
+
)
|
115 |
+
|
116 |
+
if self.local_loss:
|
117 |
+
logits_per_image = logit_scale * image_features @ all_text_features.T
|
118 |
+
logits_per_text = logit_scale * text_features @ all_image_features.T
|
119 |
+
else:
|
120 |
+
logits_per_image = logit_scale * all_image_features @ all_text_features.T
|
121 |
+
logits_per_text = logits_per_image.T
|
122 |
+
else:
|
123 |
+
logits_per_image = logit_scale * image_features @ text_features.T
|
124 |
+
logits_per_text = logit_scale * text_features @ image_features.T
|
125 |
+
|
126 |
+
return logits_per_image, logits_per_text
|
127 |
+
|
128 |
+
def forward(self, image_features, text_features, logit_scale, output_dict=False):
|
129 |
+
device = image_features.device
|
130 |
+
logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)
|
131 |
+
labels = self.get_ground_truth(device, logits_per_image.shape[0])
|
132 |
+
|
133 |
+
total_loss = (
|
134 |
+
F.cross_entropy(logits_per_image, labels) +
|
135 |
+
F.cross_entropy(logits_per_text, labels)
|
136 |
+
) / 2
|
137 |
+
|
138 |
+
return {"contrastive_loss": total_loss} if output_dict else total_loss
|
139 |
+
|
140 |
+
|
141 |
+
class CoCaLoss(ClipLoss):
|
142 |
+
def __init__(
|
143 |
+
self,
|
144 |
+
caption_loss_weight,
|
145 |
+
clip_loss_weight,
|
146 |
+
pad_id=0, # pad_token for open_clip custom tokenizer
|
147 |
+
local_loss=False,
|
148 |
+
gather_with_grad=False,
|
149 |
+
cache_labels=False,
|
150 |
+
rank=0,
|
151 |
+
world_size=1,
|
152 |
+
use_horovod=False,
|
153 |
+
):
|
154 |
+
super().__init__(
|
155 |
+
local_loss=local_loss,
|
156 |
+
gather_with_grad=gather_with_grad,
|
157 |
+
cache_labels=cache_labels,
|
158 |
+
rank=rank,
|
159 |
+
world_size=world_size,
|
160 |
+
use_horovod=use_horovod
|
161 |
+
)
|
162 |
+
|
163 |
+
self.clip_loss_weight = clip_loss_weight
|
164 |
+
self.caption_loss_weight = caption_loss_weight
|
165 |
+
self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)
|
166 |
+
|
167 |
+
def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):
|
168 |
+
if self.clip_loss_weight:
|
169 |
+
clip_loss = super().forward(image_features, text_features, logit_scale)
|
170 |
+
clip_loss = self.clip_loss_weight * clip_loss
|
171 |
+
else:
|
172 |
+
clip_loss = torch.tensor(0, device=logits.device)
|
173 |
+
|
174 |
+
caption_loss = self.caption_loss(
|
175 |
+
logits.permute(0, 2, 1),
|
176 |
+
labels,
|
177 |
+
)
|
178 |
+
caption_loss = caption_loss * self.caption_loss_weight
|
179 |
+
|
180 |
+
if output_dict:
|
181 |
+
return {"contrastive_loss": clip_loss, "caption_loss": caption_loss}
|
182 |
+
|
183 |
+
return clip_loss, caption_loss
|
184 |
+
|
185 |
+
|
186 |
+
class DistillClipLoss(ClipLoss):
|
187 |
+
|
188 |
+
def dist_loss(self, teacher_logits, student_logits):
|
189 |
+
return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0)
|
190 |
+
|
191 |
+
def forward(
|
192 |
+
self,
|
193 |
+
image_features,
|
194 |
+
text_features,
|
195 |
+
logit_scale,
|
196 |
+
dist_image_features,
|
197 |
+
dist_text_features,
|
198 |
+
dist_logit_scale,
|
199 |
+
output_dict=False,
|
200 |
+
):
|
201 |
+
logits_per_image, logits_per_text = \
|
202 |
+
self.get_logits(image_features, text_features, logit_scale)
|
203 |
+
|
204 |
+
dist_logits_per_image, dist_logits_per_text = \
|
205 |
+
self.get_logits(dist_image_features, dist_text_features, dist_logit_scale)
|
206 |
+
|
207 |
+
labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0])
|
208 |
+
|
209 |
+
contrastive_loss = (
|
210 |
+
F.cross_entropy(logits_per_image, labels) +
|
211 |
+
F.cross_entropy(logits_per_text, labels)
|
212 |
+
) / 2
|
213 |
+
|
214 |
+
distill_loss = (
|
215 |
+
self.dist_loss(dist_logits_per_image, logits_per_image) +
|
216 |
+
self.dist_loss(dist_logits_per_text, logits_per_text)
|
217 |
+
) / 2
|
218 |
+
|
219 |
+
if output_dict:
|
220 |
+
return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss}
|
221 |
+
|
222 |
+
return contrastive_loss, distill_loss
|
223 |
+
|
224 |
+
|
225 |
+
def neighbour_exchange(from_rank, to_rank, tensor, group=None):
|
226 |
+
tensor_recv = torch.zeros_like(tensor)
|
227 |
+
send_op = torch.distributed.P2POp(
|
228 |
+
torch.distributed.isend,
|
229 |
+
tensor,
|
230 |
+
to_rank,
|
231 |
+
group=group,
|
232 |
+
)
|
233 |
+
recv_op = torch.distributed.P2POp(
|
234 |
+
torch.distributed.irecv,
|
235 |
+
tensor_recv,
|
236 |
+
from_rank,
|
237 |
+
group=group,
|
238 |
+
)
|
239 |
+
reqs = torch.distributed.batch_isend_irecv([send_op, recv_op])
|
240 |
+
for req in reqs:
|
241 |
+
req.wait()
|
242 |
+
return tensor_recv
|
243 |
+
|
244 |
+
|
245 |
+
def neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None):
|
246 |
+
tensor_from_left = torch.zeros_like(tensor_to_right)
|
247 |
+
tensor_from_right = torch.zeros_like(tensor_to_left)
|
248 |
+
send_op_left = torch.distributed.P2POp(
|
249 |
+
torch.distributed.isend,
|
250 |
+
tensor_to_left,
|
251 |
+
left_rank,
|
252 |
+
group=group,
|
253 |
+
)
|
254 |
+
send_op_right = torch.distributed.P2POp(
|
255 |
+
torch.distributed.isend,
|
256 |
+
tensor_to_right,
|
257 |
+
right_rank,
|
258 |
+
group=group,
|
259 |
+
)
|
260 |
+
recv_op_left = torch.distributed.P2POp(
|
261 |
+
torch.distributed.irecv,
|
262 |
+
tensor_from_left,
|
263 |
+
left_rank,
|
264 |
+
group=group,
|
265 |
+
)
|
266 |
+
recv_op_right = torch.distributed.P2POp(
|
267 |
+
torch.distributed.irecv,
|
268 |
+
tensor_from_right,
|
269 |
+
right_rank,
|
270 |
+
group=group,
|
271 |
+
)
|
272 |
+
reqs = torch.distributed.batch_isend_irecv([send_op_right, send_op_left, recv_op_right, recv_op_left])
|
273 |
+
for req in reqs:
|
274 |
+
req.wait()
|
275 |
+
return tensor_from_right, tensor_from_left
|
276 |
+
|
277 |
+
|
278 |
+
class NeighbourExchange(torch.autograd.Function):
|
279 |
+
@staticmethod
|
280 |
+
def forward(ctx, from_rank, to_rank, group, tensor):
|
281 |
+
ctx.group = group
|
282 |
+
ctx.from_rank = from_rank
|
283 |
+
ctx.to_rank = to_rank
|
284 |
+
return neighbour_exchange(from_rank, to_rank, tensor, group=group)
|
285 |
+
|
286 |
+
@staticmethod
|
287 |
+
def backward(ctx, grad_output):
|
288 |
+
return (None, None, None) + (NeighbourExchange.apply(ctx.to_rank, ctx.from_rank, ctx.group, grad_output),)
|
289 |
+
|
290 |
+
|
291 |
+
def neighbour_exchange_with_grad(from_rank, to_rank, tensor, group=None):
|
292 |
+
return NeighbourExchange.apply(from_rank, to_rank, group, tensor)
|
293 |
+
|
294 |
+
|
295 |
+
class NeighbourExchangeBidir(torch.autograd.Function):
|
296 |
+
@staticmethod
|
297 |
+
def forward(ctx, left_rank, right_rank, group, tensor_to_left, tensor_to_right):
|
298 |
+
ctx.group = group
|
299 |
+
ctx.left_rank = left_rank
|
300 |
+
ctx.right_rank = right_rank
|
301 |
+
return neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=group)
|
302 |
+
|
303 |
+
@staticmethod
|
304 |
+
def backward(ctx, *grad_outputs):
|
305 |
+
return (None, None, None) + \
|
306 |
+
NeighbourExchangeBidir.apply(ctx.right_rank, ctx.left_rank, ctx.group, *grad_outputs)
|
307 |
+
|
308 |
+
|
309 |
+
def neighbour_exchange_bidir_with_grad(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None):
|
310 |
+
return NeighbourExchangeBidir.apply(left_rank, right_rank, group, tensor_to_left, tensor_to_right)
|
311 |
+
|
312 |
+
|
313 |
+
class SigLipLoss(nn.Module):
|
314 |
+
""" Sigmoid Loss for Language Image Pre-Training (SigLIP) - https://arxiv.org/abs/2303.15343
|
315 |
+
|
316 |
+
@article{zhai2023sigmoid,
|
317 |
+
title={Sigmoid loss for language image pre-training},
|
318 |
+
author={Zhai, Xiaohua and Mustafa, Basil and Kolesnikov, Alexander and Beyer, Lucas},
|
319 |
+
journal={arXiv preprint arXiv:2303.15343},
|
320 |
+
year={2023}
|
321 |
+
}
|
322 |
+
"""
|
323 |
+
def __init__(
|
324 |
+
self,
|
325 |
+
cache_labels: bool = False,
|
326 |
+
rank: int = 0,
|
327 |
+
world_size: int = 1,
|
328 |
+
dist_impl: Optional[str] = None,
|
329 |
+
):
|
330 |
+
super().__init__()
|
331 |
+
self.cache_labels = cache_labels
|
332 |
+
self.rank = rank
|
333 |
+
self.world_size = world_size
|
334 |
+
self.dist_impl = dist_impl or 'bidir' # default to bidir exchange for now, this will likely change
|
335 |
+
assert self.dist_impl in ('bidir', 'shift', 'reduce', 'gather')
|
336 |
+
|
337 |
+
# cache state FIXME cache not currently used, worthwhile?
|
338 |
+
self.prev_num_logits = 0
|
339 |
+
self.labels = {}
|
340 |
+
|
341 |
+
def get_ground_truth(self, device, dtype, num_logits, negative_only=False) -> torch.Tensor:
|
342 |
+
labels = -torch.ones((num_logits, num_logits), device=device, dtype=dtype)
|
343 |
+
if not negative_only:
|
344 |
+
labels = 2 * torch.eye(num_logits, device=device, dtype=dtype) + labels
|
345 |
+
return labels
|
346 |
+
|
347 |
+
def get_logits(self, image_features, text_features, logit_scale, logit_bias=None):
|
348 |
+
logits = logit_scale * image_features @ text_features.T
|
349 |
+
if logit_bias is not None:
|
350 |
+
logits += logit_bias
|
351 |
+
return logits
|
352 |
+
|
353 |
+
def _loss(self, image_features, text_features, logit_scale, logit_bias=None, negative_only=False):
|
354 |
+
logits = self.get_logits(image_features, text_features, logit_scale, logit_bias)
|
355 |
+
labels = self.get_ground_truth(
|
356 |
+
image_features.device,
|
357 |
+
image_features.dtype,
|
358 |
+
image_features.shape[0],
|
359 |
+
negative_only=negative_only,
|
360 |
+
)
|
361 |
+
loss = -F.logsigmoid(labels * logits).sum() / image_features.shape[0]
|
362 |
+
return loss
|
363 |
+
|
364 |
+
def forward(self, image_features, text_features, logit_scale, logit_bias, output_dict=False):
|
365 |
+
loss = self._loss(image_features, text_features, logit_scale, logit_bias)
|
366 |
+
|
367 |
+
if self.world_size > 1:
|
368 |
+
if self.dist_impl == 'bidir':
|
369 |
+
right_rank = (self.rank + 1) % self.world_size
|
370 |
+
left_rank = (self.rank - 1 + self.world_size) % self.world_size
|
371 |
+
text_features_to_right = text_features_to_left = text_features
|
372 |
+
num_bidir, remainder = divmod(self.world_size - 1, 2)
|
373 |
+
for i in range(num_bidir):
|
374 |
+
text_features_recv = neighbour_exchange_bidir_with_grad(
|
375 |
+
left_rank,
|
376 |
+
right_rank,
|
377 |
+
text_features_to_left,
|
378 |
+
text_features_to_right,
|
379 |
+
)
|
380 |
+
for f in text_features_recv:
|
381 |
+
loss += self._loss(
|
382 |
+
image_features,
|
383 |
+
f,
|
384 |
+
logit_scale,
|
385 |
+
logit_bias,
|
386 |
+
negative_only=True,
|
387 |
+
)
|
388 |
+
text_features_to_left, text_features_to_right = text_features_recv
|
389 |
+
|
390 |
+
if remainder:
|
391 |
+
text_features_recv = neighbour_exchange_with_grad(
|
392 |
+
left_rank,
|
393 |
+
right_rank,
|
394 |
+
text_features_to_right
|
395 |
+
)
|
396 |
+
loss += self._loss(
|
397 |
+
image_features,
|
398 |
+
text_features_recv,
|
399 |
+
logit_scale,
|
400 |
+
logit_bias,
|
401 |
+
negative_only=True,
|
402 |
+
)
|
403 |
+
elif self.dist_impl == "shift":
|
404 |
+
right_rank = (self.rank + 1) % self.world_size
|
405 |
+
left_rank = (self.rank - 1 + self.world_size) % self.world_size
|
406 |
+
text_features_to_right = text_features
|
407 |
+
for i in range(self.world_size - 1):
|
408 |
+
text_features_from_left = neighbour_exchange_with_grad(
|
409 |
+
left_rank,
|
410 |
+
right_rank,
|
411 |
+
text_features_to_right,
|
412 |
+
)
|
413 |
+
loss += self._loss(
|
414 |
+
image_features,
|
415 |
+
text_features_from_left,
|
416 |
+
logit_scale,
|
417 |
+
logit_bias,
|
418 |
+
negative_only=True,
|
419 |
+
)
|
420 |
+
text_features_to_right = text_features_from_left
|
421 |
+
elif self.dist_impl == "reduce":
|
422 |
+
for i in range(self.world_size):
|
423 |
+
text_from_other = torch.distributed.nn.all_reduce(
|
424 |
+
text_features * (self.rank == i),
|
425 |
+
torch.distributed.ReduceOp.SUM,
|
426 |
+
)
|
427 |
+
loss += float(i != self.rank) * self._loss(
|
428 |
+
image_features,
|
429 |
+
text_from_other,
|
430 |
+
logit_scale,
|
431 |
+
logit_bias,
|
432 |
+
negative_only=True,
|
433 |
+
)
|
434 |
+
elif self.dist_impl == "gather":
|
435 |
+
all_text = torch.distributed.nn.all_gather(text_features)
|
436 |
+
for i in range(self.world_size):
|
437 |
+
loss += float(i != self.rank) * self._loss(
|
438 |
+
image_features,
|
439 |
+
all_text[i],
|
440 |
+
logit_scale,
|
441 |
+
logit_bias,
|
442 |
+
negative_only=True,
|
443 |
+
)
|
444 |
+
else:
|
445 |
+
assert False
|
446 |
+
|
447 |
+
return {"contrastive_loss": loss} if output_dict else loss
|
src/open_clip/model.py
ADDED
@@ -0,0 +1,919 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" CLIP Model
|
2 |
+
|
3 |
+
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
|
4 |
+
"""
|
5 |
+
import copy
|
6 |
+
import logging
|
7 |
+
import math
|
8 |
+
from dataclasses import dataclass
|
9 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from torch import nn
|
15 |
+
from torch.utils.checkpoint import checkpoint
|
16 |
+
from functools import partial
|
17 |
+
|
18 |
+
from .hf_model import HFTextEncoder
|
19 |
+
from .modified_resnet import ModifiedResNet
|
20 |
+
from .timm_model import TimmModel
|
21 |
+
from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer,\
|
22 |
+
text_global_pool
|
23 |
+
from .utils import to_2tuple
|
24 |
+
|
25 |
+
|
26 |
+
@dataclass
|
27 |
+
class CLIPVisionCfg:
|
28 |
+
layers: Union[Tuple[int, int, int, int], int] = 12
|
29 |
+
width: int = 768
|
30 |
+
head_width: int = 64
|
31 |
+
mlp_ratio: float = 4.0
|
32 |
+
patch_size: int = 16
|
33 |
+
image_size: Union[Tuple[int, int], int] = 224
|
34 |
+
in_chans: int = 3
|
35 |
+
|
36 |
+
ls_init_value: Optional[float] = None # layer scale initial value
|
37 |
+
patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
|
38 |
+
attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer (overrides pool_type)
|
39 |
+
attn_pooler_queries: int = 256 # n_queries for attentional pooler
|
40 |
+
attn_pooler_heads: int = 8 # n heads for attentional_pooling
|
41 |
+
no_ln_pre: bool = False # disable pre transformer LayerNorm
|
42 |
+
pos_embed_type: str = 'learnable'
|
43 |
+
final_ln_after_pool: bool = False # apply final LayerNorm after pooling
|
44 |
+
pool_type: str = 'tok'
|
45 |
+
output_tokens: bool = False
|
46 |
+
act_kwargs: Optional[dict] = None
|
47 |
+
norm_kwargs: Optional[dict] = None
|
48 |
+
|
49 |
+
timm_model_name: Optional[str] = None # a valid model name overrides layers, width, patch_size
|
50 |
+
timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
|
51 |
+
timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
|
52 |
+
timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
|
53 |
+
timm_proj_bias: bool = False # enable bias final projection
|
54 |
+
timm_drop: float = 0. # head dropout
|
55 |
+
timm_drop_path: Optional[float] = None # backbone stochastic depth
|
56 |
+
|
57 |
+
|
58 |
+
@dataclass
|
59 |
+
class CLIPTextCfg:
|
60 |
+
context_length: int = 77
|
61 |
+
vocab_size: int = 49408
|
62 |
+
hf_tokenizer_name: Optional[str] = None
|
63 |
+
tokenizer_kwargs: Optional[dict] = None
|
64 |
+
|
65 |
+
width: int = 512
|
66 |
+
heads: int = 8
|
67 |
+
layers: int = 12
|
68 |
+
mlp_ratio: float = 4.0
|
69 |
+
ls_init_value: Optional[float] = None # layer scale initial value
|
70 |
+
embed_cls: bool = False
|
71 |
+
pad_id: int = 0
|
72 |
+
no_causal_mask: bool = False # disable causal masking
|
73 |
+
final_ln_after_pool: bool = False # apply final LayerNorm after pooling
|
74 |
+
pool_type: str = 'argmax'
|
75 |
+
proj_bias: bool = False
|
76 |
+
proj_type: str = 'linear' # control final text projection, 'none' forces no projection
|
77 |
+
output_tokens: bool = False
|
78 |
+
act_kwargs: dict = None
|
79 |
+
norm_kwargs: dict = None
|
80 |
+
|
81 |
+
# HuggingFace specific text tower config
|
82 |
+
hf_model_name: Optional[str] = None
|
83 |
+
hf_model_pretrained: bool = True
|
84 |
+
hf_proj_type: str = 'mlp'
|
85 |
+
hf_pooler_type: str = 'mean_pooler' # attentional pooling for HF models
|
86 |
+
|
87 |
+
|
88 |
+
def get_cast_dtype(precision: str):
|
89 |
+
cast_dtype = None
|
90 |
+
if precision == 'bf16':
|
91 |
+
cast_dtype = torch.bfloat16
|
92 |
+
elif precision == 'fp16':
|
93 |
+
cast_dtype = torch.float16
|
94 |
+
return cast_dtype
|
95 |
+
|
96 |
+
|
97 |
+
def get_input_dtype(precision: str):
|
98 |
+
input_dtype = None
|
99 |
+
if precision in ('bf16', 'pure_bf16'):
|
100 |
+
input_dtype = torch.bfloat16
|
101 |
+
elif precision in ('fp16', 'pure_fp16'):
|
102 |
+
input_dtype = torch.float16
|
103 |
+
return input_dtype
|
104 |
+
|
105 |
+
|
106 |
+
def _build_vision_tower(
|
107 |
+
embed_dim: int,
|
108 |
+
vision_cfg: CLIPVisionCfg,
|
109 |
+
quick_gelu: bool = False,
|
110 |
+
cast_dtype: Optional[torch.dtype] = None
|
111 |
+
):
|
112 |
+
if isinstance(vision_cfg, dict):
|
113 |
+
vision_cfg = CLIPVisionCfg(**vision_cfg)
|
114 |
+
|
115 |
+
# OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
|
116 |
+
# memory efficient in recent PyTorch releases (>= 1.10).
|
117 |
+
# NOTE: timm models always use native GELU regardless of quick_gelu flag.
|
118 |
+
act_layer = QuickGELU if quick_gelu else nn.GELU
|
119 |
+
|
120 |
+
if vision_cfg.timm_model_name:
|
121 |
+
visual = TimmModel(
|
122 |
+
vision_cfg.timm_model_name,
|
123 |
+
pretrained=vision_cfg.timm_model_pretrained,
|
124 |
+
pool=vision_cfg.timm_pool,
|
125 |
+
proj=vision_cfg.timm_proj,
|
126 |
+
proj_bias=vision_cfg.timm_proj_bias,
|
127 |
+
drop=vision_cfg.timm_drop,
|
128 |
+
drop_path=vision_cfg.timm_drop_path,
|
129 |
+
patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None,
|
130 |
+
embed_dim=embed_dim,
|
131 |
+
image_size=vision_cfg.image_size,
|
132 |
+
)
|
133 |
+
elif isinstance(vision_cfg.layers, (tuple, list)):
|
134 |
+
vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
|
135 |
+
visual = ModifiedResNet(
|
136 |
+
layers=vision_cfg.layers,
|
137 |
+
output_dim=embed_dim,
|
138 |
+
heads=vision_heads,
|
139 |
+
image_size=vision_cfg.image_size,
|
140 |
+
width=vision_cfg.width,
|
141 |
+
)
|
142 |
+
else:
|
143 |
+
vision_heads = vision_cfg.width // vision_cfg.head_width
|
144 |
+
norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
|
145 |
+
if vision_cfg.norm_kwargs:
|
146 |
+
norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs)
|
147 |
+
if vision_cfg.act_kwargs is not None:
|
148 |
+
act_layer = partial(act_layer, **vision_cfg.act_kwargs)
|
149 |
+
|
150 |
+
visual = VisionTransformer(
|
151 |
+
image_size=vision_cfg.image_size,
|
152 |
+
patch_size=vision_cfg.patch_size,
|
153 |
+
width=vision_cfg.width,
|
154 |
+
layers=vision_cfg.layers,
|
155 |
+
heads=vision_heads,
|
156 |
+
mlp_ratio=vision_cfg.mlp_ratio,
|
157 |
+
ls_init_value=vision_cfg.ls_init_value,
|
158 |
+
patch_dropout=vision_cfg.patch_dropout,
|
159 |
+
attentional_pool=vision_cfg.attentional_pool,
|
160 |
+
attn_pooler_queries=vision_cfg.attn_pooler_queries,
|
161 |
+
attn_pooler_heads=vision_cfg.attn_pooler_heads,
|
162 |
+
pos_embed_type=vision_cfg.pos_embed_type,
|
163 |
+
no_ln_pre=vision_cfg.no_ln_pre,
|
164 |
+
final_ln_after_pool=vision_cfg.final_ln_after_pool,
|
165 |
+
pool_type=vision_cfg.pool_type,
|
166 |
+
output_tokens=vision_cfg.output_tokens,
|
167 |
+
output_dim=embed_dim,
|
168 |
+
act_layer=act_layer,
|
169 |
+
norm_layer=norm_layer,
|
170 |
+
in_chans=vision_cfg.in_chans,
|
171 |
+
)
|
172 |
+
|
173 |
+
return visual
|
174 |
+
|
175 |
+
|
176 |
+
def _build_text_tower(
|
177 |
+
embed_dim: int,
|
178 |
+
text_cfg: CLIPTextCfg,
|
179 |
+
quick_gelu: bool = False,
|
180 |
+
cast_dtype: Optional[torch.dtype] = None,
|
181 |
+
):
|
182 |
+
if isinstance(text_cfg, dict):
|
183 |
+
text_cfg = CLIPTextCfg(**text_cfg)
|
184 |
+
|
185 |
+
if text_cfg.hf_model_name:
|
186 |
+
text = HFTextEncoder(
|
187 |
+
text_cfg.hf_model_name,
|
188 |
+
output_dim=embed_dim,
|
189 |
+
proj_type=text_cfg.hf_proj_type,
|
190 |
+
pooler_type=text_cfg.hf_pooler_type,
|
191 |
+
pretrained=text_cfg.hf_model_pretrained,
|
192 |
+
output_tokens=text_cfg.output_tokens,
|
193 |
+
)
|
194 |
+
else:
|
195 |
+
act_layer = QuickGELU if quick_gelu else nn.GELU
|
196 |
+
norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
|
197 |
+
if text_cfg.norm_kwargs:
|
198 |
+
norm_layer = partial(norm_layer, **text_cfg.norm_kwargs)
|
199 |
+
if text_cfg.act_kwargs is not None:
|
200 |
+
act_layer = partial(act_layer, **text_cfg.act_kwargs)
|
201 |
+
|
202 |
+
text = TextTransformer(
|
203 |
+
context_length=text_cfg.context_length,
|
204 |
+
vocab_size=text_cfg.vocab_size,
|
205 |
+
width=text_cfg.width,
|
206 |
+
heads=text_cfg.heads,
|
207 |
+
layers=text_cfg.layers,
|
208 |
+
mlp_ratio=text_cfg.mlp_ratio,
|
209 |
+
ls_init_value=text_cfg.ls_init_value,
|
210 |
+
output_dim=embed_dim,
|
211 |
+
embed_cls=text_cfg.embed_cls,
|
212 |
+
no_causal_mask=text_cfg.no_causal_mask,
|
213 |
+
pad_id=text_cfg.pad_id,
|
214 |
+
pool_type=text_cfg.pool_type,
|
215 |
+
proj_type=text_cfg.proj_type,
|
216 |
+
proj_bias=text_cfg.proj_bias,
|
217 |
+
output_tokens=text_cfg.output_tokens,
|
218 |
+
act_layer=act_layer,
|
219 |
+
norm_layer=norm_layer,
|
220 |
+
)
|
221 |
+
return text
|
222 |
+
|
223 |
+
|
224 |
+
|
225 |
+
class TrunkNet(nn.Module):
|
226 |
+
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
|
227 |
+
super().__init__()
|
228 |
+
self.net = nn.Sequential(
|
229 |
+
nn.Linear(input_dim, hidden_dim),
|
230 |
+
LayerNorm(hidden_dim),
|
231 |
+
nn.GELU(),
|
232 |
+
nn.Linear(hidden_dim, hidden_dim),
|
233 |
+
LayerNorm(hidden_dim),
|
234 |
+
nn.GELU(),
|
235 |
+
nn.Linear(hidden_dim, output_dim)
|
236 |
+
)
|
237 |
+
|
238 |
+
def forward(self, x):
|
239 |
+
|
240 |
+
for i, layer in enumerate(self.net):
|
241 |
+
x = layer(x)
|
242 |
+
|
243 |
+
return x
|
244 |
+
|
245 |
+
|
246 |
+
class MultiTrunkNet(nn.Module):
|
247 |
+
def __init__(self, embed_dim: int):
|
248 |
+
super().__init__()
|
249 |
+
self.embed_dim = embed_dim
|
250 |
+
|
251 |
+
self.compound_trunk = TrunkNet(input_dim=159, hidden_dim=embed_dim, output_dim=embed_dim)
|
252 |
+
self.concentration_trunk = TrunkNet(input_dim=2, hidden_dim=embed_dim, output_dim=embed_dim)
|
253 |
+
self.time_trunk = TrunkNet(input_dim=1, hidden_dim=embed_dim, output_dim=embed_dim)
|
254 |
+
|
255 |
+
total_dim = embed_dim * 3
|
256 |
+
self.projection = nn.Linear(total_dim, embed_dim)
|
257 |
+
|
258 |
+
def forward(self, compound_embedding: torch.Tensor, concentration: torch.Tensor, time: torch.Tensor):
|
259 |
+
|
260 |
+
# Process each input through its own trunk
|
261 |
+
compound_features = self.compound_trunk(compound_embedding)
|
262 |
+
|
263 |
+
concentration_features = self.concentration_trunk(concentration)
|
264 |
+
|
265 |
+
time = time.unsqueeze(-1) if time.dim() == 1 else time
|
266 |
+
time_features = self.time_trunk(time)
|
267 |
+
|
268 |
+
# Concatenate all features
|
269 |
+
return compound_features, concentration_features, time_features
|
270 |
+
|
271 |
+
|
272 |
+
class CLIP(nn.Module):
|
273 |
+
output_dict: torch.jit.Final[bool]
|
274 |
+
|
275 |
+
def __init__(
|
276 |
+
self,
|
277 |
+
embed_dim: int,
|
278 |
+
vision_cfg: CLIPVisionCfg,
|
279 |
+
text_cfg: CLIPTextCfg,
|
280 |
+
quick_gelu: bool = False,
|
281 |
+
init_logit_scale: float = np.log(1 / 0.07),
|
282 |
+
init_logit_bias: Optional[float] = None,
|
283 |
+
nonscalar_logit_scale: bool = False,
|
284 |
+
cast_dtype: Optional[torch.dtype] = None,
|
285 |
+
output_dict: bool = False,
|
286 |
+
):
|
287 |
+
super().__init__()
|
288 |
+
self.output_dict = output_dict
|
289 |
+
|
290 |
+
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
|
291 |
+
|
292 |
+
text = _build_text_tower(int(embed_dim/4), text_cfg, quick_gelu, cast_dtype)
|
293 |
+
self.transformer = text.transformer
|
294 |
+
self.context_length = text.context_length
|
295 |
+
self.vocab_size = text.vocab_size
|
296 |
+
self.token_embedding = text.token_embedding
|
297 |
+
self.positional_embedding = text.positional_embedding
|
298 |
+
self.ln_final = text.ln_final
|
299 |
+
self.text_projection = text.text_projection
|
300 |
+
self.text_pool_type = text.pool_type
|
301 |
+
self.register_buffer('attn_mask', text.attn_mask, persistent=False)
|
302 |
+
|
303 |
+
# Add multi-trunk net for additional inputs
|
304 |
+
self.multi_trunk = MultiTrunkNet(int(embed_dim/4))
|
305 |
+
|
306 |
+
# # Add projection layer for concatenated features
|
307 |
+
# self.fusion_projection = nn.Linear(embed_dim * 4, embed_dim)
|
308 |
+
|
309 |
+
lshape = [1] if nonscalar_logit_scale else []
|
310 |
+
self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
|
311 |
+
if init_logit_bias is not None:
|
312 |
+
self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
|
313 |
+
else:
|
314 |
+
self.logit_bias = None
|
315 |
+
|
316 |
+
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
|
317 |
+
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
|
318 |
+
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
|
319 |
+
|
320 |
+
@torch.jit.ignore
|
321 |
+
def set_grad_checkpointing(self, enable=True):
|
322 |
+
self.visual.set_grad_checkpointing(enable)
|
323 |
+
self.transformer.grad_checkpointing = enable
|
324 |
+
|
325 |
+
@torch.jit.ignore
|
326 |
+
def no_weight_decay(self):
|
327 |
+
# for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default
|
328 |
+
no_wd = {'positional_embedding'}
|
329 |
+
if hasattr(self.visual, 'no_weight_decay'):
|
330 |
+
for n in self.visual.no_weight_decay():
|
331 |
+
no_wd.add('visual.' + n)
|
332 |
+
return no_wd
|
333 |
+
|
334 |
+
def encode_image(self, image, normalize: bool = False):
|
335 |
+
features = self.visual(image)
|
336 |
+
return F.normalize(features, dim=-1) if normalize else features
|
337 |
+
|
338 |
+
def encode_text(self, text, normalize: bool = False, concentration: Optional[torch.Tensor] = None,
|
339 |
+
time: Optional[torch.Tensor] = None, compound_embedding: Optional[torch.Tensor] = None):
|
340 |
+
cast_dtype = self.transformer.get_cast_dtype()
|
341 |
+
|
342 |
+
x = self.token_embedding(text).to(cast_dtype)
|
343 |
+
x = x + self.positional_embedding.to(cast_dtype)
|
344 |
+
x = self.transformer(x, attn_mask=self.attn_mask)
|
345 |
+
x = self.ln_final(x)
|
346 |
+
x = text_global_pool(x, text, self.text_pool_type)
|
347 |
+
|
348 |
+
if self.text_projection is not None:
|
349 |
+
if isinstance(self.text_projection, nn.Linear):
|
350 |
+
x = self.text_projection(x)
|
351 |
+
else:
|
352 |
+
x = x @ self.text_projection
|
353 |
+
|
354 |
+
if compound_embedding is not None and concentration is not None and time is not None:
|
355 |
+
compound_features, concentration_features, time_features = self.multi_trunk(compound_embedding, concentration, time)
|
356 |
+
x = torch.cat([x, compound_features, concentration_features, time_features], dim=-1)
|
357 |
+
|
358 |
+
if normalize:
|
359 |
+
x = F.normalize(x, dim=-1)
|
360 |
+
|
361 |
+
return x
|
362 |
+
|
363 |
+
def get_logits(self, image, text, concentration: Optional[torch.Tensor] = None,
|
364 |
+
time: Optional[torch.Tensor] = None,
|
365 |
+
compound_embedding: Optional[torch.Tensor] = None):
|
366 |
+
image_features = self.encode_image(image, normalize=True)
|
367 |
+
text_features = self.encode_text(text, normalize=True,
|
368 |
+
concentration=concentration,
|
369 |
+
time=time,
|
370 |
+
compound_embedding=compound_embedding)
|
371 |
+
image_logits = self.logit_scale.exp() * image_features @ text_features.T
|
372 |
+
if self.logit_bias is not None:
|
373 |
+
image_logits += self.logit_bias
|
374 |
+
text_logits = image_logits.T
|
375 |
+
return image_logits, text_logits
|
376 |
+
|
377 |
+
def forward_intermediates(
|
378 |
+
self,
|
379 |
+
image: Optional[torch.Tensor] = None,
|
380 |
+
text: Optional[torch.Tensor] = None,
|
381 |
+
image_indices: Optional[Union[int, List[int]]] = None,
|
382 |
+
text_indices: Optional[Union[int, List[int]]] = None,
|
383 |
+
stop_early: bool = False,
|
384 |
+
normalize: bool = True,
|
385 |
+
normalize_intermediates: bool = False,
|
386 |
+
intermediates_only: bool = False,
|
387 |
+
image_output_fmt: str = 'NCHW',
|
388 |
+
image_output_extra_tokens: bool = False,
|
389 |
+
text_output_fmt: str = 'NLC',
|
390 |
+
text_output_extra_tokens: bool = False,
|
391 |
+
output_logits: bool = False,
|
392 |
+
output_logit_scale_bias: bool = False,
|
393 |
+
) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
|
394 |
+
""" Forward features that returns intermediates.
|
395 |
+
|
396 |
+
Args:
|
397 |
+
image: Input image tensor
|
398 |
+
text: Input text tensor
|
399 |
+
image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence
|
400 |
+
text_indices: Take last n blocks if int, all if None, select matching indices if sequence
|
401 |
+
stop_early: Stop iterating over blocks when last desired intermediate hit
|
402 |
+
normalize_intermediates: Apply final norm layer to all intermediates
|
403 |
+
normalize: L2 Normalize final features
|
404 |
+
intermediates_only: Only return intermediate features, do not return final features
|
405 |
+
image_output_fmt: Shape of intermediate image feature outputs
|
406 |
+
image_output_extra_tokens: Return both prefix and spatial intermediate tokens
|
407 |
+
text_output_fmt: Shape of intermediate text feature outputs (ignored for this model)
|
408 |
+
text_output_extra_tokens: Return both prefix and spatial intermediate tokens (ignored for this model)
|
409 |
+
output_logits: Include logits in output
|
410 |
+
output_logit_scale_bias: Include the logit scale bias in the output
|
411 |
+
Returns:
|
412 |
+
|
413 |
+
"""
|
414 |
+
output = {}
|
415 |
+
if intermediates_only:
|
416 |
+
# intermediates only disables final feature normalization, and include logits
|
417 |
+
normalize = False
|
418 |
+
output_logits = False
|
419 |
+
if output_logits:
|
420 |
+
assert image is not None and text is not None, 'Both image and text inputs are required to compute logits'
|
421 |
+
|
422 |
+
if image is not None:
|
423 |
+
image_output = self.visual.forward_intermediates(
|
424 |
+
image,
|
425 |
+
indices=image_indices,
|
426 |
+
stop_early=stop_early,
|
427 |
+
normalize_intermediates=normalize_intermediates,
|
428 |
+
intermediates_only=intermediates_only,
|
429 |
+
output_fmt=image_output_fmt,
|
430 |
+
output_extra_tokens=image_output_extra_tokens,
|
431 |
+
)
|
432 |
+
if normalize and "image_features" in image_output:
|
433 |
+
image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1)
|
434 |
+
output.update(image_output)
|
435 |
+
|
436 |
+
if text is not None:
|
437 |
+
cast_dtype = self.transformer.get_cast_dtype()
|
438 |
+
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
|
439 |
+
x = x + self.positional_embedding.to(cast_dtype)
|
440 |
+
x, intermediates = self.transformer.forward_intermediates(
|
441 |
+
x,
|
442 |
+
attn_mask=self.attn_mask,
|
443 |
+
indices=text_indices
|
444 |
+
)
|
445 |
+
if normalize_intermediates:
|
446 |
+
intermediates = [self.ln_final(xi) for xi in intermediates]
|
447 |
+
|
448 |
+
# NOTE this model doesn't support cls embed in text transformer, no need for extra intermediate tokens
|
449 |
+
output["text_intermediates"] = intermediates
|
450 |
+
|
451 |
+
if not intermediates_only:
|
452 |
+
x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
|
453 |
+
x = text_global_pool(x, text, self.text_pool_type)
|
454 |
+
if self.text_projection is not None:
|
455 |
+
if isinstance(self.text_projection, nn.Linear):
|
456 |
+
x = self.text_projection(x)
|
457 |
+
else:
|
458 |
+
x = x @ self.text_projection
|
459 |
+
if normalize:
|
460 |
+
x = F.normalize(x, dim=-1)
|
461 |
+
output["text_features"] = x
|
462 |
+
|
463 |
+
logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None
|
464 |
+
|
465 |
+
if output_logits:
|
466 |
+
image_logits = logit_scale_exp * output["image_features"] @ output["text_features"].T
|
467 |
+
if self.logit_bias is not None:
|
468 |
+
image_logits += self.logit_bias
|
469 |
+
text_logits = image_logits.T
|
470 |
+
output["image_logits"] = image_logits
|
471 |
+
output["text_logits"] = text_logits
|
472 |
+
|
473 |
+
if output_logit_scale_bias:
|
474 |
+
output["logit_scale"] = logit_scale_exp
|
475 |
+
if self.logit_bias is not None:
|
476 |
+
output['logit_bias'] = self.logit_bias
|
477 |
+
|
478 |
+
return output
|
479 |
+
|
480 |
+
|
481 |
+
def forward(
|
482 |
+
self,
|
483 |
+
image: Optional[torch.Tensor] = None,
|
484 |
+
text: Optional[torch.Tensor] = None,
|
485 |
+
concentration: Optional[torch.Tensor] = None,
|
486 |
+
time: Optional[torch.Tensor] = None,
|
487 |
+
compound_embedding: Optional[torch.Tensor] = None,
|
488 |
+
):
|
489 |
+
|
490 |
+
image_features = self.encode_image(image, normalize=True) if image is not None else None
|
491 |
+
text_features = self.encode_text(text, normalize=True, concentration=concentration, time=time, compound_embedding=compound_embedding)
|
492 |
+
if self.output_dict:
|
493 |
+
out_dict = {
|
494 |
+
"image_features": image_features,
|
495 |
+
"text_features": text_features,
|
496 |
+
"logit_scale": self.logit_scale.exp()
|
497 |
+
}
|
498 |
+
if self.logit_bias is not None:
|
499 |
+
out_dict['logit_bias'] = self.logit_bias
|
500 |
+
return out_dict
|
501 |
+
|
502 |
+
if self.logit_bias is not None:
|
503 |
+
return image_features, text_features, self.logit_scale.exp(), self.logit_bias
|
504 |
+
return image_features, text_features, self.logit_scale.exp()
|
505 |
+
|
506 |
+
|
507 |
+
class CustomTextCLIP(nn.Module):
|
508 |
+
output_dict: torch.jit.Final[bool]
|
509 |
+
|
510 |
+
def __init__(
|
511 |
+
self,
|
512 |
+
embed_dim: int,
|
513 |
+
vision_cfg: CLIPVisionCfg,
|
514 |
+
text_cfg: CLIPTextCfg,
|
515 |
+
quick_gelu: bool = False,
|
516 |
+
init_logit_scale: float = np.log(1 / 0.07),
|
517 |
+
init_logit_bias: Optional[float] = None,
|
518 |
+
nonscalar_logit_scale: bool = False,
|
519 |
+
cast_dtype: Optional[torch.dtype] = None,
|
520 |
+
output_dict: bool = False,
|
521 |
+
):
|
522 |
+
super().__init__()
|
523 |
+
self.output_dict = output_dict
|
524 |
+
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
|
525 |
+
self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
|
526 |
+
self.context_length = self.text.context_length
|
527 |
+
self.vocab_size = self.text.vocab_size
|
528 |
+
|
529 |
+
lshape = [1] if nonscalar_logit_scale else []
|
530 |
+
self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
|
531 |
+
if init_logit_bias is not None:
|
532 |
+
self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
|
533 |
+
else:
|
534 |
+
self.logit_bias = None
|
535 |
+
|
536 |
+
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
|
537 |
+
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
|
538 |
+
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
|
539 |
+
|
540 |
+
def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
|
541 |
+
self.text.lock(unlocked_layers, freeze_layer_norm)
|
542 |
+
|
543 |
+
@torch.jit.ignore
|
544 |
+
def set_grad_checkpointing(self, enable=True):
|
545 |
+
self.visual.set_grad_checkpointing(enable)
|
546 |
+
self.text.set_grad_checkpointing(enable)
|
547 |
+
|
548 |
+
@torch.jit.ignore
|
549 |
+
def no_weight_decay(self):
|
550 |
+
# for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default
|
551 |
+
no_wd = set()
|
552 |
+
if hasattr(self.visual, 'no_weight_decay'):
|
553 |
+
for n in self.visual.no_weight_decay():
|
554 |
+
no_wd.add('visual.' + n)
|
555 |
+
if hasattr(self.text, 'no_weight_decay'):
|
556 |
+
for n in self.visual.no_weight_decay():
|
557 |
+
no_wd.add('text.' + n)
|
558 |
+
return no_wd
|
559 |
+
|
560 |
+
def encode_image(self, image, normalize: bool = False):
|
561 |
+
features = self.visual(image)
|
562 |
+
return F.normalize(features, dim=-1) if normalize else features
|
563 |
+
|
564 |
+
def encode_text(self, text, normalize: bool = False):
|
565 |
+
features = self.text(text)
|
566 |
+
return F.normalize(features, dim=-1) if normalize else features
|
567 |
+
|
568 |
+
def get_logits(self, image, text):
|
569 |
+
image_features = self.encode_image(image, normalize=True)
|
570 |
+
text_features = self.encode_text(text, normalize=True)
|
571 |
+
image_logits = self.logit_scale.exp() * image_features @ text_features.T
|
572 |
+
if self.logit_bias is not None:
|
573 |
+
image_logits += self.logit_bias
|
574 |
+
text_logits = image_logits.T
|
575 |
+
return image_logits, text_logits
|
576 |
+
|
577 |
+
def forward_intermediates(
|
578 |
+
self,
|
579 |
+
image: Optional[torch.Tensor] = None,
|
580 |
+
text: Optional[torch.Tensor] = None,
|
581 |
+
image_indices: Optional[Union[int, List[int]]] = None,
|
582 |
+
text_indices: Optional[Union[int, List[int]]] = None,
|
583 |
+
stop_early: bool = False,
|
584 |
+
normalize: bool = True,
|
585 |
+
normalize_intermediates: bool = False,
|
586 |
+
intermediates_only: bool = False,
|
587 |
+
image_output_fmt: str = 'NCHW',
|
588 |
+
image_output_extra_tokens: bool = False,
|
589 |
+
text_output_fmt: str = 'NLC',
|
590 |
+
text_output_extra_tokens: bool = False,
|
591 |
+
output_logits: bool = False,
|
592 |
+
output_logit_scale_bias: bool = False,
|
593 |
+
) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
|
594 |
+
""" Forward features that returns intermediates.
|
595 |
+
|
596 |
+
Args:
|
597 |
+
image: Input image tensor
|
598 |
+
text: Input text tensor
|
599 |
+
image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence
|
600 |
+
text_indices: Take last n blocks if int, all if None, select matching indices if sequence
|
601 |
+
stop_early: Stop iterating over blocks when last desired intermediate hit
|
602 |
+
normalize: L2 Normalize final image and text features (if present)
|
603 |
+
normalize_intermediates: Apply final encoder norm layer to all intermediates (if possible)
|
604 |
+
intermediates_only: Only return intermediate features, do not return final features
|
605 |
+
image_output_fmt: Shape of intermediate image feature outputs
|
606 |
+
image_output_extra_tokens: Return both prefix and spatial intermediate tokens
|
607 |
+
text_output_fmt: Shape of intermediate text feature outputs
|
608 |
+
text_output_extra_tokens: Return both prefix and spatial intermediate tokens
|
609 |
+
output_logits: Include logits in output
|
610 |
+
output_logit_scale_bias: Include the logit scale bias in the output
|
611 |
+
Returns:
|
612 |
+
|
613 |
+
"""
|
614 |
+
output = {}
|
615 |
+
if intermediates_only:
|
616 |
+
# intermediates only disables final feature normalization, and include logits
|
617 |
+
normalize = False
|
618 |
+
output_logits = False
|
619 |
+
if output_logits:
|
620 |
+
assert image is not None and text is not None, 'Both image and text inputs are required to compute logits'
|
621 |
+
|
622 |
+
if image is not None:
|
623 |
+
image_output = self.visual.forward_intermediates(
|
624 |
+
image,
|
625 |
+
indices=image_indices,
|
626 |
+
stop_early=stop_early,
|
627 |
+
normalize_intermediates=normalize_intermediates,
|
628 |
+
intermediates_only=intermediates_only,
|
629 |
+
output_fmt=image_output_fmt,
|
630 |
+
output_extra_tokens=image_output_extra_tokens,
|
631 |
+
)
|
632 |
+
if normalize and "image_features" in image_output:
|
633 |
+
image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1)
|
634 |
+
output.update(image_output)
|
635 |
+
|
636 |
+
if text is not None:
|
637 |
+
text_output = self.text.forward_intermediates(
|
638 |
+
text,
|
639 |
+
indices=text_indices,
|
640 |
+
stop_early=stop_early,
|
641 |
+
normalize_intermediates=normalize_intermediates,
|
642 |
+
intermediates_only=intermediates_only,
|
643 |
+
output_fmt=text_output_fmt,
|
644 |
+
output_extra_tokens=text_output_extra_tokens,
|
645 |
+
)
|
646 |
+
if normalize and "text_features" in text_output:
|
647 |
+
text_output["text_features"] = F.normalize(text_output["text_features"], dim=-1)
|
648 |
+
output.update(text_output)
|
649 |
+
|
650 |
+
logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None
|
651 |
+
|
652 |
+
if output_logits:
|
653 |
+
image_logits = logit_scale_exp * output["image_features"] @ output["text_features"].T
|
654 |
+
if self.logit_bias is not None:
|
655 |
+
image_logits += self.logit_bias
|
656 |
+
text_logits = image_logits.T
|
657 |
+
output["image_logits"] = image_logits
|
658 |
+
output["text_logits"] = text_logits
|
659 |
+
|
660 |
+
if output_logit_scale_bias:
|
661 |
+
output["logit_scale"] = logit_scale_exp
|
662 |
+
if self.logit_bias is not None:
|
663 |
+
output['logit_bias'] = self.logit_bias
|
664 |
+
|
665 |
+
return output
|
666 |
+
|
667 |
+
def forward(
|
668 |
+
self,
|
669 |
+
image: Optional[torch.Tensor] = None,
|
670 |
+
text: Optional[torch.Tensor] = None,
|
671 |
+
):
|
672 |
+
image_features = self.encode_image(image, normalize=True) if image is not None else None
|
673 |
+
text_features = self.encode_text(text, normalize=True) if text is not None else None
|
674 |
+
|
675 |
+
if self.output_dict:
|
676 |
+
out_dict = {
|
677 |
+
"image_features": image_features,
|
678 |
+
"text_features": text_features,
|
679 |
+
"logit_scale": self.logit_scale.exp()
|
680 |
+
}
|
681 |
+
if self.logit_bias is not None:
|
682 |
+
out_dict['logit_bias'] = self.logit_bias
|
683 |
+
return out_dict
|
684 |
+
|
685 |
+
if self.logit_bias is not None:
|
686 |
+
return image_features, text_features, self.logit_scale.exp(), self.logit_bias
|
687 |
+
return image_features, text_features, self.logit_scale.exp()
|
688 |
+
|
689 |
+
|
690 |
+
def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
|
691 |
+
"""Convert applicable model parameters to low-precision (bf16 or fp16)"""
|
692 |
+
|
693 |
+
def _convert_weights(l):
|
694 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
695 |
+
l.weight.data = l.weight.data.to(dtype)
|
696 |
+
if l.bias is not None:
|
697 |
+
l.bias.data = l.bias.data.to(dtype)
|
698 |
+
|
699 |
+
if isinstance(l, (nn.MultiheadAttention, Attention)):
|
700 |
+
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
701 |
+
tensor = getattr(l, attr)
|
702 |
+
if tensor is not None:
|
703 |
+
tensor.data = tensor.data.to(dtype)
|
704 |
+
|
705 |
+
if isinstance(l, (CLIP, TextTransformer)):
|
706 |
+
# convert text nn.Parameter projections
|
707 |
+
attr = getattr(l, "text_projection", None)
|
708 |
+
if attr is not None:
|
709 |
+
attr.data = attr.data.to(dtype)
|
710 |
+
|
711 |
+
if isinstance(l, VisionTransformer):
|
712 |
+
# convert vision nn.Parameter projections
|
713 |
+
attr = getattr(l, "proj", None)
|
714 |
+
if attr is not None:
|
715 |
+
attr.data = attr.data.to(dtype)
|
716 |
+
|
717 |
+
model.apply(_convert_weights)
|
718 |
+
|
719 |
+
|
720 |
+
convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
|
721 |
+
|
722 |
+
|
723 |
+
# used to maintain checkpoint compatibility
|
724 |
+
def convert_to_custom_text_state_dict(state_dict: dict):
|
725 |
+
if 'text_projection' in state_dict:
|
726 |
+
# old format state_dict, move text tower -> .text
|
727 |
+
new_state_dict = {}
|
728 |
+
for k, v in state_dict.items():
|
729 |
+
if any(k.startswith(p) for p in (
|
730 |
+
'text_projection',
|
731 |
+
'positional_embedding',
|
732 |
+
'token_embedding',
|
733 |
+
'transformer',
|
734 |
+
'ln_final',
|
735 |
+
)):
|
736 |
+
k = 'text.' + k
|
737 |
+
new_state_dict[k] = v
|
738 |
+
return new_state_dict
|
739 |
+
return state_dict
|
740 |
+
|
741 |
+
|
742 |
+
def build_model_from_openai_state_dict(
|
743 |
+
state_dict: dict,
|
744 |
+
quick_gelu=True,
|
745 |
+
cast_dtype=torch.float16,
|
746 |
+
):
|
747 |
+
vit = "visual.proj" in state_dict
|
748 |
+
|
749 |
+
if vit:
|
750 |
+
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
751 |
+
vision_layers = len(
|
752 |
+
[k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
|
753 |
+
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
754 |
+
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
755 |
+
image_size = vision_patch_size * grid_size
|
756 |
+
else:
|
757 |
+
counts: list = [
|
758 |
+
len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
|
759 |
+
vision_layers = tuple(counts)
|
760 |
+
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
|
761 |
+
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
|
762 |
+
vision_patch_size = None
|
763 |
+
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
|
764 |
+
image_size = output_width * 32
|
765 |
+
|
766 |
+
embed_dim = state_dict["text_projection"].shape[1]
|
767 |
+
context_length = state_dict["positional_embedding"].shape[0]
|
768 |
+
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
769 |
+
transformer_width = state_dict["ln_final.weight"].shape[0]
|
770 |
+
transformer_heads = transformer_width // 64
|
771 |
+
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
|
772 |
+
|
773 |
+
vision_cfg = CLIPVisionCfg(
|
774 |
+
layers=vision_layers,
|
775 |
+
width=vision_width,
|
776 |
+
patch_size=vision_patch_size,
|
777 |
+
image_size=image_size,
|
778 |
+
)
|
779 |
+
text_cfg = CLIPTextCfg(
|
780 |
+
context_length=context_length,
|
781 |
+
vocab_size=vocab_size,
|
782 |
+
width=transformer_width,
|
783 |
+
heads=transformer_heads,
|
784 |
+
layers=transformer_layers,
|
785 |
+
)
|
786 |
+
model = CLIP(
|
787 |
+
embed_dim,
|
788 |
+
vision_cfg=vision_cfg,
|
789 |
+
text_cfg=text_cfg,
|
790 |
+
quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
|
791 |
+
cast_dtype=cast_dtype,
|
792 |
+
)
|
793 |
+
|
794 |
+
for key in ["input_resolution", "context_length", "vocab_size"]:
|
795 |
+
state_dict.pop(key, None)
|
796 |
+
convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
|
797 |
+
model.load_state_dict(state_dict)
|
798 |
+
return model.eval()
|
799 |
+
|
800 |
+
|
801 |
+
def trace_model(model, batch_size=256, device=torch.device('cpu')):
|
802 |
+
model.eval()
|
803 |
+
image_size = model.visual.image_size
|
804 |
+
example_images = torch.ones((batch_size, 2, image_size, image_size), device=device)
|
805 |
+
example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
|
806 |
+
example_concentration = torch.rand((batch_size, 2), device=device)
|
807 |
+
example_time = torch.rand((batch_size, 1), device=device)
|
808 |
+
example_compound_embedding = torch.rand((batch_size, 159), device=device)
|
809 |
+
model = torch.jit.trace_module(
|
810 |
+
model,
|
811 |
+
inputs=dict(
|
812 |
+
forward=(example_images, example_text, example_concentration, example_time, example_compound_embedding),
|
813 |
+
encode_text=(example_text, True, example_concentration, example_time, example_compound_embedding),
|
814 |
+
encode_image=(example_images,)
|
815 |
+
))
|
816 |
+
model.visual.image_size = image_size
|
817 |
+
return model
|
818 |
+
|
819 |
+
|
820 |
+
def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
|
821 |
+
# Rescale the grid of position embeddings when loading from state_dict
|
822 |
+
old_pos_embed = state_dict.get('visual.positional_embedding', None)
|
823 |
+
if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
|
824 |
+
return
|
825 |
+
grid_size = to_2tuple(model.visual.grid_size)
|
826 |
+
extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
|
827 |
+
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
|
828 |
+
if new_seq_len == old_pos_embed.shape[0]:
|
829 |
+
return
|
830 |
+
|
831 |
+
if extra_tokens:
|
832 |
+
pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
|
833 |
+
else:
|
834 |
+
pos_emb_tok, pos_emb_img = None, old_pos_embed
|
835 |
+
old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
|
836 |
+
|
837 |
+
logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
|
838 |
+
pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
|
839 |
+
pos_emb_img = F.interpolate(
|
840 |
+
pos_emb_img,
|
841 |
+
size=grid_size,
|
842 |
+
mode=interpolation,
|
843 |
+
antialias=antialias,
|
844 |
+
align_corners=False,
|
845 |
+
)
|
846 |
+
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
|
847 |
+
if pos_emb_tok is not None:
|
848 |
+
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
|
849 |
+
else:
|
850 |
+
new_pos_embed = pos_emb_img
|
851 |
+
state_dict['visual.positional_embedding'] = new_pos_embed
|
852 |
+
|
853 |
+
|
854 |
+
def resize_text_pos_embed(state_dict, model, interpolation: str = 'linear', antialias: bool = False):
|
855 |
+
old_pos_embed = state_dict.get('positional_embedding', None)
|
856 |
+
if old_pos_embed is None:
|
857 |
+
return
|
858 |
+
# FIXME add support for text cls_token
|
859 |
+
model_pos_embed = getattr(model, 'positional_embedding', None)
|
860 |
+
if model_pos_embed is None:
|
861 |
+
model_pos_embed = getattr(model.text, 'positional_embedding', None)
|
862 |
+
|
863 |
+
old_num_pos = old_pos_embed.shape[0]
|
864 |
+
old_width = old_pos_embed.shape[1]
|
865 |
+
num_pos = model_pos_embed.shape[0]
|
866 |
+
width = model_pos_embed.shape[1]
|
867 |
+
assert old_width == width, 'text pos_embed width changed!'
|
868 |
+
if old_num_pos == num_pos:
|
869 |
+
return
|
870 |
+
|
871 |
+
logging.info('Resizing text position embedding num_pos from %s to %s', old_num_pos, num_pos)
|
872 |
+
old_pos_embed = old_pos_embed.reshape(1, old_num_pos, old_width).permute(0, 2, 1)
|
873 |
+
old_pos_embed = F.interpolate(
|
874 |
+
old_pos_embed,
|
875 |
+
size=num_pos,
|
876 |
+
mode=interpolation,
|
877 |
+
antialias=antialias,
|
878 |
+
align_corners=False,
|
879 |
+
)
|
880 |
+
old_pos_embed = old_pos_embed.permute(0, 2, 1)[0]
|
881 |
+
new_pos_embed = old_pos_embed
|
882 |
+
|
883 |
+
state_dict['positional_embedding'] = new_pos_embed
|
884 |
+
|
885 |
+
|
886 |
+
def get_model_preprocess_cfg(model):
|
887 |
+
module = getattr(model, 'visual', model)
|
888 |
+
preprocess_cfg = getattr(module, 'preprocess_cfg', {})
|
889 |
+
if not preprocess_cfg:
|
890 |
+
# use separate legacy attributes if preprocess_cfg dict not found
|
891 |
+
size = getattr(module, 'image_size')
|
892 |
+
if size is not None:
|
893 |
+
preprocess_cfg['size'] = size
|
894 |
+
mean = getattr(module, 'image_mean', None)
|
895 |
+
if mean is not None:
|
896 |
+
preprocess_cfg['mean'] = mean
|
897 |
+
std = getattr(module, 'image_std', None)
|
898 |
+
if std is not None:
|
899 |
+
preprocess_cfg['std'] = std
|
900 |
+
return preprocess_cfg
|
901 |
+
|
902 |
+
|
903 |
+
def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]):
|
904 |
+
module = getattr(model, 'visual', model)
|
905 |
+
module.image_mean = preprocess_cfg['mean'] # legacy attribute, keeping for bwd compat
|
906 |
+
module.image_std = preprocess_cfg['std'] # legacy attribute, keeping for bwd compat
|
907 |
+
module.preprocess_cfg = copy.deepcopy(preprocess_cfg) # new attr, package all pp cfg as dict
|
908 |
+
|
909 |
+
|
910 |
+
def get_model_tokenize_cfg(model):
|
911 |
+
module = getattr(model, 'text', model)
|
912 |
+
cfg = {}
|
913 |
+
context_length = getattr(module, 'context_length', None)
|
914 |
+
if context_length is not None:
|
915 |
+
cfg['context_length'] = context_length
|
916 |
+
vocab_size = getattr(module, 'vocab_size', None)
|
917 |
+
if vocab_size is not None:
|
918 |
+
cfg['vocab_size'] = vocab_size
|
919 |
+
return cfg
|
src/open_clip/model_configs/EVA01-g-14-plus.json
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"timm_model_name": "eva_giant_patch14_224",
|
6 |
+
"timm_model_pretrained": false,
|
7 |
+
"timm_pool": "token",
|
8 |
+
"timm_proj": null
|
9 |
+
},
|
10 |
+
"text_cfg": {
|
11 |
+
"context_length": 77,
|
12 |
+
"vocab_size": 49408,
|
13 |
+
"width": 1024,
|
14 |
+
"heads": 16,
|
15 |
+
"layers": 24
|
16 |
+
},
|
17 |
+
"custom_text": true
|
18 |
+
}
|
src/open_clip/model_configs/EVA01-g-14.json
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"timm_model_name": "eva_giant_patch14_224",
|
6 |
+
"timm_model_pretrained": false,
|
7 |
+
"timm_pool": "token",
|
8 |
+
"timm_proj": null
|
9 |
+
},
|
10 |
+
"text_cfg": {
|
11 |
+
"context_length": 77,
|
12 |
+
"vocab_size": 49408,
|
13 |
+
"width": 768,
|
14 |
+
"heads": 12,
|
15 |
+
"layers": 12
|
16 |
+
},
|
17 |
+
"custom_text": true
|
18 |
+
}
|
src/open_clip/model_configs/EVA02-B-16.json
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"timm_model_name": "eva02_base_patch16_clip_224",
|
6 |
+
"timm_model_pretrained": false,
|
7 |
+
"timm_pool": "token",
|
8 |
+
"timm_proj": null
|
9 |
+
},
|
10 |
+
"text_cfg": {
|
11 |
+
"context_length": 77,
|
12 |
+
"vocab_size": 49408,
|
13 |
+
"width": 512,
|
14 |
+
"heads": 8,
|
15 |
+
"layers": 12
|
16 |
+
},
|
17 |
+
"custom_text": true
|
18 |
+
}
|
src/open_clip/model_configs/EVA02-E-14-plus.json
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"timm_model_name": "eva02_enormous_patch14_clip_224",
|
6 |
+
"timm_model_pretrained": false,
|
7 |
+
"timm_pool": "token",
|
8 |
+
"timm_proj": null
|
9 |
+
},
|
10 |
+
"text_cfg": {
|
11 |
+
"context_length": 77,
|
12 |
+
"vocab_size": 49408,
|
13 |
+
"width": 1280,
|
14 |
+
"heads": 20,
|
15 |
+
"layers": 32
|
16 |
+
},
|
17 |
+
"custom_text": true
|
18 |
+
}
|
src/open_clip/model_configs/EVA02-E-14.json
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"timm_model_name": "eva02_enormous_patch14_clip_224",
|
6 |
+
"timm_model_pretrained": false,
|
7 |
+
"timm_pool": "token",
|
8 |
+
"timm_proj": null
|
9 |
+
},
|
10 |
+
"text_cfg": {
|
11 |
+
"context_length": 77,
|
12 |
+
"vocab_size": 49408,
|
13 |
+
"width": 1024,
|
14 |
+
"heads": 16,
|
15 |
+
"layers": 24
|
16 |
+
},
|
17 |
+
"custom_text": true
|
18 |
+
}
|
src/open_clip/model_configs/EVA02-L-14-336.json
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 336,
|
5 |
+
"timm_model_name": "eva02_large_patch14_clip_336",
|
6 |
+
"timm_model_pretrained": false,
|
7 |
+
"timm_pool": "token",
|
8 |
+
"timm_proj": null
|
9 |
+
},
|
10 |
+
"text_cfg": {
|
11 |
+
"context_length": 77,
|
12 |
+
"vocab_size": 49408,
|
13 |
+
"width": 768,
|
14 |
+
"heads": 12,
|
15 |
+
"layers": 12
|
16 |
+
},
|
17 |
+
"custom_text": true
|
18 |
+
}
|
src/open_clip/model_configs/EVA02-L-14.json
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"timm_model_name": "eva02_large_patch14_clip_224",
|
6 |
+
"timm_model_pretrained": false,
|
7 |
+
"timm_pool": "token",
|
8 |
+
"timm_proj": null
|
9 |
+
},
|
10 |
+
"text_cfg": {
|
11 |
+
"context_length": 77,
|
12 |
+
"vocab_size": 49408,
|
13 |
+
"width": 768,
|
14 |
+
"heads": 12,
|
15 |
+
"layers": 12
|
16 |
+
},
|
17 |
+
"custom_text": true
|
18 |
+
}
|
src/open_clip/model_configs/MobileCLIP-B.json
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"vision_cfg": {
|
4 |
+
"timm_model_name": "vit_base_mci_224",
|
5 |
+
"timm_model_pretrained": false,
|
6 |
+
"timm_pool": "token",
|
7 |
+
"timm_proj": null,
|
8 |
+
"timm_drop": 0.0,
|
9 |
+
"timm_drop_path": 0.0,
|
10 |
+
"image_size": 224
|
11 |
+
},
|
12 |
+
"text_cfg": {
|
13 |
+
"context_length": 77,
|
14 |
+
"vocab_size": 49408,
|
15 |
+
"width": 512,
|
16 |
+
"heads": 8,
|
17 |
+
"layers": 12,
|
18 |
+
"no_causal_mask": false
|
19 |
+
},
|
20 |
+
"custom_text": true
|
21 |
+
}
|
src/open_clip/model_configs/MobileCLIP-S1.json
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"vision_cfg": {
|
4 |
+
"timm_model_name": "fastvit_mci1",
|
5 |
+
"timm_model_pretrained": false,
|
6 |
+
"timm_pool": "avg",
|
7 |
+
"timm_proj": null,
|
8 |
+
"timm_drop": 0.0,
|
9 |
+
"timm_drop_path": 0.0,
|
10 |
+
"image_size": 256
|
11 |
+
},
|
12 |
+
"text_cfg": {
|
13 |
+
"context_length": 77,
|
14 |
+
"vocab_size": 49408,
|
15 |
+
"width": 512,
|
16 |
+
"heads": 8,
|
17 |
+
"layers": 12,
|
18 |
+
"no_causal_mask": true
|
19 |
+
},
|
20 |
+
"custom_text": true
|
21 |
+
}
|
src/open_clip/model_configs/MobileCLIP-S2.json
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"vision_cfg": {
|
4 |
+
"timm_model_name": "fastvit_mci2",
|
5 |
+
"timm_model_pretrained": false,
|
6 |
+
"timm_pool": "avg",
|
7 |
+
"timm_proj": null,
|
8 |
+
"timm_drop": 0.0,
|
9 |
+
"timm_drop_path": 0.0,
|
10 |
+
"image_size": 256
|
11 |
+
},
|
12 |
+
"text_cfg": {
|
13 |
+
"context_length": 77,
|
14 |
+
"vocab_size": 49408,
|
15 |
+
"width": 512,
|
16 |
+
"heads": 8,
|
17 |
+
"layers": 12,
|
18 |
+
"no_causal_mask": true
|
19 |
+
},
|
20 |
+
"custom_text": true
|
21 |
+
}
|
src/open_clip/model_configs/RN101-quickgelu.json
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"quick_gelu": true,
|
4 |
+
"vision_cfg": {
|
5 |
+
"image_size": 224,
|
6 |
+
"layers": [
|
7 |
+
3,
|
8 |
+
4,
|
9 |
+
23,
|
10 |
+
3
|
11 |
+
],
|
12 |
+
"width": 64,
|
13 |
+
"patch_size": null
|
14 |
+
},
|
15 |
+
"text_cfg": {
|
16 |
+
"context_length": 77,
|
17 |
+
"vocab_size": 49408,
|
18 |
+
"width": 512,
|
19 |
+
"heads": 8,
|
20 |
+
"layers": 12
|
21 |
+
}
|
22 |
+
}
|
src/open_clip/model_configs/RN101.json
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": [
|
6 |
+
3,
|
7 |
+
4,
|
8 |
+
23,
|
9 |
+
3
|
10 |
+
],
|
11 |
+
"width": 64,
|
12 |
+
"patch_size": null
|
13 |
+
},
|
14 |
+
"text_cfg": {
|
15 |
+
"context_length": 77,
|
16 |
+
"vocab_size": 49408,
|
17 |
+
"width": 512,
|
18 |
+
"heads": 8,
|
19 |
+
"layers": 12
|
20 |
+
}
|
21 |
+
}
|
src/open_clip/model_configs/RN50-quickgelu.json
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"quick_gelu": true,
|
4 |
+
"vision_cfg": {
|
5 |
+
"image_size": 224,
|
6 |
+
"layers": [
|
7 |
+
3,
|
8 |
+
4,
|
9 |
+
6,
|
10 |
+
3
|
11 |
+
],
|
12 |
+
"width": 64,
|
13 |
+
"patch_size": null
|
14 |
+
},
|
15 |
+
"text_cfg": {
|
16 |
+
"context_length": 77,
|
17 |
+
"vocab_size": 49408,
|
18 |
+
"width": 512,
|
19 |
+
"heads": 8,
|
20 |
+
"layers": 12
|
21 |
+
}
|
22 |
+
}
|
src/open_clip/model_configs/RN50.json
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": [
|
6 |
+
3,
|
7 |
+
4,
|
8 |
+
6,
|
9 |
+
3
|
10 |
+
],
|
11 |
+
"width": 64,
|
12 |
+
"patch_size": null
|
13 |
+
},
|
14 |
+
"text_cfg": {
|
15 |
+
"context_length": 77,
|
16 |
+
"vocab_size": 49408,
|
17 |
+
"width": 512,
|
18 |
+
"heads": 8,
|
19 |
+
"layers": 12
|
20 |
+
}
|
21 |
+
}
|
src/open_clip/model_configs/RN50x16-quickgelu.json
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"quick_gelu": true,
|
4 |
+
"vision_cfg": {
|
5 |
+
"image_size": 384,
|
6 |
+
"layers": [
|
7 |
+
6,
|
8 |
+
8,
|
9 |
+
18,
|
10 |
+
8
|
11 |
+
],
|
12 |
+
"width": 96,
|
13 |
+
"patch_size": null
|
14 |
+
},
|
15 |
+
"text_cfg": {
|
16 |
+
"context_length": 77,
|
17 |
+
"vocab_size": 49408,
|
18 |
+
"width": 768,
|
19 |
+
"heads": 12,
|
20 |
+
"layers": 12
|
21 |
+
}
|
22 |
+
}
|
src/open_clip/model_configs/RN50x16.json
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 384,
|
5 |
+
"layers": [
|
6 |
+
6,
|
7 |
+
8,
|
8 |
+
18,
|
9 |
+
8
|
10 |
+
],
|
11 |
+
"width": 96,
|
12 |
+
"patch_size": null
|
13 |
+
},
|
14 |
+
"text_cfg": {
|
15 |
+
"context_length": 77,
|
16 |
+
"vocab_size": 49408,
|
17 |
+
"width": 768,
|
18 |
+
"heads": 12,
|
19 |
+
"layers": 12
|
20 |
+
}
|
21 |
+
}
|
src/open_clip/model_configs/RN50x4-quickgelu.json
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 640,
|
3 |
+
"quick_gelu": true,
|
4 |
+
"vision_cfg": {
|
5 |
+
"image_size": 288,
|
6 |
+
"layers": [
|
7 |
+
4,
|
8 |
+
6,
|
9 |
+
10,
|
10 |
+
6
|
11 |
+
],
|
12 |
+
"width": 80,
|
13 |
+
"patch_size": null
|
14 |
+
},
|
15 |
+
"text_cfg": {
|
16 |
+
"context_length": 77,
|
17 |
+
"vocab_size": 49408,
|
18 |
+
"width": 640,
|
19 |
+
"heads": 10,
|
20 |
+
"layers": 12
|
21 |
+
}
|
22 |
+
}
|
src/open_clip/model_configs/RN50x4.json
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 640,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 288,
|
5 |
+
"layers": [
|
6 |
+
4,
|
7 |
+
6,
|
8 |
+
10,
|
9 |
+
6
|
10 |
+
],
|
11 |
+
"width": 80,
|
12 |
+
"patch_size": null
|
13 |
+
},
|
14 |
+
"text_cfg": {
|
15 |
+
"context_length": 77,
|
16 |
+
"vocab_size": 49408,
|
17 |
+
"width": 640,
|
18 |
+
"heads": 10,
|
19 |
+
"layers": 12
|
20 |
+
}
|
21 |
+
}
|
src/open_clip/model_configs/RN50x64-quickgelu.json
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"quick_gelu": true,
|
4 |
+
"vision_cfg": {
|
5 |
+
"image_size": 448,
|
6 |
+
"layers": [
|
7 |
+
3,
|
8 |
+
15,
|
9 |
+
36,
|
10 |
+
10
|
11 |
+
],
|
12 |
+
"width": 128,
|
13 |
+
"patch_size": null
|
14 |
+
},
|
15 |
+
"text_cfg": {
|
16 |
+
"context_length": 77,
|
17 |
+
"vocab_size": 49408,
|
18 |
+
"width": 1024,
|
19 |
+
"heads": 16,
|
20 |
+
"layers": 12
|
21 |
+
}
|
22 |
+
}
|
src/open_clip/model_configs/RN50x64.json
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 448,
|
5 |
+
"layers": [
|
6 |
+
3,
|
7 |
+
15,
|
8 |
+
36,
|
9 |
+
10
|
10 |
+
],
|
11 |
+
"width": 128,
|
12 |
+
"patch_size": null
|
13 |
+
},
|
14 |
+
"text_cfg": {
|
15 |
+
"context_length": 77,
|
16 |
+
"vocab_size": 49408,
|
17 |
+
"width": 1024,
|
18 |
+
"heads": 16,
|
19 |
+
"layers": 12
|
20 |
+
}
|
21 |
+
}
|
src/open_clip/model_configs/ViT-B-16-SigLIP-256.json
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"init_logit_bias": -10,
|
4 |
+
"custom_text": true,
|
5 |
+
"vision_cfg": {
|
6 |
+
"image_size": 256,
|
7 |
+
"timm_model_name": "vit_base_patch16_siglip_256",
|
8 |
+
"timm_model_pretrained": false,
|
9 |
+
"timm_pool": "map",
|
10 |
+
"timm_proj": "none"
|
11 |
+
},
|
12 |
+
"text_cfg": {
|
13 |
+
"context_length": 64,
|
14 |
+
"vocab_size": 32000,
|
15 |
+
"hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
|
16 |
+
"tokenizer_kwargs": {
|
17 |
+
"clean": "canonicalize"
|
18 |
+
},
|
19 |
+
"width": 768,
|
20 |
+
"heads": 12,
|
21 |
+
"layers": 12,
|
22 |
+
"no_causal_mask": true,
|
23 |
+
"proj_bias": true,
|
24 |
+
"pool_type": "last",
|
25 |
+
"norm_kwargs":{
|
26 |
+
"eps": 1e-6
|
27 |
+
}
|
28 |
+
}
|
29 |
+
}
|
src/open_clip/model_configs/ViT-B-16-SigLIP-384.json
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"init_logit_bias": -10,
|
4 |
+
"custom_text": true,
|
5 |
+
"vision_cfg": {
|
6 |
+
"image_size": 384,
|
7 |
+
"timm_model_name": "vit_base_patch16_siglip_384",
|
8 |
+
"timm_model_pretrained": false,
|
9 |
+
"timm_pool": "map",
|
10 |
+
"timm_proj": "none"
|
11 |
+
},
|
12 |
+
"text_cfg": {
|
13 |
+
"context_length": 64,
|
14 |
+
"vocab_size": 32000,
|
15 |
+
"hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
|
16 |
+
"tokenizer_kwargs": {
|
17 |
+
"clean": "canonicalize"
|
18 |
+
},
|
19 |
+
"width": 768,
|
20 |
+
"heads": 12,
|
21 |
+
"layers": 12,
|
22 |
+
"no_causal_mask": true,
|
23 |
+
"proj_bias": true,
|
24 |
+
"pool_type": "last",
|
25 |
+
"norm_kwargs":{
|
26 |
+
"eps": 1e-6
|
27 |
+
}
|
28 |
+
}
|
29 |
+
}
|
src/open_clip/model_configs/ViT-B-16-SigLIP-512.json
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"init_logit_bias": -10,
|
4 |
+
"custom_text": true,
|
5 |
+
"vision_cfg": {
|
6 |
+
"image_size": 512,
|
7 |
+
"timm_model_name": "vit_base_patch16_siglip_512",
|
8 |
+
"timm_model_pretrained": false,
|
9 |
+
"timm_pool": "map",
|
10 |
+
"timm_proj": "none"
|
11 |
+
},
|
12 |
+
"text_cfg": {
|
13 |
+
"context_length": 64,
|
14 |
+
"vocab_size": 32000,
|
15 |
+
"hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
|
16 |
+
"tokenizer_kwargs": {
|
17 |
+
"clean": "canonicalize"
|
18 |
+
},
|
19 |
+
"width": 768,
|
20 |
+
"heads": 12,
|
21 |
+
"layers": 12,
|
22 |
+
"no_causal_mask": true,
|
23 |
+
"proj_bias": true,
|
24 |
+
"pool_type": "last",
|
25 |
+
"norm_kwargs":{
|
26 |
+
"eps": 1e-6
|
27 |
+
}
|
28 |
+
}
|
29 |
+
}
|
src/open_clip/model_configs/ViT-B-16-SigLIP-i18n-256.json
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"init_logit_bias": -10,
|
4 |
+
"custom_text": true,
|
5 |
+
"vision_cfg": {
|
6 |
+
"image_size": 256,
|
7 |
+
"timm_model_name": "vit_base_patch16_siglip_256",
|
8 |
+
"timm_model_pretrained": false,
|
9 |
+
"timm_pool": "map",
|
10 |
+
"timm_proj": "none"
|
11 |
+
},
|
12 |
+
"text_cfg": {
|
13 |
+
"context_length": 64,
|
14 |
+
"vocab_size": 250000,
|
15 |
+
"hf_tokenizer_name": "timm/ViT-B-16-SigLIP-i18n-256",
|
16 |
+
"tokenizer_kwargs": {
|
17 |
+
"clean": "canonicalize"
|
18 |
+
},
|
19 |
+
"width": 768,
|
20 |
+
"heads": 12,
|
21 |
+
"layers": 12,
|
22 |
+
"no_causal_mask": true,
|
23 |
+
"proj_bias": true,
|
24 |
+
"pool_type": "last",
|
25 |
+
"norm_kwargs":{
|
26 |
+
"eps": 1e-6
|
27 |
+
}
|
28 |
+
}
|
29 |
+
}
|
src/open_clip/model_configs/ViT-B-16-SigLIP.json
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"init_logit_bias": -10,
|
4 |
+
"custom_text": true,
|
5 |
+
"vision_cfg": {
|
6 |
+
"image_size": 224,
|
7 |
+
"timm_model_name": "vit_base_patch16_siglip_224",
|
8 |
+
"timm_model_pretrained": false,
|
9 |
+
"timm_pool": "map",
|
10 |
+
"timm_proj": "none"
|
11 |
+
},
|
12 |
+
"text_cfg": {
|
13 |
+
"context_length": 64,
|
14 |
+
"vocab_size": 32000,
|
15 |
+
"hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
|
16 |
+
"tokenizer_kwargs": {
|
17 |
+
"clean": "canonicalize"
|
18 |
+
},
|
19 |
+
"width": 768,
|
20 |
+
"heads": 12,
|
21 |
+
"layers": 12,
|
22 |
+
"no_causal_mask": true,
|
23 |
+
"proj_bias": true,
|
24 |
+
"pool_type": "last",
|
25 |
+
"norm_kwargs":{
|
26 |
+
"eps": 1e-6
|
27 |
+
}
|
28 |
+
}
|
29 |
+
}
|
src/open_clip/model_configs/ViT-B-16-SigLIP2-256.json
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"init_logit_bias": -10,
|
4 |
+
"custom_text": true,
|
5 |
+
"vision_cfg": {
|
6 |
+
"image_size": 256,
|
7 |
+
"timm_model_name": "vit_base_patch16_siglip_256",
|
8 |
+
"timm_model_pretrained": false,
|
9 |
+
"timm_pool": "map",
|
10 |
+
"timm_proj": "none"
|
11 |
+
},
|
12 |
+
"text_cfg": {
|
13 |
+
"context_length": 64,
|
14 |
+
"vocab_size": 256000,
|
15 |
+
"hf_tokenizer_name": "timm/ViT-B-16-SigLIP2-256",
|
16 |
+
"tokenizer_kwargs": {
|
17 |
+
"clean": "canonicalize"
|
18 |
+
},
|
19 |
+
"width": 768,
|
20 |
+
"heads": 12,
|
21 |
+
"layers": 12,
|
22 |
+
"no_causal_mask": true,
|
23 |
+
"proj_bias": true,
|
24 |
+
"pool_type": "last",
|
25 |
+
"norm_kwargs":{
|
26 |
+
"eps": 1e-6
|
27 |
+
},
|
28 |
+
"act_kwargs": {
|
29 |
+
"approximate": "tanh"
|
30 |
+
}
|
31 |
+
}
|
32 |
+
}
|
src/open_clip/model_configs/ViT-B-16-SigLIP2-384.json
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"init_logit_bias": -10,
|
4 |
+
"custom_text": true,
|
5 |
+
"vision_cfg": {
|
6 |
+
"image_size": 384,
|
7 |
+
"timm_model_name": "vit_base_patch16_siglip_384",
|
8 |
+
"timm_model_pretrained": false,
|
9 |
+
"timm_pool": "map",
|
10 |
+
"timm_proj": "none"
|
11 |
+
},
|
12 |
+
"text_cfg": {
|
13 |
+
"context_length": 64,
|
14 |
+
"vocab_size": 256000,
|
15 |
+
"hf_tokenizer_name": "timm/ViT-B-16-SigLIP2-384",
|
16 |
+
"tokenizer_kwargs": {
|
17 |
+
"clean": "canonicalize"
|
18 |
+
},
|
19 |
+
"width": 768,
|
20 |
+
"heads": 12,
|
21 |
+
"layers": 12,
|
22 |
+
"no_causal_mask": true,
|
23 |
+
"proj_bias": true,
|
24 |
+
"pool_type": "last",
|
25 |
+
"norm_kwargs":{
|
26 |
+
"eps": 1e-6
|
27 |
+
},
|
28 |
+
"act_kwargs": {
|
29 |
+
"approximate": "tanh"
|
30 |
+
}
|
31 |
+
}
|
32 |
+
}
|
src/open_clip/model_configs/ViT-B-16-SigLIP2-512.json
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"init_logit_bias": -10,
|
4 |
+
"custom_text": true,
|
5 |
+
"vision_cfg": {
|
6 |
+
"image_size": 512,
|
7 |
+
"timm_model_name": "vit_base_patch16_siglip_512",
|
8 |
+
"timm_model_pretrained": false,
|
9 |
+
"timm_pool": "map",
|
10 |
+
"timm_proj": "none"
|
11 |
+
},
|
12 |
+
"text_cfg": {
|
13 |
+
"context_length": 64,
|
14 |
+
"vocab_size": 256000,
|
15 |
+
"hf_tokenizer_name": "timm/ViT-B-16-SigLIP2-512",
|
16 |
+
"tokenizer_kwargs": {
|
17 |
+
"clean": "canonicalize"
|
18 |
+
},
|
19 |
+
"width": 768,
|
20 |
+
"heads": 12,
|
21 |
+
"layers": 12,
|
22 |
+
"no_causal_mask": true,
|
23 |
+
"proj_bias": true,
|
24 |
+
"pool_type": "last",
|
25 |
+
"norm_kwargs":{
|
26 |
+
"eps": 1e-6
|
27 |
+
},
|
28 |
+
"act_kwargs": {
|
29 |
+
"approximate": "tanh"
|
30 |
+
}
|
31 |
+
}
|
32 |
+
}
|