import base64
from io import BytesIO
from pathlib import Path
import glob
import numpy as np
import gradio as gr
import rasterio as rio
import matplotlib.pyplot as plt
import matplotlib as mpl
from PIL import Image
from matplotlib import rcParams
from msclip.inference import run_inference_classification
# Side length, in pixels, of the square thumbnails produced by _img_to_b64.
IMG_PX = 300
# Global matplotlib default font size (applies to the bar-chart score labels).
rcParams["font.size"] = 9
# Bundled demo datasets: example raster paths plus the candidate class names
# offered for each. Glob patterns are relative to the current working directory.
# glob.glob returns matches in filesystem-dependent order, so each list is
# sorted to keep the example gallery stable across machines and runs.
EXAMPLES = {
    "EuroSAT": {
        "images": sorted(glob.glob("examples/eurosat/*.tif")),
        "classes": [
            "AnnualCrop", "Forest", "HerbaceousVegetation", "Highway", "Industrial",
            "Pasture", "PermanentCrop", "Residential", "River", "SeaLake",
        ],
    },
    "Meter-ML": {
        "images": sorted(glob.glob("examples/meterml/*.tif")),
        # TODO: placeholder -- replace with the real Meter-ML class names.
        "classes": [
            "Todo",
        ],
    },
    "TerraMesh": {
        "images": sorted(glob.glob("examples/terramesh/*.tif")),
        "classes": [
            "Agriculture", "Beach", "River", "Ice", "Fields",
        ],
    },
}
def _load_example(dataset):
    """Return ``(image paths, comma-joined class names)`` for one EXAMPLES entry."""
    entry = EXAMPLES[dataset]
    joined_classes = ",".join(entry["classes"])
    return entry["images"], joined_classes


def load_eurosat_example():
    """Return the EuroSAT example image paths and class names (comma-joined)."""
    return _load_example("EuroSAT")


def load_meterml_example():
    """Return the Meter-ML example image paths and class names (comma-joined)."""
    return _load_example("Meter-ML")


def load_terramesh_example():
    """Return the TerraMesh example image paths and class names (comma-joined)."""
    return _load_example("TerraMesh")
# Hex strings for every colour of matplotlib's qualitative "Pastel1" colormap,
# in colormap order; used as the default palette by build_colormap.
pastel1_hex = list(map(mpl.colors.to_hex, mpl.colormaps["Pastel1"].colors))
def build_colormap(class_names, palette=None):
    """Map each class name to a colour, assigned in sorted-name order.

    Sorting makes the assignment independent of the order in which the user
    typed the classes, so for a given set of names each class always gets the
    same colour. Colours are reused cyclically when there are more classes
    than palette entries.

    Args:
        class_names: Iterable of class-name strings.
        palette: Optional sequence of matplotlib colour specs. Defaults to
            ``pastel1_hex`` (the previous, hard-coded behaviour).

    Returns:
        dict mapping class name -> colour taken from the palette.

    Raises:
        ValueError: If the palette is empty (the modulo below would otherwise
            fail with ZeroDivisionError).
    """
    colors = pastel1_hex if palette is None else list(palette)
    if not colors:
        raise ValueError("palette must contain at least one colour")
    return {name: colors[i % len(colors)] for i, name in enumerate(sorted(class_names))}
def _rgb_smooth_quantiles(array, tolerance=0.02, scaling=0.5, default=2000):
    """Soft-compress outliers and rescale raw band values for display.

    Works on an array of any shape. Quantiles are taken over all elements and
    every other operation is element-wise, so [C, H, W] and [H, W, C] inputs
    behave the same.

    Two passes:
      1. Coarse limits from the ``tolerance`` / ``1 - tolerance`` quantiles.
         The upper limit is at least ``default``; the lower limit is clipped
         to [0, 1000]. Values beyond either limit are pulled towards it by the
         factor ``scaling``.
      2. Finer limits from quantiles at a tenth of the tolerance. The upper
         limit is clipped to [default, 20000] and the lower to [0, 500].
         These become the 0 and 1 points of the output.

    In both passes the lower limit is forced to 0 unless the original median
    exceeds ``default / 2``, so already-dark images are not darkened further.

    Returns:
        Float array of the same shape, scaled so the pass-2 limits map to 0
        and 1. Values outside [0, 1] are NOT clipped here.
    """
    low, mid, high = np.quantile(array, q=[tolerance, 0.5, 1. - tolerance])
    # Only darken images whose median is not dark already.
    darken = mid > default / 2

    # Pass 1: linearly compress very dark / very bright values.
    high = high.clip(default)
    low = np.where(darken, low.clip(0, 1000), 0)
    array = np.where(array >= low, array, low + (array - low) * scaling)
    array = np.where(array <= high, array, high + (array - high) * scaling)

    # Pass 2: tighter quantiles define the final 0..1 range.
    low, high = np.quantile(array, q=[tolerance / 10, 1. - tolerance / 10])
    high = high.clip(default, 20000)
    low = np.where(darken, low.clip(0, 500), 0)
    return (array - low) / (high - low)
