Upload 90 files
This view is limited to 50 files because it contains too many changes; see the raw diff for the complete change set.
- environment.yml +248 -0
- exp/NeuralOM/20250309-195251/6_steps_finetune/model2/10_steps_finetune/training_checkpoints/.ipynb_checkpoints/readme-checkpoint.txt +1 -0
- exp/NeuralOM/20250309-195251/6_steps_finetune/model2/10_steps_finetune/training_checkpoints/best_ckpt.tar +3 -0
- exp/NeuralOM/20250309-195251/6_steps_finetune/model2/10_steps_finetune/training_checkpoints/readme.txt +1 -0
- exp/NeuralOM/20250309-195251/6_steps_finetune/training_checkpoints/.ipynb_checkpoints/readme-checkpoint.txt +1 -0
- exp/NeuralOM/20250309-195251/6_steps_finetune/training_checkpoints/best_ckpt.tar +3 -0
- exp/NeuralOM/20250309-195251/6_steps_finetune/training_checkpoints/readme.txt +1 -0
- exp/NeuralOM/20250309-195251/6_steps_finetune/training_checkpoints_atmos/best_ckpt.tar +3 -0
- exp/NeuralOM/20250309-195251/6_steps_finetune/training_checkpoints_atmos/readme.txt +1 -0
- exp/NeuralOM/20250309-195251/config.yaml +78 -0
- inference_forecasting.py +366 -0
- inference_forecasting.sh +13 -0
- inference_simulation.py +312 -0
- inference_simulation.sh +13 -0
- my_utils/YParams.py +55 -0
- my_utils/__pycache__/YParams.cpython-310.pyc +0 -0
- my_utils/__pycache__/YParams.cpython-37.pyc +0 -0
- my_utils/__pycache__/YParams.cpython-39.pyc +0 -0
- my_utils/__pycache__/bicubic.cpython-310.pyc +0 -0
- my_utils/__pycache__/bicubic.cpython-39.pyc +0 -0
- my_utils/__pycache__/darcy_loss.cpython-310.pyc +0 -0
- my_utils/__pycache__/darcy_loss.cpython-310.pyc.70370790180304 +0 -0
- my_utils/__pycache__/darcy_loss.cpython-310.pyc.70373230085584 +0 -0
- my_utils/__pycache__/darcy_loss.cpython-310.pyc.70384414393808 +0 -0
- my_utils/__pycache__/darcy_loss.cpython-37.pyc +0 -0
- my_utils/__pycache__/darcy_loss.cpython-39.pyc +0 -0
- my_utils/__pycache__/data_loader.cpython-310.pyc +0 -0
- my_utils/__pycache__/data_loader_multifiles.cpython-310.pyc +0 -0
- my_utils/__pycache__/data_loader_multifiles.cpython-37.pyc +0 -0
- my_utils/__pycache__/data_loader_multifiles.cpython-39.pyc +0 -0
- my_utils/__pycache__/get_date.cpython-310.pyc +0 -0
- my_utils/__pycache__/img_utils.cpython-310.pyc +0 -0
- my_utils/__pycache__/img_utils.cpython-37.pyc +0 -0
- my_utils/__pycache__/img_utils.cpython-39.pyc +0 -0
- my_utils/__pycache__/logging_utils.cpython-310.pyc +0 -0
- my_utils/__pycache__/logging_utils.cpython-37.pyc +0 -0
- my_utils/__pycache__/logging_utils.cpython-39.pyc +0 -0
- my_utils/__pycache__/norm.cpython-310.pyc +0 -0
- my_utils/__pycache__/time_utils.cpython-310.pyc +0 -0
- my_utils/__pycache__/time_utils.cpython-39.pyc +0 -0
- my_utils/__pycache__/weighted_acc_rmse.cpython-310.pyc +0 -0
- my_utils/__pycache__/weighted_acc_rmse.cpython-37.pyc +0 -0
- my_utils/__pycache__/weighted_acc_rmse.cpython-39.pyc +0 -0
- my_utils/data_loader.py +205 -0
- my_utils/logging_utils.py +26 -0
- my_utils/norm.py +114 -0
- networks/.ipynb_checkpoints/CirT1-checkpoint.py +301 -0
- networks/.ipynb_checkpoints/CirT2-checkpoint.py +301 -0
- networks/CirT1.py +301 -0
- networks/CirT2.py +301 -0
environment.yml
ADDED
@@ -0,0 +1,248 @@
name: neuralom
channels:
  - pytorch
  - dglteam/label/th24_cu118
  - nvidia
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - blas=1.0=mkl
  - brotli-python=1.0.9=py310h6a678d5_8
  - bzip2=1.0.8=h5eee18b_6
  - ca-certificates=2024.9.24=h06a4308_0
  - certifi=2024.8.30=py310h06a4308_0
  - charset-normalizer=3.3.2=pyhd3eb1b0_0
  - cuda-cudart=11.8.89=0
  - cuda-cupti=11.8.87=0
  - cuda-libraries=11.8.0=0
  - cuda-nvrtc=11.8.89=0
  - cuda-nvtx=11.8.86=0
  - cuda-runtime=11.8.0=0
  - dgl=2.4.0.th24.cu118=py310_0
  - ffmpeg=4.3=hf484d3e_0
  - filelock=3.13.1=py310h06a4308_0
  - freetype=2.12.1=h4a9f257_0
  - gmp=6.2.1=h295c915_3
  - gmpy2=2.1.2=py310heeb90bb_0
  - gnutls=3.6.15=he1e5248_0
  - idna=3.7=py310h06a4308_0
  - intel-openmp=2023.1.0=hdb19cb5_46306
  - jinja2=3.1.4=py310h06a4308_0
  - jpeg=9e=h5eee18b_3
  - lame=3.100=h7b6447c_0
  - lcms2=2.12=h3be6417_0
  - ld_impl_linux-64=2.38=h1181459_1
  - lerc=3.0=h295c915_0
  - libcublas=11.11.3.6=0
  - libcufft=10.9.0.58=0
  - libcufile=1.9.1.3=0
  - libcurand=10.3.5.147=0
  - libcusolver=11.4.1.48=0
  - libcusparse=11.7.5.86=0
  - libdeflate=1.17=h5eee18b_1
  - libffi=3.4.4=h6a678d5_1
  - libgcc-ng=11.2.0=h1234567_1
  - libgfortran-ng=11.2.0=h00389a5_1
  - libgfortran5=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libiconv=1.16=h5eee18b_3
  - libidn2=2.3.4=h5eee18b_0
  - libjpeg-turbo=2.0.0=h9bf148f_0
  - libnpp=11.8.0.86=0
  - libnvjpeg=11.9.0.86=0
  - libpng=1.6.39=h5eee18b_0
  - libstdcxx-ng=11.2.0=h1234567_1
  - libtasn1=4.19.0=h5eee18b_0
  - libtiff=4.5.1=h6a678d5_0
  - libunistring=0.9.10=h27cfd23_0
  - libuuid=1.41.5=h5eee18b_0
  - libwebp-base=1.3.2=h5eee18b_0
  - llvm-openmp=14.0.6=h9e868ea_0
  - lz4-c=1.9.4=h6a678d5_1
  - markupsafe=2.1.3=py310h5eee18b_0
  - mkl=2023.1.0=h213fc3f_46344
  - mkl-service=2.4.0=py310h5eee18b_1
  - mkl_fft=1.3.10=py310h5eee18b_0
  - mkl_random=1.2.7=py310h1128e8f_0
  - mpc=1.1.0=h10f8cd9_1
  - mpfr=4.0.2=hb69a4c5_1
  - mpmath=1.3.0=py310h06a4308_0
  - ncurses=6.4=h6a678d5_0
  - nettle=3.7.3=hbbd107a_1
  - networkx=3.3=py310h06a4308_0
  - numpy=1.26.4=py310h5f9d8c6_0
  - numpy-base=1.26.4=py310hb5e798b_0
  - openh264=2.1.1=h4ff587b_0
  - openjpeg=2.5.2=he7f1fd0_0
  - openssl=3.0.15=h5eee18b_0
  - pillow=10.4.0=py310h5eee18b_0
  - pip=24.2=py310h06a4308_0
  - psutil=5.9.0=py310h5eee18b_0
  - pybind11-abi=4=hd3eb1b0_1
  - pysocks=1.7.1=py310h06a4308_0
  - python=3.10.15=he870216_1
  - pytorch=2.4.0=py3.10_cuda11.8_cudnn9.1.0_0
  - pytorch-cuda=11.8=h7e8668a_5
  - pytorch-mutex=1.0=cuda
  - pyyaml=6.0.1=py310h5eee18b_0
  - readline=8.2=h5eee18b_0
  - requests=2.32.3=py310h06a4308_0
  - scipy=1.13.1=py310h5f9d8c6_0
  - setuptools=75.1.0=py310h06a4308_0
  - sqlite=3.45.3=h5eee18b_0
  - sympy=1.13.2=py310h06a4308_0
  - tbb=2021.8.0=hdb19cb5_0
  - tk=8.6.14=h39e8969_0
  - torchaudio=2.4.0=py310_cu118
  - torchtriton=3.0.0=py310
  - torchvision=0.19.0=py310_cu118
  - tqdm=4.66.5=py310h2f386ee_0
  - typing_extensions=4.11.0=py310h06a4308_0
  - urllib3=2.2.3=py310h06a4308_0
  - wheel=0.44.0=py310h06a4308_0
  - xz=5.4.6=h5eee18b_1
  - yaml=0.2.5=h7b6447c_0
  - zlib=1.2.13=h5eee18b_1
  - zstd=1.5.5=hc292b87_2
  - pip:
    - aiobotocore==2.15.1
    - aiohappyeyeballs==2.4.3
    - aiohttp==3.10.8
    - aioitertools==0.12.0
    - aiosignal==1.3.1
    - anyio==4.6.0
    - argon2-cffi==23.1.0
    - argon2-cffi-bindings==21.2.0
    - arrow==1.3.0
    - asttokens==2.4.1
    - async-lru==2.0.4
    - async-timeout==4.0.3
    - attrs==24.2.0
    - babel==2.16.0
    - beautifulsoup4==4.12.3
    - bleach==6.1.0
    - blessed==1.20.0
    - botocore==1.35.23
    - cartopy==0.24.1
    - cffi==1.17.1
    - cftime==1.6.4.post1
    - cmocean==4.0.3
    - colorama==0.4.6
    - comm==0.2.2
    - contourpy==1.3.0
    - cycler==0.12.1
    - debugpy==1.8.6
    - decorator==5.1.1
    - defusedxml==0.7.1
    - einops==0.8.0
    - exceptiongroup==1.2.2
    - executing==2.1.0
    - fastjsonschema==2.20.0
    - fonttools==4.54.1
    - fqdn==1.5.1
    - frozenlist==1.4.1
    - fsspec==2024.9.0
    - gpustat==1.1.1
    - h11==0.14.0
    - h5netcdf==1.4.0
    - h5py==3.12.1
    - httpcore==1.0.6
    - httpx==0.27.2
    - huggingface-hub==0.25.1
    - icecream==2.1.3
    - importlib-metadata==8.5.0
    - ipykernel==6.29.5
    - ipython==8.28.0
    - isoduration==20.11.0
    - jedi==0.19.1
    - jmespath==1.0.1
    - joblib==1.4.2
    - json5==0.9.25
    - jsonpointer==3.0.0
    - jsonschema==4.23.0
    - jsonschema-specifications==2024.10.1
    - jupyter-client==8.6.3
    - jupyter-core==5.7.2
    - jupyter-events==0.10.0
    - jupyter-lsp==2.2.5
    - jupyter-server==2.14.2
    - jupyter-server-terminals==0.5.3
    - jupyterlab==4.2.5
    - jupyterlab-pygments==0.3.0
    - jupyterlab-server==2.27.3
    - kiwisolver==1.4.7
    - matplotlib==3.9.2
    - matplotlib-inline==0.1.7
    - mistune==3.0.2
    - multidict==6.1.0
    - nbclient==0.10.0
    - nbconvert==7.16.4
    - nbformat==5.10.4
    - nest-asyncio==1.6.0
    - netcdf4==1.7.2
    - notebook==7.2.2
    - notebook-shim==0.2.4
    - nvfuser-cu118-torch24==0.2.9.dev20240808
    - nvidia-cuda-cupti-cu11==11.8.87
    - nvidia-cuda-nvrtc-cu11==11.8.89
    - nvidia-cuda-runtime-cu11==11.8.89
    - nvidia-ml-py==12.560.30
    - nvidia-nvtx-cu11==11.8.86
    - overrides==7.7.0
    - packaging==24.1
    - pandas==2.2.3
    - pandocfilters==1.5.1
    - parso==0.8.4
    - pexpect==4.9.0
    - platformdirs==4.3.6
    - prometheus-client==0.21.0
    - prompt-toolkit==3.0.48
    - ptyprocess==0.7.0
    - pure-eval==0.2.3
    - pycparser==2.22
    - pygments==2.18.0
    - pyparsing==3.2.0
    - pyproj==3.7.0
    - pyshp==2.3.1
    - python-dateutil==2.9.0.post0
    - python-json-logger==2.0.7
    - pytz==2024.2
    - pyzmq==26.2.0
    - referencing==0.35.1
    - rfc3339-validator==0.1.4
    - rfc3986-validator==0.1.1
    - rpds-py==0.20.0
    - ruamel-yaml==0.18.6
    - ruamel-yaml-clib==0.2.8
    - s3fs==2024.9.0
    - safetensors==0.4.5
    - scikit-learn==1.5.2
    - send2trash==1.8.3
    - shapely==2.0.6
    - six==1.16.0
    - sniffio==1.3.1
    - soupsieve==2.6
    - stack-data==0.6.3
    - terminado==0.18.1
    - thop==0.1.1-2209072238
    - threadpoolctl==3.5.0
    - timm==1.0.9
    - tinycss2==1.3.0
    - tomli==2.0.2
    - torchsummary==1.5.1
    - tornado==6.4.1
    - traitlets==5.14.3
    - treelib==1.7.0
    - types-python-dateutil==2.9.0.20241003
    - tzdata==2024.2
    - uri-template==1.3.0
    - wcwidth==0.2.13
    - webcolors==24.8.0
    - webencodings==0.5.1
    - websocket-client==1.8.0
    - wrapt==1.16.0
    - xarray==2024.9.0
    - yarl==1.13.1
    - zipp==3.20.2
prefix: /miniconda3/envs/neuralom
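The environment above pins torch 2.4.0 with CUDA 11.8 and DGL 2.4.0. As an optional, editor-added sanity check (not part of the upload), the pinned stack can be verified from Python once the environment is created and activated:

# Sanity-check the pinned stack from environment.yml (torch 2.4.0 + CUDA 11.8).
import torch

print(torch.__version__)          # expected to start with "2.4.0"
print(torch.version.cuda)         # expected "11.8"
print(torch.cuda.is_available())  # True on a machine with a visible GPU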
exp/NeuralOM/20250309-195251/6_steps_finetune/model2/10_steps_finetune/training_checkpoints/.ipynb_checkpoints/readme-checkpoint.txt
ADDED
@@ -0,0 +1 @@
The complete project is available on Hugging Face.
exp/NeuralOM/20250309-195251/6_steps_finetune/model2/10_steps_finetune/training_checkpoints/best_ckpt.tar
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f9fe78a12419997b9deaf0dd2ec912c1c936e96cc418bc507acdd2baecf908a2
size 661771939
exp/NeuralOM/20250309-195251/6_steps_finetune/model2/10_steps_finetune/training_checkpoints/readme.txt
ADDED
@@ -0,0 +1 @@
The complete project is available on Hugging Face.
exp/NeuralOM/20250309-195251/6_steps_finetune/training_checkpoints/.ipynb_checkpoints/readme-checkpoint.txt
ADDED
@@ -0,0 +1 @@
The complete project is available on Hugging Face.
exp/NeuralOM/20250309-195251/6_steps_finetune/training_checkpoints/best_ckpt.tar
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8fb1c1827478608c96b490662f9b59351a52b1ee423883158c0a9e7c09b7d919
size 661813411
exp/NeuralOM/20250309-195251/6_steps_finetune/training_checkpoints/readme.txt
ADDED
@@ -0,0 +1 @@
The complete project is available on Hugging Face.
exp/NeuralOM/20250309-195251/6_steps_finetune/training_checkpoints_atmos/best_ckpt.tar
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cb94ee483750345aea106470e39366e77bd92014c62344e9dad1837138e0ead7
size 600426467
exp/NeuralOM/20250309-195251/6_steps_finetune/training_checkpoints_atmos/readme.txt
ADDED
@@ -0,0 +1 @@
The complete project is available on Hugging Face.
exp/NeuralOM/20250309-195251/config.yaml
ADDED
@@ -0,0 +1,78 @@
### base config ###
# -*- coding: utf-8 -*-
full_field: &FULL_FIELD
  num_data_workers: 4
  dt: 1
  n_history: 0
  prediction_length: 41
  ics_type: "default"

  exp_dir: './exp'

  # data
  train_data_path: './data/train'
  valid_data_path: './data/valid'
  test_data_path: './data/test'
  test_data_path_atmos: './data/test_atmos'

  # land mask
  land_mask: !!bool True
  land_mask_path: './data/land_mask.h5'

  # normalization
  normalize: !!bool True
  normalization: 'zscore'
  global_means_path: './data/mean_s_t_ssh.npy'
  global_stds_path: './data/std_s_t_ssh.npy'

  global_means_path_atmos: './data/mean_atmos.npy'
  global_stds_path_atmos: './data/std_atmos.npy'

  # orography
  orography: !!bool False

  # noise
  add_noise: !!bool False
  noise_std: 0

  # crop
  crop_size_x: None
  crop_size_y: None

  log_to_screen: !!bool True
  log_to_wandb: !!bool False
  save_checkpoint: !!bool True
  plot_animations: !!bool False


#############################################
NeuralOM: &NeuralOM
  <<: *FULL_FIELD
  nettype: 'NeuralOM'
  log_to_wandb: !!bool False

  # Train params
  lr: 1e-3
  batch_size: 32
  scheduler: 'CosineAnnealingLR'

  loss_channel_wise: True
  loss_scale: False
  use_loss_scaler_from_metnet3: True


  atmos_channels: [93, 94, 95, 96]

  ocean_channels: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92]

  in_channels: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96]

  out_channels: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92]

  in_channels_atmos: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68]

  out_channels_atmos: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68]


  out_variables: ["S0", "S2", "S5", "S7", "S11", "S15", "S21", "S29", "S40", "S55", "S77", "S92", "S109", "S130", "S155", "S186", "S222", "S266", "S318", "S380", "S453", "S541", "S643", "U0", "U2", "U5", "U7", "U11", "U15", "U21", "U29", "U40", "U55", "U77", "U92", "U109", "U130", "U155", "U186", "U222", "U266", "U318", "U380", "U453", "U541", "U643", "V0", "V2", "V5", "V7", "V11", "V15", "V21", "V29", "V40", "V55", "V77", "V92", "V109", "V130", "V155", "V186", "V222", "V266", "V318", "V380", "V453", "V541", "V643", "T0", "T2", "T5", "T7", "T11", "T15", "T21", "T29", "T40", "T55", "T77", "T92", "T109", "T130", "T155", "T186", "T222", "T266", "T318", "T380", "T453", "T541", "T643", "SSH"]
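The channel indexing in this config is internally consistent: out_variables lists 93 ocean fields (S, U, V and T at 23 depth levels each, plus SSH), matching ocean_channels/out_channels 0-92, while atmos_channels 93-96 are the four atmospheric forcing channels appended to form the 97 in_channels. A small editor-added sketch (not part of the upload) that rebuilds and checks these lists:

# Rebuild the channel lists implied by config.yaml and verify the counts.
ocean_channels = list(range(93))        # 4 variables x 23 depths + SSH = 93
atmos_channels = [93, 94, 95, 96]       # atmospheric forcing channels
in_channels = ocean_channels + atmos_channels
out_channels = ocean_channels

depths = [0, 2, 5, 7, 11, 15, 21, 29, 40, 55, 77, 92, 109, 130, 155,
          186, 222, 266, 318, 380, 453, 541, 643]
out_variables = [f"{v}{d}" for v in ("S", "U", "V", "T") for d in depths] + ["SSH"]

assert len(out_variables) == len(out_channels) == 93
assert len(in_channels) == 97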
inference_forecasting.py
ADDED
@@ -0,0 +1,366 @@
import os
import sys
import time
import glob
import h5py
import logging
import argparse
import numpy as np
from icecream import ic
from datetime import datetime
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.cuda.amp as amp
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')
from my_utils.YParams import YParams
from my_utils.data_loader import get_data_loader
from my_utils import logging_utils
logging_utils.config_logger()


def load_model(model, params, checkpoint_file):
    model.zero_grad()
    checkpoint_fname = checkpoint_file
    checkpoint = torch.load(checkpoint_fname)
    try:
        new_state_dict = OrderedDict()
        for key, val in checkpoint['model_state'].items():
            name = key[7:]
            if name != 'ged':
                new_state_dict[name] = val
        model.load_state_dict(new_state_dict)
    except:
        model.load_state_dict(checkpoint['model_state'])
    model.eval()
    return model

def setup(params):
    device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'

    # get data loader
    valid_data_loader, valid_dataset = get_data_loader(params, params.test_data_path, dist.is_initialized(), train=False)

    img_shape_x = valid_dataset.img_shape_x
    img_shape_y = valid_dataset.img_shape_y
    params.img_shape_x = img_shape_x
    params.img_shape_y = img_shape_y

    in_channels = np.array(params.in_channels)
    out_channels = np.array(params.out_channels)
    n_in_channels = len(in_channels)
    n_out_channels = len(out_channels)

    params['N_in_channels'] = n_in_channels
    params['N_out_channels'] = n_out_channels

    if params.normalization == 'zscore':
        params.means = np.load(params.global_means_path)
        params.stds = np.load(params.global_stds_path)

        params.means_atmos = np.load(params.global_means_path_atmos)
        params.stds_atmos = np.load(params.global_stds_path_atmos)

    if params.nettype == 'NeuralOM':
        from networks.MIGNN1 import MIGraph as model
        from networks.MIGNN2 import MIGraph_stage2 as model2
        from networks.OneForecast import OneForecast as model_atmos
    else:
        raise Exception("not implemented")

    checkpoint_file = params['best_checkpoint_path']
    checkpoint_file2 = params['best_checkpoint_path2']
    checkpoint_file_atmos = params['best_checkpoint_path_atmos']
    logging.info('Loading trained model checkpoint from {}'.format(checkpoint_file))
    logging.info('Loading trained model2 checkpoint from {}'.format(checkpoint_file2))
    logging.info('Loading trained model_atmos checkpoint from {}'.format(checkpoint_file_atmos))

    model = model(params).to(device)
    model = load_model(model, params, checkpoint_file)
    model = model.to(device)

    print('model is ok')

    model2 = model2(params).to(device)
    model2 = load_model(model2, params, checkpoint_file2)
    model2 = model2.to(device)

    print('model2 is ok')

    model_atmos = model_atmos(params).to(device)
    model_atmos = load_model(model_atmos, params, checkpoint_file_atmos)
    model_atmos = model_atmos.to(device)

    print('model_atmos is ok')

    files_paths = glob.glob(params.test_data_path + "/*.h5")
    files_paths.sort()

    files_paths_atmos = glob.glob(params.test_data_path_atmos + "/*.h5")
    files_paths_atmos.sort()

    # which year
    yr = 0
    logging.info('Loading inference data')
    logging.info('Inference data from {}'.format(files_paths[yr]))
    logging.info('Inference data_atmos from {}'.format(files_paths_atmos[yr]))
    climate_mean = np.load('./data/climate_mean_s_t_ssh.npy')
    valid_data_full = h5py.File(files_paths[yr], 'r')['fields'][:365, :, :, :]
    valid_data_full = valid_data_full - climate_mean

    valid_data_full_atmos = h5py.File(files_paths_atmos[yr], 'r')['fields'][2:1460:4, :, :, :]

    return valid_data_full, valid_data_full_atmos, model, model2, model_atmos


def autoregressive_inference(params, init_condition, valid_data_full, valid_data_full_atmos, model, model2, model_atmos):
    device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'

    icd = int(init_condition)

    exp_dir = params['experiment_dir']
    dt = int(params.dt)
    prediction_length = int(params.prediction_length/dt)
    n_history = params.n_history
    img_shape_x = params.img_shape_x
    img_shape_y = params.img_shape_y
    in_channels = np.array(params.in_channels)
    out_channels = np.array(params.out_channels)
    in_channels_atmos = np.array(params.in_channels_atmos)
    out_channels_atmos = np.array(params.out_channels_atmos)
    atmos_channels = np.array(params.atmos_channels)
    n_in_channels = len(in_channels)
    n_out_channels = len(out_channels)

    seq_real = torch.zeros((prediction_length, n_out_channels, img_shape_x, img_shape_y))
    seq_pred = torch.zeros((prediction_length, n_out_channels, img_shape_x, img_shape_y))

    valid_data = valid_data_full[icd:(icd+prediction_length*dt+n_history*dt):dt][:, params.in_channels][:,:,0:360]
    valid_data_atmos = valid_data_full_atmos[icd:(icd+prediction_length*dt+n_history*dt):dt][:, params.in_channels_atmos][:,:,0:120]
    logging.info(f'valid_data_full: {valid_data_full.shape}')
    logging.info(f'valid_data: {valid_data.shape}')
    logging.info(f'valid_data_full_atmos: {valid_data_full_atmos.shape}')
    logging.info(f'valid_data_atmos: {valid_data_atmos.shape}')

    # normalize
    if params.normalization == 'zscore':
        valid_data = (valid_data - params.means[:,params.in_channels])/params.stds[:,params.in_channels]
        valid_data = np.nan_to_num(valid_data, nan=0)

        valid_data_atmos = (valid_data_atmos - params.means_atmos[:,params.in_channels_atmos])/params.stds_atmos[:,params.in_channels_atmos]
        valid_data_atmos = np.nan_to_num(valid_data_atmos, nan=0)

    valid_data = torch.as_tensor(valid_data)
    valid_data_atmos = torch.as_tensor(valid_data_atmos)

    # autoregressive inference
    logging.info('Begin autoregressive inference')

    with torch.no_grad():
        for i in range(valid_data.shape[0]):
            if i==0: # start of sequence, t0 --> t0'
                first = valid_data[0:n_history+1]
                first_atmos = valid_data_atmos[0:n_history+1]
                ic(valid_data.shape, first.shape)
                ic(valid_data_atmos.shape, first_atmos.shape)
                future = valid_data[n_history+1]
                ic(future.shape)

                for h in range(n_history+1):
                    seq_real[h] = first[h*n_in_channels : (h+1)*n_in_channels, :93]
                    seq_pred[h] = seq_real[h]

                first = first.to(device, dtype=torch.float)
                first_atmos = first_atmos.to(device, dtype=torch.float)
                first_ocean = first[:, params.ocean_channels, :, :]
                ic(first_ocean.shape)
                future_force0 = first_atmos[:, [65, 66, 67, 68], :120, :240]
                # future_force0 = torch.unsqueeze(future_force0, dim=0).to(device, dtype=torch.float)
                future_force0 = F.interpolate(future_force0, size=(360, 720), mode='bilinear', align_corners=False)

                model_input_atmos = first_atmos
                ic(model_input_atmos.shape)
                for k in range(4):
                    if k ==0:
                        model_atmos_future_pred = model_atmos(model_input_atmos)
                    else:
                        model_atmos_future_pred = model_atmos(model_atmos_future_pred)

                future_force = model_atmos_future_pred[:, [65, 66, 67, 68], :120, :240]
                # future_force = torch.unsqueeze(future_force, dim=0).to(device, dtype=torch.float)
                future_force = F.interpolate(future_force, size=(360, 720), mode='bilinear', align_corners=False)

                model_input = torch.cat((first_ocean, future_force0, future_force.cuda()), axis=1)
                ic(model_input.shape)
                model1_future_pred = model(model_input)
                with h5py.File(params.land_mask_path, 'r') as _f:
                    mask_data = torch.as_tensor(_f['fields'][:,out_channels, :360, :720], dtype=bool).to(device, dtype=torch.bool)
                model1_future_pred = torch.masked_fill(input=model1_future_pred, mask=~mask_data, value=0)
                future_pred = model2(model1_future_pred) + model1_future_pred

            else:
                if i < prediction_length-1:
                    future0 = valid_data[n_history+i]
                    future = valid_data[n_history+i+1]

                inf_one_step_start = time.time()
                future_force0 = model_atmos_future_pred[:, [65, 66, 67, 68], :120, :240]
                # future_force0 = torch.unsqueeze(future_force0, dim=0).to(device, dtype=torch.float)
                future_force0 = F.interpolate(future_force0, size=(360, 720), mode='bilinear', align_corners=False)

                for k in range(4):
                    model_atmos_future_pred = model_atmos(model_atmos_future_pred)

                future_force = model_atmos_future_pred[:, [65, 66, 67, 68], :120, :240]
                # future_force = torch.unsqueeze(future_force, dim=0).to(device, dtype=torch.float)
                future_force = F.interpolate(future_force, size=(360, 720), mode='bilinear', align_corners=False)

                model1_future_pred = model(torch.cat((future_pred.cuda(), future_force0, future_force), axis=1)) #autoregressive step
                with h5py.File(params.land_mask_path, 'r') as _f:
                    mask_data = torch.as_tensor(_f['fields'][:,out_channels, :360, :720], dtype=bool).to(device, dtype=torch.bool)
                model1_future_pred = torch.masked_fill(input=model1_future_pred, mask=~mask_data, value=0)
                future_pred = model2(model1_future_pred) + model1_future_pred
                inf_one_step_time = time.time() - inf_one_step_start

                logging.info(f'inference one step time: {inf_one_step_time}')

            if i < prediction_length - 1: # not on the last step
                with h5py.File(params.land_mask_path, 'r') as _f:
                    mask_data = torch.as_tensor(_f['fields'][:,out_channels, :360, :720], dtype=bool)
                seq_pred[n_history+i+1] = torch.masked_fill(input=future_pred.cpu(), mask=~mask_data, value=0)
                seq_real[n_history+i+1] = future[:93]
                history_stack = seq_pred[i+1:i+2+n_history]

            future_pred = history_stack

            pred = torch.unsqueeze(seq_pred[i], 0)
            tar = torch.unsqueeze(seq_real[i], 0)

            with h5py.File(params.land_mask_path, 'r') as _f:
                mask_data = torch.as_tensor(_f['fields'][:,out_channels, :360, :720], dtype=bool)
            ic(mask_data.shape, pred.shape, tar.shape)
            pred = torch.masked_fill(input=pred, mask=~mask_data, value=0)
            tar = torch.masked_fill(input=tar, mask=~mask_data, value=0)

            print(torch.mean((pred-tar)**2))

    seq_real = seq_real * params.stds[:,params.out_channels] + params.means[:,params.out_channels]
    seq_real = seq_real.numpy()
    seq_pred = seq_pred * params.stds[:,params.out_channels] + params.means[:,params.out_channels]
    seq_pred = seq_pred.numpy()

    return (np.expand_dims(seq_real[n_history:], 0),
            np.expand_dims(seq_pred[n_history:], 0),
            )


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp_dir", default='../exp_15_levels', type=str)
    parser.add_argument("--config", default='full_field', type=str)
    parser.add_argument("--run_num", default='00', type=str)
    parser.add_argument("--prediction_length", default=61, type=int)
    parser.add_argument("--finetune_dir", default='', type=str)
    parser.add_argument("--ics_type", default='default', type=str)
    args = parser.parse_args()

    config_path = os.path.join(args.exp_dir, args.config, args.run_num, 'config.yaml')
    params = YParams(config_path, args.config)

    params['resuming'] = False
    params['interp'] = 0
    params['world_size'] = 1
    params['local_rank'] = 0
    params['global_batch_size'] = params.batch_size
    params['prediction_length'] = args.prediction_length
    params['multi_steps_finetune'] = 1

    torch.cuda.set_device(0)
    torch.backends.cudnn.benchmark = True

    # Set up directory
    if args.finetune_dir == '':
        expDir = os.path.join(params.exp_dir, args.config, str(args.run_num))
    else:
        expDir = os.path.join(params.exp_dir, args.config, str(args.run_num), args.finetune_dir)
    logging.info(f'expDir: {expDir}')
    params['experiment_dir'] = expDir
    params['best_checkpoint_path'] = os.path.join(expDir, 'training_checkpoints/best_ckpt.tar')
    params['best_checkpoint_path2'] = os.path.join(expDir, 'model2/10_steps_finetune/training_checkpoints/best_ckpt.tar')

    params['best_checkpoint_path_atmos'] = os.path.join(expDir, 'training_checkpoints_atmos/best_ckpt.tar')

    # set up logging
    logging_utils.log_to_file(logger_name=None, log_filename=os.path.join(expDir, 'inference.log'))
    logging_utils.log_versions()
    params.log()

    if params["ics_type"] == 'default':
        ics = np.arange(0, 50, 1)
        n_ics = len(ics)
        print('init_condition:', ics)

    logging.info("Inference for {} initial conditions".format(n_ics))

    try:
        autoregressive_inference_filetag = params["inference_file_tag"]
    except:
        autoregressive_inference_filetag = ""
    if params.interp > 0:
        autoregressive_inference_filetag = "_coarse"

    valid_data_full, valid_data_full_atmos, model, model2, model_atmos = setup(params)

    seq_pred = []
    seq_real = []

    # run autoregressive inference for multiple initial conditions
    for i, ic_ in enumerate(ics):
        logging.info("Initial condition {} of {}".format(i+1, n_ics))
        seq_real, seq_pred = autoregressive_inference(params, ic_, valid_data_full, valid_data_full_atmos, model, model2, model_atmos)

        prediction_length = seq_real[0].shape[0]
        n_out_channels = seq_real[0].shape[1]
        img_shape_x = seq_real[0].shape[2]
        img_shape_y = seq_real[0].shape[3]

        # save predictions and loss
        save_path = os.path.join(params['experiment_dir'], 'results_forecasting.h5')
        logging.info("Saving to {}".format(save_path))
        print(f'saving to {save_path}')
        if i==0:
            f = h5py.File(save_path, 'w')
            f.create_dataset(
                "ground_truth",
                data=seq_real,
                maxshape=[None, prediction_length, n_out_channels, img_shape_x, img_shape_y],
                dtype=np.float32)
            f.create_dataset(
                "predicted",
                data=seq_pred,
                maxshape=[None, prediction_length, n_out_channels, img_shape_x, img_shape_y],
                dtype=np.float32)
            f.close()
        else:
            f = h5py.File(save_path, 'a')

            f["ground_truth"].resize((f["ground_truth"].shape[0] + 1), axis = 0)
            f["ground_truth"][-1:] = seq_real

            f["predicted"].resize((f["predicted"].shape[0] + 1), axis = 0)
            f["predicted"][-1:] = seq_pred
            f.close()
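At every rollout step the script reopens land_mask.h5 and rebuilds the mask tensor before zeroing land points. A minimal editor-added sketch (illustrative only, not the repository's code) of loading that mask once up front and reusing it:

# Load the land/sea mask once and reuse it at each autoregressive step.
import h5py
import torch

def load_land_mask(land_mask_path, out_channels, device):
    # 'fields' is assumed to have the (1, C, H, W) layout sliced in the script above
    with h5py.File(land_mask_path, 'r') as f:
        mask = torch.as_tensor(f['fields'][:, out_channels, :360, :720]).to(torch.bool)
    return mask.to(device)

def apply_land_mask(prediction, mask):
    # zero out land points, as done after each model call in inference_forecasting.py
    return prediction.masked_fill(~mask, 0)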
inference_forecasting.sh
ADDED
@@ -0,0 +1,13 @@
prediction_length=61 # 31

exp_dir='./exp'
config='NeuralOM'
run_num='20250309-195251'
finetune_dir='6_steps_finetune'

ics_type='default'

CUDA_VISIBLE_DEVICES=2 python inference_forecasting.py --exp_dir=${exp_dir} --config=${config} --run_num=${run_num} --finetune_dir=$finetune_dir --prediction_length=${prediction_length} --ics_type=${ics_type}
inference_simulation.py
ADDED
@@ -0,0 +1,312 @@
import os
import sys
import time
import glob
import h5py
import logging
import argparse
import numpy as np
from icecream import ic
from datetime import datetime
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.cuda.amp as amp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')
from my_utils.YParams import YParams
from my_utils.data_loader import get_data_loader
from my_utils import logging_utils
logging_utils.config_logger()


def load_model(model, params, checkpoint_file):
    model.zero_grad()
    checkpoint_fname = checkpoint_file
    checkpoint = torch.load(checkpoint_fname)
    try:
        new_state_dict = OrderedDict()
        for key, val in checkpoint['model_state'].items():
            name = key[7:]
            if name != 'ged':
                new_state_dict[name] = val
        model.load_state_dict(new_state_dict)
    except:
        model.load_state_dict(checkpoint['model_state'])
    model.eval()
    return model

def setup(params):
    device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'

    # get data loader
    valid_data_loader, valid_dataset = get_data_loader(params, params.test_data_path, dist.is_initialized(), train=False)

    img_shape_x = valid_dataset.img_shape_x
    img_shape_y = valid_dataset.img_shape_y
    params.img_shape_x = img_shape_x
    params.img_shape_y = img_shape_y

    in_channels = np.array(params.in_channels)
    out_channels = np.array(params.out_channels)
    n_in_channels = len(in_channels)
    n_out_channels = len(out_channels)

    params['N_in_channels'] = n_in_channels
    params['N_out_channels'] = n_out_channels

    if params.normalization == 'zscore':
        params.means = np.load(params.global_means_path)
        params.stds = np.load(params.global_stds_path)

    if params.nettype == 'NeuralOM':
        from networks.MIGNN1 import MIGraph as model
        from networks.MIGNN2 import MIGraph_stage2 as model2
    else:
        raise Exception("not implemented")

    checkpoint_file = params['best_checkpoint_path']
    checkpoint_file2 = params['best_checkpoint_path2']
    logging.info('Loading trained model checkpoint from {}'.format(checkpoint_file))
    logging.info('Loading trained model2 checkpoint from {}'.format(checkpoint_file2))

    model = model(params).to(device)
    model = load_model(model, params, checkpoint_file)
    model = model.to(device)

    print('model is ok')

    model2 = model2(params).to(device)
    model2 = load_model(model2, params, checkpoint_file2)
    model2 = model2.to(device)

    print('model2 is ok')

    files_paths = glob.glob(params.test_data_path + "/*.h5")
    files_paths.sort()

    # which year
    yr = 0
    logging.info('Loading inference data')
    logging.info('Inference data from {}'.format(files_paths[yr]))
    climate_mean = np.load('./data/climate_mean_s_t_ssh.npy')
    valid_data_full = h5py.File(files_paths[yr], 'r')['fields'][:365, :, :, :]
    valid_data_full = valid_data_full - climate_mean

    return valid_data_full, model, model2


def autoregressive_inference(params, init_condition, valid_data_full, model, model2):
    device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'

    icd = int(init_condition)

    exp_dir = params['experiment_dir']
    dt = int(params.dt)
    prediction_length = int(params.prediction_length/dt)
    n_history = params.n_history
    img_shape_x = params.img_shape_x
    img_shape_y = params.img_shape_y
    in_channels = np.array(params.in_channels)
    out_channels = np.array(params.out_channels)
    atmos_channels = np.array(params.atmos_channels)
    n_in_channels = len(in_channels)
    n_out_channels = len(out_channels)

    seq_real = torch.zeros((prediction_length, n_out_channels, img_shape_x, img_shape_y))
    seq_pred = torch.zeros((prediction_length, n_out_channels, img_shape_x, img_shape_y))

    valid_data = valid_data_full[icd:(icd+prediction_length*dt+n_history*dt):dt][:, params.in_channels][:,:,0:360]
    logging.info(f'valid_data_full: {valid_data_full.shape}')
    logging.info(f'valid_data: {valid_data.shape}')

    # normalize
    if params.normalization == 'zscore':
        valid_data = (valid_data - params.means[:,params.in_channels])/params.stds[:,params.in_channels]
        valid_data = np.nan_to_num(valid_data, nan=0)

    valid_data = torch.as_tensor(valid_data)

    # autoregressive inference
    logging.info('Begin autoregressive inference')

    with torch.no_grad():
        for i in range(valid_data.shape[0]):
            if i==0: # start of sequence, t0 --> t0'
                first = valid_data[0:n_history+1]
                ic(valid_data.shape, first.shape)
                future = valid_data[n_history+1]
                ic(future.shape)

                for h in range(n_history+1):
                    seq_real[h] = first[h*n_in_channels : (h+1)*n_in_channels, :93]
                    seq_pred[h] = seq_real[h]

                first = first.to(device, dtype=torch.float)
                first_ocean = first[:, params.ocean_channels, :, :]
                ic(first_ocean.shape)
                future_force0 = first[:, params.atmos_channels, :, :]

                future_force = future[params.atmos_channels, :360, :720]
                future_force = torch.unsqueeze(future_force, dim=0).to(device, dtype=torch.float)
                model_input = torch.cat((first_ocean, future_force0, future_force.cuda()), axis=1)
                ic(model_input.shape)
                model1_future_pred = model(model_input)
                with h5py.File(params.land_mask_path, 'r') as _f:
                    mask_data = torch.as_tensor(_f['fields'][:,out_channels, :360, :720], dtype=bool).to(device, dtype=torch.bool)
                model1_future_pred = torch.masked_fill(input=model1_future_pred, mask=~mask_data, value=0)
                future_pred = model2(model1_future_pred) + model1_future_pred

            else:
                if i < prediction_length-1:
                    future0 = valid_data[n_history+i]
                    future = valid_data[n_history+i+1]

                inf_one_step_start = time.time()
                future_force0 = future0[params.atmos_channels, :360, :720]
                future_force = future[params.atmos_channels, :360, :720]
                future_force0 = torch.unsqueeze(future_force0, dim=0).to(device, dtype=torch.float)
                future_force = torch.unsqueeze(future_force, dim=0).to(device, dtype=torch.float)
                model1_future_pred = model(torch.cat((future_pred.cuda(), future_force0, future_force), axis=1)) #autoregressive step
                with h5py.File(params.land_mask_path, 'r') as _f:
                    mask_data = torch.as_tensor(_f['fields'][:,out_channels, :360, :720], dtype=bool).to(device, dtype=torch.bool)
                model1_future_pred = torch.masked_fill(input=model1_future_pred, mask=~mask_data, value=0)
                future_pred = model2(model1_future_pred) + model1_future_pred
                inf_one_step_time = time.time() - inf_one_step_start

                logging.info(f'inference one step time: {inf_one_step_time}')

            if i < prediction_length - 1: # not on the last step
                with h5py.File(params.land_mask_path, 'r') as _f:
                    mask_data = torch.as_tensor(_f['fields'][:,out_channels, :360, :720], dtype=bool)
                seq_pred[n_history+i+1] = torch.masked_fill(input=future_pred.cpu(), mask=~mask_data, value=0)
                seq_real[n_history+i+1] = future[:93]
                history_stack = seq_pred[i+1:i+2+n_history]

            future_pred = history_stack

            pred = torch.unsqueeze(seq_pred[i], 0)
            tar = torch.unsqueeze(seq_real[i], 0)

            with h5py.File(params.land_mask_path, 'r') as _f:
                mask_data = torch.as_tensor(_f['fields'][:,out_channels, :360, :720], dtype=bool)
            ic(mask_data.shape, pred.shape, tar.shape)
            pred = torch.masked_fill(input=pred, mask=~mask_data, value=0)
            tar = torch.masked_fill(input=tar, mask=~mask_data, value=0)

            print(torch.mean((pred-tar)**2))

    seq_real = seq_real * params.stds[:,params.out_channels] + params.means[:,params.out_channels]
    seq_real = seq_real.numpy()
    seq_pred = seq_pred * params.stds[:,params.out_channels] + params.means[:,params.out_channels]
    seq_pred = seq_pred.numpy()

    return (np.expand_dims(seq_real[n_history:], 0),
            np.expand_dims(seq_pred[n_history:], 0),
            )


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp_dir", default='../exp_15_levels', type=str)
    parser.add_argument("--config", default='full_field', type=str)
    parser.add_argument("--run_num", default='00', type=str)
    parser.add_argument("--prediction_length", default=61, type=int)
    parser.add_argument("--finetune_dir", default='', type=str)
    parser.add_argument("--ics_type", default='default', type=str)
    args = parser.parse_args()

    config_path = os.path.join(args.exp_dir, args.config, args.run_num, 'config.yaml')
    params = YParams(config_path, args.config)

    params['resuming'] = False
    params['interp'] = 0
    params['world_size'] = 1
    params['local_rank'] = 0
    params['global_batch_size'] = params.batch_size
    params['prediction_length'] = args.prediction_length
    params['multi_steps_finetune'] = 1

    torch.cuda.set_device(0)
    torch.backends.cudnn.benchmark = True

    # Set up directory
    if args.finetune_dir == '':
        expDir = os.path.join(params.exp_dir, args.config, str(args.run_num))
    else:
        expDir = os.path.join(params.exp_dir, args.config, str(args.run_num), args.finetune_dir)
    logging.info(f'expDir: {expDir}')
    params['experiment_dir'] = expDir
    params['best_checkpoint_path'] = os.path.join(expDir, 'training_checkpoints/best_ckpt.tar')
    params['best_checkpoint_path2'] = os.path.join(expDir, 'model2/10_steps_finetune/training_checkpoints/best_ckpt.tar')

    # set up logging
    logging_utils.log_to_file(logger_name=None, log_filename=os.path.join(expDir, 'inference.log'))
    logging_utils.log_versions()
    params.log()

    if params["ics_type"] == 'default':
        ics = np.arange(0, 240, 1)
        n_ics = len(ics)
        print('init_condition:', ics)

    logging.info("Inference for {} initial conditions".format(n_ics))

    try:
        autoregressive_inference_filetag = params["inference_file_tag"]
    except:
        autoregressive_inference_filetag = ""
    if params.interp > 0:
        autoregressive_inference_filetag = "_coarse"

    valid_data_full, model, model2 = setup(params)

    seq_pred = []
    seq_real = []

    # run autoregressive inference for multiple initial conditions
    for i, ic_ in enumerate(ics):
        logging.info("Initial condition {} of {}".format(i+1, n_ics))
        seq_real, seq_pred = autoregressive_inference(params, ic_, valid_data_full, model, model2)

        prediction_length = seq_real[0].shape[0]
        n_out_channels = seq_real[0].shape[1]
        img_shape_x = seq_real[0].shape[2]
        img_shape_y = seq_real[0].shape[3]

        # save predictions and loss
        save_path = os.path.join(params['experiment_dir'], 'results_simulation.h5')
        logging.info("Saving to {}".format(save_path))
        print(f'saving to {save_path}')
        if i==0:
            f = h5py.File(save_path, 'w')
            f.create_dataset(
                "ground_truth",
                data=seq_real,
                maxshape=[None, prediction_length, n_out_channels, img_shape_x, img_shape_y],
                dtype=np.float32)
            f.create_dataset(
                "predicted",
                data=seq_pred,
                maxshape=[None, prediction_length, n_out_channels, img_shape_x, img_shape_y],
                dtype=np.float32)
            f.close()
        else:
            f = h5py.File(save_path, 'a')

            f["ground_truth"].resize((f["ground_truth"].shape[0] + 1), axis = 0)
            f["ground_truth"][-1:] = seq_real

            f["predicted"].resize((f["predicted"].shape[0] + 1), axis = 0)
            f["predicted"][-1:] = seq_pred
            f.close()
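The rollouts are appended per initial condition into results_simulation.h5 under the datasets "ground_truth" and "predicted". A short editor-added example of reading them back, with the path assembled from the defaults in inference_simulation.sh below:

# Read the saved rollouts; shapes are (n_ics, prediction_length, n_out_channels, H, W).
import h5py

save_path = './exp/NeuralOM/20250309-195251/6_steps_finetune/results_simulation.h5'
with h5py.File(save_path, 'r') as f:
    ground_truth = f['ground_truth'][:]
    predicted = f['predicted'][:]

print(ground_truth.shape, predicted.shape)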
inference_simulation.sh
ADDED
@@ -0,0 +1,13 @@
prediction_length=61 # 31

exp_dir='./exp'
config='NeuralOM'
run_num='20250309-195251'
finetune_dir='6_steps_finetune'

ics_type='default'

CUDA_VISIBLE_DEVICES=0 python inference_simulation.py --exp_dir=${exp_dir} --config=${config} --run_num=${run_num} --finetune_dir=$finetune_dir --prediction_length=${prediction_length} --ics_type=${ics_type}
my_utils/YParams.py
ADDED
@@ -0,0 +1,55 @@
# coding: utf-8

import importlib
import sys
import os
importlib.reload(sys)

from ruamel.yaml import YAML
import logging

class YParams():
    """ Yaml file parser """
    def __init__(self, yaml_filename, config_name, print_params=False):
        self._yaml_filename = yaml_filename
        self._config_name = config_name
        self.params = {}

        if print_params:
            print(os.system('hostname'))
            print("------------------ Configuration ------------------ ", yaml_filename)

        with open(yaml_filename, 'rb') as _file:
            yaml = YAML().load(_file)
            for key, val in yaml[config_name].items():
                if print_params: print(key, val)
                if val =='None': val = None

                self.params[key] = val
                self.__setattr__(key, val)

        if print_params:
            print("---------------------------------------------------")

    def __getitem__(self, key):
        return self.params[key]

    def __setitem__(self, key, val):
        self.params[key] = val
        self.__setattr__(key, val)

    def __contains__(self, key):
        return (key in self.params)

    def update_params(self, config):
        for key, val in config.items():
            self.params[key] = val
            self.__setattr__(key, val)

    def log(self):
        logging.info("------------------ Configuration ------------------")
        logging.info("Configuration file: "+str(self._yaml_filename))
        logging.info("Configuration name: "+str(self._config_name))
        for key, val in self.params.items():
            logging.info(str(key) + ' ' + str(val))
        logging.info("---------------------------------------------------")
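YParams loads one named section of a YAML file and exposes every key both as an attribute and through dict-style access; the inference scripts also override entries at runtime via __setitem__. A brief editor-added usage sketch against the config shipped in this upload:

# Load the 'NeuralOM' section of the uploaded config and override one entry.
from my_utils.YParams import YParams

params = YParams('./exp/NeuralOM/20250309-195251/config.yaml', 'NeuralOM')
print(params.nettype)              # attribute access -> 'NeuralOM'
print(params['batch_size'])        # dict-style access -> 32
params['prediction_length'] = 61   # runtime override, as done in the inference scripts
params.log()                       # dump the full configuration to the logger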
my_utils/__pycache__/YParams.cpython-310.pyc
ADDED
Binary file (2.1 kB)
my_utils/__pycache__/YParams.cpython-37.pyc
ADDED
Binary file (2.11 kB)
my_utils/__pycache__/YParams.cpython-39.pyc
ADDED
Binary file (2.08 kB)
my_utils/__pycache__/bicubic.cpython-310.pyc
ADDED
Binary file (9.24 kB)
my_utils/__pycache__/bicubic.cpython-39.pyc
ADDED
Binary file (9.2 kB)
my_utils/__pycache__/darcy_loss.cpython-310.pyc
ADDED
Binary file (13.7 kB)
my_utils/__pycache__/darcy_loss.cpython-310.pyc.70370790180304
ADDED
Binary file (13.5 kB)
my_utils/__pycache__/darcy_loss.cpython-310.pyc.70373230085584
ADDED
Binary file (13.5 kB)
my_utils/__pycache__/darcy_loss.cpython-310.pyc.70384414393808
ADDED
Binary file (13.5 kB)
my_utils/__pycache__/darcy_loss.cpython-37.pyc
ADDED
Binary file (9.02 kB)
my_utils/__pycache__/darcy_loss.cpython-39.pyc
ADDED
Binary file (14.2 kB)
my_utils/__pycache__/data_loader.cpython-310.pyc
ADDED
Binary file (5 kB)
my_utils/__pycache__/data_loader_multifiles.cpython-310.pyc
ADDED
Binary file (4.94 kB)
my_utils/__pycache__/data_loader_multifiles.cpython-37.pyc
ADDED
Binary file (3.69 kB)
my_utils/__pycache__/data_loader_multifiles.cpython-39.pyc
ADDED
Binary file (16.5 kB)
my_utils/__pycache__/get_date.cpython-310.pyc
ADDED
Binary file (167 Bytes)
my_utils/__pycache__/img_utils.cpython-310.pyc
ADDED
Binary file (3.39 kB)
my_utils/__pycache__/img_utils.cpython-37.pyc
ADDED
Binary file (3.81 kB)
my_utils/__pycache__/img_utils.cpython-39.pyc
ADDED
Binary file (4.83 kB)
my_utils/__pycache__/logging_utils.cpython-310.pyc
ADDED
Binary file (1 kB)
my_utils/__pycache__/logging_utils.cpython-37.pyc
ADDED
Binary file (994 Bytes)
my_utils/__pycache__/logging_utils.cpython-39.pyc
ADDED
Binary file (995 Bytes)
my_utils/__pycache__/norm.cpython-310.pyc
ADDED
Binary file (3.37 kB)
my_utils/__pycache__/time_utils.cpython-310.pyc
ADDED
Binary file (606 Bytes)
my_utils/__pycache__/time_utils.cpython-39.pyc
ADDED
Binary file (578 Bytes)
my_utils/__pycache__/weighted_acc_rmse.cpython-310.pyc
ADDED
Binary file (5.9 kB)
my_utils/__pycache__/weighted_acc_rmse.cpython-37.pyc
ADDED
Binary file (6.27 kB)
my_utils/__pycache__/weighted_acc_rmse.cpython-39.pyc
ADDED
Binary file (6 kB)
my_utils/data_loader.py
ADDED
@@ -0,0 +1,205 @@
import logging
import glob
import torch
import random
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torch import Tensor
import h5py
import math
from my_utils.norm import reshape_fields
import os


current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
climate_mean_path = os.path.join(parent_dir, 'data/climate_mean_s_t_ssh.npy')

def get_data_loader(params, files_pattern, distributed, train):
    dataset = GetDataset(params, files_pattern, train)
    sampler = DistributedSampler(dataset, shuffle=train) if distributed else None

    dataloader = DataLoader(dataset,
                            batch_size=int(params.batch_size),
                            num_workers=params.num_data_workers,
                            shuffle=False,
                            sampler=sampler if train else None,
                            drop_last=True,
                            pin_memory=True)

    if train:
        return dataloader, dataset, sampler
    else:
        return dataloader, dataset


class GetDataset(Dataset):
    def __init__(self, params, location, train):
        self.params = params
        self.location = location
        self.train = train
        self.orography = params.orography
        self.normalize = params.normalize
        self.dt = params.dt
        self.n_history = params.n_history
        self.in_channels = np.array(params.in_channels)
        self.out_channels = np.array(params.out_channels)
        self.ocean_channels = np.array(params.ocean_channels)
        self.atmos_channels = np.array(params.atmos_channels)
        self.n_in_channels = len(self.in_channels)
        self.n_out_channels = len(self.out_channels)

        self._get_files_stats()
        self.add_noise = params.add_noise if train else False
        self.climate_mean = np.load(climate_mean_path, mmap_mode='r')

    def _get_files_stats(self):
        self.files_paths = glob.glob(self.location + "/*.h5")
        self.files_paths.sort()
        self.n_years = len(self.files_paths)

        with h5py.File(self.files_paths[0], 'r') as _f:
            logging.info("Getting file stats from {}".format(self.files_paths[0]))

            self.n_samples_per_year = _f['fields'].shape[0] - self.params.multi_steps_finetune

            self.img_shape_x = _f['fields'].shape[2] - 1
            self.img_shape_y = _f['fields'].shape[3]

        self.n_samples_total = self.n_years * self.n_samples_per_year
        self.files = [None for _ in range(self.n_years)]

        logging.info("Number of samples per year: {}".format(self.n_samples_per_year))
        logging.info("Found data at path {}. Number of examples: {}. Image Shape: {} x {} x {}".format(
            self.location, self.n_samples_total, self.img_shape_x, self.img_shape_y, self.n_in_channels))
        logging.info("Delta t: {} days".format(1 * self.dt))
        logging.info("Including {} days of past history in training at a frequency of {} days".format(
            1 * self.dt * self.n_history, 1 * self.dt))

    def _open_file(self, year_idx):
        _file = h5py.File(self.files_paths[year_idx], 'r')
        self.files[year_idx] = _file['fields']

        if self.orography and self.params.normalization == 'zscore':
            _orog_file = h5py.File(self.params.orography_norm_zscore_path, 'r')
        if self.orography and self.params.normalization == 'maxmin':
            _orog_file = h5py.File(self.params.orography_norm_maxmin_path, 'r')

    def __len__(self):
        return self.n_samples_total

    def __getitem__(self, global_idx):
        year_idx = int(global_idx / self.n_samples_per_year)   # which year
        local_idx = int(global_idx % self.n_samples_per_year)  # which sample in a year

        if self.files[year_idx] is None:
            self._open_file(year_idx)

        if local_idx < self.dt * self.n_history:
            local_idx += self.dt * self.n_history

        step = 0 if local_idx >= self.n_samples_per_year - self.dt else self.dt

        orog = None

        if self.params.multi_steps_finetune == 1:
            if local_idx == 365:
                local_idx = 364

            climate_mean_ocean = self.climate_mean[(local_idx-self.dt*self.n_history):(local_idx+1):self.dt, self.ocean_channels, :360, :720]
            ocean = reshape_fields(
                self.files[year_idx][(local_idx-self.dt*self.n_history):(local_idx+1):self.dt, self.ocean_channels, :360, :720] - climate_mean_ocean,
                'ocean', self.params, self.train, self.normalize, orog, self.add_noise)

            force_future0 = reshape_fields(
                self.files[year_idx][local_idx, self.atmos_channels, :360, :720],
                'force', self.params, self.train, self.normalize, orog, self.add_noise)

            force_future1 = reshape_fields(
                self.files[year_idx][local_idx+step, self.atmos_channels, :360, :720],
                'force', self.params, self.train, self.normalize, orog, self.add_noise)

            climate_mean_tar = self.climate_mean[local_idx+step, self.out_channels, :360, :720]
            tar = reshape_fields(
                self.files[year_idx][local_idx+step, self.out_channels, :360, :720] - climate_mean_tar,
                'tar', self.params, self.train, self.normalize, orog)
        else:
            climate_mean_ocean = self.climate_mean[(local_idx-self.dt*self.n_history):(local_idx+1):self.dt, self.ocean_channels, :360, :720]
            ocean = reshape_fields(
                self.files[year_idx][(local_idx-self.dt*self.n_history):(local_idx+1):self.dt, self.ocean_channels, :360, :720] - climate_mean_ocean,
                'ocean', self.params, self.train, self.normalize, orog, self.add_noise)

            force_future0 = reshape_fields(
                self.files[year_idx][local_idx, self.atmos_channels, :360, :720],
                'force', self.params, self.train, self.normalize, orog, self.add_noise)

            force_future1 = reshape_fields(
                self.files[year_idx][local_idx+step, self.atmos_channels, :360, :720],
                'force', self.params, self.train, self.normalize, orog, self.add_noise)

            climate_mean_tar = self.climate_mean[local_idx+step:local_idx+step+self.params.multi_steps_finetune, self.in_channels, :360, :720]
            tar_data = self.files[year_idx][local_idx+step:local_idx+step+self.params.multi_steps_finetune, self.in_channels, :360, :720]
            tar = reshape_fields(
                tar_data - climate_mean_tar,
                'inp', self.params, self.train, self.normalize, orog)

        ocean = np.nan_to_num(ocean, nan=0)
        force_future0 = np.nan_to_num(force_future0, nan=0)
        force_future1 = np.nan_to_num(force_future1, nan=0)
        tar = np.nan_to_num(tar, nan=0)

        return np.concatenate((ocean, force_future0, force_future1), axis=0), tar
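
A hedged sketch of how get_data_loader is typically driven. Every value below is an illustrative stand-in (a real run reads them from config.yaml via YParams), and the data directory, channel indices and statistics paths are assumptions, not repository facts.

# Hypothetical wiring of the loader above; values are placeholders only.
from types import SimpleNamespace
from my_utils.data_loader import get_data_loader

params = SimpleNamespace(
    batch_size=1, num_data_workers=2, dt=1, n_history=0,
    in_channels=list(range(101)), out_channels=list(range(93)),
    ocean_channels=list(range(93)), atmos_channels=list(range(93, 101)),
    orography=False, normalize=True, normalization='zscore',
    global_means_path='data/global_means.npy',
    global_stds_path='data/global_stds.npy',
    add_noise=False, multi_steps_finetune=1,
)

loader, dataset, sampler = get_data_loader(params, 'data/train', distributed=False, train=True)
inp, tar = dataset[0]  # inp = ocean anomaly state plus forcing at t and t+dt; tar = ocean anomaly at t+dt
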
my_utils/logging_utils.py
ADDED
@@ -0,0 +1,26 @@
import os
import logging

_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

def config_logger(log_level=logging.INFO):
    logging.basicConfig(format=_format, level=log_level)

def log_to_file(logger_name=None, log_level=logging.INFO, log_filename='tensorflow.log'):

    if not os.path.exists(os.path.dirname(log_filename)):
        os.makedirs(os.path.dirname(log_filename))

    if logger_name is not None:
        log = logging.getLogger(logger_name)
    else:
        log = logging.getLogger()

    fh = logging.FileHandler(log_filename)
    fh.setLevel(log_level)
    fh.setFormatter(logging.Formatter(_format))
    log.addHandler(fh)

def log_versions():
    import torch
    import subprocess
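
A short sketch of the intended call order for the helpers above. Note that log_to_file derives the target directory from log_filename, so a filename with a directory component is passed here; the path itself is illustrative, not taken from the upload.

# Hypothetical logging setup; the log path is a placeholder.
import logging
from my_utils import logging_utils

logging_utils.config_logger(logging.INFO)                     # console logging with the shared format
logging_utils.log_to_file(log_filename='exp/logs/train.log')  # mirrors records into a file, creating the directory
logging.info("logging configured")
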
my_utils/norm.py
ADDED
@@ -0,0 +1,114 @@
import logging
import glob
from types import new_class
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torch import Tensor
import h5py
import math
import torchvision.transforms.functional as TF
# import matplotlib
# import matplotlib.pyplot as plt

class PeriodicPad2d(nn.Module):
    """
    pad longitudinal (left-right) circular
    and pad latitude (top-bottom) with zeros
    """
    def __init__(self, pad_width):
        super(PeriodicPad2d, self).__init__()
        self.pad_width = pad_width

    def forward(self, x):
        # pad left and right circular
        out = F.pad(x, (self.pad_width, self.pad_width, 0, 0), mode="circular")
        # pad top and bottom zeros
        out = F.pad(out, (0, 0, self.pad_width, self.pad_width), mode="constant", value=0)
        return out

def reshape_fields(img, inp_or_tar, params, train, normalize=True, orog=None, add_noise=False):
    # Takes in np array of size (n_history+1, c, h, w)
    # returns torch tensor of size ((n_channels*(n_history+1), crop_size_x, crop_size_y)

    if len(np.shape(img)) == 3:
        img = np.expand_dims(img, 0)

    if np.shape(img)[2] == 721:
        img = img[:, :, 0:720, :]  # remove last pixel

    n_history = np.shape(img)[0] - 1
    img_shape_x = np.shape(img)[-2]
    img_shape_y = np.shape(img)[-1]
    n_channels = np.shape(img)[1]  # this will either be N_in_channels or N_out_channels

    if inp_or_tar == 'inp':
        channels = params.in_channels
    elif inp_or_tar == 'ocean':
        channels = params.ocean_channels
    elif inp_or_tar == 'force':
        channels = params.atmos_channels
    else:
        channels = params.out_channels

    if normalize and params.normalization == 'minmax':
        maxs = np.load(params.global_maxs_path)[:, channels]
        mins = np.load(params.global_mins_path)[:, channels]
        img = (img - mins) / (maxs - mins)

    if normalize and params.normalization == 'zscore':
        means = np.load(params.global_means_path)[:, channels]
        stds = np.load(params.global_stds_path)[:, channels]
        img -= means
        img /= stds

    if normalize and params.normalization == 'zscore_lat':
        means = np.load(params.global_lat_means_path)[:, channels, :720]
        stds = np.load(params.global_lat_stds_path)[:, channels, :720]
        img -= means
        img /= stds

    if params.orography and inp_or_tar == 'inp':
        # print('img:', img.shape, 'orog:', orog.shape)
        orog = np.expand_dims(orog, axis=(0, 1))
        orog = np.repeat(orog, repeats=img.shape[0], axis=0)
        # print('img:', img.shape, 'orog:', orog.shape)
        img = np.concatenate((img, orog), axis=1)
        n_channels += 1

    img = np.squeeze(img)
    # if inp_or_tar == 'inp':
    #     img = np.reshape(img, (n_channels*(n_history+1))) # ??
    # elif inp_or_tar == 'tar':
    #     img = np.reshape(img, (n_channels, crop_size_x, crop_size_y)) #??

    if add_noise:
        img = img + np.random.normal(0, scale=params.noise_std, size=img.shape)

    return torch.as_tensor(img)

def vis_precip(fields):
    pred, tar = fields
    fig, ax = plt.subplots(1, 2, figsize=(24, 12))
    ax[0].imshow(pred, cmap="coolwarm")
    ax[0].set_title("tp pred")
    ax[1].imshow(tar, cmap="coolwarm")
    ax[1].set_title("tp tar")
    fig.tight_layout()
    return fig

def read_max_min_value(min_max_val_file_path):
    with h5py.File(min_max_val_file_path, 'r') as f:
        max_values = f['max_values']
        min_values = f['min_values']
    return max_values, min_values
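
A small self-contained check of PeriodicPad2d from the file above (not part of the upload): the longitude (last) dimension is wrapped circularly while the latitude dimension is zero-padded.

# Sketch only; the tiny tensor is a stand-in for a (batch, channel, lat, lon) field.
import torch
from my_utils.norm import PeriodicPad2d

pad = PeriodicPad2d(pad_width=1)
x = torch.arange(12, dtype=torch.float32).reshape(1, 1, 3, 4)  # (B, C, lat, lon)
y = pad(x)
print(y.shape)     # torch.Size([1, 1, 5, 6])
print(y[0, 0, 1])  # first data row: last column wrapped in front, first column appended at the end
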
networks/.ipynb_checkpoints/CirT1-checkpoint.py
ADDED
@@ -0,0 +1,301 @@
(Jupyter checkpoint copy; contents duplicate networks/CirT1.py below.)
networks/.ipynb_checkpoints/CirT2-checkpoint.py
ADDED
@@ -0,0 +1,301 @@
(Jupyter checkpoint copy; contents duplicate networks/CirT2.py below.)
networks/CirT1.py
ADDED
@@ -0,0 +1,301 @@
from functools import lru_cache

import numpy as np
import torch
import torch.nn as nn
from timm.models.vision_transformer import trunc_normal_, Block
from torch.jit import Final

import torch.nn.functional as F
from typing import Optional
from timm.layers import DropPath, use_fused_attn, Mlp

class PatchEmbed(nn.Module):
    def __init__(
        self,
        img_size=[121, 240],
        in_chans=63,
        embed_dim=768,
        norm_layer=None,
        flatten=True,
        bias=True,
    ):
        super().__init__()
        self.img_size = img_size
        self.num_patches = img_size[0]
        self.flatten = flatten

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=[1, img_size[1]], stride=1, bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        return x

class Attention(nn.Module):
    fused_attn: Final[bool]

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.,
        proj_drop: float = 0.,
        norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // self.num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()
        self.dim = dim
        self.attn_bias = nn.Parameter(torch.zeros(121, 121, 2), requires_grad=True)

        # self.qkv = CLinear(dim, dim * 3, bias=qkv_bias)
        self.qkv = nn.Linear((dim // 2 + 1) * 2, dim * 3, bias=qkv_bias)
        self.q = nn.Linear((dim // 2 + 1) * 2, dim, bias=qkv_bias)
        self.k = nn.Linear((dim // 2 + 1) * 2, dim, bias=qkv_bias)
        self.v = nn.Linear((dim // 2 + 1) * 2, dim, bias=qkv_bias)

        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        # self.proj = CLinear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.proj = nn.Linear(dim, dim)
        # self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape

        x = torch.fft.rfft(x, norm="forward")
        x = torch.view_as_real(x)
        x = torch.cat((x[:, :, :, 0], -x[:, :, :, 1]), dim=-1)

        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        # q, k, v = qkv.unbind(0)
        q = self.q(x).reshape(B, self.num_heads, N, self.head_dim)
        k = self.k(x).reshape(B, self.num_heads, N, self.head_dim)
        v = self.v(x).reshape(B, self.num_heads, N, self.head_dim)
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = attn @ v
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        real, img = torch.split(x, x.shape[-1] // 2, dim=-1)
        x = torch.stack([real, -img], dim=-1)
        x = torch.view_as_complex(x)
        x = torch.fft.irfft(x, self.dim, norm="forward")

        return x


class LayerScale(nn.Module):
    def __init__(
        self,
        dim: int,
        init_values: float = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


class Block(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        proj_drop: float = 0.,
        attn_drop: float = 0.,
        init_values: Optional[float] = None,
        drop_path: float = 0.,
        act_layer: nn.Module = nn.GELU,
        norm_layer: nn.Module = nn.LayerNorm,
        mlp_layer: nn.Module = Mlp,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x


class CirT(nn.Module):
    def __init__(
        self,
        params,
        img_size=[360, 720],
        input_size=101,
        output_size=93,
        patch_size=124,  # 124
        embed_dim=256,
        depth=8,
        decoder_depth=2,
        num_heads=16,
        mlp_ratio=4.0,
        drop_path=0.1,
        drop_rate=0.1
    ):
        super().__init__()

        # TODO: remove time_history parameter
        self.img_size = img_size
        self.patch_size = img_size[1]
        self.input_size = input_size
        self.output_size = output_size
        self.token_embeds = PatchEmbed(img_size, input_size, embed_dim)
        # self.token_embeds = nn.Linear(img_size[0] * 2, embed_dim)
        self.num_patches = self.token_embeds.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim), requires_grad=True)
        # self.pos_embed = PosEmbed(embed_dim=embed_dim)

        # --------------------------------------------------------------------------

        # ViT backbone
        self.pos_drop = nn.Dropout(p=drop_rate)
        dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList(
            [
                Block(
                    embed_dim,
                    num_heads,
                    mlp_ratio,
                    qkv_bias=True,
                    drop_path=dpr[i],
                    norm_layer=nn.LayerNorm,
                    # drop=drop_rate,
                )
                for i in range(depth)
            ]
        )
        self.norm = nn.LayerNorm(embed_dim)

        # --------------------------------------------------------------------------

        # prediction head
        self.head = nn.ModuleList()
        for _ in range(decoder_depth):
            self.head.append(nn.Linear(embed_dim, embed_dim))
            self.head.append(nn.GELU())
        self.head.append(nn.Linear(embed_dim, output_size * self.img_size[1]))
        self.head = nn.Sequential(*self.head)

        # --------------------------------------------------------------------------

        self.initialize_weights()

    def initialize_weights(self):
        # token embedding layer
        w = self.token_embeds.proj.weight.data
        trunc_normal_(w.view([w.shape[0], -1]), std=0.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def unpatchify(self, x: torch.Tensor, h=None, w=None):
        """
        x: (B, L, V * patch_size)
        return imgs: (B, V, H, W)
        """
        p = self.patch_size
        c_out = self.output_size
        h = self.img_size[0] // 1
        w = self.img_size[1] // p
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, c_out))
        x = torch.einsum("nhwpc->nchpw", x)
        imgs = x.reshape(shape=(x.shape[0], c_out, h, w * p))
        return imgs

    def forward_encoder(self, x: torch.Tensor):
        # x: `[B, V, H, W]` shape.

        # tokenize each variable separately
        # x = torch.fft.rfft(x, norm="forward")
        # x = torch.view_as_real(x)
        # x = torch.cat((x[:, :, :, :, 0], -x[:, :, :, :, 1]), dim=-1)

        x = self.token_embeds(x)

        # pos_embed = self.pos_embed()
        # add pos embedding
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x

    def forward(self, x):
        B, V, H, W = x.shape
        # print(x.shape)
        out_transformers = self.forward_encoder(x)  # B, L, D
        preds = self.head(out_transformers)  # B, L, V*p*p
        preds = self.unpatchify(preds)

        # real, img = torch.split(preds, preds.shape[-1] // 2, dim=-1)
        # preds = torch.cat([real, -img], dim=-1)
        # preds = torch.fft.irfft(preds, W, norm="forward")
        return preds
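
A hedged smoke test for the CirT model above (not part of the upload), assuming the default channel layout of 101 input and 93 output variables on the 360 x 720 grid; the constructor stores nothing from params, so None is passed as a stand-in.

# Illustrative shape check only; input values are random.
import torch
from networks.CirT1 import CirT

model = CirT(params=None)          # defaults: img_size=[360, 720], input_size=101, output_size=93
x = torch.randn(1, 101, 360, 720)  # one sample, 101 variables on the 360 x 720 grid
with torch.no_grad():
    y = model(x)
print(y.shape)                     # torch.Size([1, 93, 360, 720])
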
networks/CirT2.py
ADDED
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import lru_cache
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from timm.models.vision_transformer import trunc_normal_, Block
|
7 |
+
from torch.jit import Final
|
8 |
+
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from typing import Optional
|
11 |
+
from timm.layers import DropPath, use_fused_attn, Mlp
|
12 |
+
|
13 |
+
class PatchEmbed(nn.Module):
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
img_size=[121, 240],
|
17 |
+
in_chans=63,
|
18 |
+
embed_dim=768,
|
19 |
+
norm_layer=None,
|
20 |
+
flatten=True,
|
21 |
+
bias=True,
|
22 |
+
):
|
23 |
+
super().__init__()
|
24 |
+
self.img_size = img_size
|
25 |
+
self.num_patches = img_size[0]
|
26 |
+
self.flatten = flatten
|
27 |
+
|
28 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=[1, img_size[1]], stride=1, bias=bias)
|
29 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
30 |
+
|
31 |
+
def forward(self, x):
|
32 |
+
B, C, H, W = x.shape
|
33 |
+
x = self.proj(x)
|
34 |
+
if self.flatten:
|
35 |
+
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
36 |
+
x = self.norm(x)
|
37 |
+
return x
|
38 |
+
|
39 |
+
class Attention(nn.Module):
|
40 |
+
fused_attn: Final[bool]
|
41 |
+
|
42 |
+
def __init__(
|
43 |
+
self,
|
44 |
+
dim: int,
|
45 |
+
num_heads: int = 8,
|
46 |
+
qkv_bias: bool = False,
|
47 |
+
qk_norm: bool = False,
|
48 |
+
attn_drop: float = 0.,
|
49 |
+
proj_drop: float = 0.,
|
50 |
+
norm_layer: nn.Module = nn.LayerNorm,
|
51 |
+
) -> None:
|
52 |
+
super().__init__()
|
53 |
+
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
54 |
+
self.num_heads = num_heads
|
55 |
+
self.head_dim = dim // self.num_heads
|
56 |
+
self.scale = self.head_dim ** -0.5
|
57 |
+
self.fused_attn = use_fused_attn()
|
58 |
+
self.dim = dim
|
59 |
+
self.attn_bias = nn.Parameter(torch.zeros(121, 121, 2), requires_grad=True)
|
60 |
+
|
61 |
+
# self.qkv = CLinear(dim, dim * 3, bias=qkv_bias)
|
62 |
+
self.qkv = nn.Linear((dim // 2 + 1) * 2, dim * 3, bias=qkv_bias)
|
63 |
+
self.q = nn.Linear((dim // 2 + 1) * 2, dim, bias=qkv_bias)
|
64 |
+
self.k = nn.Linear((dim // 2 + 1) * 2, dim, bias=qkv_bias)
|
65 |
+
self.v = nn.Linear((dim // 2 + 1) * 2, dim, bias=qkv_bias)
|
66 |
+
|
67 |
+
|
68 |
+
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
69 |
+
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
70 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
71 |
+
# self.proj = CLinear(dim, dim)
|
72 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
73 |
+
self.proj = nn.Linear(dim, dim)
|
74 |
+
# self.proj_drop = nn.Dropout(proj_drop)
|
75 |
+
|
76 |
+
|
77 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
78 |
+
B, N, C = x.shape
|
79 |
+
|
80 |
+
x = torch.fft.rfft(x, norm="forward")
|
81 |
+
x = torch.view_as_real(x)
|
82 |
+
x = torch.cat((x[:, :, :, 0], -x[:, :, :, 1]), dim=-1)
|
83 |
+
|
84 |
+
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
85 |
+
# q, k, v = qkv.unbind(0)
|
86 |
+
q = self.q(x).reshape(B, self.num_heads, N, self.head_dim)
|
87 |
+
k = self.k(x).reshape(B, self.num_heads, N, self.head_dim)
|
88 |
+
v = self.v(x).reshape(B, self.num_heads, N, self.head_dim)
|
89 |
+
q = q * self.scale
|
90 |
+
attn = q @ k.transpose(-2, -1)
|
91 |
+
|
92 |
+
|
93 |
+
attn = attn.softmax(dim=-1)
|
94 |
+
attn = self.attn_drop(attn)
|
95 |
+
x = attn @ v
|
96 |
+
x = x.transpose(1, 2).reshape(B, N, C)
|
97 |
+
x = self.proj(x)
|
98 |
+
x = self.proj_drop(x)
|
99 |
+
|
100 |
+
real, img = torch.split(x, x.shape[-1] // 2, dim=-1)
|
101 |
+
x = torch.stack([real,-img], dim=-1)
|
102 |
+
x = torch.view_as_complex(x)
|
103 |
+
x = torch.fft.irfft(x, self.dim, norm="forward")
|
104 |
+
|
105 |
+
return x
|
106 |
+
|
107 |
+
|
108 |
+
class LayerScale(nn.Module):
|
109 |
+
def __init__(
|
110 |
+
self,
|
111 |
+
dim: int,
|
112 |
+
init_values: float = 1e-5,
|
113 |
+
inplace: bool = False,
|
114 |
+
) -> None:
|
115 |
+
super().__init__()
|
116 |
+
self.inplace = inplace
|
117 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
118 |
+
|
119 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
120 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
121 |
+
|
122 |
+
|
123 |
+
|
124 |
+
class Block(nn.Module):
    def __init__(
            self,
            dim: int,
            num_heads: int,
            mlp_ratio: float = 4.,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
            init_values: Optional[float] = None,
            drop_path: float = 0.,
            act_layer: nn.Module = nn.GELU,
            norm_layer: nn.Module = nn.LayerNorm,
            mlp_layer: nn.Module = Mlp,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x

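# --------------------------------------------------------------------------
# Illustrative sketch (editorial addition): stacking a few of the pre-norm
# Blocks above, mirroring what CirT_stage2 does with nn.ModuleList below.
# All values are placeholders for a dummy run.
#
#   blocks = nn.ModuleList([Block(256, num_heads=16, mlp_ratio=4.0, qkv_bias=True)
#                           for _ in range(8)])
#   tokens = torch.randn(2, 360, 256)
#   for blk in blocks:
#       tokens = blk(tokens)             # residual attention + MLP, shape preserved
# --------------------------------------------------------------------------
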
class CirT_stage2(nn.Module):
    def __init__(
        self,
        params,
        img_size=[360, 720],
        input_size=93,
        output_size=93,
        patch_size=124,  # 124
        embed_dim=256,
        depth=8,
        decoder_depth=2,
        num_heads=16,
        mlp_ratio=4.0,
        drop_path=0.1,
        drop_rate=0.1
    ):
        super().__init__()

        # TODO: remove time_history parameter
        self.img_size = img_size
        # The effective patch size is the full longitude width (img_size[1]);
        # the patch_size argument is not used here.
        self.patch_size = img_size[1]
        self.input_size = input_size
        self.output_size = output_size
        self.token_embeds = PatchEmbed(img_size, input_size, embed_dim)
        # self.token_embeds = nn.Linear(img_size[0] * 2, embed_dim)
        self.num_patches = self.token_embeds.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim), requires_grad=True)
        # self.pos_embed = PosEmbed(embed_dim=embed_dim)

        # --------------------------------------------------------------------------

        # ViT backbone
        self.pos_drop = nn.Dropout(p=drop_rate)
        dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList(
            [
                Block(
                    embed_dim,
                    num_heads,
                    mlp_ratio,
                    qkv_bias=True,
                    drop_path=dpr[i],
                    norm_layer=nn.LayerNorm,
                    # drop=drop_rate,
                )
                for i in range(depth)
            ]
        )
        self.norm = nn.LayerNorm(embed_dim)

        # --------------------------------------------------------------------------

        # prediction head
        self.head = nn.ModuleList()
        for _ in range(decoder_depth):
            self.head.append(nn.Linear(embed_dim, embed_dim))
            self.head.append(nn.GELU())
        self.head.append(nn.Linear(embed_dim, output_size * self.img_size[1]))
        self.head = nn.Sequential(*self.head)

        # --------------------------------------------------------------------------

        self.initialize_weights()

    def initialize_weights(self):
        # token embedding layer
        w = self.token_embeds.proj.weight.data
        trunc_normal_(w.view([w.shape[0], -1]), std=0.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def unpatchify(self, x: torch.Tensor, h=None, w=None):
        """
        x: (B, L, V * patch_size)
        return imgs: (B, V, H, W)
        """
        p = self.patch_size
        c_out = self.output_size
        h = self.img_size[0] // 1
        w = self.img_size[1] // p
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, c_out))
        x = torch.einsum("nhwpc->nchpw", x)
        imgs = x.reshape(shape=(x.shape[0], c_out, h, w * p))
        return imgs

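    # ----------------------------------------------------------------------
    # Shape walk-through for unpatchify() (editorial note): with the defaults
    # img_size=[360, 720] and output_size=93, the effective patch size is
    # p = img_size[1] = 720, so h = 360, w = 720 // 720 = 1 and the encoder
    # must emit L = h * w = 360 tokens, one per latitude row. Each token is
    # decoded from V * p = 93 * 720 values back into a full row:
    #   (B, 360, 93 * 720) -> (B, 93, 360, 720)
    # ----------------------------------------------------------------------
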
    def forward_encoder(self, x: torch.Tensor):
        # x: `[B, V, H, W]` shape.

        # tokenize each variable separately
        # x = torch.fft.rfft(x, norm="forward")
        # x = torch.view_as_real(x)
        # x = torch.cat((x[:, :, :, :, 0], -x[:, :, :, :, 1]), dim=-1)

        x = self.token_embeds(x)

        # pos_embed = self.pos_embed()
        # add pos embedding
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x

    def forward(self, x):
        B, V, H, W = x.shape
        # print(x.shape)
        out_transformers = self.forward_encoder(x)  # B, L, D
        preds = self.head(out_transformers)  # B, L, V * p, where p = img_size[1]
        preds = self.unpatchify(preds)

        # real, img = torch.split(preds, preds.shape[-1] // 2, dim=-1)
        # preds = torch.cat([real, -img], dim=-1)
        # preds = torch.fft.irfft(preds, W, norm="forward")
        return preds
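
# --------------------------------------------------------------------------
# Minimal smoke-test sketch (editorial addition, not part of the original
# upload). Whether it runs as-is depends on the PatchEmbed defined earlier in
# this file producing num_patches == img_size[0] tokens, which is what the
# prediction head and unpatchify() assume; `params` is accepted by __init__
# but not used there.
#
#   model = CirT_stage2(params=None)
#   x = torch.randn(1, 93, 360, 720)        # (B, V, H, W)
#   y = model(x)
#   print(y.shape)                          # expected: torch.Size([1, 93, 360, 720])
# --------------------------------------------------------------------------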