PD03 committed on
Commit d162c32 · verified · 1 Parent(s): b1faf3e

Update app.py

Files changed (1)
  1. app.py +71 -48
app.py CHANGED
@@ -2,79 +2,102 @@ import os
  import gradio as gr
  import pandas as pd
  import tensorflow as tf
- from tapas.scripts import prediction_utils
- from tapas.utils import number_annotation_utils
- from tapas.protos import interaction_pb2

- # 1) Read CSV and build list-of-lists table
- import pandas as pd

  df = pd.read_csv("synthetic_profit.csv")
- # Ensure all values are strings
  df = df.astype(str)
- # Build TAPAS-style table: header row + data rows
- table = [list(df.columns)] + df.values.tolist()

- # 2) Configure TAPAS conversion with aggregation support
- from tapas.utils import example_utils as tf_example_utils
  config = tf_example_utils.ClassifierConversionConfig(
      vocab_file="tapas_sqa_base/vocab.txt",
      max_seq_length=512,
      max_column_id=512,
      max_row_id=512,
-     strip_column_names=False,         # Keep header names
-     add_aggregation_candidates=True,  # Propose SUM/AVERAGE operations
  )
  converter = tf_example_utils.ToClassifierTensorflowExample(config)

- # 3) Helper: convert one interaction to model input
- def interaction_from_query(question: str):
      interaction = interaction_pb2.Interaction()
-     # Add question
      q = interaction.questions.add()
-     q.original_text = question
-     # Add table columns
      for col in table[0]:
          interaction.table.columns.add().text = col
-     # Add table rows/cells
-     for row in table[1:]:
-         r = interaction.table.rows.add()
-         for cell in row:
-             r.cells.add().text = cell
-     # Annotate numeric values
      number_annotation_utils.add_numeric_values(interaction)
-     return interaction

- # 4) Instantiate TAPAS model and tokenizer
- from transformers import TFAutoModelForSequenceClassification, AutoTokenizer
- MODEL = "google/tapas-base-finetuned-wtq"
- tokenizer = AutoTokenizer.from_pretrained(MODEL)
- model = TFAutoModelForSequenceClassification.from_pretrained(MODEL)
-
- # 5) Prediction helper
- def predict_answer(question: str):
-     interaction = interaction_from_query(question)
-     # Convert to TensorFlowExample
-     tf_example = converter.convert(interaction)
-     # Run prediction
-     result = model(tf_example.features)
-     # Parse answer coordinates
-     coords = prediction_utils.parse_coordinates(result.logits)
-     # Map coordinates back to table cells
      answers = []
-     for r, c in coords:
          answers.append(table[r+1][c])
-     return ", ".join(answers)

- # 6) Gradio interface
  iface = gr.Interface(
-     fn=predict_answer,
-     inputs=gr.Textbox(lines=2, placeholder="Ask a question"),
-     outputs=gr.Textbox(lines=3),
      title="SAP Profitability Q&A (TAPAS Low-Level)",
      description=(
-         "Low-level TAPAS: list-of-lists input, numeric annotations, "
-         "aggregation candidates, and coordinate post-processing."
      ),
      allow_flagging="never",
  )
 
  import gradio as gr
  import pandas as pd
  import tensorflow as tf

+ # TAPAS imports
+ from tapas.protos import interaction_pb2
+ from tapas.utils import number_annotation_utils, tf_example_utils, prediction_utils
+ from tapas.scripts.run_task_main import get_classifier_model, get_task_config

+ # 1) Load & stringify your CSV
  df = pd.read_csv("synthetic_profit.csv")
  df = df.astype(str)

+ # 2) Build the “list of lists” table
+ #    (header row + all data rows)
+ table = [list(df.columns)]
+ table.extend(df.values.tolist())
+
+ # 3) Prepare the TAPAS converter + model
+ #    – add_aggregation_candidates=True to surface SUM/AVG ops
+ #    – strip_column_names=False so your exact headers stay visible
  config = tf_example_utils.ClassifierConversionConfig(
      vocab_file="tapas_sqa_base/vocab.txt",
      max_seq_length=512,
      max_column_id=512,
      max_row_id=512,
+     strip_column_names=False,
+     add_aggregation_candidates=True,
  )
  converter = tf_example_utils.ToClassifierTensorflowExample(config)

+ # 4) Load your pretrained checkpoint
+ #    (uses the same flags as run_task_main.py --mode=predict)
+ task_config = get_task_config(
+     task="sqa",
+     init_checkpoint="tapas_sqa_base/model.ckpt-0",
+     vocab_file=config.vocab_file,
+     bsz=1,
+     max_seq_length=config.max_seq_length,
+ )
+ model, tokenizer = get_classifier_model(task_config)
+
+ # 5) Convert a single (table, query) into a TF Example
+ def make_tf_example(table, query):
      interaction = interaction_pb2.Interaction()
+     # a) question
      q = interaction.questions.add()
+     q.original_text = query
+     # b) columns
      for col in table[0]:
          interaction.table.columns.add().text = col
+     # c) rows
+     for row_vals in table[1:]:
+         row = interaction.table.rows.add()
+         for cell in row_vals:
+             row.cells.add().text = cell
+     # d) numeric annotation helps SUM/AVG
      number_annotation_utils.add_numeric_values(interaction)
+     # e) convert to example
+     serialized = converter.convert(interaction)
+     return serialized

+ # 6) Run TAPAS and parse its coordinate output
+ def predict_answer(query):
+     # build TF example
+     example = make_tf_example(table, query)
+     # run prediction
+     input_fn = tf_example_utils.input_fn_builder(
+         [example],
+         is_training=False,
+         drop_remainder=False,
+         batch_size=1,
+         seq_length=config.max_seq_length,
+     )
+     preds = model.predict(input_fn)
+     # parse answer coordinates
+     coords = prediction_utils.parse_coordinates(preds[0]["answer_coordinates"])
+     # map back to table values
      answers = []
+     for (r, c) in coords:
+         # table[0] is header row, so data starts at index 1
          answers.append(table[r+1][c])
+     return ", ".join(answers) if answers else "No answer found."
+
+ # 7) Gradio interface
+ def answer_fn(question: str) -> str:
+     try:
+         return predict_answer(question)
+     except Exception as e:
+         return f"❌ Error: {e}"

  iface = gr.Interface(
+     fn=answer_fn,
+     inputs=gr.Textbox(lines=2, label="Your question"),
+     outputs=gr.Textbox(label="Answer"),
      title="SAP Profitability Q&A (TAPAS Low-Level)",
      description=(
+         "Uses TAPAS’s Interaction + Converter APIs with aggregation candidates "
+         "and numeric annotations to reliably answer sum/average queries."
      ),
      allow_flagging="never",
  )
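
The coordinate post-processing in predict_answer() maps (row, column) pairs back into the list-of-lists table, offsetting rows by one because table[0] holds the headers. A minimal standalone sketch of that layout and indexing, using toy data in place of synthetic_profit.csv, runs without TAPAS installed:

# Minimal sketch (hypothetical toy data, not synthetic_profit.csv) of the
# list-of-lists table layout and the coordinate-to-cell mapping in app.py.
import pandas as pd

df = pd.DataFrame({"Region": ["EMEA", "APAC"], "Profit": ["120.5", "98.0"]}).astype(str)

# Header row first, then data rows - the same shape app.py builds.
table = [list(df.columns)] + df.values.tolist()

# Answer coordinates are (row, column) pairs over the data rows,
# so row r maps to table[r + 1] because table[0] is the header.
coords = [(0, 1), (1, 1)]  # hypothetical model output
answers = [table[r + 1][c] for (r, c) in coords]
print(", ".join(answers))  # -> 120.5, 98.0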