blumenstiel commited on
Commit
cbc0399
·
1 Parent(s): 316d81d

Update app

Browse files
Files changed (2) hide show
  1. app.py +30 -33
  2. requirements.txt +0 -2
app.py CHANGED
@@ -4,7 +4,6 @@ from io import BytesIO
4
  from pathlib import Path
5
 
6
  import glob
7
- import spaces
8
  import numpy as np
9
  import gradio as gr
10
  import rasterio as rio
@@ -18,6 +17,11 @@ rcParams["font.size"] = 9
18
  rcParams["axes.titlesize"] = 9
19
  IMG_PX = 300
20
 
 
 
 
 
 
21
  EXAMPLES = {
22
  "EuroSAT": {
23
  "images": glob.glob("examples/eurosat/*.tif"),
@@ -47,18 +51,6 @@ EXAMPLES = {
47
  }
48
 
49
 
50
- def load_eurosat_example():
51
- return EXAMPLES["EuroSAT"]["images"], ", ".join(EXAMPLES["EuroSAT"]["classes"])
52
-
53
-
54
- def load_meterml_example():
55
- return EXAMPLES["Meter-ML"]["images"], ", ".join(EXAMPLES["Meter-ML"]["classes"])
56
-
57
-
58
- def load_terramesh_example():
59
- return EXAMPLES["TerraMesh"]["images"], ", ".join(EXAMPLES["TerraMesh"]["classes"])
60
-
61
-
62
  pastel1_hex = [mpl.colors.to_hex(c) for c in mpl.colormaps["Pastel1"].colors]
63
 
64
 
@@ -164,8 +156,8 @@ def _bar_chart(top_scores, img_name, cmap) -> str:
164
  b64 = base64.b64encode(buf.getvalue()).decode()
165
  return f'<img src="data:image/png;base64,{b64}" style="display:block;margin:auto;width:{IMG_PX}px;" />'
166
 
167
-
168
- @spaces.GPU
169
  def classify(images, class_text):
170
  class_names = [c.strip() for c in class_text.split(",") if c.strip()]
171
  cards = []
@@ -192,8 +184,25 @@ def classify(images, class_text):
192
  )
193
 
194
 
195
- # UI
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
 
197
  with gr.Blocks(
198
  css="""
199
  .gradio-container
@@ -203,7 +212,8 @@ with gr.Blocks(
203
  gr.Markdown("## Zero‑shot Classification with Llama3-MS‑CLIP")
204
  gr.Markdown("Provide Sentinel-2 L2A tif files with all 12 bands and define the class names for running zero-shot classification. "
205
  "You can also use S-2 L1C files with 13 bands but the model might not work as well (e.g., misclassifing forests as sea because of the differrently scaled values). "
206
- "We provide three sets of example images with class names that you can modify. The examples are from [EuroSAT](https://arxiv.org/abs/1709.00029), [Meter-ML](https://arxiv.org/abs/2207.11166), and [TerraMesh](https://arxiv.org/abs/2504.11172) (We downloaded S-2 L2A images for the same locations). "
 
207
  "The images are classified based on the similarity between the images embeddings and text embeddings. "
208
  "You find more information in the [model card](https://huggingface.co/ibm-esa-geospatial/Llama3-MS-CLIP-base) and the [paper](https://arxiv.org/abs/2503.15969). ")
209
  with gr.Row():
@@ -212,7 +222,6 @@ with gr.Blocks(
212
  )
213
  cls_in = gr.Textbox(
214
  value=", ".join(["Forest", "River", "Buildings", "Agriculture", "Mountain", "Snow"]),
215
- # some default classes
216
  label="Class names (comma‑separated)",
217
  )
218
 
@@ -233,29 +242,17 @@ with gr.Blocks(
233
 
234
  btn_terramesh.click(
235
  load_terramesh_example,
236
- outputs=[img_in, cls_in],
237
- ).then(
238
- classify,
239
- inputs=[img_in, cls_in],
240
- outputs=out_html,
241
  )
242
 
243
  btn_eurosat.click(
244
  load_eurosat_example,
245
- outputs=[img_in, cls_in],
246
- ).then(
247
- classify,
248
- inputs=[img_in, cls_in],
249
- outputs=out_html,
250
  )
251
 
252
  btn_meterml.click(
253
  load_meterml_example,
254
- outputs=[img_in, cls_in],
255
- ).then(
256
- classify,
257
- inputs=[img_in, cls_in],
258
- outputs=out_html,
259
  )
260
 
261
  if __name__ == "__main__":
 
4
  from pathlib import Path
5
 
6
  import glob
 
7
  import numpy as np
8
  import gradio as gr
9
  import rasterio as rio
 
17
  rcParams["axes.titlesize"] = 9
18
  IMG_PX = 300
19
 
20
+ import sys
21
+ import csv
22
+
23
+ csv.field_size_limit(sys.maxsize)
24
+
25
  EXAMPLES = {
26
  "EuroSAT": {
27
  "images": glob.glob("examples/eurosat/*.tif"),
 
51
  }
52
 
53
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  pastel1_hex = [mpl.colors.to_hex(c) for c in mpl.colormaps["Pastel1"].colors]
55
 
56
 
 
156
  b64 = base64.b64encode(buf.getvalue()).decode()
157
  return f'<img src="data:image/png;base64,{b64}" style="display:block;margin:auto;width:{IMG_PX}px;" />'
158
 
159
+ # import spaces
160
+ # @spaces.GPU # ZeroGPU does not seem to be working
161
  def classify(images, class_text):
162
  class_names = [c.strip() for c in class_text.split(",") if c.strip()]
163
  cards = []
 
184
  )
185
 
186
 
187
+ # Cache examples
188
+ terramesh_html = classify(EXAMPLES["TerraMesh"]["images"], ", ".join(EXAMPLES["TerraMesh"]["classes"]))
189
+ eurosat_html = classify(EXAMPLES["EuroSAT"]["images"], ", ".join(EXAMPLES["EuroSAT"]["classes"]))
190
+ meterml_html = classify(EXAMPLES["Meter-ML"]["images"], ", ".join(EXAMPLES["Meter-ML"]["classes"]))
191
+
192
+
193
+ def load_eurosat_example():
194
+ return EXAMPLES["EuroSAT"]["images"], ", ".join(EXAMPLES["EuroSAT"]["classes"]), eurosat_html
195
+
196
+
197
+ def load_meterml_example():
198
+ return EXAMPLES["Meter-ML"]["images"], ", ".join(EXAMPLES["Meter-ML"]["classes"]), meterml_html
199
+
200
+
201
+ def load_terramesh_example():
202
+ return EXAMPLES["TerraMesh"]["images"], ", ".join(EXAMPLES["TerraMesh"]["classes"]), terramesh_html
203
+
204
 
205
+ # UI
206
  with gr.Blocks(
207
  css="""
208
  .gradio-container
 
212
  gr.Markdown("## Zero‑shot Classification with Llama3-MS‑CLIP")
213
  gr.Markdown("Provide Sentinel-2 L2A tif files with all 12 bands and define the class names for running zero-shot classification. "
214
  "You can also use S-2 L1C files with 13 bands but the model might not work as well (e.g., misclassifing forests as sea because of the differrently scaled values). "
215
+ "We provide three sets of example images with class names and cached outputs. "
216
+ "The examples are from [EuroSAT](https://arxiv.org/abs/1709.00029), [Meter-ML](https://arxiv.org/abs/2207.11166), and [TerraMesh](https://arxiv.org/abs/2504.11172) (We downloaded S-2 L2A images for the same locations). "
217
  "The images are classified based on the similarity between the images embeddings and text embeddings. "
218
  "You find more information in the [model card](https://huggingface.co/ibm-esa-geospatial/Llama3-MS-CLIP-base) and the [paper](https://arxiv.org/abs/2503.15969). ")
219
  with gr.Row():
 
222
  )
223
  cls_in = gr.Textbox(
224
  value=", ".join(["Forest", "River", "Buildings", "Agriculture", "Mountain", "Snow"]),
 
225
  label="Class names (comma‑separated)",
226
  )
227
 
 
242
 
243
  btn_terramesh.click(
244
  load_terramesh_example,
245
+ outputs=[img_in, cls_in, out_html],
 
 
 
 
246
  )
247
 
248
  btn_eurosat.click(
249
  load_eurosat_example,
250
+ outputs=[img_in, cls_in, out_html],
 
 
 
 
251
  )
252
 
253
  btn_meterml.click(
254
  load_meterml_example,
255
+ outputs=[img_in, cls_in, out_html],
 
 
 
 
256
  )
257
 
258
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -1,5 +1,3 @@
1
- --extra-index-url https://download.pytorch.org/whl/cu113
2
- torch
3
  gradio>=4.31.0
4
  plotly
5
  rasterio
 
 
 
1
  gradio>=4.31.0
2
  plotly
3
  rasterio