---
library_name: transformers
tags: []
---

```python
import torch
from transformers import AutoTokenizer, AutoModel

# Pick one sentence
sentence = "The patient has a right pneumothorax."

# Load pretrained model and tokenizer
model_name = "IAMJB/RadEvalModernBERT"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# Put model in eval mode and set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# Tokenize input
inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True).to(device)

# Get embeddings (no gradients are needed for inference)
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)
    last_hidden_state = outputs.hidden_states[-1]
    cls_embedding = last_hidden_state[:, 0, :]  # CLS token (first position)

print("Sentence:", sentence)
print("Embedding shape:", cls_embedding.shape)
```

### Similarity heatmap example

```python
import argparse

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from transformers import AutoModel, AutoTokenizer

# Hub identifier of the checkpoint; must be passed to `from_pretrained` as a string.
MODEL_NAME = "IAMJB/RadEvalModernBERT"


def get_cls_embeddings(model, tokenizer, texts, device):
    """Get CLS token embeddings for a list of texts.

    Each text is tokenized and run through the model on its own, and the
    first-token vector of the last entry in ``outputs.hidden_states`` is kept.

    Args:
        model: Model called as ``model(**inputs, output_hidden_states=True)``.
        tokenizer: Tokenizer that returns PyTorch tensors for a single text.
        texts: Iterable of strings to embed.
        device: Torch device the input tensors are moved to; the model is
            expected to already live on this device.

    Returns:
        A NumPy array with one row per text, i.e. shape
        ``(len(texts), hidden_size)``.
    """
    embeddings = []

    for text in texts:
        # Tokenize the text and move every tensor to the target device
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Inference only: skip gradient tracking
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)

        # Use the last hidden state and extract the CLS (first) token embedding
        last_hidden_state = outputs.hidden_states[-1]
        cls_embedding = last_hidden_state[:, 0, :]

        # Batch size is 1 here, so take row 0 to get a 1-D vector
        embeddings.append(cls_embedding.cpu().numpy()[0])

    return np.array(embeddings)


def compute_similarities(embeddings):
    """Compute the pairwise cosine similarity between embeddings.

    Args:
        embeddings: 2-D array with one embedding per row.

    Returns:
        A square matrix whose entry ``[i, j]`` is the cosine similarity
        between rows ``i`` and ``j`` of ``embeddings``.
    """
    # Floor the norms so an all-zero row yields zeros instead of NaN
    # (division by zero); non-degenerate rows are unaffected.
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    normalized_embeddings = embeddings / np.maximum(norms, 1e-12)

    # Dot products of unit vectors are cosine similarities
    return np.matmul(normalized_embeddings, normalized_embeddings.T)


def plot_heatmap(similarity_matrix, labels, output_path="cls_embedding_similarities.png"):
    """Generate a heatmap visualization of the similarity matrix.

    The figure is written to ``output_path`` and then displayed.

    Args:
        similarity_matrix: Square matrix of pairwise similarities.
        labels: Tick labels, used on both axes.
        output_path: Destination file for the saved figure.
    """
    plt.figure(figsize=(10, 8))

    # Lower end of the color scale: the smallest similarity, clamped at 0
    # so negative values do not stretch the scale.
    min_val = max(0.0, np.min(similarity_matrix))

    # Create the heatmap with the adjusted color scale
    sns.heatmap(
        similarity_matrix,
        annot=True,
        fmt=".3f",
        cmap="viridis",  # Better colormap for distinguishing high values
        vmin=min_val,
        vmax=1.0,
        xticklabels=labels,
        yticklabels=labels,
        cbar_kws={"label": "Similarity"},
    )

    plt.title("CLS Token Embedding Similarities")

    # Rotate x-axis labels before computing the layout so the rotated
    # labels are taken into account when margins are adjusted.
    plt.xticks(rotation=90)
    plt.tight_layout()

    # Save the figure
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    print(f"Heatmap saved to {output_path}")

    # Show the plot
    plt.show()


def main():
    """Embed a fixed list of medical phrases and plot their similarities."""
    # Medical terms to compare
    medical_terms = [
        "large right pneumothorax",
        "right pneumothorax",
        "pneumonia in the right lower lobe",
        "consolidation in the right lower lobe",
        "right 9th rib fracture",
        "left 9th rib fracture",
        "left 5th rib fracture",
        "5th metatarsal fracture",
        "no pneumothorax is present",
        "prior consolidation has cleared",
        "no rib fractures",
    ]

    # Set the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load the tokenizer and the model (the model id must be a string)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModel.from_pretrained(MODEL_NAME)
    model.to(device)
    model.eval()

    # Get CLS token embeddings for the medical terms
    print("Generating CLS token embeddings...")
    embeddings = get_cls_embeddings(model, tokenizer, medical_terms, device)

    # Compute similarities
    print("Computing similarity matrix...")
    similarity_matrix = compute_similarities(embeddings)

    # Plot and save the heatmap
    print("Generating heatmap...")
    plot_heatmap(similarity_matrix, medical_terms, "cls_embedding_similarities.png")

    print("Done!")


if __name__ == "__main__":
    main()
```

![image/png](https://cdn-uploads.huggingface.co/production/uploads/62716952bcef985363db8485/6mzZ5_Xz2ovl3a6TlAzxo.png)