def _s2_to_rgb(data, smooth_quantiles=True):
    """Convert a multi-band raster array into an 8-bit RGB image.

    Bands at indices 3, 2, 1 are used as red, green, blue (presumably
    B04/B03/B02 of a Sentinel-2 band stack -- TODO confirm against the data).

    The layout is guessed from the first axis. If it has more than 13 entries,
    ``data`` is treated as channel-last [H, W, C]; otherwise as channel-first
    [C, H, W].

    Args:
        data: 3-D array of band values.
        smooth_quantiles: If True, normalise with ``_rgb_smooth_quantiles``;
            otherwise simply divide by 2000.

    Returns:
        uint8 array of shape [H, W, 3], rounded and clipped to 0..255.
    """
    rgb_bands = [3, 2, 1]
    channels_last = data.shape[0] > 13
    if channels_last:
        rgb = data[:, :, rgb_bands]
    else:
        rgb = data[rgb_bands].transpose(1, 2, 0)

    scaled = _rgb_smooth_quantiles(rgb) if smooth_quantiles else rgb / 2000
    return np.clip(np.round(scaled * 255), 0, 255).astype(np.uint8)
def _img_to_b64(path: str | Path) -> str:
    """Render a raster file as a square PNG thumbnail, base64-encoded.

    Steps:
      1. Read all bands with rasterio and convert them to RGB via
         ``_s2_to_rgb``.
      2. Centre the image on a white square canvas (letterboxing non-square
         inputs).
      3. Resize to IMG_PX x IMG_PX. This always happens, so small images are
         upscaled as well as large ones downscaled.

    Returns:
        ASCII base64 string of the PNG bytes (no ``data:`` URI prefix).
    """
    with rio.open(path) as src:
        pixels = src.read()
    thumb = Image.fromarray(_s2_to_rgb(pixels))

    edge = max(thumb.size)
    offset = ((edge - thumb.width) // 2, (edge - thumb.height) // 2)
    square = Image.new("RGB", (edge, edge), (255, 255, 255))
    square.paste(thumb, offset)

    out = BytesIO()
    square.resize((IMG_PX, IMG_PX)).save(out, format="PNG")
    return base64.b64encode(out.getvalue()).decode()
def _bar_chart(top_scores, cmap) -> str:
    """Render class scores as a small horizontal bar chart embedded in HTML.

    Args:
        top_scores: pandas-Series-like object (``.values`` / ``.index``) of
            scores in [0, 1] keyed by class name. Presumably at most three
            entries sorted best-first, as built in ``classify`` -- TODO confirm.
        cmap: dict mapping class name -> matplotlib colour spec (e.g. from
            ``build_colormap``). Classes missing from it are drawn with no fill.

    Returns:
        An HTML ``<img>`` tag whose ``src`` is a base64-encoded PNG data URI.
    """
    scores = top_scores.values.tolist()
    labels = top_scores.index.tolist()
    # Pad to at least three rows so every chart has the same layout. Padded
    # rows have score 0; they are drawn fully transparent and get no label.
    pad = max(0, 3 - len(scores))
    scores += [0] * pad
    labels += [""] * pad

    fig, ax = plt.subplots(figsize=(3, 1))
    try:
        # One row per score (the old fixed arange(3) crashed on >3 scores).
        y_pos = np.arange(len(scores))
        colors = [cmap.get(cls, "none") if val > 0 else (0, 0, 0, 0)
                  for cls, val in zip(labels, scores)]
        ax.barh(y_pos, scores, height=0.7, color=colors)
        ax.set_xlim(0, 1)
        ax.invert_yaxis()  # first entry is drawn at the top
        ax.axis("off")
        for i, (cls, val) in enumerate(zip(labels, scores)):
            if val > 0:  # skip padded rows
                ax.text(0.02, i + 0.03, f"{cls} ({round(val * 100)}%)", ha="left", va="center")
        buf = BytesIO()
        fig.savefig(buf, format="png", dpi=300, bbox_inches="tight", transparent=True)
    finally:
        # Always release the figure, even if plotting fails; pyplot keeps
        # figures alive until they are explicitly closed.
        plt.close(fig)

    b64 = base64.b64encode(buf.getvalue()).decode()
    # Previously returned f'' (an empty string), discarding the rendered chart.
    return f'<img src="data:image/png;base64,{b64}">'
def classify(images, class_text):
class_names = [c.strip() for c in class_text.split(",") if c.strip()]
cards = []
df = run_inference_classification(image_path=images, class_names=class_names) # one row per call
for img_path, (id, row) in zip(images, df.iterrows()):
scores = row[2:].astype(float) # drop filename column
top = scores.sort_values(ascending=False)[:3]
top = top[top > 0.01] # filter low scores
cmap = build_colormap(class_names)
cards.append(f"""