# Author: SmilingWolf
# Commit d06266d: "Update and trim dependencies, isort imports"
import json
import os
from collections import defaultdict
from typing import Dict, List
import faiss
import gradio as gr
import numpy as np
from cheesechaser.datapool import (
AnimePicturesWebpDataPool,
DanbooruNewestWebpDataPool,
GelbooruWebpDataPool,
KonachanWebpDataPool,
Rule34WebpDataPool,
YandeWebpDataPool,
ZerochanWebpDataPool,
)
from hfutils.operate import get_hf_client, get_hf_fs
from hfutils.utils import TemporaryDirectory
from imgutils.generic import siglip
from imgutils.utils import ts_lru_cache
from PIL import Image
from pools import quick_webp_pool
# Hugging Face repository holding the SigLIP encoder weights.
_SIGLIP_REPO_ID = "deepghs/siglip_beta"
# Hugging Face repository holding the pre-built k-NN indices, one
# sub-directory per index.
_INDEX_REPO_ID = 'deepghs/anime_sites_indices'

# Module-wide Hugging Face filesystem and API clients, created at import time.
hf_fs = get_hf_fs()
hf_client = get_hf_client()

# Index preselected in the UI dropdown.
_DEFAULT_MODEL_NAME = 'SwinV2_v3_danbooru_8005009_4GB'

# Each sub-directory of the index repo that contains a ``knn.index`` file is
# offered as a selectable index: "<repo>/<name>/knn.index" -> "<name>".
_ALL_MODEL_NAMES = [
    os.path.dirname(index_file)
    for index_file in (
        os.path.relpath(repo_path, _INDEX_REPO_ID)
        for repo_path in hf_fs.glob(f'{_INDEX_REPO_ID}/*/knn.index')
    )
]
# Site name -> dedicated cheesechaser WebP data-pool class.
# Any site that is not listed here is served by ``quick_webp_pool`` in
# ``_get_from_ids``.
_SITE_CLS = dict(
    danbooru=DanbooruNewestWebpDataPool,
    yandere=YandeWebpDataPool,
    zerochan=ZerochanWebpDataPool,
    gelbooru=GelbooruWebpDataPool,
    konachan=KonachanWebpDataPool,
    anime_pictures=AnimePicturesWebpDataPool,
    rule34=Rule34WebpDataPool,
)
def _get_from_ids(site_name: str, ids: List[int]) -> Dict[int, Image.Image]:
    """Download the given post ids from one site and load them as PIL images.

    The pool class comes from ``_SITE_CLS``. Sites not listed there fall back
    to ``quick_webp_pool(site_name, 3)``. Files are downloaded into a
    temporary directory, and each file is fully decoded before that directory
    is removed.

    Args:
        site_name: Site key, e.g. ``'danbooru'``.
        ids: Numeric post ids on that site.

    Returns:
        A mapping from numeric id (the downloaded file's stem) to its loaded
        image. Ids that the pool did not write to disk are absent.
    """
    with TemporaryDirectory() as tmp_dir:
        pool_cls = _SITE_CLS.get(site_name) or quick_webp_pool(site_name, 3)
        pool = pool_cls()
        pool.batch_download_to_directory(resource_ids=ids, dst_dir=tmp_dir)

        images = {}
        for filename in os.listdir(tmp_dir):
            stem, _ext = os.path.splitext(filename)
            key = int(stem)
            loaded = Image.open(os.path.join(tmp_dir, filename))
            # Decode now: the file disappears when the temp dir is cleaned up.
            loaded.load()
            images[key] = loaded
    return images
def _get_from_raw_ids(ids: List[str]) -> Dict[str, Image.Image]:
    """Fetch images for ``"<site>_<number>"`` ids, batching downloads per site.

    The split happens on the last underscore, so site names that contain
    underscores (e.g. ``anime_pictures``) are handled.

    Returns:
        A mapping from ``"<site>_<number>"`` to the loaded image. Ids whose
        download produced no file are absent.
    """
    by_site = defaultdict(list)
    for raw_id in ids:
        site_name, numeric_part = raw_id.rsplit('_', maxsplit=1)
        by_site[site_name].append(int(numeric_part))

    merged = {}
    for site_name, numeric_ids in by_site.items():
        for number, img in _get_from_ids(site_name, numeric_ids).items():
            merged[f'{site_name}_{number}'] = img
    return merged
