update tokenization decode
tokenization_ernie4_5.py  (+38 -22)
@@ -14,9 +14,8 @@
 
 import os
 from shutil import copyfile
-from typing import List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 import sentencepiece as spm
-
 from transformers.tokenization_utils import PreTrainedTokenizer
 from transformers.utils import logging
 
@@ -84,6 +83,7 @@ class Ernie4_5_Tokenizer(PreTrainedTokenizer):
             verbose=verbose,
             **kwargs,
         )
+        self.all_spec_tok = set(self.all_special_tokens)
 
     @property
     def vocab_size(self):
@@ -149,17 +149,7 @@ class Ernie4_5_Tokenizer(PreTrainedTokenizer):
         Returns:
            str: The reconstructed string.
        """
-
-        out_string = ""
-        for token in tokens:
-            # make sure that special tokens are not decoded using sentencepiece model
-            if token in self.all_special_tokens:
-                out_string += self.sp_model.decode(current_sub_tokens) + token
-                current_sub_tokens = []
-            else:
-                current_sub_tokens.append(token)
-        out_string += self.sp_model.decode(current_sub_tokens)
-        return out_string
+        return self.sp_model.decode(tokens)
 
     def prepare_for_model(self, *args, **kwargs):
         if "add_special_tokens" in kwargs:
@@ -202,13 +192,39 @@ class Ernie4_5_Tokenizer(PreTrainedTokenizer):
 
         return (out_vocab_file,)
 
-    def _decode(
-        ...
-    )
+    def _decode(
+        self,
+        token_ids: Union[int, list[int]],
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: Optional[bool] = False,
+        spaces_between_special_tokens: bool = False,
+        **kwargs,
+    ) -> str:
+        self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
+
+        filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
+        # If given is a single id, prevents splitting the string in upcoming loop
+        if isinstance(filtered_tokens, str):
+            filtered_tokens = [filtered_tokens]
+
+        sub_texts = []
+        current_sub_text = []
+        for token in filtered_tokens:
+            if skip_special_tokens and token in self.all_spec_tok:
+                continue
+            else:
+                current_sub_text.append(token)
+        if current_sub_text:
+            sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+
+        if spaces_between_special_tokens:
+            text = " ".join(sub_texts)
+        else:
+            text = "".join(sub_texts)
+
+        if clean_up_tokenization_spaces:
+            clean_text = self.clean_up_tokenization(text)
+            return clean_text
+        else:
+            return text
 
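A minimal usage sketch (not part of this commit) of the decode path the change touches; the checkpoint name below is an assumption, so substitute any ERNIE 4.5 repository that ships this tokenization_ernie4_5.py as custom tokenizer code.

# Sketch: exercise the updated decode path of Ernie4_5_Tokenizer.
# The repo id is assumed, not taken from this commit.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "baidu/ERNIE-4.5-0.3B-PT",  # assumed checkpoint name
    trust_remote_code=True,     # load the repo's Ernie4_5_Tokenizer implementation
)

ids = tokenizer("Hello, ERNIE 4.5!")["input_ids"]

# convert_tokens_to_string now delegates directly to sp_model.decode(tokens), and
# _decode filters special tokens against the precomputed all_spec_tok set when
# skip_special_tokens=True.
print(tokenizer.decode(ids, skip_special_tokens=True))
print(tokenizer.decode(ids, skip_special_tokens=False))

Design-wise, precomputing all_spec_tok as a set turns the per-token special-token check in _decode into an O(1) membership test, and replacing the manual sub-token buffering in convert_tokens_to_string with a single sp_model.decode call does the SentencePiece round-trip in one library call instead of a Python loop.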