Description
This PEFT model is trained on an NER dataset.
The goal is for the base model to better recognize the presence of entities through binary classification.
Base model accuracy on the NER validation set: 0.3777
PEFT model accuracy on the NER validation set: 0.7657
You can see the PEFT training process here.
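For reference, the accuracy figures above are the fraction of validation examples whose predicted label matches the gold label. The snippet below is only a minimal sketch of that computation with made-up placeholder labels; it is not the actual evaluation script.
import numpy as np

def accuracy(predicted_labels, gold_labels):
    # Fraction of examples where the prediction matches the gold label
    return float(np.mean(np.asarray(predicted_labels) == np.asarray(gold_labels)))

accuracy([1, 0, 1, 1], [1, 0, 0, 1])  # 0.75 on this toy example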
Usage
# Import libraries
import numpy as np
import torch
from transformers import (
    AutoTokenizer,
    BitsAndBytesConfig,
    AutoModelForSequenceClassification,
)
# Load model
model_id = "MVesalA/peft-NER-existenceCLS-HooshvareLab"
tokenizer = AutoTokenizer.from_pretrained(model_id)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # Enable 4-bit quantization
    bnb_4bit_use_double_quant=True,        # Use double quantization for potentially higher accuracy (optional)
    bnb_4bit_quant_type="nf4",             # Quantization type (specifics depend on hardware and library)
    bnb_4bit_compute_dtype=torch.bfloat16  # Compute dtype for improved efficiency (optional)
)

id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {"NEGATIVE": 0, "POSITIVE": 1}

model = AutoModelForSequenceClassification.from_pretrained(
    model_id,                        # "MVesalA/peft-NER-existenceCLS-HooshvareLab"
    num_labels=2,                    # Number of output labels (binary classification)
    id2label=id2label,               # {0: "NEGATIVE", 1: "POSITIVE"}
    label2id=label2id,               # {"NEGATIVE": 0, "POSITIVE": 1}
    quantization_config=bnb_config   # Quantization configuration
)
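# Note: model_id points to a PEFT adapter. As a hedged alternative (assuming the
# `peft` library is installed), the adapter could also be loaded explicitly through
# peft's auto class instead of relying on transformers to resolve it:
#
#   from peft import AutoPeftModelForSequenceClassification
#   model = AutoPeftModelForSequenceClassification.from_pretrained(
#       model_id,
#       quantization_config=bnb_config,
#   )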
# Predict entity existence
def predict(input_text, model=model):
    """
    Predicts entity-existence scores for a given text input.

    Args:
        input_text (str): The text to classify.

    Returns:
        numpy.ndarray: Sigmoid-activated scores for [NEGATIVE, POSITIVE],
        rounded to 5 decimal places.
    """
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)  # Tokenize and move tensors to the model's device
    with torch.no_grad():
        outputs = model(**inputs).logits  # Get the model's output logits
    y_prob = torch.sigmoid(outputs).tolist()[0]  # Apply sigmoid activation and convert to a list
    return np.round(y_prob, 5)  # Round the scores to 5 decimal places

predict("input_text")  # [negative_prob, positive_prob]