Alexey Kalinin commited on
Commit
5a69625
·
1 Parent(s): 84152d3

update inference

Browse files
Files changed (2) hide show
  1. app.py +5 -3
  2. inference.py +10 -3
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import streamlit as st
2
  import arxiv
3
- from inference import classify, load_model # Replace with your actual module
4
 
5
  @st.cache_resource
6
  def get_model():
@@ -41,6 +41,8 @@ elif input_method == "Manual Input":
41
 
42
  # Classification and output
43
  if title and abstract:
44
- category = classify(pipe, class2name, title, abstract)
45
  st.write(f"### Title: {title}")
46
- st.write(f"**Predicted Category:** {category}")
 
 
 
1
  import streamlit as st
2
  import arxiv
3
+ from inference import top_95_labels, load_model # Replace with your actual module
4
 
5
  @st.cache_resource
6
  def get_model():
 
41
 
42
  # Classification and output
43
  if title and abstract:
44
+ categories = top_95_labels(title, abstract) # Assuming classify returns a list of labels
45
  st.write(f"### Title: {title}")
46
+ st.write("**Predicted Categories:**")
47
+ for category in categories:
48
+ st.write(f"- {category}")
inference.py CHANGED
@@ -16,9 +16,16 @@ def load_model(path2chkpt: str, path2mapping: str):
16
  return pipe, class2name
17
 
18
 
19
- def classify(pipe: Pipeline, class2name: dict[str, str], title: str, abstract: str):
20
  inputs = ".".join([title, abstract])
21
- class_code = pipe(inputs)[0]["label"]
22
 
23
- return class2name[class_code]
 
 
 
 
 
 
 
24
 
 
16
  return pipe, class2name
17
 
18
 
19
+ def top_95_labels(pipe: Pipeline, class2name: dict[str, str], title: str, abstract: str):
20
  inputs = ".".join([title, abstract])
21
+ result = pipe(inputs, top_k=20)
22
 
23
+ proba = 0
24
+ labels = []
25
+ i = 0
26
+ while proba < 0.95:
27
+ proba += result[i]["score"]
28
+ labels.append(result[i]["label"])
29
+ i += 1
30
+ return [class2name[label] for label in labels]
31