jbilcke-hf's picture
jbilcke-hf HF Staff
Upload 210 files
4bf9661 verified
raw
history blame
4.51 kB
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModel
from typing import List, Union
import os
from .config import MODEL_PATHS
class PickScore(torch.nn.Module):
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS):
super().__init__()
"""Initialize the Selector with a processor and model.
Args:
device (Union[str, torch.device]): The device to load the model on.
"""
self.device = device if isinstance(device, torch.device) else torch.device(device)
processor_name_or_path = path.get("clip")
model_pretrained_name_or_path = path.get("pickscore")
self.processor = AutoProcessor.from_pretrained(processor_name_or_path)
self.model = AutoModel.from_pretrained(model_pretrained_name_or_path).eval().to(self.device)
def _calculate_score(self, image: torch.Tensor, prompt: str, softmax: bool = False) -> float:
"""Calculate the score for a single image and prompt.
Args:
image (torch.Tensor): The processed image tensor.
prompt (str): The prompt text.
softmax (bool): Whether to apply softmax to the scores.
Returns:
float: The score for the image.
"""
with torch.no_grad():
# Prepare text inputs
text_inputs = self.processor(
text=prompt,
padding=True,
truncation=True,
max_length=77,
return_tensors="pt",
).to(self.device)
# Embed images and text
image_embs = self.model.get_image_features(pixel_values=image)
image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
text_embs = self.model.get_text_features(**text_inputs)
text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)
# Compute score
score = (text_embs @ image_embs.T)[0]
if softmax:
# Apply logit scale and softmax
score = torch.softmax(self.model.logit_scale.exp() * score, dim=-1)
return score.cpu().item()
@torch.no_grad()
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str, softmax: bool = False) -> List[float]:
"""Score the images based on the prompt.
Args:
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
prompt (str): The prompt text.
softmax (bool): Whether to apply softmax to the scores.
Returns:
List[float]: List of scores for the images.
"""
try:
if isinstance(images, (str, Image.Image)):
# Single image
if isinstance(images, str):
pil_image = Image.open(images)
else:
pil_image = images
# Prepare image inputs
image_inputs = self.processor(
images=pil_image,
padding=True,
truncation=True,
max_length=77,
return_tensors="pt",
).to(self.device)
return [self._calculate_score(image_inputs["pixel_values"], prompt, softmax)]
elif isinstance(images, list):
# Multiple images
scores = []
for one_image in images:
if isinstance(one_image, str):
pil_image = Image.open(one_image)
elif isinstance(one_image, Image.Image):
pil_image = one_image
else:
raise TypeError("The type of parameter images is illegal.")
# Prepare image inputs
image_inputs = self.processor(
images=pil_image,
padding=True,
truncation=True,
max_length=77,
return_tensors="pt",
).to(self.device)
scores.append(self._calculate_score(image_inputs["pixel_values"], prompt, softmax))
return scores
else:
raise TypeError("The type of parameter images is illegal.")
except Exception as e:
raise RuntimeError(f"Error in scoring images: {e}")