@ts_lru_cache(maxsize=3)
def _get_index_info(repo_id: str, model_name: str):
    """Load the id table and faiss index for one index sub-directory.

    Uses ``hf_client.hf_hub_download`` to obtain ``<model_name>/ids.npy``,
    ``<model_name>/knn.index`` and ``<model_name>/infos.json`` from
    ``repo_id``. The search parameters stored under ``"index_param"`` in
    ``infos.json`` are then applied to the index. Results are memoized per
    ``(repo_id, model_name)`` with ``maxsize=3``.

    Returns:
        ``(image_ids, knn_index)``: the array loaded from ``ids.npy``, indexed
        by faiss result labels, and the configured faiss index.
    """
    def _fetch(filename: str) -> str:
        # Local path of ``filename`` inside this index's repo sub-directory.
        return hf_client.hf_hub_download(
            repo_id=repo_id,
            repo_type='model',
            filename=f'{model_name}/{filename}',
        )

    image_ids = np.load(_fetch('ids.npy'))
    knn_index = faiss.read_index(_fetch('knn.index'))
    # Use a context manager so the handle is closed deterministically instead
    # of leaking until garbage collection. JSON is UTF-8 by specification.
    with open(_fetch('infos.json'), 'r', encoding='utf-8') as infos_file:
        config = json.load(infos_file)["index_param"]
    faiss.ParameterSpace().set_index_parameters(knn_index, config)
    return image_ids, knn_index
def search(model_name: str, img_input, str_input: str, n_neighbours: int):
    """Return gallery items for the indexed images nearest to a query.

    Args:
        model_name: Index sub-directory name inside ``_INDEX_REPO_ID``.
        img_input: PIL image query. It is used only when ``str_input`` is
            empty or whitespace-only.
        str_input: Text query. When it is non-blank, it takes precedence over
            the image.
        n_neighbours: Number of neighbours to request from the index.

    Returns:
        A list of ``(image, caption)`` tuples for ``gr.Gallery``. Each caption
        is ``"<raw id>/<distance>"``. Neighbours whose image could not be
        downloaded, and result slots faiss left empty, are skipped.

    Raises:
        gr.Error: If there is no text query and no image was supplied.
    """
    images_ids, knn_index = _get_index_info(_INDEX_REPO_ID, model_name)
    siglip_model = "smilingwolf/siglip_swinv2_base_2025_02_22_18h56m54s"

    # A blank / whitespace-only textbox means "use the image input".
    if str_input is None or not str_input.strip():
        if img_input is None:
            raise gr.Error("Please provide an image or a text query.")
        embeddings = siglip.siglip_image_encode(
            img_input,
            repo_id=_SIGLIP_REPO_ID,
            model_name=siglip_model,
            fmt="embeddings",
        )
    else:
        embeddings = siglip.siglip_text_encode(
            str_input,
            repo_id=_SIGLIP_REPO_ID,
            model_name=siglip_model,
            fmt="embeddings",
        )
    # In the model, the "embeddings" output node is already normalized.
    # Ask for the "encodings" output if you want the raw logits.
    dists, indexes = knn_index.search(embeddings, k=int(n_neighbours))

    # faiss pads the result with label -1 when it finds fewer than k
    # neighbours. Indexing images_ids with -1 would silently return the last
    # id, so drop those slots before looking up ids.
    found = indexes[0] >= 0
    neighbours_ids = images_ids[indexes[0][found]]
    neighbours_dists = dists[0][found]

    ids_to_images = _get_from_raw_ids(neighbours_ids)
    gallery = []
    for image_id, dist in zip(neighbours_ids, neighbours_dists):
        if image_id in ids_to_images:
            gallery.append((ids_to_images[image_id], f"{image_id}/{dist:.2f}"))
    return gallery
# Gradio UI: build the layout, wire the button to ``search``, and launch.
# Layout is determined by the order in which components are created inside
# the nested ``with`` blocks, so statement order here matters.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        with gr.Row():
            # Left column: the two query inputs (image, or text when non-empty).
            with gr.Column():
                img_input = gr.Image(type="pil", image_mode="RGBA", label="Image input")
                str_input = gr.Textbox(label="Text input (leave empty to use image input)")
            # Right column: index selection, result count and the trigger button.
            with gr.Column():
                with gr.Row():
                    # Choices are the index sub-directories discovered in the
                    # index repo at import time.
                    n_model = gr.Dropdown(
                        choices=_ALL_MODEL_NAMES,
                        value=_DEFAULT_MODEL_NAME,
                        label='Index to Use',
                    )
                with gr.Row():
                    n_neighbours = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=20,
                        step=1,
                        label="# of images",
                    )
                find_btn = gr.Button("Find similar images")
        with gr.Row():
            similar_images = gr.Gallery(label="Similar images", columns=[5])
        # Input order must match search(model_name, img_input, str_input, n_neighbours).
        find_btn.click(
            fn=search,
            inputs=[
                n_model,
                img_input,
                str_input,
                n_neighbours,
            ],
            outputs=[similar_images],
        )
    # Enable Gradio's request queue, then start the server.
    demo.queue().launch()