Maximilian Werk committed on
Commit
cf456d3
·
1 Parent(s): b7707d5

feat: reduced default noise of the model

Browse files
configuration_jina_embeddings_v4.py CHANGED
@@ -2,6 +2,7 @@ from transformers.models.qwen2_5_vl import Qwen2_5_VLConfig
2
 
3
  from typing import Optional
4
 
 
5
  class JinaEmbeddingsV4Config(Qwen2_5_VLConfig):
6
  """
7
  Configuration for the JinaEmbeddingsV4 model.
@@ -12,10 +13,11 @@ class JinaEmbeddingsV4Config(Qwen2_5_VLConfig):
12
  single_vector_pool_strategy: str = "mean",
13
  multi_vector_projector_dim: int = 128,
14
  pretrained_peft_model_name_or_path: Optional[str] = None,
 
15
  **kwargs,
16
  ):
17
  super().__init__(**kwargs)
18
  self.single_vector_pool_strategy = single_vector_pool_strategy
19
  self.multi_vector_projector_dim = multi_vector_projector_dim
20
  self.pretrained_peft_model_name_or_path = pretrained_peft_model_name_or_path
21
-
 
2
 
3
  from typing import Optional
4
 
5
+
6
  class JinaEmbeddingsV4Config(Qwen2_5_VLConfig):
7
  """
8
  Configuration for the JinaEmbeddingsV4 model.
 
13
  single_vector_pool_strategy: str = "mean",
14
  multi_vector_projector_dim: int = 128,
15
  pretrained_peft_model_name_or_path: Optional[str] = None,
16
+ verbosity: int = 0,
17
  **kwargs,
18
  ):
19
  super().__init__(**kwargs)
20
  self.single_vector_pool_strategy = single_vector_pool_strategy
21
  self.multi_vector_projector_dim = multi_vector_projector_dim
22
  self.pretrained_peft_model_name_or_path = pretrained_peft_model_name_or_path
23
+ self.verbosity = verbosity
modeling_jina_embeddings_v4.py CHANGED
@@ -146,6 +146,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
146
  self.name_or_path, trust_remote_code=True, use_fast=True
147
  )
148
  self.multi_vector_projector_dim = config.multi_vector_projector_dim
 
149
  self._task = None
150
 
151
  @property
@@ -335,7 +336,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
335
  assert not return_numpy, "`return_numpy` is not supported when `return_multivector=True` and more than one data is encoded"
336
  results = []
337
  self.eval()
338
- for batch in tqdm(dataloader, desc=desc):
339
  with torch.no_grad():
340
  batch = {k: v.to(self.device) for k, v in batch.items()}
341
  with torch.autocast(
@@ -349,7 +350,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
349
  embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
350
  else:
351
  embeddings = embeddings.multi_vec_emb
352
-
353
  if return_multivector and not return_numpy:
354
  valid_tokens = batch["attention_mask"].bool()
355
  embeddings = [
@@ -453,7 +454,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
453
  if return_numpy:
454
  print("Warning: `return_numpy` is ignored when `return_multivector=True` and `len(texts) > 1`")
455
  return_numpy = False
456
-
457
  if isinstance(texts, str):
458
  texts = [texts]
459
 
@@ -468,7 +469,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
468
  **encode_kwargs,
469
  )
470
 
471
- return embeddings if return_list else embeddings[0]
472
 
473
  def _load_images_if_needed(
474
  self, images: List[Union[str, Image.Image]]
@@ -515,7 +516,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
515
  )
516
  encode_kwargs = self._validate_encoding_params(truncate_dim=truncate_dim)
517
  task = self._validate_task(task)
518
-
519
  return_list = isinstance(images, list)
520
 
521
  # If return_multivector is True and encoding multiple images, ignore return_numpy
@@ -527,7 +528,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
527
  # Convert single image to list
528
  if isinstance(images, (str, Image.Image)):
529
  images = [images]
530
-
531
  images = self._load_images_if_needed(images)
532
  embeddings = self._process_batches(
533
  data=images,
 
146
  self.name_or_path, trust_remote_code=True, use_fast=True
147
  )
148
  self.multi_vector_projector_dim = config.multi_vector_projector_dim
149
+ self.verbosity = config.verbosity
150
  self._task = None
151
 
152
  @property
 
336
  assert not return_numpy, "`return_numpy` is not supported when `return_multivector=True` and more than one data is encoded"
337
  results = []
338
  self.eval()
339
+ for batch in tqdm(dataloader, desc=desc, disable=self.verbosity == 0):
340
  with torch.no_grad():
341
  batch = {k: v.to(self.device) for k, v in batch.items()}
342
  with torch.autocast(
 
350
  embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
351
  else:
352
  embeddings = embeddings.multi_vec_emb
353
+
354
  if return_multivector and not return_numpy:
355
  valid_tokens = batch["attention_mask"].bool()
356
  embeddings = [
 
454
  if return_numpy:
455
  print("Warning: `return_numpy` is ignored when `return_multivector=True` and `len(texts) > 1`")
456
  return_numpy = False
457
+
458
  if isinstance(texts, str):
459
  texts = [texts]
460
 
 
469
  **encode_kwargs,
470
  )
471
 
472
+ return embeddings if return_list else embeddings[0]
473
 
474
  def _load_images_if_needed(
475
  self, images: List[Union[str, Image.Image]]
 
516
  )
517
  encode_kwargs = self._validate_encoding_params(truncate_dim=truncate_dim)
518
  task = self._validate_task(task)
519
+
520
  return_list = isinstance(images, list)
521
 
522
  # If return_multivector is True and encoding multiple images, ignore return_numpy
 
528
  # Convert single image to list
529
  if isinstance(images, (str, Image.Image)):
530
  images = [images]
531
+
532
  images = self._load_images_if_needed(images)
533
  embeddings = self._process_batches(
534
  data=images,