import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Hugging Face Hub id of the sequence-classification checkpoint used below;
# both the tokenizer and the model are loaded from this same id (from_pretrained
# may fetch it from the Hub or reuse a local cache).
classifier_path = "KomeijiForce/deberta-v3-base-check-conversational-scene"
classifier_tokenizer, classifier = (
    AutoTokenizer.from_pretrained(classifier_path),
    AutoModelForSequenceClassification.from_pretrained(classifier_path),
)
# Dialogue excerpt to check; .strip() drops the newline right after the opening
# quotes and the one before the closing quotes.
scene = """
Arisa: I know you're on a roll right now Kasumi, but just calm down...
Kasumi: I can't! I can't calm down!
Arisa: I'm saying you're losing it, Kasumi.
Saaya: We know exactly what you mean, Kasumi.
Kasumi: Right?!
Rimi: Of course. We all felt it, singing together like that♪
Kasumi: Yay! I'm so glad you understand, Rimi-rin♪
Tae: Sharing moments like this is what's really important.
Kasumi: Sharing...! I wanna share this feeling with everybody! This sparkling, heart-pounding feeling!
""".strip()
question = "Is Kasumi inspired by something exciting?"
# Prompt layout, one item per line: the scene, the question, then the
# instruction restricting the answer to yes/no/unknown.
prompt = "\n".join(
    (
        f"Scene: {scene}",
        f"Question: {question}",
        "Directly answer only yes/no/unknown.",
    )
)
# Score the prompt with the classifier and map the winning class to
# False / None / True, then print it.
# no_grad: inference only, so autograd does not need to record the forward pass.
with torch.no_grad():
    # Tokenize the prompt into PyTorch tensors (a batch of one) and keep the
    # logits row for that single example.
    logits = classifier(**classifier_tokenizer(prompt, return_tensors="pt")).logits[0]
# Index of the highest-scoring class.
choice = logits.argmax(-1).item()
# NOTE(review): assumes class ids 0/1/2 mean no/unknown/yes respectively —
# confirm against classifier.config.id2label before relying on this mapping.
answer = [False, None, True][choice]
print(answer)