import base64
import glob
from io import BytesIO
from pathlib import Path

import numpy as np
import gradio as gr
import rasterio as rio
import matplotlib as mpl
import matplotlib.pyplot as plt
from PIL import Image
from matplotlib import rcParams

from msclip.inference import run_inference_classification

rcParams["font.size"] = 9
IMG_PX = 300

EXAMPLES = {
    "EuroSAT": {
        "images": glob.glob("examples/eurosat/*.tif"),
        "classes": [
            "AnnualCrop", "Forest", "HerbaceousVegetation", "Highway", "Industrial",
            "Pasture", "PermanentCrop", "Residential", "River", "SeaLake"
        ],
    },
    "Meter-ML": {
        "images": glob.glob("examples/meterml/*.tif"),
        "classes": ["Todo"],
    },
    "TerraMesh": {
        "images": glob.glob("examples/terramesh/*.tif"),
        "classes": ["Agriculture", "Beach", "River", "Ice", "Fields"],
    },
}


def load_eurosat_example():
    return EXAMPLES["EuroSAT"]["images"], ",".join(EXAMPLES["EuroSAT"]["classes"])


def load_meterml_example():
    return EXAMPLES["Meter-ML"]["images"], ",".join(EXAMPLES["Meter-ML"]["classes"])


def load_terramesh_example():
    return EXAMPLES["TerraMesh"]["images"], ",".join(EXAMPLES["TerraMesh"]["classes"])


pastel1_hex = [mpl.colors.to_hex(c) for c in mpl.colormaps["Pastel1"].colors]


def build_colormap(class_names):
    return {c: pastel1_hex[i % len(pastel1_hex)] for i, c in enumerate(sorted(class_names))}


def _rgb_smooth_quantiles(array, tolerance=0.02, scaling=0.5, default=2000):
    """
    array: numpy array with dimensions [H, W, C]
    returns 0-1 scaled array
    """
    # Get scaling thresholds for smoothing the brightness
    limit_low, median, limit_high = np.quantile(array, q=[tolerance, 0.5, 1. - tolerance])
    limit_high = limit_high.clip(default)  # Scale only pixels above default value
    limit_low = limit_low.clip(0, 1000)  # Scale only pixels below 1000
    limit_low = np.where(median > default / 2, limit_low, 0)  # Make image only darker if it is not dark already

    # Smooth very dark and bright values using linear scaling
    array = np.where(array >= limit_low, array, limit_low + (array - limit_low) * scaling)
    array = np.where(array <= limit_high, array, limit_high + (array - limit_high) * scaling)

    # Update scaling params using a tenth of the tolerance for the max value
    limit_low, limit_high = np.quantile(array, q=[tolerance / 10, 1. - tolerance / 10])
    limit_high = limit_high.clip(default, 20000)  # Scale only pixels above default value
    limit_low = limit_low.clip(0, 500)  # Scale only pixels below 500
    limit_low = np.where(median > default / 2, limit_low, 0)  # Make image only darker if it is not dark already

    # Scale data to 0-1
    array = (array - limit_low) / (limit_high - limit_low)
    return array


def _s2_to_rgb(data, smooth_quantiles=True):
    # Select the RGB bands (B04, B03, B02)
    if data.shape[0] > 13:  # assuming channel last
        rgb = data[:, :, [3, 2, 1]]
    else:  # assuming channel first
        rgb = data[[3, 2, 1]].transpose((1, 2, 0))

    if smooth_quantiles:
        rgb = _rgb_smooth_quantiles(rgb)
    else:
        rgb = rgb / 2000

    # to uint8
    rgb = (rgb * 255).round().clip(0, 255).astype(np.uint8)
    return rgb


