Alexey Kalinin committed on
Commit
44ee98d
·
1 Parent(s): eda7916

add trained model

Browse files
Files changed (4) hide show
  1. app.py +3 -3
  2. class2name.joblib +0 -0
  3. inference.py +22 -4
  4. requirements.txt +2 -1
app.py CHANGED
@@ -4,7 +4,7 @@ from inference import classify, load_model # Replace with your actual module
4
 
5
  @st.cache_resource
6
  def get_model():
7
- return load_model()
8
 
9
  def get_arxiv_article_info(url):
10
  """Extracts title and abstract from an arXiv article link."""
@@ -16,7 +16,7 @@ def get_arxiv_article_info(url):
16
  return None, None
17
 
18
  # Load model once
19
- model = get_model()
20
 
21
  st.title("ArXiv Article Classifier")
22
 
@@ -41,6 +41,6 @@ elif input_method == "Manual Input":
41
 
42
  # Classification and output
43
  if title and abstract:
44
- category = classify(title, abstract)
45
  st.write(f"### Title: {title}")
46
  st.write(f"**Predicted Category:** {category}")
 
4
 
5
  @st.cache_resource
6
  def get_model():
7
+ return load_model("kalinin-a-i/ml2-hw4", "class2name.joblib")
8
 
9
  def get_arxiv_article_info(url):
10
  """Extracts title and abstract from an arXiv article link."""
 
16
  return None, None
17
 
18
  # Load model once
19
+ pipe, class2name = get_model()
20
 
21
  st.title("ArXiv Article Classifier")
22
 
 
41
 
42
  # Classification and output
43
  if title and abstract:
44
+ category = classify(pipe, class2name, title, abstract)
45
  st.write(f"### Title: {title}")
46
  st.write(f"**Predicted Category:** {category}")
class2name.joblib ADDED
Binary file (400 Bytes). View file
 
inference.py CHANGED
@@ -1,6 +1,24 @@
1
- def load_model():
2
- pass
 
 
3
 
4
- def classify(title, abstract):
5
- return "physics"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
 
1
+ from transformers import AutoTokenizer
2
+ from transformers import AutoModelForSequenceClassification
3
+ from transformers import pipeline, Pipeline
4
+ from joblib import load
5
 
6
+
7
def load_model(path2chkpt: str, path2mapping: str):
    """Build the text-classification pipeline and load the label mapping.

    Args:
        path2chkpt: Location of the fine-tuned sequence-classification
            weights, in any form ``from_pretrained`` accepts (a local
            directory or a Hugging Face Hub repo id).
        path2mapping: Path to the joblib file that stores the mapping from
            the model's label codes to human-readable category names.

    Returns:
        A ``(pipe, class2name)`` tuple: the ready-to-call
        ``text-classification`` pipeline and the loaded label mapping.
    """
    # Load the weights from the location the caller asked for. The previous
    # version ignored ``path2chkpt`` and read a hard-coded absolute path that
    # only exists on the training machine.
    model = AutoModelForSequenceClassification.from_pretrained(path2chkpt)

    # NOTE(review): the tokenizer is taken from the base DistilBERT checkpoint,
    # presumably the one used for fine-tuning -- confirm it matches the weights.
    tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-cased")

    pipe = pipeline(
        "text-classification",
        model=model,
        tokenizer=tokenizer,
    )

    # NOTE(review): joblib files are pickles; only load mappings you trust.
    class2name = load(path2mapping)
    return pipe, class2name
17
+
18
+
19
def classify(pipe: "Pipeline", class2name: dict[str, str], title: str, abstract: str):
    """Predict the category name for an article.

    Args:
        pipe: Text-classification pipeline (any callable that takes the text
            plus tokenizer keyword arguments and returns a list of dicts
            with a ``"label"`` key).
        class2name: Mapping from the pipeline's label codes to readable
            category names.
        title: Article title.
        abstract: Article abstract.

    Returns:
        The category name that ``class2name`` holds for the predicted label.

    Raises:
        KeyError: If the predicted label is missing from ``class2name``.
    """
    # Title and abstract are joined with a bare "." exactly as before, so the
    # model input format is unchanged.
    inputs = ".".join([title, abstract])

    # Truncate to the model's maximum sequence length; without it an
    # over-long abstract makes the model fail instead of returning a label.
    class_code = pipe(inputs, truncation=True)[0]["label"]

    return class2name[class_code]
24
 
requirements.txt CHANGED
@@ -1,2 +1,3 @@
1
  streamlit
2
- arxiv
 
 
1
  streamlit
2
+ arxiv
3
+ joblib