from transformers import AutoTokenizer, AutoConfig |
|
from transformers.modeling_outputs import SequenceClassifierOutputWithPast |
|
import base64 |
|
import os |
|
from io import BytesIO |
|
from typing import Tuple |
|
import math |
|
import requests |
|
import torch |
|
from torch import Tensor |
|
import torchvision.transforms as T |
|
from PIL import Image |
|
from torchvision.transforms.functional import InterpolationMode |
|
from typing import Optional, Any, Union, Dict, List |
|
|
|
from tqdm import tqdm |
|
import torch.nn.functional as F |
|
from datasets import Dataset |
|
from torch.utils.data import DataLoader |
|
|
|
from .modeling_eagle_chat import Eagle2ChatModel |
|
from .configuration_eagle_chat import Eagle2ChatConfig |
|
from .conversation import get_conv_template |
|
|
|
from .configuration_siglip import SiglipVisionConfig |
|
from .modeling_siglip import SiglipVisionModel |
|
from .flash_attention import * |
|
|
|
from .llama_bidirectional_model import LlamaBidirectionalModel |
|
from transformers import PreTrainedModel |
|
|
|
IMAGENET_MEAN = (0.485, 0.456, 0.406) |
|
IMAGENET_STD = (0.229, 0.224, 0.225) |
|
|
|
SIGLIP_MEAN = (0.5, 0.5, 0.5) |
|
SIGLIP_STD = (0.5, 0.5, 0.5) |
|
|
|
def load_image(image): |
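    """Load a PIL image from a PIL.Image, a local path, or a dict containing one of
    'disk_path', 'base64', 'url', or 'bytes'."""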
|
if isinstance(image, Image.Image): |
|
return image |
|
elif isinstance(image, str) and os.path.exists(image): |
|
return Image.open(image) |
|
elif isinstance(image, dict): |
|
if 'disk_path' in image: |
|
return Image.open(image['disk_path']) |
|
elif 'base64' in image: |
|
return Image.open(BytesIO(base64.b64decode(image['base64']))) |
|
elif 'url' in image: |
|
response = requests.get(image['url']) |
|
return Image.open(BytesIO(response.content)) |
|
elif 'bytes' in image: |
|
return Image.open(BytesIO(image['bytes'])) |
|
else: |
|
raise ValueError(f'Invalid image: {image}') |
|
else: |
|
raise ValueError(f'Invalid image: {image}') |
|
|
|
def build_transform(input_size, norm_type='imagenet'): |
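    """Build the eval-time preprocessing pipeline: RGB conversion, bicubic resize to a
    square of side `input_size`, tensor conversion, and normalization with either
    ImageNet or SigLIP statistics."""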
|
    if norm_type == 'imagenet':
        MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    elif norm_type == 'siglip':
        MEAN, STD = SIGLIP_MEAN, SIGLIP_STD
    else:
        raise ValueError(f"Unsupported norm_type: {norm_type!r} (expected 'imagenet' or 'siglip')")
|
|
|
transform = T.Compose([ |
|
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), |
|
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), |
|
T.ToTensor(), |
|
T.Normalize(mean=MEAN, std=STD) |
|
]) |
|
return transform |
|
|
|
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): |
|
""" |
|
previous version mainly foucs on ratio. |
|
We also consider area ratio here. |
|
""" |
|
best_factor = float('-inf') |
|
best_ratio = (1, 1) |
|
area = width * height |
|
for ratio in target_ratios: |
|
target_aspect_ratio = ratio[0] / ratio[1] |
|
ratio_diff = abs(aspect_ratio - target_aspect_ratio) |
|
area_ratio = (ratio[0]*ratio[1]*image_size*image_size)/ area |
|
""" |
|
new area > 60% of original image area is enough. |
|
""" |
|
factor_based_on_area_n_ratio = min((ratio[0]*ratio[1]*image_size*image_size)/ area, 0.6)* \ |
|
min(target_aspect_ratio/aspect_ratio, aspect_ratio/target_aspect_ratio) |
|
|
|
if factor_based_on_area_n_ratio > best_factor: |
|
best_factor = factor_based_on_area_n_ratio |
|
best_ratio = ratio |
|
|
|
return best_ratio |
|
|
|
def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False): |
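    """Split an image into between `min_num` and `max_num` tiles of size `image_size`,
    choosing the tiling grid closest to the original aspect ratio. When more than one
    tile is produced and `use_thumbnail` is set, a thumbnail of the whole image is
    appended as an extra tile."""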
|
orig_width, orig_height = image.size |
|
aspect_ratio = orig_width / orig_height |
|
|
|
|
|
target_ratios = set( |
|
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if |
|
i * j <= max_num and i * j >= min_num) |
|
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) |
|
|
|
|
|
target_aspect_ratio = find_closest_aspect_ratio( |
|
aspect_ratio, target_ratios, orig_width, orig_height, image_size) |
|
|
|
|
|
target_width = image_size * target_aspect_ratio[0] |
|
target_height = image_size * target_aspect_ratio[1] |
|
blocks = target_aspect_ratio[0] * target_aspect_ratio[1] |
|
|
|
|
|
resized_img = image.resize((target_width, target_height)) |
|
processed_images = [] |
|
for i in range(blocks): |
|
box = ( |
|
(i % (target_width // image_size)) * image_size, |
|
(i // (target_width // image_size)) * image_size, |
|
((i % (target_width // image_size)) + 1) * image_size, |
|
((i // (target_width // image_size)) + 1) * image_size |
|
) |
|
|
|
split_img = resized_img.crop(box) |
|
processed_images.append(split_img) |
|
assert len(processed_images) == blocks |
|
if use_thumbnail and len(processed_images) != 1: |
|
thumbnail_img = image.resize((image_size, image_size)) |
|
processed_images.append(thumbnail_img) |
|
return processed_images |
|
|
|
def split_model(model_path, device): |
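    """Build a `device_map` that spreads the LLM decoder layers across all visible GPUs,
    leaving only the remainder of the layers on `device` so it has headroom for the
    vision tower, projector, embeddings, output head, and the last decoder layer, which
    are all pinned to `device`. Assumes more than one CUDA device is visible."""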
|
|
|
device_map = {} |
|
world_size = torch.cuda.device_count() |
|
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) |
|
num_layers = config.llm_config.num_hidden_layers |
|
|
|
print('world_size', world_size) |
|
num_layers_per_gpu_ = math.floor(num_layers / (world_size - 1)) |
|
num_layers_per_gpu = [num_layers_per_gpu_] * world_size |
|
num_layers_per_gpu[device] = num_layers - num_layers_per_gpu_ * (world_size-1) |
|
print(num_layers_per_gpu) |
|
layer_cnt = 0 |
|
for i, num_layer in enumerate(num_layers_per_gpu): |
|
for j in range(num_layer): |
|
device_map[f'language_model.model.layers.{layer_cnt}'] = i |
|
layer_cnt += 1 |
|
device_map['vision_model'] = device |
|
device_map['mlp1'] = device |
|
device_map['language_model.model.tok_embeddings'] = device |
|
device_map['language_model.model.embed_tokens'] = device |
|
device_map['language_model.output'] = device |
|
device_map['language_model.model.norm'] = device |
|
device_map['language_model.lm_head'] = device |
|
device_map['language_model.model.rotary_emb'] = device |
|
device_map[f'language_model.model.layers.{num_layers - 1}'] = device |
|
return device_map |
|
|
|
class llama_NemoRetrieverColEmbedConfig(Eagle2ChatConfig): |
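    """Configuration for llama_NemoRetrieverColEmbed, extending Eagle2ChatConfig with
    retrieval-specific options: query/passage prefixes and max lengths, pooling,
    bidirectional attention, tiling, and output dimension."""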
|
model_type = "llama_nemoretrievercolembed" |
|
|
|
q_max_length: Optional[int] |
|
p_max_length: Optional[int] |
|
query_prefix: str |
|
passage_prefix: str |
|
pooling: str |
|
bidirectional_attention: bool |
|
|
|
def __init__( |
|
self, |
|
q_max_length: Optional[int] = 512, |
|
p_max_length: Optional[int] = 10240, |
|
query_prefix: str = "query:", |
|
passage_prefix: str = "passage:", |
|
pooling: str = "last", |
|
bidirectional_attention: bool = False, |
|
max_input_tiles: int = 2, |
|
img_context_token_id: int = 128258, |
|
out_dimension: int = -1, |
|
**kwargs, |
|
): |
|
self.q_max_length = q_max_length |
|
self.p_max_length = p_max_length |
|
self.query_prefix = query_prefix |
|
self.passage_prefix = passage_prefix |
|
self.pooling = pooling |
|
self.bidirectional_attention = bidirectional_attention |
|
self.img_context_token_id = img_context_token_id |
|
self.max_input_tiles = max_input_tiles |
|
self.out_dimension = out_dimension |
|
super().__init__(**kwargs) |
|
|
|
class llama_NemoRetrieverColEmbed(Eagle2ChatModel): |
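    """ColBERT-style multimodal embedding model built on Eagle2ChatModel.

    Queries are embedded as text prompts; documents may combine text and images, which
    are tiled via `dynamic_preprocess` and injected through <img> context tokens.
    `forward_queries` / `forward_documents` / `forward_passages` return per-token
    embeddings, and `get_scores` computes late-interaction MaxSim scores.

    Usage sketch (the checkpoint path is a placeholder; loading details may differ
    for your setup):

        from transformers import AutoModel
        model = AutoModel.from_pretrained(
            "path/to/checkpoint", trust_remote_code=True,
            torch_dtype=torch.bfloat16).eval().cuda()
        q_emb = model.forward_queries(["what does the chart show?"], batch_size=2)
        d_emb = model.forward_passages([Image.open("page.png")], batch_size=2)
        scores = model.get_scores(q_emb, d_emb)
    """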
|
|
|
config_class = llama_NemoRetrieverColEmbedConfig |
|
_supports_flash_attn_2 = True |
|
|
|
def __init__(self, *args, **kwargs): |
|
super().__init__(*args, **kwargs) |
|
self.padding = True |
|
self.q_max_length = 512 |
|
self.p_max_length = 10240 |
|
self.pad_to_multiple_of = None |
|
self.query_prefix = 'query:' |
|
self.passage_prefix = 'passage:' |
|
|
|
if isinstance(args[0], llama_NemoRetrieverColEmbedConfig): |
|
tokenizer = AutoTokenizer.from_pretrained(args[0]._name_or_path, trust_remote_code=True) |
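            # The grounding tokens below are dropped from `additional_special_tokens`;
            # they stay in the vocabulary but are no longer treated as special tokens.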
|
tokens_to_keep = ['<box>', '</box>', '<ref>', '</ref>'] |
|
tokenizer.additional_special_tokens = [item for item in tokenizer.additional_special_tokens if item not in tokens_to_keep] |
|
tokenizer.padding_side = 'left' |
|
self.tokenizer = tokenizer |
|
|
|
self.norm_type = 'siglip' |
|
self.image_size = self.config.force_image_size |
|
self.max_input_tiles = 6 |
|
self.system_message = "" |
|
self.use_visual_embedding = True |
|
|
|
def process_documents(self, documents: Union[Dict,List[Dict]], **kwargs): |
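        """Tokenize documents (text plus optional image) into model inputs.

        Images are tiled, transformed, and stacked into `pixel_values`; the text is
        wrapped in the conversation template with the passage prefix applied and the
        <image> placeholder expanded into image context tokens.
        """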
|
if isinstance(documents, dict): |
|
images = documents["images"] |
|
texts = documents["texts"] |
|
assert len(texts) == len(images) |
|
elif isinstance(documents, list): |
|
images = [pair['image'] for pair in documents ] |
|
texts = [pair['text'] for pair in documents ] |
|
else: |
|
raise ValueError("The documents need to be a dict or list of dicts") |
|
|
|
|
|
|
contents, pil_images, max_input_tile_list, llm_onlys = [], [], [], [] |
|
for image, text in zip(images, texts): |
|
prefix = '' |
|
llm_only = True |
|
if image != '': |
|
pil_images.append(load_image(image)) |
|
prefix = '<image>' |
|
max_input_tile_list.append(self.max_input_tiles) |
|
llm_only = False |
|
else: |
|
pil_images.append(None) |
|
max_input_tile_list.append(self.max_input_tiles) |
|
|
|
llm_onlys.append(llm_only) |
|
|
|
content = text |
|
if prefix!='': |
|
content = prefix + ' ' + content |
|
if self.passage_prefix: |
|
content = self.passage_prefix + ' ' + content |
|
contents.append(content) |
|
|
|
transform = build_transform(input_size=self.image_size, norm_type=self.norm_type) |
|
|
|
template = get_conv_template(self.config.template) |
|
template.system_message = self.system_message |
|
|
|
content_prompts = [] |
|
pixel_values_list = [] |
|
for content, pil_image, max_input_tiles, llm_only in zip(contents, pil_images, max_input_tile_list, llm_onlys): |
|
if pil_image is not None: |
|
if self.config.dynamic_image_size: |
|
image_tiles = dynamic_preprocess( |
|
pil_image, image_size=self.image_size, max_num=max_input_tiles, |
|
use_thumbnail=self.config.use_thumbnail) |
|
else: |
|
image_tiles = [pil_image] |
|
|
|
pixel_values = [transform(item) for item in image_tiles] |
|
pixel_values = torch.stack(pixel_values).to(dtype=torch.bfloat16) |
|
pixel_values_list.append(pixel_values) |
|
else: |
|
pixel_values = None |
|
|
|
IMG_START_TOKEN='<img>' |
|
IMG_END_TOKEN='</img>' |
|
IMG_CONTEXT_TOKEN='<IMG_CONTEXT>' |
|
|
|
if pixel_values is not None and '<image>' not in content and not llm_only: |
|
content = '<image> ' + content |
|
|
|
|
|
template.messages.clear() |
|
|
|
|
|
template.append_message(template.roles[0], content) |
|
template.append_message(template.roles[1], None) |
|
content_prompt = template.get_prompt() |
|
|
|
            if '<image>' in content:
                # Expand the single <image> placeholder into per-tile image tokens:
                # each tile contributes `num_image_token` IMG_CONTEXT tokens, wrapped in <img>...</img>.
                num_patches = pixel_values.shape[0]
                image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
                content_prompt = content_prompt.replace('<image>', image_tokens, 1)
|
|
|
content_prompts.append(content_prompt) |
|
|
|
model_inputs = self.tokenizer(content_prompts, |
|
truncation=True, |
|
max_length=self.p_max_length, |
|
padding=self.padding, |
|
pad_to_multiple_of=self.pad_to_multiple_of, |
|
return_tensors='pt') |
|
|
|
        if len(pixel_values_list) > 1:
            pixel_values_squeezed = torch.cat(pixel_values_list, dim=0)
        elif len(pixel_values_list) == 1:
            pixel_values_squeezed = pixel_values_list[0]
        else:
            pixel_values_squeezed = None
|
|
|
batch_docs = { |
|
"input_ids": model_inputs['input_ids'], |
|
"attention_mask": model_inputs['attention_mask'], |
|
"pixel_values": None |
|
} |
|
if pixel_values_squeezed is not None: |
|
batch_docs["pixel_values"] = pixel_values_squeezed |
|
|
|
return batch_docs |
|
|
|
def process_queries(self, queries: List[str], **kwargs): |
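        """Tokenize text queries into model inputs, applying the query prefix and the
        conversation template, truncated/padded to `q_max_length`."""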
|
|
|
template = get_conv_template(self.config.template) |
|
template.system_message = self.system_message |
|
|
|
query_prompts = [] |
|
for query in queries: |
|
if self.query_prefix: |
|
query = f"{self.query_prefix} {query}" |
|
|
|
|
|
template.messages.clear() |
|
|
|
template.append_message(template.roles[0], query) |
|
template.append_message(template.roles[1], None) |
|
query_prompt = template.get_prompt() |
|
|
|
query_prompts.append(query_prompt) |
|
|
|
|
|
batch_query = self.tokenizer( |
|
query_prompts, |
|
truncation=True, |
|
max_length=self.q_max_length, |
|
padding=self.padding, |
|
pad_to_multiple_of=self.pad_to_multiple_of, |
|
return_tensors='pt' |
|
) |
|
|
|
return batch_query |
|
|
|
def get_scores( |
|
self, |
|
query_embeddings: Union[torch.Tensor, List[torch.Tensor]], |
|
passage_embeddings: Union[torch.Tensor, List[torch.Tensor]], |
|
batch_size: Optional[int] = 8, |
|
) -> torch.Tensor: |
|
"""Dot-product similarity between queries and passages.""" |
|
if isinstance(query_embeddings, list): |
|
if len(query_embeddings[0].shape)==2: |
|
|
|
query_embeddings = [q.unsqueeze(0) for q in query_embeddings] |
|
query_embeddings = self.padding_various_shape_tensor(query_embeddings) |
|
if isinstance(passage_embeddings, list): |
|
if len(passage_embeddings[0].shape)==2: |
|
|
|
passage_embeddings = [p.unsqueeze(0) for p in passage_embeddings] |
|
passage_embeddings = self.padding_various_shape_tensor(passage_embeddings) |
|
|
|
return self.colbert_score(query_embeddings, passage_embeddings, batch_size) |
|
|
|
def colbert_score( |
|
self, |
|
qs: Union[torch.Tensor, List[torch.Tensor]], |
|
ps: Union[torch.Tensor, List[torch.Tensor]], |
|
batch_size: int = 128, |
|
device: Optional[Union[str, torch.device]] = None, |
|
) -> torch.Tensor: |
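        """Late-interaction scoring: for every query/passage pair, each query token is
        matched to its most similar passage token (MaxSim) and the maxima are summed
        over query tokens. Computed in batches to bound memory use.

        Returns a (num_queries, num_passages) score matrix.
        """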
|
if batch_size is None: |
|
batch_size = 128 |
|
if len(qs) == 0: |
|
raise ValueError("No queries provided") |
|
if len(ps) == 0: |
|
raise ValueError("No passages provided") |
|
|
|
scores_list: List[torch.Tensor] = [] |
|
for i in range(0, len(qs), batch_size): |
|
scores_batch = [] |
|
qs_batch = torch.nn.utils.rnn.pad_sequence(qs[i : i + batch_size].cuda(), batch_first=True, padding_value=0) |
|
for j in range(0, len(ps), batch_size): |
|
ps_batch = torch.nn.utils.rnn.pad_sequence( |
|
ps[j : j + batch_size].cuda(), batch_first=True, padding_value=0 |
|
) |
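                # MaxSim: best-matching passage token per query token (max over s),
                # then sum over query tokens (n) -> (batch_q, batch_p) scores.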
|
scores_batch.append(torch.einsum("bnd,csd->bcns", qs_batch, ps_batch).max(dim=3)[0].sum(dim=2)) |
|
|
|
scores_batch = torch.cat(scores_batch, dim=1) |
|
scores_list.append(scores_batch) |
|
|
|
scores = torch.cat(scores_list, dim=0) |
|
        return scores
|
|
|
    def _extract_embeddings(self, dataloader: DataLoader, is_query: bool) -> torch.Tensor:
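        """Run the model over a dataloader and collect L2-normalized last-hidden-state
        token embeddings (padding positions zeroed via the attention mask), returned as
        a single padded CPU tensor."""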
|
qs = [] |
|
message = "query" if is_query else "document" |
|
for batch in tqdm(dataloader, desc=f"Extracting {message} embeddings..."): |
|
with torch.inference_mode(): |
|
with torch.autocast(device_type="cuda", dtype=torch.bfloat16): |
|
if 'pixel_values' in batch and batch['pixel_values'] is None: |
|
batch.pop('pixel_values') |
|
batch = {k: v.to(self.device) for k, v in batch.items()} |
|
embeddings = self(**batch, output_hidden_states=True).hidden_states[-1] |
|
embeddings = embeddings*batch['attention_mask'].unsqueeze(-1) |
|
embeddings = F.normalize(embeddings, dim=-1) |
|
|
|
|
|
assert torch.sum(embeddings).float().item() not in [float(0.), float("inf")] |
|
qs.append(embeddings.contiguous()) |
|
|
|
qs_tensor = self.padding_various_shape_tensor(qs) |
|
all_embeddings_tensor = qs_tensor.detach().cpu() |
|
return all_embeddings_tensor |
|
|
|
def forward_passages(self, passages, batch_size=8, **kwargs) -> Union[torch.Tensor, List[torch.Tensor]]: |
|
"""Forward passages as image-only documents.""" |
|
corpus = [] |
|
for image in passages: |
|
corpus.append({ |
|
"image": image, |
|
"text": '' |
|
}) |
|
return self.forward_documents(corpus, batch_size) |
|
|
|
    def forward_queries(self, queries: List[str], batch_size=8) -> torch.Tensor:
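        """Embed a list of query strings and return their padded token embeddings."""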
|
dataset = ListDataset[str](queries) |
|
dataloader = DataLoader( |
|
dataset=dataset, |
|
batch_size=batch_size, |
|
collate_fn=self.process_queries, |
|
shuffle=False, |
|
num_workers=8, |
|
pin_memory=True, |
|
drop_last=False, |
|
) |
|
return self._extract_embeddings(dataloader=dataloader, is_query=True) |
|
|
|
    def forward_documents(self, corpus: List[Dict], batch_size=8) -> torch.Tensor:
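        """Embed a corpus of {"text": ..., "image": ...} documents and return their
        padded token embeddings."""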
|
images = [] |
|
texts = [] |
|
for doc in corpus: |
|
text = doc["text"] |
|
            image = doc.get("image", "")
            # Text-only documents use an empty string as the image placeholder.
            if isinstance(image, Image.Image) and image.mode != "RGB":
                image = image.convert("RGB")
|
images.append(image) |
|
texts.append(text) |
|
|
|
dataset = Dataset.from_dict({"image": images, "text": texts}) |
|
dataloader = DataLoader( |
|
dataset=dataset, |
|
batch_size=batch_size, |
|
collate_fn=self.process_documents, |
|
shuffle=False, |
|
num_workers=8, |
|
pin_memory=True, |
|
drop_last=False, |
|
) |
|
return self._extract_embeddings(dataloader=dataloader, is_query=False) |
|
|
|
def padding_various_shape_tensor(self, tensors: List[torch.Tensor]) -> torch.Tensor: |
|
"""Pad tensors of various shapes for colbert-like scoring""" |
|
max_seq_len = max(t.shape[1] for t in tensors) |
|
padded_tensors = [F.pad(t, (0, 0, 0, max_seq_len - t.shape[1]), mode="constant", value=0) for t in tensors] |
|
return torch.cat(padded_tensors, dim=0) |
|
|
|
|
|
from typing import TypeVar |
|
from torch.utils.data import Dataset as TorchDataset |
|
TV = TypeVar("TV")
|
class ListDataset(TorchDataset[TV]): |
|
def __init__(self, elements: List[TV]): |
|
self.elements = elements |
|
|
|
def __len__(self) -> int: |
|
return len(self.elements) |
|
|
|
def __getitem__(self, idx: int) -> TV: |
|
return self.elements[idx] |