Pclanglais committed (verified)
Commit 5a43ec1 · 1 Parent(s): a924afe

Create inference_transcript_ner.py

Files changed (1)
  1. inference_transcript_ner.py +119 -0
inference_transcript_ner.py ADDED
@@ -0,0 +1,119 @@
import re

import pandas as pd
from tqdm.auto import tqdm
from transformers import AutoTokenizer, pipeline

model_checkpoint = "Pclanglais/French-TV-transcript-NER"
token_classifier = pipeline(
    "token-classification", model=model_checkpoint, aggregation_strategy="simple"
)

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)


def split_text(text, max_tokens=500):
    """Split text into chunks of at most max_tokens tokens, preferring newline boundaries."""
    # Split the text by newline characters
    parts = text.split("\n")
    chunks = []
    current_chunk = ""

    for part in parts:
        # Tentatively add the part to the current chunk
        if current_chunk:
            temp_chunk = current_chunk + "\n" + part
        else:
            temp_chunk = part

        # Tokenize the temporary chunk and keep it only while it fits the budget
        num_tokens = len(tokenizer.tokenize(temp_chunk))

        if num_tokens <= max_tokens:
            current_chunk = temp_chunk
        else:
            if current_chunk:
                chunks.append(current_chunk)
            current_chunk = part

    if current_chunk:
        chunks.append(current_chunk)

    # If no newlines were found and the text still exceeds max_tokens, split further on whitespace
    if len(chunks) == 1 and len(tokenizer.tokenize(chunks[0])) > max_tokens:
        long_text = chunks[0]
        chunks = []
        while len(tokenizer.tokenize(long_text)) > max_tokens:
            split_point = len(long_text) // 2
            while split_point < len(long_text) and not re.match(r"\s", long_text[split_point]):
                split_point += 1
            # Ensure split_point does not go out of range
            if split_point >= len(long_text):
                split_point = len(long_text) - 1
            chunks.append(long_text[:split_point].strip())
            long_text = long_text[split_point:].strip()
        if long_text:
            chunks.append(long_text)

    return chunks


complete_data = pd.read_parquet("../ocr/ocr_corrected_yacast.parquet")

print(complete_data)

classified_list = []

list_prompt = []
list_page = []
list_file = []
list_id = []
text_id = 1
for index, row in complete_data.iterrows():
    prompt, current_file = str(row["corrected_text"]), row["identifier"]
    # Replace newlines with a pilcrow marker so they can be restored after classification
    prompt = re.sub("\n", " ¶ ", prompt)

    # Tokenize the prompt and check whether it exceeds 500 tokens
    num_tokens = len(tokenizer.tokenize(prompt))

    if num_tokens > 500:
        # Split the prompt into chunks that fit the model's context
        chunks = split_text(prompt, max_tokens=500)
        for chunk in chunks:
            list_file.append(current_file)
            list_prompt.append(chunk)
            list_id.append(text_id)
    else:
        list_file.append(current_file)
        list_prompt.append(prompt)
        list_id.append(text_id)

    text_id = text_id + 1

# Run the NER pipeline over all chunks; it yields one result per input
full_classification = []
batch_size = 4
for out in tqdm(token_classifier(list_prompt, batch_size=batch_size), total=len(list_prompt)):
    full_classification.append(out)

# Attach source metadata to each chunk's predictions
id_row = 0
for classification in full_classification:
    try:
        df = pd.DataFrame(classification)

        df["identifier"] = list_file[id_row]
        df["text_id"] = list_id[id_row]

        # Restore the newlines that were replaced by pilcrows above
        df["word"] = df["word"].replace(" ¶ ", " \n ", regex=True)

        print(df)

        classified_list.append(df)

    except Exception:
        # Skip chunks where classification produced no entities
        pass
    id_row = id_row + 1

classified_list = pd.concat(classified_list)

# Display the DataFrame
print(classified_list)

classified_list.to_csv("result_transcripts.tsv", sep="\t")
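
A minimal follow-up sketch, not part of the committed file: how the resulting result_transcripts.tsv could be loaded and filtered. It assumes the default columns of the aggregated token-classification pipeline (entity_group, score, word, start, end) plus the identifier and text_id columns added above; the 0.9 score threshold is only an illustrative value.

import pandas as pd

# Load the entities written by inference_transcript_ner.py (the index column was saved by to_csv)
entities = pd.read_csv("result_transcripts.tsv", sep="\t", index_col=0)

# Count predicted entities per transcript and per entity type
print(entities.groupby(["identifier", "entity_group"]).size())

# Keep only confident predictions (0.9 is an arbitrary example threshold)
confident = entities[entities["score"] > 0.9]
print(confident[["identifier", "text_id", "entity_group", "word"]].head(20))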