# -------------------------------------------------------- # Copyright (c) 2025 NVIDIA # Licensed under customized NSCLv1 [see LICENSE.md for details] # -------------------------------------------------------- # Based on https://github.com/OpenGVLab/InternVL/blob/main/streamlit_demo/model_worker.py # https://github.com/OpenGVLab/InternVL/?tab=MIT-1-ov-file#readme # Importing torch before transformers can cause `segmentation fault` from transformers import AutoTokenizer, AutoConfig from transformers.modeling_outputs import SequenceClassifierOutputWithPast import base64 import os from io import BytesIO from typing import Tuple import math import requests import torch from torch import Tensor import torchvision.transforms as T from PIL import Image from torchvision.transforms.functional import InterpolationMode from typing import Optional, Any, Union, Dict, List from tqdm import tqdm import torch.nn.functional as F from datasets import Dataset from torch.utils.data import DataLoader from .modeling_eagle_chat import Eagle2ChatModel from .configuration_eagle_chat import Eagle2ChatConfig from .conversation import get_conv_template from .configuration_siglip import SiglipVisionConfig from .modeling_siglip import SiglipVisionModel from .flash_attention import * from .llama_bidirectional_model import LlamaBidirectionalModel from transformers import PreTrainedModel IMAGENET_MEAN = (0.485, 0.456, 0.406) IMAGENET_STD = (0.229, 0.224, 0.225) SIGLIP_MEAN = (0.5, 0.5, 0.5) SIGLIP_STD = (0.5, 0.5, 0.5) def load_image(image): if isinstance(image, Image.Image): return image elif isinstance(image, str) and os.path.exists(image): return Image.open(image) elif isinstance(image, dict): if 'disk_path' in image: return Image.open(image['disk_path']) elif 'base64' in image: return Image.open(BytesIO(base64.b64decode(image['base64']))) elif 'url' in image: response = requests.get(image['url']) return Image.open(BytesIO(response.content)) elif 'bytes' in image: return Image.open(BytesIO(image['bytes'])) else: raise ValueError(f'Invalid image: {image}') else: raise ValueError(f'Invalid image: {image}') def build_transform(input_size, norm_type='imagenet'): if norm_type == 'imagenet': MEAN, STD = IMAGENET_MEAN, IMAGENET_STD elif norm_type == 'siglip': MEAN, STD = SIGLIP_MEAN, SIGLIP_STD transform = T.Compose([ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), T.ToTensor(), T.Normalize(mean=MEAN, std=STD) ]) return transform def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): """ previous version mainly foucs on ratio. We also consider area ratio here. """ best_factor = float('-inf') best_ratio = (1, 1) area = width * height for ratio in target_ratios: target_aspect_ratio = ratio[0] / ratio[1] ratio_diff = abs(aspect_ratio - target_aspect_ratio) area_ratio = (ratio[0]*ratio[1]*image_size*image_size)/ area """ new area > 60% of original image area is enough. """ factor_based_on_area_n_ratio = min((ratio[0]*ratio[1]*image_size*image_size)/ area, 0.6)* \ min(target_aspect_ratio/aspect_ratio, aspect_ratio/target_aspect_ratio) if factor_based_on_area_n_ratio > best_factor: best_factor = factor_based_on_area_n_ratio best_ratio = ratio return best_ratio def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False): orig_width, orig_height = image.size aspect_ratio = orig_width / orig_height # calculate the existing image aspect ratio target_ratios = set( (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num) target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) # find the closest aspect ratio to the target target_aspect_ratio = find_closest_aspect_ratio( aspect_ratio, target_ratios, orig_width, orig_height, image_size) # calculate the target width and height target_width = image_size * target_aspect_ratio[0] target_height = image_size * target_aspect_ratio[1] blocks = target_aspect_ratio[0] * target_aspect_ratio[1] # resize the image resized_img = image.resize((target_width, target_height)) processed_images = [] for i in range(blocks): box = ( (i % (target_width // image_size)) * image_size, (i // (target_width // image_size)) * image_size, ((i % (target_width // image_size)) + 1) * image_size, ((i // (target_width // image_size)) + 1) * image_size ) # split the image split_img = resized_img.crop(box) processed_images.append(split_img) assert len(processed_images) == blocks if use_thumbnail and len(processed_images) != 1: thumbnail_img = image.resize((image_size, image_size)) processed_images.append(thumbnail_img) return processed_images def split_model(model_path, device): device_map = {} world_size = torch.cuda.device_count() config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) num_layers = config.llm_config.num_hidden_layers print('world_size', world_size) num_layers_per_gpu_ = math.floor(num_layers / (world_size - 1)) num_layers_per_gpu = [num_layers_per_gpu_] * world_size num_layers_per_gpu[device] = num_layers - num_layers_per_gpu_ * (world_size-1) print(num_layers_per_gpu) layer_cnt = 0 for i, num_layer in enumerate(num_layers_per_gpu): for j in range(num_layer): device_map[f'language_model.model.layers.{layer_cnt}'] = i layer_cnt += 1 device_map['vision_model'] = device device_map['mlp1'] = device device_map['language_model.model.tok_embeddings'] = device device_map['language_model.model.embed_tokens'] = device device_map['language_model.output'] = device device_map['language_model.model.norm'] = device device_map['language_model.lm_head'] = device device_map['language_model.model.rotary_emb'] = device device_map[f'language_model.model.layers.{num_layers - 1}'] = device return device_map class llama_NemoRetrieverColEmbedConfig(Eagle2ChatConfig): model_type = "llama_nemoretrievercolembed" q_max_length: Optional[int] p_max_length: Optional[int] query_prefix: str passage_prefix: str pooling: str bidirectional_attention: bool def __init__( self, q_max_length: Optional[int] = 512, p_max_length: Optional[int] = 10240, query_prefix: str = "query:", passage_prefix: str = "passage:", pooling: str = "last", bidirectional_attention: bool = False, max_input_tiles: int = 2, img_context_token_id: int = 128258, #tokenizer.convert_tokens_to_ids("") out_dimension: int = -1, **kwargs, ): self.q_max_length = q_max_length self.p_max_length = p_max_length self.query_prefix = query_prefix self.passage_prefix = passage_prefix self.pooling = pooling self.bidirectional_attention = bidirectional_attention self.img_context_token_id = img_context_token_id self.max_input_tiles = max_input_tiles self.out_dimension = out_dimension super().__init__(**kwargs) class llama_NemoRetrieverColEmbed(Eagle2ChatModel): config_class = llama_NemoRetrieverColEmbedConfig _supports_flash_attn_2 = True def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.padding = True self.q_max_length = 512 self.p_max_length = 10240 self.pad_to_multiple_of = None self.query_prefix = 'query:' self.passage_prefix = 'passage:' if isinstance(args[0], llama_NemoRetrieverColEmbedConfig): tokenizer = AutoTokenizer.from_pretrained(args[0]._name_or_path, trust_remote_code=True) tokens_to_keep = ['', '', '', ''] tokenizer.additional_special_tokens = [item for item in tokenizer.additional_special_tokens if item not in tokens_to_keep] tokenizer.padding_side = 'left' self.tokenizer = tokenizer self.norm_type = 'siglip' self.image_size = self.config.force_image_size self.max_input_tiles = 6 self.system_message = "" self.use_visual_embedding = True def process_documents(self, documents: Union[Dict,List[Dict]], **kwargs): if isinstance(documents, dict): images = documents["images"] texts = documents["texts"] assert len(texts) == len(images) elif isinstance(documents, list): images = [pair['image'] for pair in documents ] texts = [pair['text'] for pair in documents ] else: raise ValueError("The documents need to be a dict or list of dicts") if self.passage_prefix: texts = [self.passage_prefix + ' ' + t for t in texts] contents, pil_images, max_input_tile_list, llm_onlys = [], [], [], [] for image, text in zip(images, texts): prefix = '' llm_only = True if image != '': pil_images.append(load_image(image)) prefix = '' max_input_tile_list.append(self.max_input_tiles) llm_only = False else: pil_images.append(None) max_input_tile_list.append(self.max_input_tiles) llm_onlys.append(llm_only) content = text if prefix!='': content = prefix + ' ' + content if self.passage_prefix: content = self.passage_prefix + ' ' + content contents.append(content) transform = build_transform(input_size=self.image_size, norm_type=self.norm_type) template = get_conv_template(self.config.template) template.system_message = self.system_message content_prompts = [] pixel_values_list = [] for content, pil_image, max_input_tiles, llm_only in zip(contents, pil_images, max_input_tile_list, llm_onlys): if pil_image is not None: if self.config.dynamic_image_size: image_tiles = dynamic_preprocess( pil_image, image_size=self.image_size, max_num=max_input_tiles, use_thumbnail=self.config.use_thumbnail) else: image_tiles = [pil_image] pixel_values = [transform(item) for item in image_tiles] pixel_values = torch.stack(pixel_values).to(dtype=torch.bfloat16) pixel_values_list.append(pixel_values) else: pixel_values = None IMG_START_TOKEN='' IMG_END_TOKEN='' IMG_CONTEXT_TOKEN='' if pixel_values is not None and '' not in content and not llm_only: content = ' ' + content # Reseting conversation messages template.messages.clear() # TODO: do we need this template? template.append_message(template.roles[0], content) # user template.append_message(template.roles[1], None) # assistant content_prompt = template.get_prompt() if '' not in content: content_prompt = content_prompt else: num_patches = pixel_values.shape[0] image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN content_prompt = content_prompt.replace('', image_tokens, 1) content_prompts.append(content_prompt) model_inputs = self.tokenizer(content_prompts, truncation=True, max_length=self.p_max_length, padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of, return_tensors='pt') if len(pixel_values_list)>1: pixel_values_squeezed = torch.concat(pixel_values_list, axis=0) elif len(pixel_values_list)==1: pixel_values_squeezed = pixel_values_list[0] else: pixel_values_squeezed = None batch_docs = { "input_ids": model_inputs['input_ids'], "attention_mask": model_inputs['attention_mask'], "pixel_values": None } if pixel_values_squeezed is not None: batch_docs["pixel_values"] = pixel_values_squeezed return batch_docs def process_queries(self, queries: List[str], **kwargs): template = get_conv_template(self.config.template) template.system_message = self.system_message query_prompts = [] for query in queries: if self.query_prefix: query = f"{self.query_prefix} {query}" # Reseting conversation messages template.messages.clear() template.append_message(template.roles[0], query) # user template.append_message(template.roles[1], None) # assistant query_prompt = template.get_prompt() query_prompts.append(query_prompt) batch_query = self.tokenizer( query_prompts, truncation=True, max_length=self.q_max_length, padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of, return_tensors='pt' ) return batch_query def get_scores( self, query_embeddings: Union[torch.Tensor, List[torch.Tensor]], passage_embeddings: Union[torch.Tensor, List[torch.Tensor]], batch_size: Optional[int] = 8, ) -> torch.Tensor: """Dot-product similarity between queries and passages.""" if isinstance(query_embeddings, list): if len(query_embeddings[0].shape)==2: # Expend Batch Dimension as ViDoRe Framework remove it query_embeddings = [q.unsqueeze(0) for q in query_embeddings] query_embeddings = self.padding_various_shape_tensor(query_embeddings) if isinstance(passage_embeddings, list): if len(passage_embeddings[0].shape)==2: # Expend Batch Dimension as ViDoRe Framework remove it passage_embeddings = [p.unsqueeze(0) for p in passage_embeddings] passage_embeddings = self.padding_various_shape_tensor(passage_embeddings) return self.colbert_score(query_embeddings, passage_embeddings, batch_size) def colbert_score( self, qs: Union[torch.Tensor, List[torch.Tensor]], ps: Union[torch.Tensor, List[torch.Tensor]], batch_size: int = 128, device: Optional[Union[str, torch.device]] = None, ) -> torch.Tensor: if batch_size is None: batch_size = 128 if len(qs) == 0: raise ValueError("No queries provided") if len(ps) == 0: raise ValueError("No passages provided") scores_list: List[torch.Tensor] = [] for i in range(0, len(qs), batch_size): scores_batch = [] qs_batch = torch.nn.utils.rnn.pad_sequence(qs[i : i + batch_size].cuda(), batch_first=True, padding_value=0) for j in range(0, len(ps), batch_size): ps_batch = torch.nn.utils.rnn.pad_sequence( ps[j : j + batch_size].cuda(), batch_first=True, padding_value=0 ) scores_batch.append(torch.einsum("bnd,csd->bcns", qs_batch, ps_batch).max(dim=3)[0].sum(dim=2)) # Keep scores_batch on the GPU scores_batch = torch.cat(scores_batch, dim=1) scores_list.append(scores_batch) scores = torch.cat(scores_list, dim=0) return(scores) def _extract_embeddings(self, dataloader: DataLoader, is_query: bool) -> List[torch.Tensor]: qs = [] message = "query" if is_query else "document" for batch in tqdm(dataloader, desc=f"Extracting {message} embeddings..."): with torch.inference_mode(): with torch.autocast(device_type="cuda", dtype=torch.bfloat16): if 'pixel_values' in batch and batch['pixel_values'] is None: batch.pop('pixel_values') batch = {k: v.to(self.device) for k, v in batch.items()} embeddings = self(**batch, output_hidden_states=True).hidden_states[-1] embeddings = embeddings*batch['attention_mask'].unsqueeze(-1) embeddings = F.normalize(embeddings, dim=-1) # Detecting abnormal outputs assert torch.sum(embeddings).float().item() not in [float(0.), float("inf")] qs.append(embeddings.contiguous()) qs_tensor = self.padding_various_shape_tensor(qs) all_embeddings_tensor = qs_tensor.detach().cpu() return all_embeddings_tensor def forward_passages(self, passages, batch_size=8, **kwargs) -> Union[torch.Tensor, List[torch.Tensor]]: """Forward passages as image-only documents.""" corpus = [] for image in passages: corpus.append({ "image": image, "text": '' }) return self.forward_documents(corpus, batch_size) def forward_queries(self, queries: List, batch_size=8) -> List[torch.Tensor]: dataset = ListDataset[str](queries) dataloader = DataLoader( dataset=dataset, batch_size=batch_size, collate_fn=self.process_queries, shuffle=False, num_workers=8, pin_memory=True, drop_last=False, ) return self._extract_embeddings(dataloader=dataloader, is_query=True) def forward_documents(self, corpus: List, batch_size=8) -> List[torch.Tensor]: images = [] texts = [] for doc in corpus: text = doc["text"] image = doc.get("image", "") if image.mode != "RGB": image = image.convert("RGB") images.append(image) texts.append(text) dataset = Dataset.from_dict({"image": images, "text": texts}) dataloader = DataLoader( dataset=dataset, batch_size=batch_size, collate_fn=self.process_documents, shuffle=False, num_workers=8, pin_memory=True, drop_last=False, ) return self._extract_embeddings(dataloader=dataloader, is_query=False) def padding_various_shape_tensor(self, tensors: List[torch.Tensor]) -> torch.Tensor: """Pad tensors of various shapes for colbert-like scoring""" max_seq_len = max(t.shape[1] for t in tensors) padded_tensors = [F.pad(t, (0, 0, 0, max_seq_len - t.shape[1]), mode="constant", value=0) for t in tensors] return torch.cat(padded_tensors, dim=0) from typing import TypeVar from torch.utils.data import Dataset as TorchDataset TV = TypeVar("T") class ListDataset(TorchDataset[TV]): def __init__(self, elements: List[TV]): self.elements = elements def __len__(self) -> int: return len(self.elements) def __getitem__(self, idx: int) -> TV: return self.elements[idx]