matt-tries-dl commited on
Commit
357d6d7
·
1 Parent(s): 9084b03

more nat language stuff

Browse files
Files changed (1) hide show
  1. llama_test.ipynb +144 -7
llama_test.ipynb CHANGED
@@ -2,7 +2,7 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 1,
6
  "metadata": {},
7
  "outputs": [
8
  {
@@ -11,7 +11,7 @@
11
  "True"
12
  ]
13
  },
14
- "execution_count": 1,
15
  "metadata": {},
16
  "output_type": "execute_result"
17
  }
@@ -32,7 +32,7 @@
32
  },
33
  {
34
  "cell_type": "code",
35
- "execution_count": 2,
36
  "metadata": {},
37
  "outputs": [
38
  {
@@ -47,7 +47,7 @@
47
  {
48
  "data": {
49
  "application/vnd.jupyter.widget-view+json": {
50
- "model_id": "fc0f4312aa7c4009a912c66dd1443763",
51
  "version_major": 2,
52
  "version_minor": 0
53
  },
@@ -83,7 +83,7 @@
83
  },
84
  {
85
  "cell_type": "code",
86
- "execution_count": 3,
87
  "metadata": {},
88
  "outputs": [
89
  {
@@ -132,7 +132,7 @@
132
  },
133
  {
134
  "cell_type": "code",
135
- "execution_count": 5,
136
  "metadata": {},
137
  "outputs": [
138
  {
@@ -158,9 +158,17 @@
158
  "! tar xvjf WikiSQL/data.tar.bz2"
159
  ]
160
  },
 
 
 
 
 
 
 
 
161
  {
162
  "cell_type": "code",
163
- "execution_count": 12,
164
  "metadata": {},
165
  "outputs": [
166
  {
@@ -213,6 +221,135 @@
213
  "qs = replace_cols(str(q),cm)\n",
214
  "print(qs)"
215
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  }
217
  ],
218
  "metadata": {
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 13,
6
  "metadata": {},
7
  "outputs": [
8
  {
 
11
  "True"
12
  ]
13
  },
14
+ "execution_count": 13,
15
  "metadata": {},
16
  "output_type": "execute_result"
17
  }
 
32
  },
33
  {
34
  "cell_type": "code",
35
+ "execution_count": 14,
36
  "metadata": {},
37
  "outputs": [
38
  {
 
47
  {
48
  "data": {
49
  "application/vnd.jupyter.widget-view+json": {
50
+ "model_id": "ca1fb983d9884b91a3c0feed1e207d0e",
51
  "version_major": 2,
52
  "version_minor": 0
53
  },
 
83
  },
84
  {
85
  "cell_type": "code",
86
+ "execution_count": 15,
87
  "metadata": {},
88
  "outputs": [
89
  {
 
132
  },
133
  {
134
  "cell_type": "code",
135
+ "execution_count": 16,
136
  "metadata": {},
137
  "outputs": [
138
  {
 
158
  "! tar xvjf WikiSQL/data.tar.bz2"
159
  ]
160
  },
161
+ {
162
+ "attachments": {},
163
+ "cell_type": "markdown",
164
+ "metadata": {},
165
+ "source": [
166
+ "Figure out what the actual data set has in it."
167
+ ]
168
+ },
169
  {
170
  "cell_type": "code",
171
+ "execution_count": 17,
172
  "metadata": {},
173
  "outputs": [
174
  {
 
221
  "qs = replace_cols(str(q),cm)\n",
222
  "print(qs)"
223
  ]
224
+ },
225
+ {
226
+ "attachments": {},
227
+ "cell_type": "markdown",
228
+ "metadata": {},
229
+ "source": [
230
+ "Ok, their query class deals poorly with stringifying the constraints. Let's mock up natural language prompt and SQL response."
231
+ ]
232
+ },
233
+ {
234
+ "cell_type": "code",
235
+ "execution_count": 56,
236
+ "metadata": {},
237
+ "outputs": [
238
+ {
239
+ "name": "stdout",
240
+ "output_type": "stream",
241
+ "text": [
242
+ "\n",
243
+ "Respond to the following data request with a SQL query.\n",
244
+ "Q: Table 2-16763320-1 has columns Tournament (text),Surface (text),Week (text),Winner (text),Finalist (text),Semifinalists (text). Which finalist has Semifinalists of andre agassi (1) lleyton hewitt (14)?\n",
245
+ "A: SELECT Finalist FROM 2-16763320-1 WHERE Semifinalists = 'andre agassi (1) lleyton hewitt (14)'\n",
246
+ "\n",
247
+ "Respond to the following data request with a SQL query.\n",
248
+ "Q: Table 1-27755784-10 has columns Game (real),Date (text),Team (text),Score (text),High points (text),High rebounds (text),High assists (text),Location Attendance (text),Record (text). What is the highest game number?\n",
249
+ "A: SELECT MAX Game FROM 1-27755784-10\n",
250
+ "\n",
251
+ "Respond to the following data request with a SQL query.\n",
252
+ "Q: Table 2-17231086-5 has columns Place (text),Player (text),Country (text),Score (text),To par (text). What place is the United States in that has a score of 68-73-68=209?\n",
253
+ "A: SELECT Place FROM 2-17231086-5 WHERE Country = 'united states' AND Score = '68-73-68=209'\n",
254
+ "\n",
255
+ "Respond to the following data request with a SQL query.\n",
256
+ "Q: Table 2-1302729-1 has columns Season (real),Overall (text),Slalom (text),Giant Slalom (text),Super G (text),Downhill (text),Combined (text). What is the combined of 2 overalls and 5 slaloms?\n",
257
+ "A: SELECT Combined FROM 2-1302729-1 WHERE Overall = '2' AND Slalom = '5'\n",
258
+ "\n",
259
+ "Respond to the following data request with a SQL query.\n",
260
+ "Q: Table 2-15295737-56 has columns Nation (text),Skip (text),Third (text),Second (text),Lead (text),Alternate (text). Who is the alternate for the team for which Monika Wagner is the third?\n",
261
+ "A: SELECT Alternate FROM 2-15295737-56 WHERE Third = 'monika wagner'\n"
262
+ ]
263
+ }
264
+ ],
265
+ "source": [
266
+ "import random\n",
267
+ "\n",
268
+ "# defined by WikiSQL\n",
269
+ "\n",
270
+ "agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']\n",
271
+ "cond_ops = ['=', '>', '<', 'OP']\n",
272
+ "syms = ['SELECT', 'WHERE', 'AND', 'COL', 'TABLE', 'CAPTION', 'PAGE', 'SECTION', 'OP', 'COND', 'QUESTION', 'AGG', 'AGGOPS', 'CONDOPS']\n",
273
+ "\n",
274
+ "def fix_repr(d,cols,types,tid):\n",
275
+ " sel_index=d['sel'] \n",
276
+ " agg_index=d['agg']\n",
277
+ " conditions=d['conds']\n",
278
+ " col = cols[sel_index]\n",
279
+ " rep = 'SELECT {agg} {sel} FROM {tid}'.format(\n",
280
+ " agg=agg_ops[agg_index],\n",
281
+ " sel=col,\n",
282
+ " tid=tid\n",
283
+ " )\n",
284
+ " if conditions:\n",
285
+ " cs = []\n",
286
+ " for i, o, v in conditions:\n",
287
+ " #print(i,cols)\n",
288
+ " nm = cols[i]\n",
289
+ " op = cond_ops[o]\n",
290
+ " \n",
291
+ " if types[i] in ['text']:\n",
292
+ " val = f\"\\'{v}\\'\"\n",
293
+ " else:\n",
294
+ " val = v\n",
295
+ " cs.append(f'{nm} {op} {val}')\n",
296
+ " #print(cs)\n",
297
+ "\n",
298
+ " rep += ' WHERE ' + ' AND '.join(cs)\n",
299
+ " \n",
300
+ " return rep\n",
301
+ "\n",
302
+ "tbl_cols = {}\n",
303
+ "tbl_types = {}\n",
304
+ "tbl_str = {}\n",
305
+ "\n",
306
+ "prefix = 'Respond to the following data request with a SQL query.\\n'\n",
307
+ "\n",
308
+ "def tbl_def_to_string(id, header, types):\n",
309
+ " ht = [f'{header[i]} ({types[i]})' for i in range(len(header))]\n",
310
+ " s = f'Q: Table {id} has columns ' + ','.join(ht) + '. '\n",
311
+ " return s\n",
312
+ "\n",
313
+ "with open('data/train.tables.jsonl') as f:\n",
314
+ " for line in f:\n",
315
+ " js = json.loads(line)\n",
316
+ " id = js['id']\n",
317
+ " hdr = js['header']\n",
318
+ " ts = js['types']\n",
319
+ " tbl_str[id] = tbl_def_to_string(id,hdr,ts)\n",
320
+ " tbl_cols[id] = hdr\n",
321
+ " tbl_types[id] = ts\n",
322
+ "\n",
323
+ "\n",
324
+ "nl_q = []\n",
325
+ "sql_a = []\n",
326
+ "\n",
327
+ "with open('data/train.jsonl') as f:\n",
328
+ " for line in f:\n",
329
+ " js = json.loads(line)\n",
330
+ " id = js['table_id']\n",
331
+ " s = tbl_str[id]\n",
332
+ " qst = js['question']\n",
333
+ " nl = prefix + s + qst\n",
334
+ " nl_q.append(nl)\n",
335
+ "\n",
336
+ " sql = js['sql']\n",
337
+ " a = fix_repr(sql,tbl_cols[id],tbl_types[id],id)\n",
338
+ " a = 'A: ' + a\n",
339
+ " sql_a.append(a)\n",
340
+ "\n",
341
+ "\n",
342
+ "M = len(nl_q)\n",
343
+ "\n",
344
+ "\n",
345
+ "for i in range(5):\n",
346
+ " j = random.randint(0,M-1)\n",
347
+ " print()\n",
348
+ " print(nl_q[j])\n",
349
+ " print(sql_a[j]) \n",
350
+ " \n",
351
+ " "
352
+ ]
353
  }
354
  ],
355
  "metadata": {