Chamin09 commited on
Commit
5aea4f1
·
verified ·
1 Parent(s): 71aaa5d

Update models/image_models.py

Browse files
Files changed (1) hide show
  1. models/image_models.py +2 -1
models/image_models.py CHANGED
@@ -78,7 +78,7 @@ class ImageModelManager:
78
  self.logger.info(f"Loading advanced image model: {self.advanced_model_name}")
79
  self.advanced_processor = Blip2Processor.from_pretrained(self.advanced_model_name)
80
  self.advanced_model = Blip2ForConditionalGeneration.from_pretrained(
81
- self.advanced_model_name, torch_dtype=torch.float16)
82
 
83
  self.initialized["advanced"] = True
84
  self.logger.info("Advanced image model initialized successfully")
@@ -211,6 +211,7 @@ class ImageModelManager:
211
  # Generate caption
212
  with torch.no_grad():
213
  if model_type == "advanced":
 
214
  generated_ids = model.generate(
215
  pixel_values=inputs.pixel_values,
216
  max_new_tokens=50, # Using max_new_tokens instead of max_length
 
78
  self.logger.info(f"Loading advanced image model: {self.advanced_model_name}")
79
  self.advanced_processor = Blip2Processor.from_pretrained(self.advanced_model_name)
80
  self.advanced_model = Blip2ForConditionalGeneration.from_pretrained(
81
+ self.advanced_model_name, torch_dtype=torch.float32)
82
 
83
  self.initialized["advanced"] = True
84
  self.logger.info("Advanced image model initialized successfully")
 
211
  # Generate caption
212
  with torch.no_grad():
213
  if model_type == "advanced":
214
+ pixel_values = inputs.pixel_values.to(torch.float32)
215
  generated_ids = model.generate(
216
  pixel_values=inputs.pixel_values,
217
  max_new_tokens=50, # Using max_new_tokens instead of max_length