def _img_to_b64(path: str | Path) -> str:
    """Encode an image as a base64 PNG, padded to a square and resized to IMG_PX."""
    with rio.open(path) as src:
        data = src.read()
    rgb = _s2_to_rgb(data)
    img = Image.fromarray(rgb)

    side = max(img.size)
    # create square canvas, paste centred, then resize
    canvas = Image.new("RGB", (side, side), (255, 255, 255))
    canvas.paste(img, ((side - img.width) // 2, (side - img.height) // 2))
    canvas = canvas.resize((IMG_PX, IMG_PX))

    buf = BytesIO()
    canvas.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode()


def _bar_chart(top_scores, cmap) -> str:
    scores = top_scores.values.tolist()
    labels = top_scores.index.tolist()
    # pad to three bars so all cards have the same height
    while len(scores) < 3:
        scores.append(0)
        labels.append("")

    fig, ax = plt.subplots(figsize=(3, 1))
    y_pos = np.arange(3)
    colors = [cmap.get(cls, "none") if val > 0 else (0, 0, 0, 0)
              for cls, val in zip(labels, scores)]
    ax.barh(y_pos, scores, height=0.7, color=colors)
    ax.set_xlim(0, 1)
    ax.invert_yaxis()
    ax.axis("off")
    for i, (cls, val) in enumerate(zip(labels, scores)):
        if val > 0:  # skip padded rows
            ax.text(0.02, i + 0.03, f"{cls} ({round(val * 100)}%)", ha="left", va="center")

    buf = BytesIO()
    fig.savefig(buf, format="png", dpi=300, bbox_inches="tight", transparent=True)
    plt.close(fig)
    b64 = base64.b64encode(buf.getvalue()).decode()
    # return the chart as an inline base64 <img> tag for the HTML output
    return f'<img src="data:image/png;base64,{b64}" alt="class scores"/>'


def classify(images, class_text):
    class_names = [c.strip() for c in class_text.split(",") if c.strip()]
    cards = []
    df = run_inference_classification(image_path=images, class_names=class_names)  # one row per image
    for img_path, (_, row) in zip(images, df.iterrows()):
        scores = row[2:].astype(float)  # drop non-score columns (filename etc.)
        top = scores.sort_values(ascending=False)[:3]
        top = top[top > 0.01]  # filter low scores
        cmap = build_colormap(class_names)
        # each card: base64 thumbnail above the top-3 score chart (plain inline HTML/CSS)
        cards.append(f"""
            <div style="width:{IMG_PX}px;margin:8px;text-align:center;">
                <img src="data:image/png;base64,{_img_to_b64(img_path)}" width="{IMG_PX}"/>
                {_bar_chart(top, cmap)}
            </div>
        """)

    # flex container lets the cards wrap into rows
    return (
        '<div style="display:flex;flex-wrap:wrap;">'
        + "".join(cards)
        + "</div>"
    )
" ) # UI DEFAULT_CLASSES = ["Forest", "River", "Buildings", "Agriculture", "Mountain", "Snow"] # with gr.Blocks(css=".gradio-container") as demo: with gr.Blocks( css=""" .gradio-container #result_box, #result_box.gr-skeleton {min-height:280px !important;} """) as demo: gr.Markdown("## Zero‑shot Classification with Llama3-MS‑CLIP") gr.Markdown("Provide Sentinel-2 tif files with all 12 or 13 bands and define the class names. " "You can also load one of the three provided example sets with class names that you can modify. The example images are comming from [EuroSAT](), [Meter-ML](), and [TerraMesh](). " "The images are classified based on the similarity between the images embeddings and text embeddings. " "You find more information in the [model card](https://huggingface.co/ibm-esa-geospatial/Llama3-MS-CLIP-base) and the [paper](https://arxiv.org/abs/2503.15969). ") with gr.Row(): img_in = gr.File( label="Upload S-2 images", file_count="multiple", type="filepath" ) cls_in = gr.Textbox( value=", ".join(DEFAULT_CLASSES), label="Class names (comma‑separated)", ) run_btn = gr.Button("Classify", variant="primary") # Examples gr.Markdown("#### Load examples") with gr.Row(): btn_terramesh = gr.Button("TerraMesh") btn_eurosat = gr.Button("EuroSAT") btn_meterml = gr.Button("Meter-ML") out_html = gr.HTML(label="Results", elem_id="result_box", min_height=280) run_btn.click(classify, inputs=[img_in, cls_in], outputs=out_html) btn_terramesh.click( load_terramesh_example, outputs=[img_in, cls_in], ).then( classify, inputs=[img_in, cls_in], outputs=out_html, ) btn_eurosat.click( load_eurosat_example, outputs=[img_in, cls_in], ).then( classify, inputs=[img_in, cls_in], outputs=out_html, ) btn_meterml.click( load_meterml_example, outputs=[img_in, cls_in], ).then( classify, inputs=[img_in, cls_in], outputs=out_html, ) if __name__ == "__main__": demo.launch()