YuanGao-YG committed
Commit 9117894 · verified · 1 Parent(s): 71d2b66

Upload 90 files

This view is limited to 50 files because it contains too many changes.

Files changed (50):
  1. environment.yml +248 -0
  2. exp/NeuralOM/20250309-195251/6_steps_finetune/model2/10_steps_finetune/training_checkpoints/.ipynb_checkpoints/readme-checkpoint.txt +1 -0
  3. exp/NeuralOM/20250309-195251/6_steps_finetune/model2/10_steps_finetune/training_checkpoints/best_ckpt.tar +3 -0
  4. exp/NeuralOM/20250309-195251/6_steps_finetune/model2/10_steps_finetune/training_checkpoints/readme.txt +1 -0
  5. exp/NeuralOM/20250309-195251/6_steps_finetune/training_checkpoints/.ipynb_checkpoints/readme-checkpoint.txt +1 -0
  6. exp/NeuralOM/20250309-195251/6_steps_finetune/training_checkpoints/best_ckpt.tar +3 -0
  7. exp/NeuralOM/20250309-195251/6_steps_finetune/training_checkpoints/readme.txt +1 -0
  8. exp/NeuralOM/20250309-195251/6_steps_finetune/training_checkpoints_atmos/best_ckpt.tar +3 -0
  9. exp/NeuralOM/20250309-195251/6_steps_finetune/training_checkpoints_atmos/readme.txt +1 -0
  10. exp/NeuralOM/20250309-195251/config.yaml +78 -0
  11. inference_forecasting.py +366 -0
  12. inference_forecasting.sh +13 -0
  13. inference_simulation.py +312 -0
  14. inference_simulation.sh +13 -0
  15. my_utils/YParams.py +55 -0
  16. my_utils/__pycache__/YParams.cpython-310.pyc +0 -0
  17. my_utils/__pycache__/YParams.cpython-37.pyc +0 -0
  18. my_utils/__pycache__/YParams.cpython-39.pyc +0 -0
  19. my_utils/__pycache__/bicubic.cpython-310.pyc +0 -0
  20. my_utils/__pycache__/bicubic.cpython-39.pyc +0 -0
  21. my_utils/__pycache__/darcy_loss.cpython-310.pyc +0 -0
  22. my_utils/__pycache__/darcy_loss.cpython-310.pyc.70370790180304 +0 -0
  23. my_utils/__pycache__/darcy_loss.cpython-310.pyc.70373230085584 +0 -0
  24. my_utils/__pycache__/darcy_loss.cpython-310.pyc.70384414393808 +0 -0
  25. my_utils/__pycache__/darcy_loss.cpython-37.pyc +0 -0
  26. my_utils/__pycache__/darcy_loss.cpython-39.pyc +0 -0
  27. my_utils/__pycache__/data_loader.cpython-310.pyc +0 -0
  28. my_utils/__pycache__/data_loader_multifiles.cpython-310.pyc +0 -0
  29. my_utils/__pycache__/data_loader_multifiles.cpython-37.pyc +0 -0
  30. my_utils/__pycache__/data_loader_multifiles.cpython-39.pyc +0 -0
  31. my_utils/__pycache__/get_date.cpython-310.pyc +0 -0
  32. my_utils/__pycache__/img_utils.cpython-310.pyc +0 -0
  33. my_utils/__pycache__/img_utils.cpython-37.pyc +0 -0
  34. my_utils/__pycache__/img_utils.cpython-39.pyc +0 -0
  35. my_utils/__pycache__/logging_utils.cpython-310.pyc +0 -0
  36. my_utils/__pycache__/logging_utils.cpython-37.pyc +0 -0
  37. my_utils/__pycache__/logging_utils.cpython-39.pyc +0 -0
  38. my_utils/__pycache__/norm.cpython-310.pyc +0 -0
  39. my_utils/__pycache__/time_utils.cpython-310.pyc +0 -0
  40. my_utils/__pycache__/time_utils.cpython-39.pyc +0 -0
  41. my_utils/__pycache__/weighted_acc_rmse.cpython-310.pyc +0 -0
  42. my_utils/__pycache__/weighted_acc_rmse.cpython-37.pyc +0 -0
  43. my_utils/__pycache__/weighted_acc_rmse.cpython-39.pyc +0 -0
  44. my_utils/data_loader.py +205 -0
  45. my_utils/logging_utils.py +26 -0
  46. my_utils/norm.py +114 -0
  47. networks/.ipynb_checkpoints/CirT1-checkpoint.py +301 -0
  48. networks/.ipynb_checkpoints/CirT2-checkpoint.py +301 -0
  49. networks/CirT1.py +301 -0
  50. networks/CirT2.py +301 -0
environment.yml ADDED
@@ -0,0 +1,248 @@
+ name: neuralom
+ channels:
+ - pytorch
+ - dglteam/label/th24_cu118
+ - nvidia
+ - defaults
+ dependencies:
+ - _libgcc_mutex=0.1=main
+ - _openmp_mutex=5.1=1_gnu
+ - blas=1.0=mkl
+ - brotli-python=1.0.9=py310h6a678d5_8
+ - bzip2=1.0.8=h5eee18b_6
+ - ca-certificates=2024.9.24=h06a4308_0
+ - certifi=2024.8.30=py310h06a4308_0
+ - charset-normalizer=3.3.2=pyhd3eb1b0_0
+ - cuda-cudart=11.8.89=0
+ - cuda-cupti=11.8.87=0
+ - cuda-libraries=11.8.0=0
+ - cuda-nvrtc=11.8.89=0
+ - cuda-nvtx=11.8.86=0
+ - cuda-runtime=11.8.0=0
+ - dgl=2.4.0.th24.cu118=py310_0
+ - ffmpeg=4.3=hf484d3e_0
+ - filelock=3.13.1=py310h06a4308_0
+ - freetype=2.12.1=h4a9f257_0
+ - gmp=6.2.1=h295c915_3
+ - gmpy2=2.1.2=py310heeb90bb_0
+ - gnutls=3.6.15=he1e5248_0
+ - idna=3.7=py310h06a4308_0
+ - intel-openmp=2023.1.0=hdb19cb5_46306
+ - jinja2=3.1.4=py310h06a4308_0
+ - jpeg=9e=h5eee18b_3
+ - lame=3.100=h7b6447c_0
+ - lcms2=2.12=h3be6417_0
+ - ld_impl_linux-64=2.38=h1181459_1
+ - lerc=3.0=h295c915_0
+ - libcublas=11.11.3.6=0
+ - libcufft=10.9.0.58=0
+ - libcufile=1.9.1.3=0
+ - libcurand=10.3.5.147=0
+ - libcusolver=11.4.1.48=0
+ - libcusparse=11.7.5.86=0
+ - libdeflate=1.17=h5eee18b_1
+ - libffi=3.4.4=h6a678d5_1
+ - libgcc-ng=11.2.0=h1234567_1
+ - libgfortran-ng=11.2.0=h00389a5_1
+ - libgfortran5=11.2.0=h1234567_1
+ - libgomp=11.2.0=h1234567_1
+ - libiconv=1.16=h5eee18b_3
+ - libidn2=2.3.4=h5eee18b_0
+ - libjpeg-turbo=2.0.0=h9bf148f_0
+ - libnpp=11.8.0.86=0
+ - libnvjpeg=11.9.0.86=0
+ - libpng=1.6.39=h5eee18b_0
+ - libstdcxx-ng=11.2.0=h1234567_1
+ - libtasn1=4.19.0=h5eee18b_0
+ - libtiff=4.5.1=h6a678d5_0
+ - libunistring=0.9.10=h27cfd23_0
+ - libuuid=1.41.5=h5eee18b_0
+ - libwebp-base=1.3.2=h5eee18b_0
+ - llvm-openmp=14.0.6=h9e868ea_0
+ - lz4-c=1.9.4=h6a678d5_1
+ - markupsafe=2.1.3=py310h5eee18b_0
+ - mkl=2023.1.0=h213fc3f_46344
+ - mkl-service=2.4.0=py310h5eee18b_1
+ - mkl_fft=1.3.10=py310h5eee18b_0
+ - mkl_random=1.2.7=py310h1128e8f_0
+ - mpc=1.1.0=h10f8cd9_1
+ - mpfr=4.0.2=hb69a4c5_1
+ - mpmath=1.3.0=py310h06a4308_0
+ - ncurses=6.4=h6a678d5_0
+ - nettle=3.7.3=hbbd107a_1
+ - networkx=3.3=py310h06a4308_0
+ - numpy=1.26.4=py310h5f9d8c6_0
+ - numpy-base=1.26.4=py310hb5e798b_0
+ - openh264=2.1.1=h4ff587b_0
+ - openjpeg=2.5.2=he7f1fd0_0
+ - openssl=3.0.15=h5eee18b_0
+ - pillow=10.4.0=py310h5eee18b_0
+ - pip=24.2=py310h06a4308_0
+ - psutil=5.9.0=py310h5eee18b_0
+ - pybind11-abi=4=hd3eb1b0_1
+ - pysocks=1.7.1=py310h06a4308_0
+ - python=3.10.15=he870216_1
+ - pytorch=2.4.0=py3.10_cuda11.8_cudnn9.1.0_0
+ - pytorch-cuda=11.8=h7e8668a_5
+ - pytorch-mutex=1.0=cuda
+ - pyyaml=6.0.1=py310h5eee18b_0
+ - readline=8.2=h5eee18b_0
+ - requests=2.32.3=py310h06a4308_0
+ - scipy=1.13.1=py310h5f9d8c6_0
+ - setuptools=75.1.0=py310h06a4308_0
+ - sqlite=3.45.3=h5eee18b_0
+ - sympy=1.13.2=py310h06a4308_0
+ - tbb=2021.8.0=hdb19cb5_0
+ - tk=8.6.14=h39e8969_0
+ - torchaudio=2.4.0=py310_cu118
+ - torchtriton=3.0.0=py310
+ - torchvision=0.19.0=py310_cu118
+ - tqdm=4.66.5=py310h2f386ee_0
+ - typing_extensions=4.11.0=py310h06a4308_0
+ - urllib3=2.2.3=py310h06a4308_0
+ - wheel=0.44.0=py310h06a4308_0
+ - xz=5.4.6=h5eee18b_1
+ - yaml=0.2.5=h7b6447c_0
+ - zlib=1.2.13=h5eee18b_1
+ - zstd=1.5.5=hc292b87_2
+ - pip:
+   - aiobotocore==2.15.1
+   - aiohappyeyeballs==2.4.3
+   - aiohttp==3.10.8
+   - aioitertools==0.12.0
+   - aiosignal==1.3.1
+   - anyio==4.6.0
+   - argon2-cffi==23.1.0
+   - argon2-cffi-bindings==21.2.0
+   - arrow==1.3.0
+   - asttokens==2.4.1
+   - async-lru==2.0.4
+   - async-timeout==4.0.3
+   - attrs==24.2.0
+   - babel==2.16.0
+   - beautifulsoup4==4.12.3
+   - bleach==6.1.0
+   - blessed==1.20.0
+   - botocore==1.35.23
+   - cartopy==0.24.1
+   - cffi==1.17.1
+   - cftime==1.6.4.post1
+   - cmocean==4.0.3
+   - colorama==0.4.6
+   - comm==0.2.2
+   - contourpy==1.3.0
+   - cycler==0.12.1
+   - debugpy==1.8.6
+   - decorator==5.1.1
+   - defusedxml==0.7.1
+   - einops==0.8.0
+   - exceptiongroup==1.2.2
+   - executing==2.1.0
+   - fastjsonschema==2.20.0
+   - fonttools==4.54.1
+   - fqdn==1.5.1
+   - frozenlist==1.4.1
+   - fsspec==2024.9.0
+   - gpustat==1.1.1
+   - h11==0.14.0
+   - h5netcdf==1.4.0
+   - h5py==3.12.1
+   - httpcore==1.0.6
+   - httpx==0.27.2
+   - huggingface-hub==0.25.1
+   - icecream==2.1.3
+   - importlib-metadata==8.5.0
+   - ipykernel==6.29.5
+   - ipython==8.28.0
+   - isoduration==20.11.0
+   - jedi==0.19.1
+   - jmespath==1.0.1
+   - joblib==1.4.2
+   - json5==0.9.25
+   - jsonpointer==3.0.0
+   - jsonschema==4.23.0
+   - jsonschema-specifications==2024.10.1
+   - jupyter-client==8.6.3
+   - jupyter-core==5.7.2
+   - jupyter-events==0.10.0
+   - jupyter-lsp==2.2.5
+   - jupyter-server==2.14.2
+   - jupyter-server-terminals==0.5.3
+   - jupyterlab==4.2.5
+   - jupyterlab-pygments==0.3.0
+   - jupyterlab-server==2.27.3
+   - kiwisolver==1.4.7
+   - matplotlib==3.9.2
+   - matplotlib-inline==0.1.7
+   - mistune==3.0.2
+   - multidict==6.1.0
+   - nbclient==0.10.0
+   - nbconvert==7.16.4
+   - nbformat==5.10.4
+   - nest-asyncio==1.6.0
+   - netcdf4==1.7.2
+   - notebook==7.2.2
+   - notebook-shim==0.2.4
+   - nvfuser-cu118-torch24==0.2.9.dev20240808
+   - nvidia-cuda-cupti-cu11==11.8.87
+   - nvidia-cuda-nvrtc-cu11==11.8.89
+   - nvidia-cuda-runtime-cu11==11.8.89
+   - nvidia-ml-py==12.560.30
+   - nvidia-nvtx-cu11==11.8.86
+   - overrides==7.7.0
+   - packaging==24.1
+   - pandas==2.2.3
+   - pandocfilters==1.5.1
+   - parso==0.8.4
+   - pexpect==4.9.0
+   - platformdirs==4.3.6
+   - prometheus-client==0.21.0
+   - prompt-toolkit==3.0.48
+   - ptyprocess==0.7.0
+   - pure-eval==0.2.3
+   - pycparser==2.22
+   - pygments==2.18.0
+   - pyparsing==3.2.0
+   - pyproj==3.7.0
+   - pyshp==2.3.1
+   - python-dateutil==2.9.0.post0
+   - python-json-logger==2.0.7
+   - pytz==2024.2
+   - pyzmq==26.2.0
+   - referencing==0.35.1
+   - rfc3339-validator==0.1.4
+   - rfc3986-validator==0.1.1
+   - rpds-py==0.20.0
+   - ruamel-yaml==0.18.6
+   - ruamel-yaml-clib==0.2.8
+   - s3fs==2024.9.0
+   - safetensors==0.4.5
+   - scikit-learn==1.5.2
+   - send2trash==1.8.3
+   - shapely==2.0.6
+   - six==1.16.0
+   - sniffio==1.3.1
+   - soupsieve==2.6
+   - stack-data==0.6.3
+   - terminado==0.18.1
+   - thop==0.1.1-2209072238
+   - threadpoolctl==3.5.0
+   - timm==1.0.9
+   - tinycss2==1.3.0
+   - tomli==2.0.2
+   - torchsummary==1.5.1
+   - tornado==6.4.1
+   - traitlets==5.14.3
+   - treelib==1.7.0
+   - types-python-dateutil==2.9.0.20241003
+   - tzdata==2024.2
+   - uri-template==1.3.0
+   - wcwidth==0.2.13
+   - webcolors==24.8.0
+   - webencodings==0.5.1
+   - websocket-client==1.8.0
+   - wrapt==1.16.0
+   - xarray==2024.9.0
+   - yarl==1.13.1
+   - zipp==3.20.2
+ prefix: /miniconda3/envs/neuralom
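
A quick import check is a convenient way to confirm that an environment built from this file matches the key pins (PyTorch 2.4.0 with CUDA 11.8, NumPy 1.26.4, h5py 3.12.1, xarray 2024.9.0). The snippet below is illustrative only and is not a file in this commit; it assumes the environment was created from environment.yml.

import torch
import numpy as np
import h5py
import xarray as xr

# Print the versions pinned in environment.yml and check that the CUDA 11.8 build sees a GPU.
print("torch:", torch.__version__, "| CUDA build:", torch.version.cuda)
print("numpy:", np.__version__, "| h5py:", h5py.__version__, "| xarray:", xr.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))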
exp/NeuralOM/20250309-195251/6_steps_finetune/model2/10_steps_finetune/training_checkpoints/.ipynb_checkpoints/readme-checkpoint.txt ADDED
@@ -0,0 +1 @@
+ The complete project is available on Hugging Face.
exp/NeuralOM/20250309-195251/6_steps_finetune/model2/10_steps_finetune/training_checkpoints/best_ckpt.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f9fe78a12419997b9deaf0dd2ec912c1c936e96cc418bc507acdd2baecf908a2
+ size 661771939
exp/NeuralOM/20250309-195251/6_steps_finetune/model2/10_steps_finetune/training_checkpoints/readme.txt ADDED
@@ -0,0 +1 @@
+ The complete project is available on Hugging Face.
exp/NeuralOM/20250309-195251/6_steps_finetune/training_checkpoints/.ipynb_checkpoints/readme-checkpoint.txt ADDED
@@ -0,0 +1 @@
+ The complete project is available on Hugging Face.
exp/NeuralOM/20250309-195251/6_steps_finetune/training_checkpoints/best_ckpt.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8fb1c1827478608c96b490662f9b59351a52b1ee423883158c0a9e7c09b7d919
+ size 661813411
exp/NeuralOM/20250309-195251/6_steps_finetune/training_checkpoints/readme.txt ADDED
@@ -0,0 +1 @@
+ The complete project is available on Hugging Face.
exp/NeuralOM/20250309-195251/6_steps_finetune/training_checkpoints_atmos/best_ckpt.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cb94ee483750345aea106470e39366e77bd92014c62344e9dad1837138e0ead7
+ size 600426467
exp/NeuralOM/20250309-195251/6_steps_finetune/training_checkpoints_atmos/readme.txt ADDED
@@ -0,0 +1 @@
+ The complete project is available on Hugging Face.
exp/NeuralOM/20250309-195251/config.yaml ADDED
@@ -0,0 +1,78 @@
+ ### base config ###
+ # -*- coding: utf-8 -*-
+ full_field: &FULL_FIELD
+   num_data_workers: 4
+   dt: 1
+   n_history: 0
+   prediction_length: 41
+   ics_type: "default"
+
+   exp_dir: './exp'
+
+   # data
+   train_data_path: './data/train'
+   valid_data_path: './data/valid'
+   test_data_path: './data/test'
+   test_data_path_atmos: './data/test_atmos'
+
+   # land mask
+   land_mask: !!bool True
+   land_mask_path: './data/land_mask.h5'
+
+   # normalization
+   normalize: !!bool True
+   normalization: 'zscore'
+   global_means_path: './data/mean_s_t_ssh.npy'
+   global_stds_path: './data/std_s_t_ssh.npy'
+
+   global_means_path_atmos: './data/mean_atmos.npy'
+   global_stds_path_atmos: './data/std_atmos.npy'
+
+   # orography
+   orography: !!bool False
+
+   # noise
+   add_noise: !!bool False
+   noise_std: 0
+
+   # crop
+   crop_size_x: None
+   crop_size_y: None
+
+   log_to_screen: !!bool True
+   log_to_wandb: !!bool False
+   save_checkpoint: !!bool True
+   plot_animations: !!bool False
+
+
+ #############################################
+ NeuralOM: &NeuralOM
+   <<: *FULL_FIELD
+   nettype: 'NeuralOM'
+   log_to_wandb: !!bool False
+
+   # Train params
+   lr: 1e-3
+   batch_size: 32
+   scheduler: 'CosineAnnealingLR'
+
+   loss_channel_wise: True
+   loss_scale: False
+   use_loss_scaler_from_metnet3: True
+
+
+   atmos_channels: [93, 94, 95, 96]
+
+   ocean_channels: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92]
+
+   in_channels: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96]
+
+   out_channels: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92]
+
+   in_channels_atmos: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68]
+
+   out_channels_atmos: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68]
+
+
+   out_variables: ["S0", "S2", "S5", "S7", "S11", "S15", "S21", "S29", "S40", "S55", "S77", "S92", "S109", "S130", "S155", "S186", "S222", "S266", "S318", "S380", "S453", "S541", "S643", "U0", "U2", "U5", "U7", "U11", "U15", "U21", "U29", "U40", "U55", "U77", "U92", "U109", "U130", "U155", "U186", "U222", "U266", "U318", "U380", "U453", "U541", "U643", "V0", "V2", "V5", "V7", "V11", "V15", "V21", "V29", "V40", "V55", "V77", "V92", "V109", "V130", "V155", "V186", "V222", "V266", "V318", "V380", "V453", "V541", "V643", "T0", "T2", "T5", "T7", "T11", "T15", "T21", "T29", "T40", "T55", "T77", "T92", "T109", "T130", "T155", "T186", "T222", "T266", "T318", "T380", "T453", "T541", "T643", "SSH"]
+
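
For readers of the config above: the long index lists encode a fixed channel layout, and the sketch below (illustrative, not a file in this commit) shows how they relate. It assumes, as the out_variables names suggest, 93 ocean channels (salinity S, currents U/V, temperature T at 23 depth levels, plus SSH) followed by 4 atmospheric forcing channels.

# Reconstruct the channel lists programmatically instead of spelling out every index.
ocean_channels = list(range(93))                 # matches `ocean_channels` (0..92)
atmos_channels = list(range(93, 97))             # matches `atmos_channels` (93..96)
in_channels = ocean_channels + atmos_channels    # matches `in_channels` (97 entries)
out_channels = ocean_channels                    # matches `out_channels` (93 entries)

depth_levels = [0, 2, 5, 7, 11, 15, 21, 29, 40, 55, 77, 92, 109, 130, 155,
                186, 222, 266, 318, 380, 453, 541, 643]
out_variables = [f"{v}{d}" for v in ("S", "U", "V", "T") for d in depth_levels] + ["SSH"]
assert len(out_variables) == len(out_channels) == 93   # 4 variables x 23 depths + SSH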
inference_forecasting.py ADDED
@@ -0,0 +1,366 @@
+ import os
+ import sys
+ import time
+ import glob
+ import h5py
+ import logging
+ import argparse
+ import numpy as np
+ from icecream import ic
+ from datetime import datetime
+ from collections import OrderedDict
+ import torch
+ import torch.nn as nn
+ import torch.cuda.amp as amp
+ import torch.nn.functional as F
+ import torch.distributed as dist
+ from torch.nn.parallel import DistributedDataParallel
+
+ sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')
+ from my_utils.YParams import YParams
+ from my_utils.data_loader import get_data_loader
+ from my_utils import logging_utils
+ logging_utils.config_logger()
+
+
+ def load_model(model, params, checkpoint_file):
+     model.zero_grad()
+     checkpoint_fname = checkpoint_file
+     checkpoint = torch.load(checkpoint_fname)
+     try:
+         new_state_dict = OrderedDict()
+         for key, val in checkpoint['model_state'].items():
+             name = key[7:]
+             if name != 'ged':
+                 new_state_dict[name] = val
+         model.load_state_dict(new_state_dict)
+     except:
+         model.load_state_dict(checkpoint['model_state'])
+     model.eval()
+     return model
+
+ def setup(params):
+     device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
+
+     # get data loader
+     valid_data_loader, valid_dataset = get_data_loader(params, params.test_data_path, dist.is_initialized(), train=False)
+
+     img_shape_x = valid_dataset.img_shape_x
+     img_shape_y = valid_dataset.img_shape_y
+     params.img_shape_x = img_shape_x
+     params.img_shape_y = img_shape_y
+
+     in_channels = np.array(params.in_channels)
+     out_channels = np.array(params.out_channels)
+     n_in_channels = len(in_channels)
+     n_out_channels = len(out_channels)
+
+     params['N_in_channels'] = n_in_channels
+     params['N_out_channels'] = n_out_channels
+
+     if params.normalization == 'zscore':
+         params.means = np.load(params.global_means_path)
+         params.stds = np.load(params.global_stds_path)
+
+         params.means_atmos = np.load(params.global_means_path_atmos)
+         params.stds_atmos = np.load(params.global_stds_path_atmos)
+
+     if params.nettype == 'NeuralOM':
+         from networks.MIGNN1 import MIGraph as model
+         from networks.MIGNN2 import MIGraph_stage2 as model2
+         from networks.OneForecast import OneForecast as model_atmos
+     else:
+         raise Exception("not implemented")
+
+     checkpoint_file = params['best_checkpoint_path']
+     checkpoint_file2 = params['best_checkpoint_path2']
+     checkpoint_file_atmos = params['best_checkpoint_path_atmos']
+     logging.info('Loading trained model checkpoint from {}'.format(checkpoint_file))
+     logging.info('Loading trained model2 checkpoint from {}'.format(checkpoint_file2))
+     logging.info('Loading trained model_atmos checkpoint from {}'.format(checkpoint_file_atmos))
+
+     model = model(params).to(device)
+     model = load_model(model, params, checkpoint_file)
+     model = model.to(device)
+
+     print('model is ok')
+
+     model2 = model2(params).to(device)
+     model2 = load_model(model2, params, checkpoint_file2)
+     model2 = model2.to(device)
+
+     print('model2 is ok')
+
+     model_atmos = model_atmos(params).to(device)
+     model_atmos = load_model(model_atmos, params, checkpoint_file_atmos)
+     model_atmos = model_atmos.to(device)
+
+     print('model_atmos is ok')
+
+     files_paths = glob.glob(params.test_data_path + "/*.h5")
+     files_paths.sort()
+
+     files_paths_atmos = glob.glob(params.test_data_path_atmos + "/*.h5")
+     files_paths_atmos.sort()
+
+     # which year
+     yr = 0
+     logging.info('Loading inference data')
+     logging.info('Inference data from {}'.format(files_paths[yr]))
+     logging.info('Inference data_atmos from {}'.format(files_paths_atmos[yr]))
+     climate_mean = np.load('./data/climate_mean_s_t_ssh.npy')
+     valid_data_full = h5py.File(files_paths[yr], 'r')['fields'][:365, :, :, :]
+     valid_data_full = valid_data_full - climate_mean
+
+     valid_data_full_atmos = h5py.File(files_paths_atmos[yr], 'r')['fields'][2:1460:4, :, :, :]
+
+     return valid_data_full, valid_data_full_atmos, model, model2, model_atmos
+
+
+ def autoregressive_inference(params, init_condition, valid_data_full, valid_data_full_atmos, model, model2, model_atmos):
+     device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
+
+     icd = int(init_condition)
+
+     exp_dir = params['experiment_dir']
+     dt = int(params.dt)
+     prediction_length = int(params.prediction_length/dt)
+     n_history = params.n_history
+     img_shape_x = params.img_shape_x
+     img_shape_y = params.img_shape_y
+     in_channels = np.array(params.in_channels)
+     out_channels = np.array(params.out_channels)
+     in_channels_atmos = np.array(params.in_channels_atmos)
+     out_channels_atmos = np.array(params.out_channels_atmos)
+     atmos_channels = np.array(params.atmos_channels)
+     n_in_channels = len(in_channels)
+     n_out_channels = len(out_channels)
+
+     seq_real = torch.zeros((prediction_length, n_out_channels, img_shape_x, img_shape_y))
+     seq_pred = torch.zeros((prediction_length, n_out_channels, img_shape_x, img_shape_y))
+
+
+     valid_data = valid_data_full[icd:(icd+prediction_length*dt+n_history*dt):dt][:, params.in_channels][:,:,0:360]
+     valid_data_atmos = valid_data_full_atmos[icd:(icd+prediction_length*dt+n_history*dt):dt][:, params.in_channels_atmos][:,:,0:120]
+     logging.info(f'valid_data_full: {valid_data_full.shape}')
+     logging.info(f'valid_data: {valid_data.shape}')
+     logging.info(f'valid_data_full_atmos: {valid_data_full_atmos.shape}')
+     logging.info(f'valid_data_atmos: {valid_data_atmos.shape}')
+
+     # normalize
+     if params.normalization == 'zscore':
+         valid_data = (valid_data - params.means[:,params.in_channels])/params.stds[:,params.in_channels]
+         valid_data = np.nan_to_num(valid_data, nan=0)
+
+         valid_data_atmos = (valid_data_atmos - params.means_atmos[:,params.in_channels_atmos])/params.stds_atmos[:,params.in_channels_atmos]
+         valid_data_atmos = np.nan_to_num(valid_data_atmos, nan=0)
+
+     valid_data = torch.as_tensor(valid_data)
+     valid_data_atmos = torch.as_tensor(valid_data_atmos)
+
+     # autoregressive inference
+     logging.info('Begin autoregressive inference')
+
+
+     with torch.no_grad():
+         for i in range(valid_data.shape[0]):
+             if i==0: # start of sequence, t0 --> t0'
+                 first = valid_data[0:n_history+1]
+                 first_atmos = valid_data_atmos[0:n_history+1]
+                 ic(valid_data.shape, first.shape)
+                 ic(valid_data_atmos.shape, first_atmos.shape)
+                 future = valid_data[n_history+1]
+                 ic(future.shape)
+
+                 for h in range(n_history+1):
+
+                     seq_real[h] = first[h*n_in_channels : (h+1)*n_in_channels, :93]
+
+                     seq_pred[h] = seq_real[h]
+
+                 first = first.to(device, dtype=torch.float)
+                 first_atmos = first_atmos.to(device, dtype=torch.float)
+                 first_ocean = first[:, params.ocean_channels, :, :]
+                 ic(first_ocean.shape)
+                 future_force0 = first_atmos[:, [65, 66, 67, 68], :120, :240]
+                 # future_force0 = torch.unsqueeze(future_force0, dim=0).to(device, dtype=torch.float)
+                 future_force0 = F.interpolate(future_force0, size=(360, 720), mode='bilinear', align_corners=False)
+
+                 model_input_atmos = first_atmos
+                 ic(model_input_atmos.shape)
+                 for k in range(4):
+                     if k ==0:
+                         model_atmos_future_pred = model_atmos(model_input_atmos)
+                     else:
+                         model_atmos_future_pred = model_atmos(model_atmos_future_pred)
+
+                 future_force = model_atmos_future_pred[:, [65, 66, 67, 68], :120, :240]
+                 # future_force = torch.unsqueeze(future_force, dim=0).to(device, dtype=torch.float)
+                 future_force = F.interpolate(future_force, size=(360, 720), mode='bilinear', align_corners=False)
+
+                 model_input = torch.cat((first_ocean, future_force0, future_force.cuda()), axis=1)
+                 ic(model_input.shape)
+                 model1_future_pred = model(model_input)
+                 with h5py.File(params.land_mask_path, 'r') as _f:
+                     mask_data = torch.as_tensor(_f['fields'][:,out_channels, :360, :720], dtype=bool).to(device, dtype=torch.bool)
+                 model1_future_pred = torch.masked_fill(input=model1_future_pred, mask=~mask_data, value=0)
+                 future_pred = model2(model1_future_pred) + model1_future_pred
+
+
+             else:
+                 if i < prediction_length-1:
+                     future0 = valid_data[n_history+i]
+                     future = valid_data[n_history+i+1]
+
+                 inf_one_step_start = time.time()
+                 future_force0 = model_atmos_future_pred[:, [65, 66, 67, 68], :120, :240]
+                 # future_force0 = torch.unsqueeze(future_force0, dim=0).to(device, dtype=torch.float)
+                 future_force0 = F.interpolate(future_force0, size=(360, 720), mode='bilinear', align_corners=False)
+
+                 for k in range(4):
+                     model_atmos_future_pred = model_atmos(model_atmos_future_pred)
+
+                 future_force = model_atmos_future_pred[:, [65, 66, 67, 68], :120, :240]
+                 # future_force = torch.unsqueeze(future_force, dim=0).to(device, dtype=torch.float)
+                 future_force = F.interpolate(future_force, size=(360, 720), mode='bilinear', align_corners=False)
+
+                 model1_future_pred = model(torch.cat((future_pred.cuda(), future_force0, future_force), axis=1)) #autoregressive step
+                 with h5py.File(params.land_mask_path, 'r') as _f:
+                     mask_data = torch.as_tensor(_f['fields'][:,out_channels, :360, :720], dtype=bool).to(device, dtype=torch.bool)
+                 model1_future_pred = torch.masked_fill(input=model1_future_pred, mask=~mask_data, value=0)
+                 future_pred = model2(model1_future_pred) + model1_future_pred
+                 inf_one_step_time = time.time() - inf_one_step_start
+
+                 logging.info(f'inference one step time: {inf_one_step_time}')
+
+
+             if i < prediction_length - 1: # not on the last step
+                 with h5py.File(params.land_mask_path, 'r') as _f:
+                     mask_data = torch.as_tensor(_f['fields'][:,out_channels, :360, :720], dtype=bool)
+                 seq_pred[n_history+i+1] = torch.masked_fill(input=future_pred.cpu(), mask=~mask_data, value=0)
+                 seq_real[n_history+i+1] = future[:93]
+                 history_stack = seq_pred[i+1:i+2+n_history]
+
+             future_pred = history_stack
+
+             pred = torch.unsqueeze(seq_pred[i], 0)
+             tar = torch.unsqueeze(seq_real[i], 0)
+
+             with h5py.File(params.land_mask_path, 'r') as _f:
+                 mask_data = torch.as_tensor(_f['fields'][:,out_channels, :360, :720], dtype=bool)
+             ic(mask_data.shape, pred.shape, tar.shape)
+             pred = torch.masked_fill(input=pred, mask=~mask_data, value=0)
+             tar = torch.masked_fill(input=tar, mask=~mask_data, value=0)
+
+             print(torch.mean((pred-tar)**2))
+
+
+     seq_real = seq_real * params.stds[:,params.out_channels] + params.means[:,params.out_channels]
+     seq_real = seq_real.numpy()
+     seq_pred = seq_pred * params.stds[:,params.out_channels] + params.means[:,params.out_channels]
+     seq_pred = seq_pred.numpy()
+
+
+     return (np.expand_dims(seq_real[n_history:], 0),
+             np.expand_dims(seq_pred[n_history:], 0),
+             )
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--exp_dir", default='../exp_15_levels', type=str)
+     parser.add_argument("--config", default='full_field', type=str)
+     parser.add_argument("--run_num", default='00', type=str)
+     parser.add_argument("--prediction_length", default=61, type=int)
+     parser.add_argument("--finetune_dir", default='', type=str)
+     parser.add_argument("--ics_type", default='default', type=str)
+     args = parser.parse_args()
+
+     config_path = os.path.join(args.exp_dir, args.config, args.run_num, 'config.yaml')
+     params = YParams(config_path, args.config)
+
+     params['resuming'] = False
+     params['interp'] = 0
+     params['world_size'] = 1
+     params['local_rank'] = 0
+     params['global_batch_size'] = params.batch_size
+     params['prediction_length'] = args.prediction_length
+     params['multi_steps_finetune'] = 1
+
+     torch.cuda.set_device(0)
+     torch.backends.cudnn.benchmark = True
+
+     # Set up directory
+     if args.finetune_dir == '':
+         expDir = os.path.join(params.exp_dir, args.config, str(args.run_num))
+     else:
+         expDir = os.path.join(params.exp_dir, args.config, str(args.run_num), args.finetune_dir)
+     logging.info(f'expDir: {expDir}')
+     params['experiment_dir'] = expDir
+     params['best_checkpoint_path'] = os.path.join(expDir, 'training_checkpoints/best_ckpt.tar')
+     params['best_checkpoint_path2'] = os.path.join(expDir, 'model2/10_steps_finetune/training_checkpoints/best_ckpt.tar')
+
+     params['best_checkpoint_path_atmos'] = os.path.join(expDir, 'training_checkpoints_atmos/best_ckpt.tar')
+
+     # set up logging
+     logging_utils.log_to_file(logger_name=None, log_filename=os.path.join(expDir, 'inference.log'))
+     logging_utils.log_versions()
+     params.log()
+
+     if params["ics_type"] == 'default':
+         ics = np.arange(0, 50, 1)
+         n_ics = len(ics)
+     print('init_condition:', ics)
+
+     logging.info("Inference for {} initial conditions".format(n_ics))
+
+     try:
+         autoregressive_inference_filetag = params["inference_file_tag"]
+     except:
+         autoregressive_inference_filetag = ""
+     if params.interp > 0:
+         autoregressive_inference_filetag = "_coarse"
+
+     valid_data_full, valid_data_full_atmos, model, model2, model_atmos = setup(params)
+
+
+     seq_pred = []
+     seq_real = []
+
+     # run autoregressive inference for multiple initial conditions
+     for i, ic_ in enumerate(ics):
+         logging.info("Initial condition {} of {}".format(i+1, n_ics))
+         seq_real, seq_pred = autoregressive_inference(params, ic_, valid_data_full, valid_data_full_atmos, model, model2, model_atmos)
+
+         prediction_length = seq_real[0].shape[0]
+         n_out_channels = seq_real[0].shape[1]
+         img_shape_x = seq_real[0].shape[2]
+         img_shape_y = seq_real[0].shape[3]
+
+         # save predictions and loss
+         save_path = os.path.join(params['experiment_dir'], 'results_forecasting.h5')
+         logging.info("Saving to {}".format(save_path))
+         print(f'saving to {save_path}')
+         if i==0:
+             f = h5py.File(save_path, 'w')
+             f.create_dataset(
+                 "ground_truth",
+                 data=seq_real,
+                 maxshape=[None, prediction_length, n_out_channels, img_shape_x, img_shape_y],
+                 dtype=np.float32)
+             f.create_dataset(
+                 "predicted",
+                 data=seq_pred,
+                 maxshape=[None, prediction_length, n_out_channels, img_shape_x, img_shape_y],
+                 dtype=np.float32)
+             f.close()
+         else:
+             f = h5py.File(save_path, 'a')
+
+             f["ground_truth"].resize((f["ground_truth"].shape[0] + 1), axis = 0)
+             f["ground_truth"][-1:] = seq_real
+
+             f["predicted"].resize((f["predicted"].shape[0] + 1), axis = 0)
+             f["predicted"][-1:] = seq_pred
+             f.close()
+
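
One pattern worth noting in the script above: the land-mask HDF5 file is re-opened at every autoregressive step before torch.masked_fill zeroes out land points. A minimal equivalent that loads the mask once up front would look like the sketch below; the helper names are illustrative and are not part of this commit.

import h5py
import torch

def load_land_mask(land_mask_path, out_channels, device):
    # Boolean ocean mask: True over ocean, False over land, shape (1, C, 360, 720).
    with h5py.File(land_mask_path, 'r') as f:
        mask = torch.as_tensor(f['fields'][:, out_channels, :360, :720]).to(torch.bool)
    return mask.to(device)

def apply_land_mask(prediction, mask):
    # Zero out land points, as done after each call to model/model2 above.
    return torch.masked_fill(prediction, ~mask, 0.0)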
inference_forecasting.sh ADDED
@@ -0,0 +1,13 @@
+ prediction_length=61 # 31
+
+ exp_dir='./exp'
+ config='NeuralOM'
+ run_num='20250309-195251'
+ finetune_dir='6_steps_finetune'
+
+ ics_type='default'
+
+ CUDA_VISIBLE_DEVICES=2 python inference_forecasting.py --exp_dir=${exp_dir} --config=${config} --run_num=${run_num} --finetune_dir=$finetune_dir --prediction_length=${prediction_length} --ics_type=${ics_type}
+
+
+
inference_simulation.py ADDED
@@ -0,0 +1,312 @@
+ import os
+ import sys
+ import time
+ import glob
+ import h5py
+ import logging
+ import argparse
+ import numpy as np
+ from icecream import ic
+ from datetime import datetime
+ from collections import OrderedDict
+ import torch
+ import torch.nn as nn
+ import torch.cuda.amp as amp
+ import torch.distributed as dist
+ from torch.nn.parallel import DistributedDataParallel
+
+ sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')
+ from my_utils.YParams import YParams
+ from my_utils.data_loader import get_data_loader
+ from my_utils import logging_utils
+ logging_utils.config_logger()
+
+
+ def load_model(model, params, checkpoint_file):
+     model.zero_grad()
+     checkpoint_fname = checkpoint_file
+     checkpoint = torch.load(checkpoint_fname)
+     try:
+         new_state_dict = OrderedDict()
+         for key, val in checkpoint['model_state'].items():
+             name = key[7:]
+             if name != 'ged':
+                 new_state_dict[name] = val
+         model.load_state_dict(new_state_dict)
+     except:
+         model.load_state_dict(checkpoint['model_state'])
+     model.eval()
+     return model
+
+ def setup(params):
+     device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
+
+     # get data loader
+     valid_data_loader, valid_dataset = get_data_loader(params, params.test_data_path, dist.is_initialized(), train=False)
+
+     img_shape_x = valid_dataset.img_shape_x
+     img_shape_y = valid_dataset.img_shape_y
+     params.img_shape_x = img_shape_x
+     params.img_shape_y = img_shape_y
+
+     in_channels = np.array(params.in_channels)
+     out_channels = np.array(params.out_channels)
+     n_in_channels = len(in_channels)
+     n_out_channels = len(out_channels)
+
+     params['N_in_channels'] = n_in_channels
+     params['N_out_channels'] = n_out_channels
+
+     if params.normalization == 'zscore':
+         params.means = np.load(params.global_means_path)
+         params.stds = np.load(params.global_stds_path)
+
+     if params.nettype == 'NeuralOM':
+         from networks.MIGNN1 import MIGraph as model
+         from networks.MIGNN2 import MIGraph_stage2 as model2
+     else:
+         raise Exception("not implemented")
+
+     checkpoint_file = params['best_checkpoint_path']
+     checkpoint_file2 = params['best_checkpoint_path2']
+     logging.info('Loading trained model checkpoint from {}'.format(checkpoint_file))
+     logging.info('Loading trained model2 checkpoint from {}'.format(checkpoint_file2))
+
+     model = model(params).to(device)
+     model = load_model(model, params, checkpoint_file)
+     model = model.to(device)
+
+     print('model is ok')
+
+     model2 = model2(params).to(device)
+     model2 = load_model(model2, params, checkpoint_file2)
+     model2 = model2.to(device)
+
+     print('model2 is ok')
+
+     files_paths = glob.glob(params.test_data_path + "/*.h5")
+     files_paths.sort()
+
+     # which year
+     yr = 0
+     logging.info('Loading inference data')
+     logging.info('Inference data from {}'.format(files_paths[yr]))
+     climate_mean = np.load('./data/climate_mean_s_t_ssh.npy')
+     valid_data_full = h5py.File(files_paths[yr], 'r')['fields'][:365, :, :, :]
+     valid_data_full = valid_data_full - climate_mean
+
+     return valid_data_full, model, model2
+
+
+ def autoregressive_inference(params, init_condition, valid_data_full, model, model2):
+     device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
+
+     icd = int(init_condition)
+
+     exp_dir = params['experiment_dir']
+     dt = int(params.dt)
+     prediction_length = int(params.prediction_length/dt)
+     n_history = params.n_history
+     img_shape_x = params.img_shape_x
+     img_shape_y = params.img_shape_y
+     in_channels = np.array(params.in_channels)
+     out_channels = np.array(params.out_channels)
+     atmos_channels = np.array(params.atmos_channels)
+     n_in_channels = len(in_channels)
+     n_out_channels = len(out_channels)
+
+     seq_real = torch.zeros((prediction_length, n_out_channels, img_shape_x, img_shape_y))
+     seq_pred = torch.zeros((prediction_length, n_out_channels, img_shape_x, img_shape_y))
+
+
+     valid_data = valid_data_full[icd:(icd+prediction_length*dt+n_history*dt):dt][:, params.in_channels][:,:,0:360]
+     logging.info(f'valid_data_full: {valid_data_full.shape}')
+     logging.info(f'valid_data: {valid_data.shape}')
+
+     # normalize
+     if params.normalization == 'zscore':
+         valid_data = (valid_data - params.means[:,params.in_channels])/params.stds[:,params.in_channels]
+         valid_data = np.nan_to_num(valid_data, nan=0)
+
+     valid_data = torch.as_tensor(valid_data)
+
+     # autoregressive inference
+     logging.info('Begin autoregressive inference')
+
+
+     with torch.no_grad():
+         for i in range(valid_data.shape[0]):
+             if i==0: # start of sequence, t0 --> t0'
+                 first = valid_data[0:n_history+1]
+                 ic(valid_data.shape, first.shape)
+                 future = valid_data[n_history+1]
+                 ic(future.shape)
+
+                 for h in range(n_history+1):
+
+                     seq_real[h] = first[h*n_in_channels : (h+1)*n_in_channels, :93]
+
+                     seq_pred[h] = seq_real[h]
+
+                 first = first.to(device, dtype=torch.float)
+                 first_ocean = first[:, params.ocean_channels, :, :]
+                 ic(first_ocean.shape)
+                 future_force0 = first[:, params.atmos_channels, :, :]
+
+                 future_force = future[params.atmos_channels, :360, :720]
+                 future_force = torch.unsqueeze(future_force, dim=0).to(device, dtype=torch.float)
+                 model_input = torch.cat((first_ocean, future_force0, future_force.cuda()), axis=1)
+                 ic(model_input.shape)
+                 model1_future_pred = model(model_input)
+                 with h5py.File(params.land_mask_path, 'r') as _f:
+                     mask_data = torch.as_tensor(_f['fields'][:,out_channels, :360, :720], dtype=bool).to(device, dtype=torch.bool)
+                 model1_future_pred = torch.masked_fill(input=model1_future_pred, mask=~mask_data, value=0)
+                 future_pred = model2(model1_future_pred) + model1_future_pred
+
+             else:
+                 if i < prediction_length-1:
+                     future0 = valid_data[n_history+i]
+                     future = valid_data[n_history+i+1]
+
+                 inf_one_step_start = time.time()
+                 future_force0 = future0[params.atmos_channels, :360, :720]
+                 future_force = future[params.atmos_channels, :360, :720]
+                 future_force0 = torch.unsqueeze(future_force0, dim=0).to(device, dtype=torch.float)
+                 future_force = torch.unsqueeze(future_force, dim=0).to(device, dtype=torch.float)
+                 model1_future_pred = model(torch.cat((future_pred.cuda(), future_force0, future_force), axis=1)) #autoregressive step
+                 with h5py.File(params.land_mask_path, 'r') as _f:
+                     mask_data = torch.as_tensor(_f['fields'][:,out_channels, :360, :720], dtype=bool).to(device, dtype=torch.bool)
+                 model1_future_pred = torch.masked_fill(input=model1_future_pred, mask=~mask_data, value=0)
+                 future_pred = model2(model1_future_pred) + model1_future_pred
+                 inf_one_step_time = time.time() - inf_one_step_start
+
+                 logging.info(f'inference one step time: {inf_one_step_time}')
+
+             if i < prediction_length - 1: # not on the last step
+                 with h5py.File(params.land_mask_path, 'r') as _f:
+                     mask_data = torch.as_tensor(_f['fields'][:,out_channels, :360, :720], dtype=bool)
+                 seq_pred[n_history+i+1] = torch.masked_fill(input=future_pred.cpu(), mask=~mask_data, value=0)
+                 seq_real[n_history+i+1] = future[:93]
+                 history_stack = seq_pred[i+1:i+2+n_history]
+
+             future_pred = history_stack
+
+             pred = torch.unsqueeze(seq_pred[i], 0)
+             tar = torch.unsqueeze(seq_real[i], 0)
+
+             with h5py.File(params.land_mask_path, 'r') as _f:
+                 mask_data = torch.as_tensor(_f['fields'][:,out_channels, :360, :720], dtype=bool)
+             ic(mask_data.shape, pred.shape, tar.shape)
+             pred = torch.masked_fill(input=pred, mask=~mask_data, value=0)
+             tar = torch.masked_fill(input=tar, mask=~mask_data, value=0)
+
+             print(torch.mean((pred-tar)**2))
+
+
+     seq_real = seq_real * params.stds[:,params.out_channels] + params.means[:,params.out_channels]
+     seq_real = seq_real.numpy()
+     seq_pred = seq_pred * params.stds[:,params.out_channels] + params.means[:,params.out_channels]
+     seq_pred = seq_pred.numpy()
+
+
+     return (np.expand_dims(seq_real[n_history:], 0),
+             np.expand_dims(seq_pred[n_history:], 0),
+             )
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--exp_dir", default='../exp_15_levels', type=str)
+     parser.add_argument("--config", default='full_field', type=str)
+     parser.add_argument("--run_num", default='00', type=str)
+     parser.add_argument("--prediction_length", default=61, type=int)
+     parser.add_argument("--finetune_dir", default='', type=str)
+     parser.add_argument("--ics_type", default='default', type=str)
+     args = parser.parse_args()
+
+     config_path = os.path.join(args.exp_dir, args.config, args.run_num, 'config.yaml')
+     params = YParams(config_path, args.config)
+
+     params['resuming'] = False
+     params['interp'] = 0
+     params['world_size'] = 1
+     params['local_rank'] = 0
+     params['global_batch_size'] = params.batch_size
+     params['prediction_length'] = args.prediction_length
+     params['multi_steps_finetune'] = 1
+
+     torch.cuda.set_device(0)
+     torch.backends.cudnn.benchmark = True
+
+     # Set up directory
+     if args.finetune_dir == '':
+         expDir = os.path.join(params.exp_dir, args.config, str(args.run_num))
+     else:
+         expDir = os.path.join(params.exp_dir, args.config, str(args.run_num), args.finetune_dir)
+     logging.info(f'expDir: {expDir}')
+     params['experiment_dir'] = expDir
+     params['best_checkpoint_path'] = os.path.join(expDir, 'training_checkpoints/best_ckpt.tar')
+     params['best_checkpoint_path2'] = os.path.join(expDir, 'model2/10_steps_finetune/training_checkpoints/best_ckpt.tar')
+
+     # set up logging
+     logging_utils.log_to_file(logger_name=None, log_filename=os.path.join(expDir, 'inference.log'))
+     logging_utils.log_versions()
+     params.log()
+
+     if params["ics_type"] == 'default':
+         ics = np.arange(0, 240, 1)
+         n_ics = len(ics)
+     print('init_condition:', ics)
+
+     logging.info("Inference for {} initial conditions".format(n_ics))
+
+     try:
+         autoregressive_inference_filetag = params["inference_file_tag"]
+     except:
+         autoregressive_inference_filetag = ""
+     if params.interp > 0:
+         autoregressive_inference_filetag = "_coarse"
+
+     valid_data_full, model, model2 = setup(params)
+
+
+     seq_pred = []
+     seq_real = []
+
+     # run autoregressive inference for multiple initial conditions
+     for i, ic_ in enumerate(ics):
+         logging.info("Initial condition {} of {}".format(i+1, n_ics))
+         seq_real, seq_pred = autoregressive_inference(params, ic_, valid_data_full, model, model2)
+
+         prediction_length = seq_real[0].shape[0]
+         n_out_channels = seq_real[0].shape[1]
+         img_shape_x = seq_real[0].shape[2]
+         img_shape_y = seq_real[0].shape[3]
+
+         # save predictions and loss
+         save_path = os.path.join(params['experiment_dir'], 'results_simulation.h5')
+         logging.info("Saving to {}".format(save_path))
+         print(f'saving to {save_path}')
+         if i==0:
+             f = h5py.File(save_path, 'w')
+             f.create_dataset(
+                 "ground_truth",
+                 data=seq_real,
+                 maxshape=[None, prediction_length, n_out_channels, img_shape_x, img_shape_y],
+                 dtype=np.float32)
+             f.create_dataset(
+                 "predicted",
+                 data=seq_pred,
+                 maxshape=[None, prediction_length, n_out_channels, img_shape_x, img_shape_y],
+                 dtype=np.float32)
+             f.close()
+         else:
+             f = h5py.File(save_path, 'a')
+
+             f["ground_truth"].resize((f["ground_truth"].shape[0] + 1), axis = 0)
+             f["ground_truth"][-1:] = seq_real
+
+             f["predicted"].resize((f["predicted"].shape[0] + 1), axis = 0)
+             f["predicted"][-1:] = seq_pred
+             f.close()
+
inference_simulation.sh ADDED
@@ -0,0 +1,13 @@
+ prediction_length=61 # 31
+
+ exp_dir='./exp'
+ config='NeuralOM'
+ run_num='20250309-195251'
+ finetune_dir='6_steps_finetune'
+
+ ics_type='default'
+
+ CUDA_VISIBLE_DEVICES=0 python inference_simulation.py --exp_dir=${exp_dir} --config=${config} --run_num=${run_num} --finetune_dir=$finetune_dir --prediction_length=${prediction_length} --ics_type=${ics_type}
+
+
+
my_utils/YParams.py ADDED
@@ -0,0 +1,55 @@
+ # coding: utf-8
+
+ import importlib
+ import sys
+ import os
+ importlib.reload(sys)
+
+ from ruamel.yaml import YAML
+ import logging
+
+ class YParams():
+     """ Yaml file parser """
+     def __init__(self, yaml_filename, config_name, print_params=False):
+         self._yaml_filename = yaml_filename
+         self._config_name = config_name
+         self.params = {}
+
+         if print_params:
+             print(os.system('hostname'))
+             print("------------------ Configuration ------------------ ", yaml_filename)
+
+         with open(yaml_filename, 'rb') as _file:
+             yaml = YAML().load(_file)
+             for key, val in yaml[config_name].items():
+                 if print_params: print(key, val)
+                 if val =='None': val = None
+
+                 self.params[key] = val
+                 self.__setattr__(key, val)
+
+         if print_params:
+             print("---------------------------------------------------")
+
+     def __getitem__(self, key):
+         return self.params[key]
+
+     def __setitem__(self, key, val):
+         self.params[key] = val
+         self.__setattr__(key, val)
+
+     def __contains__(self, key):
+         return (key in self.params)
+
+     def update_params(self, config):
+         for key, val in config.items():
+             self.params[key] = val
+             self.__setattr__(key, val)
+
+     def log(self):
+         logging.info("------------------ Configuration ------------------")
+         logging.info("Configuration file: "+str(self._yaml_filename))
+         logging.info("Configuration name: "+str(self._config_name))
+         for key, val in self.params.items():
+             logging.info(str(key) + ' ' + str(val))
+         logging.info("---------------------------------------------------")
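
Usage sketch for YParams (illustrative values; the paths match the config shipped in this commit, and this snippet is not itself part of the repository):

from my_utils.YParams import YParams

params = YParams('./exp/NeuralOM/20250309-195251/config.yaml', 'NeuralOM', print_params=True)
print(params['batch_size'])        # dict-style access, e.g. 32
print(params.lr)                   # attribute-style access, e.g. 1e-3
params['prediction_length'] = 61   # runtime overrides, as done in the inference scripts
params.log()                       # dump the full configuration to the logger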
my_utils/__pycache__/YParams.cpython-310.pyc ADDED
Binary file (2.1 kB).
my_utils/__pycache__/YParams.cpython-37.pyc ADDED
Binary file (2.11 kB).
my_utils/__pycache__/YParams.cpython-39.pyc ADDED
Binary file (2.08 kB).
my_utils/__pycache__/bicubic.cpython-310.pyc ADDED
Binary file (9.24 kB).
my_utils/__pycache__/bicubic.cpython-39.pyc ADDED
Binary file (9.2 kB).
my_utils/__pycache__/darcy_loss.cpython-310.pyc ADDED
Binary file (13.7 kB).
my_utils/__pycache__/darcy_loss.cpython-310.pyc.70370790180304 ADDED
Binary file (13.5 kB).
my_utils/__pycache__/darcy_loss.cpython-310.pyc.70373230085584 ADDED
Binary file (13.5 kB).
my_utils/__pycache__/darcy_loss.cpython-310.pyc.70384414393808 ADDED
Binary file (13.5 kB).
my_utils/__pycache__/darcy_loss.cpython-37.pyc ADDED
Binary file (9.02 kB).
my_utils/__pycache__/darcy_loss.cpython-39.pyc ADDED
Binary file (14.2 kB).
my_utils/__pycache__/data_loader.cpython-310.pyc ADDED
Binary file (5 kB).
my_utils/__pycache__/data_loader_multifiles.cpython-310.pyc ADDED
Binary file (4.94 kB).
my_utils/__pycache__/data_loader_multifiles.cpython-37.pyc ADDED
Binary file (3.69 kB).
my_utils/__pycache__/data_loader_multifiles.cpython-39.pyc ADDED
Binary file (16.5 kB).
my_utils/__pycache__/get_date.cpython-310.pyc ADDED
Binary file (167 Bytes).
my_utils/__pycache__/img_utils.cpython-310.pyc ADDED
Binary file (3.39 kB).
my_utils/__pycache__/img_utils.cpython-37.pyc ADDED
Binary file (3.81 kB).
my_utils/__pycache__/img_utils.cpython-39.pyc ADDED
Binary file (4.83 kB).
my_utils/__pycache__/logging_utils.cpython-310.pyc ADDED
Binary file (1 kB).
my_utils/__pycache__/logging_utils.cpython-37.pyc ADDED
Binary file (994 Bytes).
my_utils/__pycache__/logging_utils.cpython-39.pyc ADDED
Binary file (995 Bytes).
my_utils/__pycache__/norm.cpython-310.pyc ADDED
Binary file (3.37 kB).
my_utils/__pycache__/time_utils.cpython-310.pyc ADDED
Binary file (606 Bytes).
my_utils/__pycache__/time_utils.cpython-39.pyc ADDED
Binary file (578 Bytes).
my_utils/__pycache__/weighted_acc_rmse.cpython-310.pyc ADDED
Binary file (5.9 kB).
my_utils/__pycache__/weighted_acc_rmse.cpython-37.pyc ADDED
Binary file (6.27 kB).
my_utils/__pycache__/weighted_acc_rmse.cpython-39.pyc ADDED
Binary file (6 kB).
 
my_utils/data_loader.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import glob
3
+ import torch
4
+ import random
5
+ import numpy as np
6
+ from torch.utils.data import DataLoader, Dataset
7
+ from torch.utils.data.distributed import DistributedSampler
8
+ from torch import Tensor
9
+ import h5py
10
+ import math
11
+ from my_utils.norm import reshape_fields
12
+ import os
13
+
14
+
15
+ current_dir = os.path.dirname(os.path.abspath(__file__))
16
+ parent_dir = os.path.dirname(current_dir)
17
+ climate_mean_path = os.path.join(parent_dir, 'data/climate_mean_s_t_ssh.npy')
18
+
19
+ def get_data_loader(params, files_pattern, distributed, train):
20
+ dataset = GetDataset(params, files_pattern, train)
21
+ sampler = DistributedSampler(dataset, shuffle=train) if distributed else None
22
+
23
+
24
+ dataloader = DataLoader(dataset,
25
+ batch_size = int(params.batch_size),
26
+ num_workers = params.num_data_workers,
27
+ shuffle = False,
28
+ sampler = sampler if train else None,
29
+ drop_last = True,
30
+ pin_memory = True)
31
+
32
+ if train:
33
+ return dataloader, dataset, sampler
34
+ else:
35
+ return dataloader, dataset
36
+
37
+
38
+ class GetDataset(Dataset):
39
+ def __init__(self, params, location, train):
40
+ self.params = params
41
+ self.location = location
42
+ self.train = train
43
+ self.orography = params.orography
44
+ self.normalize = params.normalize
45
+ self.dt = params.dt
46
+ self.n_history = params.n_history
47
+ self.in_channels = np.array(params.in_channels)
48
+ self.out_channels = np.array(params.out_channels)
49
+ self.ocean_channels = np.array(params.ocean_channels)
50
+ self.atmos_channels = np.array(params.atmos_channels)
51
+ self.n_in_channels = len(self.in_channels)
52
+ self.n_out_channels = len(self.out_channels)
53
+
54
+ self._get_files_stats()
55
+ self.add_noise = params.add_noise if train else False
56
+ self.climate_mean = np.load(climate_mean_path, mmap_mode='r')
57
+
58
+
59
+ def _get_files_stats(self):
60
+ self.files_paths = glob.glob(self.location + "/*.h5")
61
+ self.files_paths.sort()
62
+ self.n_years = len(self.files_paths)
63
+
64
+ with h5py.File(self.files_paths[0], 'r') as _f:
65
+ logging.info("Getting file stats from {}".format(self.files_paths[0]))
66
+
67
+ self.n_samples_per_year = _f['fields'].shape[0] - self.params.multi_steps_finetune
68
+
69
+ self.img_shape_x = _f['fields'].shape[2] - 1
70
+ self.img_shape_y = _f['fields'].shape[3]
71
+
72
+ self.n_samples_total = self.n_years * self.n_samples_per_year
73
+ self.files = [None for _ in range(self.n_years)]
74
+
75
+ logging.info("Number of samples per year: {}".format(self.n_samples_per_year))
76
+ logging.info("Found data at path {}. Number of examples: {}. Image Shape: {} x {} x {}".format(self.location,
77
+ self.n_samples_total,
78
+ self.img_shape_x,
79
+ self.img_shape_y,
80
+ self.n_in_channels))
81
+ logging.info("Delta t: {} days".format(1 * self.dt))
82
+ logging.info("Including {} days of past history in training at a frequency of {} days".format(
83
+ 1 * self.dt * self.n_history, 1 * self.dt))
84
+
85
+ def _open_file(self, year_idx):
86
+ _file = h5py.File(self.files_paths[year_idx], 'r')
87
+ self.files[year_idx] = _file['fields']
88
+
89
+ if self.orography and self.params.normalization == 'zscore':
90
+ _orog_file = h5py.File(self.params.orography_norm_zscore_path, 'r')
91
+ if self.orography and self.params.normalization == 'maxmin':
92
+ _orog_file = h5py.File(self.params.orography_norm_maxmin_path, 'r')
93
+
94
+ def __len__(self):
95
+ return self.n_samples_total
96
+
97
+ def __getitem__(self, global_idx):
98
+ year_idx = int(global_idx / self.n_samples_per_year) # which year
99
+ local_idx = int(global_idx % self.n_samples_per_year) # which sample in a year
100
+
101
+ if self.files[year_idx] is None:
102
+ self._open_file(year_idx)
103
+
104
+ if local_idx < self.dt * self.n_history:
105
+ local_idx += self.dt * self.n_history
106
+
107
+ step = 0 if local_idx >= self.n_samples_per_year - self.dt else self.dt
108
+
109
+ orog = None
110
+
111
+
112
+ if self.params.multi_steps_finetune == 1:
113
+ if local_idx == 365:
114
+ local_idx = 364
115
+
116
+ climate_mean_ocean = self.climate_mean[(local_idx-self.dt*self.n_history):(local_idx+1):self.dt, self.ocean_channels, :360, :720]
117
+ ocean = reshape_fields(
118
+ self.files[year_idx][(local_idx-self.dt*self.n_history):(local_idx+1):self.dt, self.ocean_channels, :360, :720] - climate_mean_ocean,
119
+ 'ocean',
120
+ self.params,
121
+ self.train,
122
+ self.normalize,
123
+ orog,
124
+ self.add_noise
125
+ )
126
+
127
+ force_future0 = reshape_fields(
128
+ self.files[year_idx][local_idx, self.atmos_channels, :360, :720],
129
+ 'force',
130
+ self.params,
131
+ self.train,
132
+ self.normalize,
133
+ orog,
134
+ self.add_noise
135
+ )
136
+
137
+ force_future1 = reshape_fields(
138
+ self.files[year_idx][local_idx+step, self.atmos_channels, :360, :720],
139
+ 'force',
140
+ self.params,
141
+ self.train,
142
+ self.normalize,
143
+ orog,
144
+ self.add_noise
145
+ )
146
+
147
+ climate_mean_tar = self.climate_mean[local_idx+step, self.out_channels, :360, :720]
148
+ tar = reshape_fields(
149
+ self.files[year_idx][local_idx+step, self.out_channels, :360, :720] - climate_mean_tar,
150
+ 'tar',
151
+ self.params,
152
+ self.train,
153
+ self.normalize,
154
+ orog
155
+ )
156
+ else:
157
+ climate_mean_ocean = self.climate_mean[(local_idx-self.dt*self.n_history):(local_idx+1):self.dt, self.ocean_channels, :360, :720]
158
+ ocean = reshape_fields(
159
+ self.files[year_idx][(local_idx-self.dt*self.n_history):(local_idx+1):self.dt, self.ocean_channels, :360, :720] - climate_mean_ocean,
160
+ 'ocean',
161
+ self.params,
162
+ self.train,
163
+ self.normalize,
164
+ orog,
165
+ self.add_noise
166
+ )
167
+
168
+ force_future0 = reshape_fields(
169
+ self.files[year_idx][local_idx, self.atmos_channels, :360, :720],
170
+ 'force',
171
+ self.params,
172
+ self.train,
173
+ self.normalize,
174
+ orog,
175
+ self.add_noise
176
+ )
177
+
178
+ force_future1 = reshape_fields(
179
+ self.files[year_idx][local_idx+step, self.atmos_channels, :360, :720],
180
+ 'force',
181
+ self.params,
182
+ self.train,
183
+ self.normalize,
184
+ orog,
185
+ self.add_noise
186
+ )
187
+
188
+ climate_mean_tar = self.climate_mean[local_idx+step:local_idx+step+self.params.multi_steps_finetune, self.in_channels, :360, :720]
189
+ tar_data = self.files[year_idx][local_idx+step:local_idx+step+self.params.multi_steps_finetune, self.in_channels, :360, :720]
190
+ tar = reshape_fields(
191
+ tar_data - climate_mean_tar,
192
+ 'inp',
193
+ self.params,
194
+ self.train,
195
+ self.normalize,
196
+ orog
197
+ )
198
+
199
+ ocean = np.nan_to_num(ocean, nan=0)
200
+ force_future0 = np.nan_to_num(force_future0, nan=0)
201
+ force_future1 = np.nan_to_num(force_future1, nan=0)
202
+ tar = np.nan_to_num(tar, nan=0)
203
+
204
+
205
+ return np.concatenate((ocean, force_future0, force_future1), axis=0), tar
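
A minimal sketch of the sample layout returned by __getitem__ above. The channel counts are assumptions inferred from the CirT default input_size=101 further down (93 ocean channels plus two snapshots of 4 atmospheric forcing channels); the real counts come from params.ocean_channels and params.atmos_channels.

import numpy as np

# Illustrative shapes only; 93 ocean channels and 4 forcing channels are assumed, not read from params.
n_ocean, n_force, H, W = 93, 4, 360, 720

ocean = np.zeros((n_ocean, H, W), dtype=np.float32)          # ocean anomaly w.r.t. climate_mean
force_future0 = np.zeros((n_force, H, W), dtype=np.float32)  # atmospheric forcing at step t
force_future1 = np.zeros((n_force, H, W), dtype=np.float32)  # atmospheric forcing at step t + dt

inp = np.concatenate((ocean, force_future0, force_future1), axis=0)
print(inp.shape)  # (101, 360, 720) under these assumptions, matching CirT's input_size
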
my_utils/logging_utils.py ADDED
@@ -0,0 +1,26 @@
1
+ import os
2
+ import logging
3
+
4
+ _format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
5
+
6
+ def config_logger(log_level=logging.INFO):
7
+ logging.basicConfig(format=_format, level=log_level)
8
+
9
+ def log_to_file(logger_name=None, log_level=logging.INFO, log_filename='tensorflow.log'):
10
+
11
+ if os.path.dirname(log_filename) and not os.path.exists(os.path.dirname(log_filename)):
12
+ os.makedirs(os.path.dirname(log_filename))
13
+
14
+ if logger_name is not None:
15
+ log = logging.getLogger(logger_name)
16
+ else:
17
+ log = logging.getLogger()
18
+
19
+ fh = logging.FileHandler(log_filename)
20
+ fh.setLevel(log_level)
21
+ fh.setFormatter(logging.Formatter(_format))
22
+ log.addHandler(fh)
23
+
24
+ def log_versions():
25
+ import torch
26
+ import subprocess
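
A short, hedged usage sketch for the helpers above; the file path is an illustrative placeholder, not one defined by the repository.

import logging
from my_utils.logging_utils import config_logger, log_to_file

config_logger(logging.INFO)                 # root logger to stdout with the shared format
log_to_file(log_filename='logs/train.log')  # mirror the root logger to a file (path is a placeholder)
logging.info('logging configured')
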
my_utils/norm.py ADDED
@@ -0,0 +1,114 @@
1
+ import logging
2
+ import glob
3
+ from types import new_class
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import random
8
+ import numpy as np
9
+ import torch
10
+ from torch.utils.data import DataLoader, Dataset
11
+ from torch.utils.data.distributed import DistributedSampler
12
+ from torch import Tensor
13
+ import h5py
14
+ import math
15
+ import torchvision.transforms.functional as TF
16
+ # import matplotlib
17
+ # import matplotlib.pyplot as plt
18
+
19
+ class PeriodicPad2d(nn.Module):
20
+ """
21
+ pad longitudinal (left-right) circular
22
+ and pad latitude (top-bottom) with zeros
23
+ """
24
+ def __init__(self, pad_width):
25
+ super(PeriodicPad2d, self).__init__()
26
+ self.pad_width = pad_width
27
+
28
+ def forward(self, x):
29
+ # pad left and right circular
30
+ out = F.pad(x, (self.pad_width, self.pad_width, 0, 0), mode="circular")
31
+ # pad top and bottom zeros
32
+ out = F.pad(out, (0, 0, self.pad_width, self.pad_width), mode="constant", value=0)
33
+ return out
34
+
35
+ def reshape_fields(img, inp_or_tar, params, train, normalize=True, orog=None, add_noise=False):
36
+ # Takes in np array of size (n_history+1, c, h, w)
37
+ # returns torch tensor of size ((n_channels*(n_history+1), crop_size_x, crop_size_y)
38
+
39
+ if len(np.shape(img)) == 3:
40
+ img = np.expand_dims(img, 0)
41
+
42
+ if np.shape(img)[2] == 721:
43
+ img = img[:,:, 0:720, :] # remove last pixel
44
+
45
+ n_history = np.shape(img)[0] - 1
46
+ img_shape_x = np.shape(img)[-2]
47
+ img_shape_y = np.shape(img)[-1]
48
+ n_channels = np.shape(img)[1] # this will either be N_in_channels or N_out_channels
49
+
50
+ if inp_or_tar == 'inp':
51
+ channels = params.in_channels
52
+ elif inp_or_tar == 'ocean':
53
+ channels = params.ocean_channels
54
+ elif inp_or_tar == 'force':
55
+ channels = params.atmos_channels
56
+ else:
57
+ channels = params.out_channels
58
+
59
+ if normalize and params.normalization == 'minmax':
60
+ maxs = np.load(params.global_maxs_path)[:, channels]
61
+ mins = np.load(params.global_mins_path)[:, channels]
62
+ img = (img - mins) / (maxs - mins)
63
+
64
+ if normalize and params.normalization == 'zscore':
65
+ means = np.load(params.global_means_path)[:, channels]
66
+ stds = np.load(params.global_stds_path)[:, channels]
67
+ img -= means
68
+ img /= stds
69
+
70
+ if normalize and params.normalization == 'zscore_lat':
71
+ means = np.load(params.global_lat_means_path)[:, channels,:720]
72
+ stds = np.load(params.global_lat_stds_path)[:, channels,:720]
73
+ img -= means
74
+ img /= stds
75
+
76
+ if params.orography and inp_or_tar == 'inp':
77
+ # print('img:', img.shape, 'orog:', orog.shape)
78
+ orog = np.expand_dims(orog, axis = (0,1))
79
+ orog = np.repeat(orog, repeats=img.shape[0], axis=0)
80
+ # print('img:', img.shape, 'orog:', orog.shape)
81
+ img = np.concatenate((img, orog), axis = 1)
82
+ n_channels += 1
83
+
84
+ img = np.squeeze(img)
85
+ # if inp_or_tar == 'inp':
86
+ # img = np.reshape(img, (n_channels*(n_history+1))) # ??
87
+ # elif inp_or_tar == 'tar':
88
+ # img = np.reshape(img, (n_channels, crop_size_x, crop_size_y)) #??
89
+
90
+ if add_noise:
91
+ img = img + np.random.normal(0, scale=params.noise_std, size=img.shape)
92
+
93
+ return torch.as_tensor(img)
94
+
95
+ def vis_precip(fields):
96
+ import matplotlib.pyplot as plt  # imported locally: the module-level matplotlib import above is commented out
+ pred, tar = fields
97
+ fig, ax = plt.subplots(1, 2, figsize=(24,12))
98
+ ax[0].imshow(pred, cmap="coolwarm")
99
+ ax[0].set_title("tp pred")
100
+ ax[1].imshow(tar, cmap="coolwarm")
101
+ ax[1].set_title("tp tar")
102
+ fig.tight_layout()
103
+ return fig
104
+
105
+ def read_max_min_value(min_max_val_file_path):
106
+ with h5py.File(min_max_val_file_path, 'r') as f:
107
+ max_values = f['max_values'][:]  # copy into memory; the h5py dataset handle becomes invalid once the file closes
108
+ min_values = f['min_values'][:]
109
+ return max_values, min_values
110
+
111
+
112
+
113
+
114
+
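
The normalization branches of reshape_fields broadcast per-channel statistics over the (time, channel, lat, lon) array. A self-contained sketch of the 'zscore' case with synthetic statistics (the real arrays come from params.global_means_path / params.global_stds_path):

import numpy as np

img = np.random.randn(1, 3, 360, 720).astype(np.float32)  # (n_history + 1, C, H, W)
means = np.zeros((1, 3, 1, 1), dtype=np.float32)           # stand-in for np.load(params.global_means_path)[:, channels]
stds = np.ones((1, 3, 1, 1), dtype=np.float32)             # stand-in for np.load(params.global_stds_path)[:, channels]

img = (img - means) / stds                                 # same arithmetic as the zscore branch above
print(img.shape)                                           # shape is unchanged: (1, 3, 360, 720)
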
networks/.ipynb_checkpoints/CirT1-checkpoint.py ADDED
@@ -0,0 +1,301 @@
1
+ from functools import lru_cache
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from timm.models.vision_transformer import trunc_normal_, Block
7
+ from torch.jit import Final
8
+
9
+ import torch.nn.functional as F
10
+ from typing import Optional
11
+ from timm.layers import DropPath, use_fused_attn, Mlp
12
+
13
+ class PatchEmbed(nn.Module):
14
+ def __init__(
15
+ self,
16
+ img_size=[121, 240],
17
+ in_chans=63,
18
+ embed_dim=768,
19
+ norm_layer=None,
20
+ flatten=True,
21
+ bias=True,
22
+ ):
23
+ super().__init__()
24
+ self.img_size = img_size
25
+ self.num_patches = img_size[0]
26
+ self.flatten = flatten
27
+
28
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=[1, img_size[1]], stride=1, bias=bias)
29
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
30
+
31
+ def forward(self, x):
32
+ B, C, H, W = x.shape
33
+ x = self.proj(x)
34
+ if self.flatten:
35
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
36
+ x = self.norm(x)
37
+ return x
38
+
39
+ class Attention(nn.Module):
40
+ fused_attn: Final[bool]
41
+
42
+ def __init__(
43
+ self,
44
+ dim: int,
45
+ num_heads: int = 8,
46
+ qkv_bias: bool = False,
47
+ qk_norm: bool = False,
48
+ attn_drop: float = 0.,
49
+ proj_drop: float = 0.,
50
+ norm_layer: nn.Module = nn.LayerNorm,
51
+ ) -> None:
52
+ super().__init__()
53
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
54
+ self.num_heads = num_heads
55
+ self.head_dim = dim // self.num_heads
56
+ self.scale = self.head_dim ** -0.5
57
+ self.fused_attn = use_fused_attn()
58
+ self.dim = dim
59
+ self.attn_bias = nn.Parameter(torch.zeros(121, 121, 2), requires_grad=True)
60
+
61
+ # self.qkv = CLinear(dim, dim * 3, bias=qkv_bias)
62
+ self.qkv = nn.Linear((dim // 2 + 1) * 2, dim * 3, bias=qkv_bias)
63
+ self.q = nn.Linear((dim // 2 + 1) * 2, dim, bias=qkv_bias)
64
+ self.k = nn.Linear((dim // 2 + 1) * 2, dim, bias=qkv_bias)
65
+ self.v = nn.Linear((dim // 2 + 1) * 2, dim, bias=qkv_bias)
66
+
67
+
68
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
69
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
70
+ self.attn_drop = nn.Dropout(attn_drop)
71
+ # self.proj = CLinear(dim, dim)
72
+ self.proj_drop = nn.Dropout(proj_drop)
73
+ self.proj = nn.Linear(dim, dim)
74
+ # self.proj_drop = nn.Dropout(proj_drop)
75
+
76
+
77
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
78
+ B, N, C = x.shape
79
+
80
+ x = torch.fft.rfft(x, norm="forward")
81
+ x = torch.view_as_real(x)
82
+ x = torch.cat((x[:, :, :, 0], -x[:, :, :, 1]), dim=-1)
83
+
84
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
85
+ # q, k, v = qkv.unbind(0)
86
+ q = self.q(x).reshape(B, self.num_heads, N, self.head_dim)
87
+ k = self.k(x).reshape(B, self.num_heads, N, self.head_dim)
88
+ v = self.v(x).reshape(B, self.num_heads, N, self.head_dim)
89
+ q = q * self.scale
90
+ attn = q @ k.transpose(-2, -1)
91
+
92
+
93
+ attn = attn.softmax(dim=-1)
94
+ attn = self.attn_drop(attn)
95
+ x = attn @ v
96
+ x = x.transpose(1, 2).reshape(B, N, C)
97
+ x = self.proj(x)
98
+ x = self.proj_drop(x)
99
+
100
+ real, img = torch.split(x, x.shape[-1] // 2, dim=-1)
101
+ x = torch.stack([real,-img], dim=-1)
102
+ x = torch.view_as_complex(x)
103
+ x = torch.fft.irfft(x, self.dim, norm="forward")
104
+
105
+ return x
106
+
107
+
108
+ class LayerScale(nn.Module):
109
+ def __init__(
110
+ self,
111
+ dim: int,
112
+ init_values: float = 1e-5,
113
+ inplace: bool = False,
114
+ ) -> None:
115
+ super().__init__()
116
+ self.inplace = inplace
117
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
118
+
119
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
120
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
121
+
122
+
123
+
124
+ class Block(nn.Module):
125
+ def __init__(
126
+ self,
127
+ dim: int,
128
+ num_heads: int,
129
+ mlp_ratio: float = 4.,
130
+ qkv_bias: bool = False,
131
+ qk_norm: bool = False,
132
+ proj_drop: float = 0.,
133
+ attn_drop: float = 0.,
134
+ init_values: Optional[float] = None,
135
+ drop_path: float = 0.,
136
+ act_layer: nn.Module = nn.GELU,
137
+ norm_layer: nn.Module = nn.LayerNorm,
138
+ mlp_layer: nn.Module = Mlp,
139
+ ) -> None:
140
+ super().__init__()
141
+ self.norm1 = norm_layer(dim)
142
+ self.attn = Attention(
143
+ dim,
144
+ num_heads=num_heads,
145
+ qkv_bias=qkv_bias,
146
+ qk_norm=qk_norm,
147
+ attn_drop=attn_drop,
148
+ proj_drop=proj_drop,
149
+ norm_layer=norm_layer,
150
+ )
151
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
152
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
153
+
154
+ self.norm2 = norm_layer(dim)
155
+ self.mlp = mlp_layer(
156
+ in_features=dim,
157
+ hidden_features=int(dim * mlp_ratio),
158
+ act_layer=act_layer,
159
+ drop=proj_drop,
160
+ )
161
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
162
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
163
+
164
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
165
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
166
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
167
+ return x
168
+
169
+
170
+ class CirT(nn.Module):
171
+ def __init__(
172
+ self,
173
+ params,
174
+ img_size=[360, 720],
175
+ input_size=101,
176
+ output_size=93,
177
+ patch_size=124, #124
178
+ embed_dim=256,
179
+ depth=8,
180
+ decoder_depth=2,
181
+ num_heads=16,
182
+ mlp_ratio=4.0,
183
+ drop_path=0.1,
184
+ drop_rate=0.1
185
+ ):
186
+ super().__init__()
187
+
188
+ # TODO: remove time_history parameter
189
+ self.img_size = img_size
190
+ self.patch_size = img_size[1]
191
+ self.input_size = input_size
192
+ self.output_size = output_size
193
+ self.token_embeds = PatchEmbed(img_size, input_size, embed_dim)
194
+ # self.token_embeds = nn.Linear(img_size[0] * 2, embed_dim)
195
+ self.num_patches = self.token_embeds.num_patches
196
+ self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim), requires_grad=True)
197
+ # self.pos_embed = PosEmbed(embed_dim=embed_dim)
198
+
199
+
200
+ # --------------------------------------------------------------------------
201
+
202
+ # ViT backbone
203
+ self.pos_drop = nn.Dropout(p=drop_rate)
204
+ dpr = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
205
+ self.blocks = nn.ModuleList(
206
+ [
207
+ Block(
208
+ embed_dim,
209
+ num_heads,
210
+ mlp_ratio,
211
+ qkv_bias=True,
212
+ drop_path=dpr[i],
213
+ norm_layer=nn.LayerNorm,
214
+ # drop=drop_rate,
215
+ )
216
+ for i in range(depth)
217
+ ]
218
+ )
219
+ self.norm = nn.LayerNorm(embed_dim)
220
+
221
+ # --------------------------------------------------------------------------
222
+
223
+ # prediction head
224
+ self.head = nn.ModuleList()
225
+ for _ in range(decoder_depth):
226
+ self.head.append(nn.Linear(embed_dim, embed_dim))
227
+ self.head.append(nn.GELU())
228
+ self.head.append(nn.Linear(embed_dim, output_size * self.img_size[1]))
229
+ self.head = nn.Sequential(*self.head)
230
+
231
+ # --------------------------------------------------------------------------
232
+
233
+ self.initialize_weights()
234
+
235
+ def initialize_weights(self):
236
+ # token embedding layer
237
+ w = self.token_embeds.proj.weight.data
238
+ trunc_normal_(w.view([w.shape[0], -1]), std=0.02)
239
+
240
+ # initialize nn.Linear and nn.LayerNorm
241
+ self.apply(self._init_weights)
242
+
243
+ def _init_weights(self, m):
244
+ if isinstance(m, nn.Linear):
245
+ trunc_normal_(m.weight, std=0.02)
246
+ if m.bias is not None:
247
+ nn.init.constant_(m.bias, 0)
248
+ elif isinstance(m, nn.LayerNorm):
249
+ nn.init.constant_(m.bias, 0)
250
+ nn.init.constant_(m.weight, 1.0)
251
+
252
+
253
+ def unpatchify(self, x: torch.Tensor, h=None, w=None):
254
+ """
255
+ x: (B, L, V * patch_size)
256
+ return imgs: (B, V, H, W)
257
+ """
258
+ p = self.patch_size
259
+ c_out = self.output_size
260
+ h = self.img_size[0] // 1
261
+ w = self.img_size[1] // p
262
+ assert h * w == x.shape[1]
263
+
264
+ x = x.reshape(shape=(x.shape[0], h, w, p, c_out))
265
+ x = torch.einsum("nhwpc->nchpw", x)
266
+ imgs = x.reshape(shape=(x.shape[0], c_out, h, w * p))
267
+ return imgs
268
+
269
+ def forward_encoder(self, x: torch.Tensor):
270
+ # x: `[B, V, H, W]` shape.
271
+
272
+ # tokenize each variable separately
273
+ # x = torch.fft.rfft(x, norm="forward")
274
+ # x = torch.view_as_real(x)
275
+ # x = torch.cat((x[:, :, :, :, 0], -x[:, :, :, :, 1]), dim=-1)
276
+
277
+ x = self.token_embeds(x)
278
+
279
+ # pos_embed = self.pos_embed()
280
+ # add pos embedding
281
+ x = x + self.pos_embed
282
+ x = self.pos_drop(x)
283
+
284
+ # apply Transformer blocks
285
+ for blk in self.blocks:
286
+ x = blk(x)
287
+ x = self.norm(x)
288
+
289
+ return x
290
+
291
+ def forward(self, x):
292
+ B, V, H, W = x.shape
293
+ # print(x.shape)
294
+ out_transformers = self.forward_encoder(x) # B, L, D
295
+ preds = self.head(out_transformers) # B, L, V*p*p
296
+ preds = self.unpatchify(preds)
297
+
298
+ # real, img = torch.split(preds, preds.shape[-1] // 2, dim=-1)
299
+ # preds = torch.cat([real, -img], dim=-1)
300
+ # preds = torch.fft.irfft(preds, W, norm="forward")
301
+ return preds
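
In CirT, each latitude row is one token: PatchEmbed uses a Conv2d kernel spanning the full longitude circle, giving 360 tokens, and the Attention module applies an rfft to each token's embedding, attends across the 360 row tokens, and maps back with an irfft. A hedged forward-pass sketch on random data (requires torch and timm; params is not used inside the constructor shown above, so None suffices here):

import torch
from networks.CirT1 import CirT  # the checkpoint copy above mirrors this module

model = CirT(params=None, img_size=[360, 720], input_size=101, output_size=93)
x = torch.randn(1, 101, 360, 720)   # (B, V_in, H, W): one latitude row per token, 360 tokens
with torch.no_grad():
    y = model(x)
print(y.shape)                       # torch.Size([1, 93, 360, 720])
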
networks/.ipynb_checkpoints/CirT2-checkpoint.py ADDED
@@ -0,0 +1,301 @@
1
+ from functools import lru_cache
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from timm.models.vision_transformer import trunc_normal_, Block
7
+ from torch.jit import Final
8
+
9
+ import torch.nn.functional as F
10
+ from typing import Optional
11
+ from timm.layers import DropPath, use_fused_attn, Mlp
12
+
13
+ class PatchEmbed(nn.Module):
14
+ def __init__(
15
+ self,
16
+ img_size=[121, 240],
17
+ in_chans=63,
18
+ embed_dim=768,
19
+ norm_layer=None,
20
+ flatten=True,
21
+ bias=True,
22
+ ):
23
+ super().__init__()
24
+ self.img_size = img_size
25
+ self.num_patches = img_size[0]
26
+ self.flatten = flatten
27
+
28
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=[1, img_size[1]], stride=1, bias=bias)
29
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
30
+
31
+ def forward(self, x):
32
+ B, C, H, W = x.shape
33
+ x = self.proj(x)
34
+ if self.flatten:
35
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
36
+ x = self.norm(x)
37
+ return x
38
+
39
+ class Attention(nn.Module):
40
+ fused_attn: Final[bool]
41
+
42
+ def __init__(
43
+ self,
44
+ dim: int,
45
+ num_heads: int = 8,
46
+ qkv_bias: bool = False,
47
+ qk_norm: bool = False,
48
+ attn_drop: float = 0.,
49
+ proj_drop: float = 0.,
50
+ norm_layer: nn.Module = nn.LayerNorm,
51
+ ) -> None:
52
+ super().__init__()
53
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
54
+ self.num_heads = num_heads
55
+ self.head_dim = dim // self.num_heads
56
+ self.scale = self.head_dim ** -0.5
57
+ self.fused_attn = use_fused_attn()
58
+ self.dim = dim
59
+ self.attn_bias = nn.Parameter(torch.zeros(121, 121, 2), requires_grad=True)
60
+
61
+ # self.qkv = CLinear(dim, dim * 3, bias=qkv_bias)
62
+ self.qkv = nn.Linear((dim // 2 + 1) * 2, dim * 3, bias=qkv_bias)
63
+ self.q = nn.Linear((dim // 2 + 1) * 2, dim, bias=qkv_bias)
64
+ self.k = nn.Linear((dim // 2 + 1) * 2, dim, bias=qkv_bias)
65
+ self.v = nn.Linear((dim // 2 + 1) * 2, dim, bias=qkv_bias)
66
+
67
+
68
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
69
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
70
+ self.attn_drop = nn.Dropout(attn_drop)
71
+ # self.proj = CLinear(dim, dim)
72
+ self.proj_drop = nn.Dropout(proj_drop)
73
+ self.proj = nn.Linear(dim, dim)
74
+ # self.proj_drop = nn.Dropout(proj_drop)
75
+
76
+
77
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
78
+ B, N, C = x.shape
79
+
80
+ x = torch.fft.rfft(x, norm="forward")
81
+ x = torch.view_as_real(x)
82
+ x = torch.cat((x[:, :, :, 0], -x[:, :, :, 1]), dim=-1)
83
+
84
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
85
+ # q, k, v = qkv.unbind(0)
86
+ q = self.q(x).reshape(B, self.num_heads, N, self.head_dim)
87
+ k = self.k(x).reshape(B, self.num_heads, N, self.head_dim)
88
+ v = self.v(x).reshape(B, self.num_heads, N, self.head_dim)
89
+ q = q * self.scale
90
+ attn = q @ k.transpose(-2, -1)
91
+
92
+
93
+ attn = attn.softmax(dim=-1)
94
+ attn = self.attn_drop(attn)
95
+ x = attn @ v
96
+ x = x.transpose(1, 2).reshape(B, N, C)
97
+ x = self.proj(x)
98
+ x = self.proj_drop(x)
99
+
100
+ real, img = torch.split(x, x.shape[-1] // 2, dim=-1)
101
+ x = torch.stack([real,-img], dim=-1)
102
+ x = torch.view_as_complex(x)
103
+ x = torch.fft.irfft(x, self.dim, norm="forward")
104
+
105
+ return x
106
+
107
+
108
+ class LayerScale(nn.Module):
109
+ def __init__(
110
+ self,
111
+ dim: int,
112
+ init_values: float = 1e-5,
113
+ inplace: bool = False,
114
+ ) -> None:
115
+ super().__init__()
116
+ self.inplace = inplace
117
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
118
+
119
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
120
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
121
+
122
+
123
+
124
+ class Block(nn.Module):
125
+ def __init__(
126
+ self,
127
+ dim: int,
128
+ num_heads: int,
129
+ mlp_ratio: float = 4.,
130
+ qkv_bias: bool = False,
131
+ qk_norm: bool = False,
132
+ proj_drop: float = 0.,
133
+ attn_drop: float = 0.,
134
+ init_values: Optional[float] = None,
135
+ drop_path: float = 0.,
136
+ act_layer: nn.Module = nn.GELU,
137
+ norm_layer: nn.Module = nn.LayerNorm,
138
+ mlp_layer: nn.Module = Mlp,
139
+ ) -> None:
140
+ super().__init__()
141
+ self.norm1 = norm_layer(dim)
142
+ self.attn = Attention(
143
+ dim,
144
+ num_heads=num_heads,
145
+ qkv_bias=qkv_bias,
146
+ qk_norm=qk_norm,
147
+ attn_drop=attn_drop,
148
+ proj_drop=proj_drop,
149
+ norm_layer=norm_layer,
150
+ )
151
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
152
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
153
+
154
+ self.norm2 = norm_layer(dim)
155
+ self.mlp = mlp_layer(
156
+ in_features=dim,
157
+ hidden_features=int(dim * mlp_ratio),
158
+ act_layer=act_layer,
159
+ drop=proj_drop,
160
+ )
161
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
162
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
163
+
164
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
165
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
166
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
167
+ return x
168
+
169
+
170
+ class CirT_stage2(nn.Module):
171
+ def __init__(
172
+ self,
173
+ params,
174
+ img_size=[360, 720],
175
+ input_size=93,
176
+ output_size=93,
177
+ patch_size=124, #124
178
+ embed_dim=256,
179
+ depth=8,
180
+ decoder_depth=2,
181
+ num_heads=16,
182
+ mlp_ratio=4.0,
183
+ drop_path=0.1,
184
+ drop_rate=0.1
185
+ ):
186
+ super().__init__()
187
+
188
+ # TODO: remove time_history parameter
189
+ self.img_size = img_size
190
+ self.patch_size = img_size[1]
191
+ self.input_size = input_size
192
+ self.output_size = output_size
193
+ self.token_embeds = PatchEmbed(img_size, input_size, embed_dim)
194
+ # self.token_embeds = nn.Linear(img_size[0] * 2, embed_dim)
195
+ self.num_patches = self.token_embeds.num_patches
196
+ self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim), requires_grad=True)
197
+ # self.pos_embed = PosEmbed(embed_dim=embed_dim)
198
+
199
+
200
+ # --------------------------------------------------------------------------
201
+
202
+ # ViT backbone
203
+ self.pos_drop = nn.Dropout(p=drop_rate)
204
+ dpr = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
205
+ self.blocks = nn.ModuleList(
206
+ [
207
+ Block(
208
+ embed_dim,
209
+ num_heads,
210
+ mlp_ratio,
211
+ qkv_bias=True,
212
+ drop_path=dpr[i],
213
+ norm_layer=nn.LayerNorm,
214
+ # drop=drop_rate,
215
+ )
216
+ for i in range(depth)
217
+ ]
218
+ )
219
+ self.norm = nn.LayerNorm(embed_dim)
220
+
221
+ # --------------------------------------------------------------------------
222
+
223
+ # prediction head
224
+ self.head = nn.ModuleList()
225
+ for _ in range(decoder_depth):
226
+ self.head.append(nn.Linear(embed_dim, embed_dim))
227
+ self.head.append(nn.GELU())
228
+ self.head.append(nn.Linear(embed_dim, output_size * self.img_size[1]))
229
+ self.head = nn.Sequential(*self.head)
230
+
231
+ # --------------------------------------------------------------------------
232
+
233
+ self.initialize_weights()
234
+
235
+ def initialize_weights(self):
236
+ # token embedding layer
237
+ w = self.token_embeds.proj.weight.data
238
+ trunc_normal_(w.view([w.shape[0], -1]), std=0.02)
239
+
240
+ # initialize nn.Linear and nn.LayerNorm
241
+ self.apply(self._init_weights)
242
+
243
+ def _init_weights(self, m):
244
+ if isinstance(m, nn.Linear):
245
+ trunc_normal_(m.weight, std=0.02)
246
+ if m.bias is not None:
247
+ nn.init.constant_(m.bias, 0)
248
+ elif isinstance(m, nn.LayerNorm):
249
+ nn.init.constant_(m.bias, 0)
250
+ nn.init.constant_(m.weight, 1.0)
251
+
252
+
253
+ def unpatchify(self, x: torch.Tensor, h=None, w=None):
254
+ """
255
+ x: (B, L, V * patch_size)
256
+ return imgs: (B, V, H, W)
257
+ """
258
+ p = self.patch_size
259
+ c_out = self.output_size
260
+ h = self.img_size[0] // 1
261
+ w = self.img_size[1] // p
262
+ assert h * w == x.shape[1]
263
+
264
+ x = x.reshape(shape=(x.shape[0], h, w, p, c_out))
265
+ x = torch.einsum("nhwpc->nchpw", x)
266
+ imgs = x.reshape(shape=(x.shape[0], c_out, h, w * p))
267
+ return imgs
268
+
269
+ def forward_encoder(self, x: torch.Tensor):
270
+ # x: `[B, V, H, W]` shape.
271
+
272
+ # tokenize each variable separately
273
+ # x = torch.fft.rfft(x, norm="forward")
274
+ # x = torch.view_as_real(x)
275
+ # x = torch.cat((x[:, :, :, :, 0], -x[:, :, :, :, 1]), dim=-1)
276
+
277
+ x = self.token_embeds(x)
278
+
279
+ # pos_embed = self.pos_embed()
280
+ # add pos embedding
281
+ x = x + self.pos_embed
282
+ x = self.pos_drop(x)
283
+
284
+ # apply Transformer blocks
285
+ for blk in self.blocks:
286
+ x = blk(x)
287
+ x = self.norm(x)
288
+
289
+ return x
290
+
291
+ def forward(self, x):
292
+ B, V, H, W = x.shape
293
+ # print(x.shape)
294
+ out_transformers = self.forward_encoder(x) # B, L, D
295
+ preds = self.head(out_transformers) # B, L, V*p*p
296
+ preds = self.unpatchify(preds)
297
+
298
+ # real, img = torch.split(preds, preds.shape[-1] // 2, dim=-1)
299
+ # preds = torch.cat([real, -img], dim=-1)
300
+ # preds = torch.fft.irfft(preds, W, norm="forward")
301
+ return preds
networks/CirT1.py ADDED
@@ -0,0 +1,301 @@
1
+ from functools import lru_cache
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from timm.models.vision_transformer import trunc_normal_, Block
7
+ from torch.jit import Final
8
+
9
+ import torch.nn.functional as F
10
+ from typing import Optional
11
+ from timm.layers import DropPath, use_fused_attn, Mlp
12
+
13
+ class PatchEmbed(nn.Module):
14
+ def __init__(
15
+ self,
16
+ img_size=[121, 240],
17
+ in_chans=63,
18
+ embed_dim=768,
19
+ norm_layer=None,
20
+ flatten=True,
21
+ bias=True,
22
+ ):
23
+ super().__init__()
24
+ self.img_size = img_size
25
+ self.num_patches = img_size[0]
26
+ self.flatten = flatten
27
+
28
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=[1, img_size[1]], stride=1, bias=bias)
29
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
30
+
31
+ def forward(self, x):
32
+ B, C, H, W = x.shape
33
+ x = self.proj(x)
34
+ if self.flatten:
35
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
36
+ x = self.norm(x)
37
+ return x
38
+
39
+ class Attention(nn.Module):
40
+ fused_attn: Final[bool]
41
+
42
+ def __init__(
43
+ self,
44
+ dim: int,
45
+ num_heads: int = 8,
46
+ qkv_bias: bool = False,
47
+ qk_norm: bool = False,
48
+ attn_drop: float = 0.,
49
+ proj_drop: float = 0.,
50
+ norm_layer: nn.Module = nn.LayerNorm,
51
+ ) -> None:
52
+ super().__init__()
53
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
54
+ self.num_heads = num_heads
55
+ self.head_dim = dim // self.num_heads
56
+ self.scale = self.head_dim ** -0.5
57
+ self.fused_attn = use_fused_attn()
58
+ self.dim = dim
59
+ self.attn_bias = nn.Parameter(torch.zeros(121, 121, 2), requires_grad=True)
60
+
61
+ # self.qkv = CLinear(dim, dim * 3, bias=qkv_bias)
62
+ self.qkv = nn.Linear((dim // 2 + 1) * 2, dim * 3, bias=qkv_bias)
63
+ self.q = nn.Linear((dim // 2 + 1) * 2, dim, bias=qkv_bias)
64
+ self.k = nn.Linear((dim // 2 + 1) * 2, dim, bias=qkv_bias)
65
+ self.v = nn.Linear((dim // 2 + 1) * 2, dim, bias=qkv_bias)
66
+
67
+
68
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
69
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
70
+ self.attn_drop = nn.Dropout(attn_drop)
71
+ # self.proj = CLinear(dim, dim)
72
+ self.proj_drop = nn.Dropout(proj_drop)
73
+ self.proj = nn.Linear(dim, dim)
74
+ # self.proj_drop = nn.Dropout(proj_drop)
75
+
76
+
77
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
78
+ B, N, C = x.shape
79
+
80
+ x = torch.fft.rfft(x, norm="forward")
81
+ x = torch.view_as_real(x)
82
+ x = torch.cat((x[:, :, :, 0], -x[:, :, :, 1]), dim=-1)
83
+
84
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
85
+ # q, k, v = qkv.unbind(0)
86
+ q = self.q(x).reshape(B, self.num_heads, N, self.head_dim)
87
+ k = self.k(x).reshape(B, self.num_heads, N, self.head_dim)
88
+ v = self.v(x).reshape(B, self.num_heads, N, self.head_dim)
89
+ q = q * self.scale
90
+ attn = q @ k.transpose(-2, -1)
91
+
92
+
93
+ attn = attn.softmax(dim=-1)
94
+ attn = self.attn_drop(attn)
95
+ x = attn @ v
96
+ x = x.transpose(1, 2).reshape(B, N, C)
97
+ x = self.proj(x)
98
+ x = self.proj_drop(x)
99
+
100
+ real, img = torch.split(x, x.shape[-1] // 2, dim=-1)
101
+ x = torch.stack([real,-img], dim=-1)
102
+ x = torch.view_as_complex(x)
103
+ x = torch.fft.irfft(x, self.dim, norm="forward")
104
+
105
+ return x
106
+
107
+
108
+ class LayerScale(nn.Module):
109
+ def __init__(
110
+ self,
111
+ dim: int,
112
+ init_values: float = 1e-5,
113
+ inplace: bool = False,
114
+ ) -> None:
115
+ super().__init__()
116
+ self.inplace = inplace
117
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
118
+
119
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
120
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
121
+
122
+
123
+
124
+ class Block(nn.Module):
125
+ def __init__(
126
+ self,
127
+ dim: int,
128
+ num_heads: int,
129
+ mlp_ratio: float = 4.,
130
+ qkv_bias: bool = False,
131
+ qk_norm: bool = False,
132
+ proj_drop: float = 0.,
133
+ attn_drop: float = 0.,
134
+ init_values: Optional[float] = None,
135
+ drop_path: float = 0.,
136
+ act_layer: nn.Module = nn.GELU,
137
+ norm_layer: nn.Module = nn.LayerNorm,
138
+ mlp_layer: nn.Module = Mlp,
139
+ ) -> None:
140
+ super().__init__()
141
+ self.norm1 = norm_layer(dim)
142
+ self.attn = Attention(
143
+ dim,
144
+ num_heads=num_heads,
145
+ qkv_bias=qkv_bias,
146
+ qk_norm=qk_norm,
147
+ attn_drop=attn_drop,
148
+ proj_drop=proj_drop,
149
+ norm_layer=norm_layer,
150
+ )
151
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
152
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
153
+
154
+ self.norm2 = norm_layer(dim)
155
+ self.mlp = mlp_layer(
156
+ in_features=dim,
157
+ hidden_features=int(dim * mlp_ratio),
158
+ act_layer=act_layer,
159
+ drop=proj_drop,
160
+ )
161
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
162
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
163
+
164
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
165
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
166
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
167
+ return x
168
+
169
+
170
+ class CirT(nn.Module):
171
+ def __init__(
172
+ self,
173
+ params,
174
+ img_size=[360, 720],
175
+ input_size=101,
176
+ output_size=93,
177
+ patch_size=124, #124
178
+ embed_dim=256,
179
+ depth=8,
180
+ decoder_depth=2,
181
+ num_heads=16,
182
+ mlp_ratio=4.0,
183
+ drop_path=0.1,
184
+ drop_rate=0.1
185
+ ):
186
+ super().__init__()
187
+
188
+ # TODO: remove time_history parameter
189
+ self.img_size = img_size
190
+ self.patch_size = img_size[1]
191
+ self.input_size = input_size
192
+ self.output_size = output_size
193
+ self.token_embeds = PatchEmbed(img_size, input_size, embed_dim)
194
+ # self.token_embeds = nn.Linear(img_size[0] * 2, embed_dim)
195
+ self.num_patches = self.token_embeds.num_patches
196
+ self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim), requires_grad=True)
197
+ # self.pos_embed = PosEmbed(embed_dim=embed_dim)
198
+
199
+
200
+ # --------------------------------------------------------------------------
201
+
202
+ # ViT backbone
203
+ self.pos_drop = nn.Dropout(p=drop_rate)
204
+ dpr = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
205
+ self.blocks = nn.ModuleList(
206
+ [
207
+ Block(
208
+ embed_dim,
209
+ num_heads,
210
+ mlp_ratio,
211
+ qkv_bias=True,
212
+ drop_path=dpr[i],
213
+ norm_layer=nn.LayerNorm,
214
+ # drop=drop_rate,
215
+ )
216
+ for i in range(depth)
217
+ ]
218
+ )
219
+ self.norm = nn.LayerNorm(embed_dim)
220
+
221
+ # --------------------------------------------------------------------------
222
+
223
+ # prediction head
224
+ self.head = nn.ModuleList()
225
+ for _ in range(decoder_depth):
226
+ self.head.append(nn.Linear(embed_dim, embed_dim))
227
+ self.head.append(nn.GELU())
228
+ self.head.append(nn.Linear(embed_dim, output_size * self.img_size[1]))
229
+ self.head = nn.Sequential(*self.head)
230
+
231
+ # --------------------------------------------------------------------------
232
+
233
+ self.initialize_weights()
234
+
235
+ def initialize_weights(self):
236
+ # token embedding layer
237
+ w = self.token_embeds.proj.weight.data
238
+ trunc_normal_(w.view([w.shape[0], -1]), std=0.02)
239
+
240
+ # initialize nn.Linear and nn.LayerNorm
241
+ self.apply(self._init_weights)
242
+
243
+ def _init_weights(self, m):
244
+ if isinstance(m, nn.Linear):
245
+ trunc_normal_(m.weight, std=0.02)
246
+ if m.bias is not None:
247
+ nn.init.constant_(m.bias, 0)
248
+ elif isinstance(m, nn.LayerNorm):
249
+ nn.init.constant_(m.bias, 0)
250
+ nn.init.constant_(m.weight, 1.0)
251
+
252
+
253
+ def unpatchify(self, x: torch.Tensor, h=None, w=None):
254
+ """
255
+ x: (B, L, V * patch_size)
256
+ return imgs: (B, V, H, W)
257
+ """
258
+ p = self.patch_size
259
+ c_out = self.output_size
260
+ h = self.img_size[0] // 1
261
+ w = self.img_size[1] // p
262
+ assert h * w == x.shape[1]
263
+
264
+ x = x.reshape(shape=(x.shape[0], h, w, p, c_out))
265
+ x = torch.einsum("nhwpc->nchpw", x)
266
+ imgs = x.reshape(shape=(x.shape[0], c_out, h, w * p))
267
+ return imgs
268
+
269
+ def forward_encoder(self, x: torch.Tensor):
270
+ # x: `[B, V, H, W]` shape.
271
+
272
+ # tokenize each variable separately
273
+ # x = torch.fft.rfft(x, norm="forward")
274
+ # x = torch.view_as_real(x)
275
+ # x = torch.cat((x[:, :, :, :, 0], -x[:, :, :, :, 1]), dim=-1)
276
+
277
+ x = self.token_embeds(x)
278
+
279
+ # pos_embed = self.pos_embed()
280
+ # add pos embedding
281
+ x = x + self.pos_embed
282
+ x = self.pos_drop(x)
283
+
284
+ # apply Transformer blocks
285
+ for blk in self.blocks:
286
+ x = blk(x)
287
+ x = self.norm(x)
288
+
289
+ return x
290
+
291
+ def forward(self, x):
292
+ B, V, H, W = x.shape
293
+ # print(x.shape)
294
+ out_transformers = self.forward_encoder(x) # B, L, D
295
+ preds = self.head(out_transformers) # B, L, V*p*p
296
+ preds = self.unpatchify(preds)
297
+
298
+ # real, img = torch.split(preds, preds.shape[-1] // 2, dim=-1)
299
+ # preds = torch.cat([real, -img], dim=-1)
300
+ # preds = torch.fft.irfft(preds, W, norm="forward")
301
+ return preds
networks/CirT2.py ADDED
@@ -0,0 +1,301 @@
1
+ from functools import lru_cache
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from timm.models.vision_transformer import trunc_normal_, Block
7
+ from torch.jit import Final
8
+
9
+ import torch.nn.functional as F
10
+ from typing import Optional
11
+ from timm.layers import DropPath, use_fused_attn, Mlp
12
+
13
+ class PatchEmbed(nn.Module):
14
+ def __init__(
15
+ self,
16
+ img_size=[121, 240],
17
+ in_chans=63,
18
+ embed_dim=768,
19
+ norm_layer=None,
20
+ flatten=True,
21
+ bias=True,
22
+ ):
23
+ super().__init__()
24
+ self.img_size = img_size
25
+ self.num_patches = img_size[0]
26
+ self.flatten = flatten
27
+
28
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=[1, img_size[1]], stride=1, bias=bias)
29
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
30
+
31
+ def forward(self, x):
32
+ B, C, H, W = x.shape
33
+ x = self.proj(x)
34
+ if self.flatten:
35
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
36
+ x = self.norm(x)
37
+ return x
38
+
39
+ class Attention(nn.Module):
40
+ fused_attn: Final[bool]
41
+
42
+ def __init__(
43
+ self,
44
+ dim: int,
45
+ num_heads: int = 8,
46
+ qkv_bias: bool = False,
47
+ qk_norm: bool = False,
48
+ attn_drop: float = 0.,
49
+ proj_drop: float = 0.,
50
+ norm_layer: nn.Module = nn.LayerNorm,
51
+ ) -> None:
52
+ super().__init__()
53
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
54
+ self.num_heads = num_heads
55
+ self.head_dim = dim // self.num_heads
56
+ self.scale = self.head_dim ** -0.5
57
+ self.fused_attn = use_fused_attn()
58
+ self.dim = dim
59
+ self.attn_bias = nn.Parameter(torch.zeros(121, 121, 2), requires_grad=True)
60
+
61
+ # self.qkv = CLinear(dim, dim * 3, bias=qkv_bias)
62
+ self.qkv = nn.Linear((dim // 2 + 1) * 2, dim * 3, bias=qkv_bias)
63
+ self.q = nn.Linear((dim // 2 + 1) * 2, dim, bias=qkv_bias)
64
+ self.k = nn.Linear((dim // 2 + 1) * 2, dim, bias=qkv_bias)
65
+ self.v = nn.Linear((dim // 2 + 1) * 2, dim, bias=qkv_bias)
66
+
67
+
68
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
69
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
70
+ self.attn_drop = nn.Dropout(attn_drop)
71
+ # self.proj = CLinear(dim, dim)
72
+ self.proj_drop = nn.Dropout(proj_drop)
73
+ self.proj = nn.Linear(dim, dim)
74
+ # self.proj_drop = nn.Dropout(proj_drop)
75
+
76
+
77
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
78
+ B, N, C = x.shape
79
+
80
+ x = torch.fft.rfft(x, norm="forward")
81
+ x = torch.view_as_real(x)
82
+ x = torch.cat((x[:, :, :, 0], -x[:, :, :, 1]), dim=-1)
83
+
84
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
85
+ # q, k, v = qkv.unbind(0)
86
+ q = self.q(x).reshape(B, self.num_heads, N, self.head_dim)
87
+ k = self.k(x).reshape(B, self.num_heads, N, self.head_dim)
88
+ v = self.v(x).reshape(B, self.num_heads, N, self.head_dim)
89
+ q = q * self.scale
90
+ attn = q @ k.transpose(-2, -1)
91
+
92
+
93
+ attn = attn.softmax(dim=-1)
94
+ attn = self.attn_drop(attn)
95
+ x = attn @ v
96
+ x = x.transpose(1, 2).reshape(B, N, C)
97
+ x = self.proj(x)
98
+ x = self.proj_drop(x)
99
+
100
+ real, img = torch.split(x, x.shape[-1] // 2, dim=-1)
101
+ x = torch.stack([real,-img], dim=-1)
102
+ x = torch.view_as_complex(x)
103
+ x = torch.fft.irfft(x, self.dim, norm="forward")
104
+
105
+ return x
106
+
107
+
108
+ class LayerScale(nn.Module):
109
+ def __init__(
110
+ self,
111
+ dim: int,
112
+ init_values: float = 1e-5,
113
+ inplace: bool = False,
114
+ ) -> None:
115
+ super().__init__()
116
+ self.inplace = inplace
117
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
118
+
119
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
120
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
121
+
122
+
123
+
124
+ class Block(nn.Module):
125
+ def __init__(
126
+ self,
127
+ dim: int,
128
+ num_heads: int,
129
+ mlp_ratio: float = 4.,
130
+ qkv_bias: bool = False,
131
+ qk_norm: bool = False,
132
+ proj_drop: float = 0.,
133
+ attn_drop: float = 0.,
134
+ init_values: Optional[float] = None,
135
+ drop_path: float = 0.,
136
+ act_layer: nn.Module = nn.GELU,
137
+ norm_layer: nn.Module = nn.LayerNorm,
138
+ mlp_layer: nn.Module = Mlp,
139
+ ) -> None:
140
+ super().__init__()
141
+ self.norm1 = norm_layer(dim)
142
+ self.attn = Attention(
143
+ dim,
144
+ num_heads=num_heads,
145
+ qkv_bias=qkv_bias,
146
+ qk_norm=qk_norm,
147
+ attn_drop=attn_drop,
148
+ proj_drop=proj_drop,
149
+ norm_layer=norm_layer,
150
+ )
151
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
152
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
153
+
154
+ self.norm2 = norm_layer(dim)
155
+ self.mlp = mlp_layer(
156
+ in_features=dim,
157
+ hidden_features=int(dim * mlp_ratio),
158
+ act_layer=act_layer,
159
+ drop=proj_drop,
160
+ )
161
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
162
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
163
+
164
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
165
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
166
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
167
+ return x
168
+
169
+
170
+ class CirT_stage2(nn.Module):
171
+ def __init__(
172
+ self,
173
+ params,
174
+ img_size=[360, 720],
175
+ input_size=93,
176
+ output_size=93,
177
+ patch_size=124, #124
178
+ embed_dim=256,
179
+ depth=8,
180
+ decoder_depth=2,
181
+ num_heads=16,
182
+ mlp_ratio=4.0,
183
+ drop_path=0.1,
184
+ drop_rate=0.1
185
+ ):
186
+ super().__init__()
187
+
188
+ # TODO: remove time_history parameter
189
+ self.img_size = img_size
190
+ self.patch_size = img_size[1]
191
+ self.input_size = input_size
192
+ self.output_size = output_size
193
+ self.token_embeds = PatchEmbed(img_size, input_size, embed_dim)
194
+ # self.token_embeds = nn.Linear(img_size[0] * 2, embed_dim)
195
+ self.num_patches = self.token_embeds.num_patches
196
+ self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim), requires_grad=True)
197
+ # self.pos_embed = PosEmbed(embed_dim=embed_dim)
198
+
199
+
200
+ # --------------------------------------------------------------------------
201
+
202
+ # ViT backbone
203
+ self.pos_drop = nn.Dropout(p=drop_rate)
204
+ dpr = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
205
+ self.blocks = nn.ModuleList(
206
+ [
207
+ Block(
208
+ embed_dim,
209
+ num_heads,
210
+ mlp_ratio,
211
+ qkv_bias=True,
212
+ drop_path=dpr[i],
213
+ norm_layer=nn.LayerNorm,
214
+ # drop=drop_rate,
215
+ )
216
+ for i in range(depth)
217
+ ]
218
+ )
219
+ self.norm = nn.LayerNorm(embed_dim)
220
+
221
+ # --------------------------------------------------------------------------
222
+
223
+ # prediction head
224
+ self.head = nn.ModuleList()
225
+ for _ in range(decoder_depth):
226
+ self.head.append(nn.Linear(embed_dim, embed_dim))
227
+ self.head.append(nn.GELU())
228
+ self.head.append(nn.Linear(embed_dim, output_size * self.img_size[1]))
229
+ self.head = nn.Sequential(*self.head)
230
+
231
+ # --------------------------------------------------------------------------
232
+
233
+ self.initialize_weights()
234
+
235
+ def initialize_weights(self):
236
+ # token embedding layer
237
+ w = self.token_embeds.proj.weight.data
238
+ trunc_normal_(w.view([w.shape[0], -1]), std=0.02)
239
+
240
+ # initialize nn.Linear and nn.LayerNorm
241
+ self.apply(self._init_weights)
242
+
243
+ def _init_weights(self, m):
244
+ if isinstance(m, nn.Linear):
245
+ trunc_normal_(m.weight, std=0.02)
246
+ if m.bias is not None:
247
+ nn.init.constant_(m.bias, 0)
248
+ elif isinstance(m, nn.LayerNorm):
249
+ nn.init.constant_(m.bias, 0)
250
+ nn.init.constant_(m.weight, 1.0)
251
+
252
+
253
+ def unpatchify(self, x: torch.Tensor, h=None, w=None):
254
+ """
255
+ x: (B, L, V * patch_size)
256
+ return imgs: (B, V, H, W)
257
+ """
258
+ p = self.patch_size
259
+ c_out = self.output_size
260
+ h = self.img_size[0] // 1
261
+ w = self.img_size[1] // p
262
+ assert h * w == x.shape[1]
263
+
264
+ x = x.reshape(shape=(x.shape[0], h, w, p, c_out))
265
+ x = torch.einsum("nhwpc->nchpw", x)
266
+ imgs = x.reshape(shape=(x.shape[0], c_out, h, w * p))
267
+ return imgs
268
+
269
+ def forward_encoder(self, x: torch.Tensor):
270
+ # x: `[B, V, H, W]` shape.
271
+
272
+ # tokenize each variable separately
273
+ # x = torch.fft.rfft(x, norm="forward")
274
+ # x = torch.view_as_real(x)
275
+ # x = torch.cat((x[:, :, :, :, 0], -x[:, :, :, :, 1]), dim=-1)
276
+
277
+ x = self.token_embeds(x)
278
+
279
+ # pos_embed = self.pos_embed()
280
+ # add pos embedding
281
+ x = x + self.pos_embed
282
+ x = self.pos_drop(x)
283
+
284
+ # apply Transformer blocks
285
+ for blk in self.blocks:
286
+ x = blk(x)
287
+ x = self.norm(x)
288
+
289
+ return x
290
+
291
+ def forward(self, x):
292
+ B, V, H, W = x.shape
293
+ # print(x.shape)
294
+ out_transformers = self.forward_encoder(x) # B, L, D
295
+ preds = self.head(out_transformers) # B, L, V*p*p
296
+ preds = self.unpatchify(preds)
297
+
298
+ # real, img = torch.split(preds, preds.shape[-1] // 2, dim=-1)
299
+ # preds = torch.cat([real, -img], dim=-1)
300
+ # preds = torch.fft.irfft(preds, W, norm="forward")
301
+ return preds
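
A composition sketch only: how CirT (stage 1, 101 input channels) and CirT_stage2 (93 input channels) are actually chained is decided by the inference scripts in this commit (inference_simulation.py / inference_forecasting.py), which are not part of this excerpt. Here we simply assume stage 2 refines stage 1's 93-channel output.

import torch
from networks.CirT1 import CirT
from networks.CirT2 import CirT_stage2

stage1 = CirT(params=None, input_size=101, output_size=93)        # ocean + forcing -> ocean state
stage2 = CirT_stage2(params=None, input_size=93, output_size=93)  # ocean state -> refined ocean state (assumed role)

x = torch.randn(1, 101, 360, 720)
with torch.no_grad():
    y1 = stage1(x)    # (1, 93, 360, 720)
    y2 = stage2(y1)   # (1, 93, 360, 720)
print(y1.shape, y2.shape)
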