|
import streamlit as st |
|
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast |
|
|
|
@st.cache_resource
def load_model(model_name: str = "MahmutCanBoran/mbart-audi-diagnosis-agent"):
    """Load the mBART-50 tokenizer and generation model for one checkpoint.

    The function is wrapped in ``st.cache_resource``, so the pair is built
    once per distinct ``model_name`` and then handed back from the cache
    instead of being reloaded on every call.

    Args:
        model_name: Identifier passed to both ``from_pretrained`` calls.
            Defaults to the Audi diagnosis checkpoint, so existing
            zero-argument callers keep their behavior.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # Both halves must come from the same checkpoint, so the identifier is
    # named once and shared rather than repeated as two string literals.
    tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
    model = MBartForConditionalGeneration.from_pretrained(model_name)
    return tokenizer, model
|
|
|
# Fetch the tokenizer/model pair from the cached loader.
tokenizer, model = load_model()

# Page header and a one-line usage hint.
st.title(body="π§ Audi AI Diagnosis Agent")
st.markdown(
    body="Enter your Audi issue in **English**, and the AI will try to diagnose it."
)

# Free-text description of the fault; starts out empty.
text = st.text_area(label="π What's the problem?", value="")
|
|
|
# Run a diagnosis when the button is pressed.
# NOTE(review): the "π§" prefix on the labels below looks like a mis-encoded
# emoji; the strings are left untouched -- confirm the intended character.
if st.button("π§ Diagnose"):
    # Validate and tokenize the same stripped value, so surrounding
    # whitespace is neither accepted as input nor fed to the model.
    problem = text.strip()

    if not problem:
        st.warning("Please enter a vehicle problem.")
    else:
        # Only one sequence is encoded, so no padding is requested: padding
        # it out to max_length would just add pad positions for the encoder
        # to process. Over-long input is still truncated to 128 tokens.
        inputs = tokenizer(
            problem,
            return_tensors="pt",
            truncation=True,
            max_length=128,
        ).to(model.device)

        # Beam search with the first generated token forced to the "en_XX"
        # language code, which is how mBART-50 selects the output language.
        output_ids = model.generate(
            **inputs,
            max_length=128,
            num_beams=4,
            early_stopping=True,
            forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"],
        )

        # Decode the first returned sequence, dropping special tokens.
        result = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        st.success("π§ AI Diagnosis:")
        st.markdown(f"**{result}**")