# llama-nemoretriever-colembed-1b-v1 / modeling_llama_nemoretrievercolembed.py
# --------------------------------------------------------
# Copyright (c) 2025 NVIDIA
# Licensed under customized NSCLv1 [see LICENSE.md for details]
# --------------------------------------------------------
# Based on https://github.com/OpenGVLab/InternVL/blob/main/streamlit_demo/model_worker.py
# https://github.com/OpenGVLab/InternVL/?tab=MIT-1-ov-file#readme
# Importing torch before transformers can cause a `segmentation fault`, so transformers is imported first
from transformers import AutoTokenizer, AutoConfig
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
import base64
import os
from io import BytesIO
import math
import requests
import torch
from torch import Tensor
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from typing import Any, Dict, List, Optional, Tuple, Union
from tqdm import tqdm
import torch.nn.functional as F
from datasets import Dataset
from torch.utils.data import DataLoader
from .modeling_eagle_chat import Eagle2ChatModel
from .configuration_eagle_chat import Eagle2ChatConfig
from .conversation import get_conv_template
from .configuration_siglip import SiglipVisionConfig
from .modeling_siglip import SiglipVisionModel
from .flash_attention import *
from .llama_bidirectional_model import LlamaBidirectionalModel
from transformers import PreTrainedModel
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
SIGLIP_MEAN = (0.5, 0.5, 0.5)
SIGLIP_STD = (0.5, 0.5, 0.5)
def load_image(image):
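    """Return a PIL image from a PIL.Image, an existing file path, or a dict
    holding one of 'disk_path', 'base64', 'url', or 'bytes'."""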
if isinstance(image, Image.Image):
return image
elif isinstance(image, str) and os.path.exists(image):
return Image.open(image)
elif isinstance(image, dict):
if 'disk_path' in image:
return Image.open(image['disk_path'])
elif 'base64' in image:
return Image.open(BytesIO(base64.b64decode(image['base64'])))
elif 'url' in image:
response = requests.get(image['url'])
return Image.open(BytesIO(response.content))
elif 'bytes' in image:
return Image.open(BytesIO(image['bytes']))
else:
raise ValueError(f'Invalid image: {image}')
else:
raise ValueError(f'Invalid image: {image}')
def build_transform(input_size, norm_type='imagenet'):
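    """Build a torchvision transform: RGB conversion, bicubic resize to
    (input_size, input_size), tensor conversion, and normalization with either
    ImageNet or SigLIP statistics."""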
if norm_type == 'imagenet':
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    elif norm_type == 'siglip':
        MEAN, STD = SIGLIP_MEAN, SIGLIP_STD
    else:
        raise ValueError(f"Unsupported norm_type: {norm_type}")
transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=MEAN, std=STD)
])
return transform
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
"""
previous version mainly foucs on ratio.
We also consider area ratio here.
"""
best_factor = float('-inf')
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        area_ratio = (ratio[0] * ratio[1] * image_size * image_size) / area
        # A tiled area covering more than 60% of the original image area is enough,
        # so the area term is capped at 0.6.
        factor_based_on_area_n_ratio = min(area_ratio, 0.6) * \
            min(target_aspect_ratio / aspect_ratio, aspect_ratio / target_aspect_ratio)
if factor_based_on_area_n_ratio > best_factor:
best_factor = factor_based_on_area_n_ratio
best_ratio = ratio
return best_ratio
def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
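    """Tile an image into between `min_num` and `max_num` square crops of side
    `image_size`, choosing the tiling grid whose aspect/area ratio best matches
    the original image; optionally append a thumbnail of the full image."""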
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
    # enumerate candidate tiling grids (columns, rows) within the [min_num, max_num] tile budget
target_ratios = set(
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
i * j <= max_num and i * j >= min_num)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = (
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images
def split_model(model_path, device):
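    """Build a `device_map` that spreads the LLM layers across all visible GPUs
    while pinning the vision tower, projector, embeddings, and final layers to
    `device`. Note: this assumes at least two GPUs (it divides by world_size - 1)."""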
device_map = {}
world_size = torch.cuda.device_count()
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
num_layers = config.llm_config.num_hidden_layers
print('world_size', world_size)
num_layers_per_gpu_ = math.floor(num_layers / (world_size - 1))
num_layers_per_gpu = [num_layers_per_gpu_] * world_size
num_layers_per_gpu[device] = num_layers - num_layers_per_gpu_ * (world_size-1)
print(num_layers_per_gpu)
layer_cnt = 0
for i, num_layer in enumerate(num_layers_per_gpu):
for j in range(num_layer):
device_map[f'language_model.model.layers.{layer_cnt}'] = i
layer_cnt += 1
device_map['vision_model'] = device
device_map['mlp1'] = device
device_map['language_model.model.tok_embeddings'] = device
device_map['language_model.model.embed_tokens'] = device
device_map['language_model.output'] = device
device_map['language_model.model.norm'] = device
device_map['language_model.lm_head'] = device
device_map['language_model.model.rotary_emb'] = device
device_map[f'language_model.model.layers.{num_layers - 1}'] = device
return device_map
class llama_NemoRetrieverColEmbedConfig(Eagle2ChatConfig):
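    """Configuration for the NemoRetriever ColBERT-style embedding model, adding
    retrieval-specific fields (query/passage prefixes, max lengths, pooling,
    image-tiling limits) on top of `Eagle2ChatConfig`."""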
model_type = "llama_nemoretrievercolembed"
q_max_length: Optional[int]
p_max_length: Optional[int]
query_prefix: str
passage_prefix: str
pooling: str
bidirectional_attention: bool
def __init__(
self,
q_max_length: Optional[int] = 512,
p_max_length: Optional[int] = 10240,
query_prefix: str = "query:",
passage_prefix: str = "passage:",
pooling: str = "last",
bidirectional_attention: bool = False,
max_input_tiles: int = 2,
img_context_token_id: int = 128258, #tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>")
out_dimension: int = -1,
**kwargs,
):
self.q_max_length = q_max_length
self.p_max_length = p_max_length
self.query_prefix = query_prefix
self.passage_prefix = passage_prefix
self.pooling = pooling
self.bidirectional_attention = bidirectional_attention
self.img_context_token_id = img_context_token_id
self.max_input_tiles = max_input_tiles
self.out_dimension = out_dimension
super().__init__(**kwargs)
class llama_NemoRetrieverColEmbed(Eagle2ChatModel):
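    """Multi-vector (ColBERT-style) retriever built on `Eagle2ChatModel`: queries
    and documents (text and/or images) are encoded into per-token embeddings, and
    relevance is scored with MaxSim late interaction (see `colbert_score`)."""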
config_class = llama_NemoRetrieverColEmbedConfig
_supports_flash_attn_2 = True
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.padding = True
self.q_max_length = 512
self.p_max_length = 10240
self.pad_to_multiple_of = None
self.query_prefix = 'query:'
self.passage_prefix = 'passage:'
if isinstance(args[0], llama_NemoRetrieverColEmbedConfig):
tokenizer = AutoTokenizer.from_pretrained(args[0]._name_or_path, trust_remote_code=True)
            tokens_to_remove = ['<box>', '</box>', '<ref>', '</ref>']
            tokenizer.additional_special_tokens = [item for item in tokenizer.additional_special_tokens if item not in tokens_to_remove]
tokenizer.padding_side = 'left'
self.tokenizer = tokenizer
self.norm_type = 'siglip'
self.image_size = self.config.force_image_size
self.max_input_tiles = 6
self.system_message = ""
self.use_visual_embedding = True
def process_documents(self, documents: Union[Dict,List[Dict]], **kwargs):
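        """Tokenize documents (text plus optional image) into model inputs.
        Images are tiled via `dynamic_preprocess`, each tile contributing
        `num_image_token` <IMG_CONTEXT> placeholders in the prompt; returns a
        dict with `input_ids`, `attention_mask`, and `pixel_values`."""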
if isinstance(documents, dict):
images = documents["images"]
texts = documents["texts"]
assert len(texts) == len(images)
elif isinstance(documents, list):
images = [pair['image'] for pair in documents ]
texts = [pair['text'] for pair in documents ]
else:
raise ValueError("The documents need to be a dict or list of dicts")
contents, pil_images, max_input_tile_list, llm_onlys = [], [], [], []
for image, text in zip(images, texts):
prefix = ''
llm_only = True
if image != '':
pil_images.append(load_image(image))
prefix = '<image>'
max_input_tile_list.append(self.max_input_tiles)
llm_only = False
else:
pil_images.append(None)
max_input_tile_list.append(self.max_input_tiles)
llm_onlys.append(llm_only)
content = text
if prefix!='':
content = prefix + ' ' + content
if self.passage_prefix:
content = self.passage_prefix + ' ' + content
contents.append(content)
transform = build_transform(input_size=self.image_size, norm_type=self.norm_type)
template = get_conv_template(self.config.template)
template.system_message = self.system_message
content_prompts = []
pixel_values_list = []
for content, pil_image, max_input_tiles, llm_only in zip(contents, pil_images, max_input_tile_list, llm_onlys):
if pil_image is not None:
if self.config.dynamic_image_size:
image_tiles = dynamic_preprocess(
pil_image, image_size=self.image_size, max_num=max_input_tiles,
use_thumbnail=self.config.use_thumbnail)
else:
image_tiles = [pil_image]
pixel_values = [transform(item) for item in image_tiles]
pixel_values = torch.stack(pixel_values).to(dtype=torch.bfloat16)
pixel_values_list.append(pixel_values)
else:
pixel_values = None
IMG_START_TOKEN='<img>'
IMG_END_TOKEN='</img>'
IMG_CONTEXT_TOKEN='<IMG_CONTEXT>'
if pixel_values is not None and '<image>' not in content and not llm_only:
content = '<image> ' + content
            # Resetting conversation messages
template.messages.clear()
# TODO: do we need this template?
template.append_message(template.roles[0], content) # user
template.append_message(template.roles[1], None) # assistant
content_prompt = template.get_prompt()
            if '<image>' in content:
                num_patches = pixel_values.shape[0]
                image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
                content_prompt = content_prompt.replace('<image>', image_tokens, 1)
content_prompts.append(content_prompt)
model_inputs = self.tokenizer(content_prompts,
truncation=True,
max_length=self.p_max_length,
padding=self.padding,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors='pt')
if len(pixel_values_list)>1:
            pixel_values_squeezed = torch.cat(pixel_values_list, dim=0)
elif len(pixel_values_list)==1:
pixel_values_squeezed = pixel_values_list[0]
else:
pixel_values_squeezed = None
batch_docs = {
"input_ids": model_inputs['input_ids'],
"attention_mask": model_inputs['attention_mask'],
"pixel_values": None
}
if pixel_values_squeezed is not None:
batch_docs["pixel_values"] = pixel_values_squeezed
return batch_docs
def process_queries(self, queries: List[str], **kwargs):
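        """Prepend the query prefix, wrap each query in the conversation
        template, and tokenize to `input_ids`/`attention_mask` (text only)."""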
template = get_conv_template(self.config.template)
template.system_message = self.system_message
query_prompts = []
for query in queries:
if self.query_prefix:
query = f"{self.query_prefix} {query}"
            # Resetting conversation messages
template.messages.clear()
template.append_message(template.roles[0], query) # user
template.append_message(template.roles[1], None) # assistant
query_prompt = template.get_prompt()
query_prompts.append(query_prompt)
batch_query = self.tokenizer(
query_prompts,
truncation=True,
max_length=self.q_max_length,
padding=self.padding,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors='pt'
)
return batch_query
def get_scores(
self,
query_embeddings: Union[torch.Tensor, List[torch.Tensor]],
passage_embeddings: Union[torch.Tensor, List[torch.Tensor]],
batch_size: Optional[int] = 8,
) -> torch.Tensor:
"""Dot-product similarity between queries and passages."""
if isinstance(query_embeddings, list):
            if len(query_embeddings[0].shape) == 2:
                # Expand the batch dimension, as the ViDoRe framework removes it
                query_embeddings = [q.unsqueeze(0) for q in query_embeddings]
query_embeddings = self.padding_various_shape_tensor(query_embeddings)
if isinstance(passage_embeddings, list):
            if len(passage_embeddings[0].shape) == 2:
                # Expand the batch dimension, as the ViDoRe framework removes it
                passage_embeddings = [p.unsqueeze(0) for p in passage_embeddings]
passage_embeddings = self.padding_various_shape_tensor(passage_embeddings)
return self.colbert_score(query_embeddings, passage_embeddings, batch_size)
def colbert_score(
self,
qs: Union[torch.Tensor, List[torch.Tensor]],
ps: Union[torch.Tensor, List[torch.Tensor]],
batch_size: int = 128,
device: Optional[Union[str, torch.device]] = None,
) -> torch.Tensor:
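        """Late-interaction (MaxSim) scores between per-token embeddings:
        score(q, p) = sum_n max_s <q_n, p_s>, computed blockwise over batches of
        `batch_size` queries and passages to bound memory; returns a
        (num_queries, num_passages) score matrix."""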
        if batch_size is None:
            batch_size = 128
        if device is None:
            device = "cuda"
if len(qs) == 0:
raise ValueError("No queries provided")
if len(ps) == 0:
raise ValueError("No passages provided")
scores_list: List[torch.Tensor] = []
for i in range(0, len(qs), batch_size):
scores_batch = []
            qs_batch = torch.nn.utils.rnn.pad_sequence(qs[i : i + batch_size].to(device), batch_first=True, padding_value=0)
for j in range(0, len(ps), batch_size):
                ps_batch = torch.nn.utils.rnn.pad_sequence(
                    ps[j : j + batch_size].to(device), batch_first=True, padding_value=0
                )
scores_batch.append(torch.einsum("bnd,csd->bcns", qs_batch, ps_batch).max(dim=3)[0].sum(dim=2))
# Keep scores_batch on the GPU
scores_batch = torch.cat(scores_batch, dim=1)
scores_list.append(scores_batch)
scores = torch.cat(scores_list, dim=0)
        return scores
def _extract_embeddings(self, dataloader: DataLoader, is_query: bool) -> List[torch.Tensor]:
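        """Run the model over a dataloader and collect L2-normalized per-token
        hidden states (last layer), masked by the attention mask, as a single
        padded tensor on CPU."""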
qs = []
message = "query" if is_query else "document"
for batch in tqdm(dataloader, desc=f"Extracting {message} embeddings..."):
with torch.inference_mode():
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
if 'pixel_values' in batch and batch['pixel_values'] is None:
batch.pop('pixel_values')
batch = {k: v.to(self.device) for k, v in batch.items()}
embeddings = self(**batch, output_hidden_states=True).hidden_states[-1]
embeddings = embeddings*batch['attention_mask'].unsqueeze(-1)
embeddings = F.normalize(embeddings, dim=-1)
# Detecting abnormal outputs
assert torch.sum(embeddings).float().item() not in [float(0.), float("inf")]
qs.append(embeddings.contiguous())
qs_tensor = self.padding_various_shape_tensor(qs)
all_embeddings_tensor = qs_tensor.detach().cpu()
return all_embeddings_tensor
def forward_passages(self, passages, batch_size=8, **kwargs) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Forward passages as image-only documents."""
corpus = []
for image in passages:
corpus.append({
"image": image,
"text": ''
})
return self.forward_documents(corpus, batch_size)
def forward_queries(self, queries: List, batch_size=8) -> List[torch.Tensor]:
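        """Embed a list of query strings (see `_extract_embeddings`)."""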
dataset = ListDataset[str](queries)
dataloader = DataLoader(
dataset=dataset,
batch_size=batch_size,
collate_fn=self.process_queries,
shuffle=False,
num_workers=8,
pin_memory=True,
drop_last=False,
)
return self._extract_embeddings(dataloader=dataloader, is_query=True)
def forward_documents(self, corpus: List, batch_size=8) -> List[torch.Tensor]:
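        """Embed a corpus of {"text", "image"} dicts; PIL images are converted
        to RGB before batching."""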
images = []
texts = []
for doc in corpus:
text = doc["text"]
image = doc.get("image", "")
if image.mode != "RGB":
image = image.convert("RGB")
images.append(image)
texts.append(text)
dataset = Dataset.from_dict({"image": images, "text": texts})
dataloader = DataLoader(
dataset=dataset,
batch_size=batch_size,
collate_fn=self.process_documents,
shuffle=False,
num_workers=8,
pin_memory=True,
drop_last=False,
)
return self._extract_embeddings(dataloader=dataloader, is_query=False)
def padding_various_shape_tensor(self, tensors: List[torch.Tensor]) -> torch.Tensor:
"""Pad tensors of various shapes for colbert-like scoring"""
max_seq_len = max(t.shape[1] for t in tensors)
padded_tensors = [F.pad(t, (0, 0, 0, max_seq_len - t.shape[1]), mode="constant", value=0) for t in tensors]
return torch.cat(padded_tensors, dim=0)
from typing import TypeVar
from torch.utils.data import Dataset as TorchDataset
TV = TypeVar("T")
class ListDataset(TorchDataset[TV]):
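    """Minimal map-style dataset over an in-memory list of elements."""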
def __init__(self, elements: List[TV]):
self.elements = elements
def __len__(self) -> int:
return len(self.elements)
def __getitem__(self, idx: int) -> TV:
return self.elements[idx]
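
# A minimal usage sketch (kept as comments so nothing runs on import). It assumes the
# model is loaded from this repository with `trust_remote_code=True`; the repo id and
# file names below are illustrative, not guaranteed.
#
#   import torch
#   from PIL import Image
#   from transformers import AutoModel
#
#   model = AutoModel.from_pretrained(
#       "nvidia/llama-nemoretriever-colembed-1b-v1",
#       torch_dtype=torch.bfloat16,
#       trust_remote_code=True,
#   ).eval().cuda()
#
#   query_embs = model.forward_queries(["How does late-interaction scoring work?"])
#   passage_embs = model.forward_passages([Image.open("page_0.png")])
#   scores = model.get_scores(query_embs, passage_embs)  # (num_queries, num_passages)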