PD03 committed on
Commit
6a97111
·
verified ·
1 Parent(s): 887b999

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -16
app.py CHANGED
@@ -1,29 +1,67 @@
1
- # app.py
2
- import pandas as pd
3
- from transformers import pipeline
4
- import gradio as gr
5
 
6
- # Load synthetic data
7
- df = pd.read_csv("synthetic_profit.csv")
8
 
9
- # Initialize TAPAS QA pipeline
 
10
  qa = pipeline(
11
  "table-question-answering",
12
  model="google/tapas-base-finetuned-sqa",
13
  tokenizer="google/tapas-base-finetuned-sqa"
14
  )
15
 
16
- def answer(query: str) -> str:
17
- res = qa(table=df, query=query)
18
- return f"**Answer:** {res['answer']} _(agg: {res['aggregate']})_"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  demo = gr.Interface(
21
  fn=answer,
22
- inputs=gr.Textbox(lines=2, placeholder="e.g. What was profit margin for Product B in EMEA Q2 2024?"),
23
- outputs="markdown",
24
  title="S/4HANA Profitability Chat",
25
- description="Ask questions of synthetic S/4HANA data using TAPAS"
26
  )
27
-
28
- if __name__ == "__main__":
29
- demo.launch()
 
 
 
 
 
1
 
 
 
2
 
3
+ # 3) load TAPAS
4
+ from transformers import pipeline
5
  qa = pipeline(
6
  "table-question-answering",
7
  model="google/tapas-base-finetuned-sqa",
8
  tokenizer="google/tapas-base-finetuned-sqa"
9
  )
10
 
11
+ # 4) cast to strings to avoid the regex bug
12
+ df_str = df.astype(str)
13
+
14
+ # 5) sanity check
15
+ print( qa(table=df_str, query="What was the ProfitMargin for Product B in EMEA Q2 2024?") )
16
+
17
+ # 6) launch Gradio
18
+ import gradio as gr
19
+
20
+ import re
21
+
22
+ def answer(q: str) -> str:
23
+ # --- 1. try to parse explicit total/average queries ---
24
+ m = re.search(r"\b(total|average)\s+(ProfitMargin|Profit|Revenue|Cost)\b", q, re.IGNORECASE)
25
+ p = re.search(r"\bProduct\s*([A-D])\b", q, re.IGNORECASE)
26
+ t = re.search(r"\b(Q[1-4])\s*(\d{4})\b", q, re.IGNORECASE)
27
+
28
+ if m and p and t:
29
+ agg_type = m.group(1).lower() # "total" or "average"
30
+ metric = m.group(2) # column name
31
+ product = f"Product {p.group(1).upper()}"
32
+ quarter = t.group(1)
33
+ year = int(t.group(2))
34
+
35
+ # filter the *numeric* DataFrame
36
+ subset = df[
37
+ (df["Product"] == product) &
38
+ (df["FiscalQuarter"] == quarter) &
39
+ (df["FiscalYear"] == year)
40
+ ]
41
+
42
+ if not subset.empty:
43
+ if agg_type == "total":
44
+ val = subset[metric].sum()
45
+ return f"Total {metric} for {product} in {quarter} {year}: {val:,.2f}"
46
+ else: # average
47
+ val = subset[metric].mean()
48
+ # show 3 decimal places for margins, 2 for currency
49
+ fmt = "{:,.3f}" if metric=="ProfitMargin" else "{:,.2f}"
50
+ return f"Average {metric} for {product} in {quarter} {year}: " + fmt.format(val)
51
+
52
+ # --- 2. fallback to TAPAS for everything else ---
53
+ res = qa(table=df_str, query=q)
54
+ agg = res.get("aggregator","")
55
+ if agg and agg != "NONE":
56
+ return f"Answer: {res['answer']} (agg: {agg})"
57
+ # last-resort: raw answer
58
+ return f"Answer: {res['answer']}"
59
+
60
 
61
  demo = gr.Interface(
62
  fn=answer,
63
+ inputs=gr.Textbox(lines=2, placeholder="e.g. Profit for Product A in Q1 2023?"),
64
+ outputs="text",
65
  title="S/4HANA Profitability Chat",
 
66
  )
67
+ demo.launch(share=True, debug=True)