"""Tokenization classes for BLT.""" |
|
|
|
import os |
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple |
|
|
|
from ...tokenization_utils import AddedToken, PreTrainedTokenizer |
|
from ...utils import logging |
|
|
|
|
|
if TYPE_CHECKING: |
|
from ...tokenization_utils_base import TextInput |
|
|
|
logger = logging.get_logger(__name__) |


SEP = " "
BOS_ID: int = 1
EOS_ID: int = 2
PAD_ID: int = -1
BOE_ID: int = 0
BPE_ID: int = 3
OFFSET: int = 4
BYTE_UNITS: int = 256
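
# Illustrative mapping, following the constants above: the byte 0x41 ("A", decimal 65) maps to
# id 69 (65 + OFFSET), while ids 0-3 are reserved for the BOE, BOS, EOS, and BPE tokens (PAD uses -1).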

# BLT operates directly on raw bytes, so there are no vocabulary files to load.
VOCAB_FILES_NAMES = {}


class BLTTokenizer(PreTrainedTokenizer):
    """
    Construct a BLT tokenizer. Based on byte-level tokenization where each byte is treated as a token.

    This tokenizer converts text to UTF-8 bytes and then maps each byte to a token ID with a fixed offset.
    It supports special tokens for beginning of sequence (BOS), end of sequence (EOS),
    beginning of example (BOE), and padding (PAD).

    Args:
        bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token.
        eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"</s>"`):
            The end of sequence token.
        pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<pad>"`):
            The padding token.
        unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
            The unknown token. Not used in BLT but kept for compatibility.
        boe_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<boe>"`):
            The beginning of example token, specific to BLT.
        add_bos_token (`bool`, *optional*, defaults to `True`):
            Whether or not to add a `bos_token` at the start of sequences.
        add_eos_token (`bool`, *optional*, defaults to `True`):
            Whether or not to add an `eos_token` at the end of sequences.
        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
            Whether or not to clean up spaces after decoding.
        spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
            Whether or not to add spaces between special tokens.
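
    Example (an illustrative sketch; the ids follow directly from the byte value + offset scheme above):

    ```python
    >>> tokenizer = BLTTokenizer()
    >>> tokenizer.encode("abc")  # bytes 97, 98, 99 shifted by the offset of 4, wrapped in BOS/EOS ids
    [1, 101, 102, 103, 2]
    >>> tokenizer.decode([1, 101, 102, 103, 2])
    'abc'
    ```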
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        bos_token="<s>",
        eos_token="</s>",
        pad_token="<pad>",
        unk_token="<unk>",
        boe_token="<boe>",
        add_bos_token=True,
        add_eos_token=True,
        clean_up_tokenization_spaces=False,
        spaces_between_special_tokens=False,
        **kwargs,
    ):
        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        self.vocab_size_unit_1 = BYTE_UNITS
        self.offsetting_special_char = OFFSET

        self.boe_id = BOE_ID
        self.bos_id = BOS_ID
        self.eos_id = EOS_ID
        self.pad_id = PAD_ID
        self.bpe_id = BPE_ID
        self.n_words = self.vocab_size_unit_1 + self.offsetting_special_char
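        # 256 byte values plus the 4 offset slots reserved for special tokens: 260 ids in total.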

        bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
        eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
        pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
        unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
        self.boe_token = (
            AddedToken(boe_token, normalized=False, special=True) if isinstance(boe_token, str) else boe_token
        )

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            unk_token=unk_token,
            add_bos_token=add_bos_token,
            add_eos_token=add_eos_token,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            spaces_between_special_tokens=spaces_between_special_tokens,
            **kwargs,
        )

    @property
    def vocab_size(self):
        """Returns vocab size"""
        return self.vocab_size_unit_1 + self.offsetting_special_char

    def get_vocab(self):
        """Returns vocab as a dict"""
        vocab = {}

        if hasattr(self, "bos_token"):
            vocab[str(self.bos_token)] = self.bos_id
        if hasattr(self, "eos_token"):
            vocab[str(self.eos_token)] = self.eos_id
        if hasattr(self, "pad_token"):
            vocab[str(self.pad_token)] = self.pad_id
        if hasattr(self, "boe_token"):
            vocab[str(self.boe_token)] = self.boe_id

        vocab_size_unit_1 = getattr(self, "vocab_size_unit_1", BYTE_UNITS)
        offsetting_special_char = getattr(self, "offsetting_special_char", OFFSET)
        for i in range(vocab_size_unit_1):
            vocab[str(i)] = i + offsetting_special_char

        if hasattr(self, "added_tokens_encoder"):
            vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text: str, **kwargs) -> List[str]:
        """
        Converts a string to a list of tokens. For BLT, we work directly with byte values.
        Returns a list of strings that represent the byte values.
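
        For example, `_tokenize("hi")` returns `["104", "105"]`, the decimal UTF-8 byte values as strings.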
        """
        # `errors="ignore"` already prevents encoding failures, so no exception handling is needed here.
        bytes_data = text.encode("utf-8", errors="ignore")
        return [str(byte_val) for byte_val in bytes_data]

    def _convert_token_to_id(self, token: str) -> int:
        """Converts a token (str) to an id using the vocab."""
        if token == str(self.bos_token):
            return self.bos_id
        elif token == str(self.eos_token):
            return self.eos_id
        elif token == str(self.pad_token):
            return self.pad_id
        elif token == str(self.boe_token):
            return self.boe_id
        else:
            try:
                byte_val = int(token)
                if 0 <= byte_val <= 255:
                    return byte_val + self.offsetting_special_char
            except ValueError:
                pass

            # Fall back to added tokens, then to the unknown token id.
            return self.added_tokens_encoder.get(token, self.unk_token_id)

    def _convert_id_to_token(self, index: int) -> str:
        """Converts an index (integer) to a token (str) using the vocab."""
        if index == self.bos_id:
            return str(self.bos_token)
        elif index == self.eos_id:
            return str(self.eos_token)
        elif index == self.pad_id:
            return str(self.pad_token)
        elif index == self.boe_id:
            return str(self.boe_token)
        elif self.offsetting_special_char <= index < self.vocab_size:
            # Undo the offset to recover the raw byte value.
            byte_val = index - self.offsetting_special_char
            return str(byte_val)
        else:
            # Check tokens added after initialization before falling back to the unknown token.
            for token, token_id in self.added_tokens_encoder.items():
                if token_id == index:
                    return token
            return str(self.unk_token)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Converts a sequence of tokens to a single string."""
        byte_values = []

        for token in tokens:
            # Special tokens carry no byte content.
            if token in [str(self.bos_token), str(self.eos_token), str(self.pad_token), str(self.boe_token)]:
                continue

            try:
                byte_val = int(token)
                if 0 <= byte_val <= 255:
                    byte_values.append(byte_val)
            except ValueError:
                continue

        try:
            return bytes(byte_values).decode("utf-8", errors="ignore")
        except (UnicodeDecodeError, ValueError):
            return ""

    def encode(self, text: str, add_bos: Optional[bool] = None, add_eos: Optional[bool] = None):
        """
        Encode text exactly like the original BLT tokenizer.
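
        For example, `encode("hi")` with the default BOS/EOS settings returns `[1, 108, 109, 2]`:
        the bytes 104 and 105 shifted by the offset of 4 and wrapped in the BOS and EOS ids.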
        """
        if add_bos is None:
            add_bos = self.add_bos_token
        if add_eos is None:
            add_eos = self.add_eos_token

        # Convert the text to UTF-8 bytes, then shift every byte by the special-token offset.
        tokens = bytes(text, encoding="utf-8", errors="ignore")
        tokens = [int(unit) + self.offsetting_special_char for unit in tokens]

        if add_bos:
            tokens.insert(0, self.bos_id)
        if add_eos:
            tokens.append(self.eos_id)

        return tokens

    def decode(self, tokens: List[int], cut_at_eos: bool = False):
        """
        Decode tokens exactly like the original BLT tokenizer.
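
        For example, `decode([1, 108, 109, 2])` drops the special ids and returns `"hi"`.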
        """
        if cut_at_eos:
            for k, t in enumerate(tokens):
                if t == self.eos_id:
                    tokens = tokens[: k + 1]
                    break
        # Undo the offset and drop any id below the byte range (i.e. the special tokens).
        return bytes(
            [tok - self.offsetting_special_char for tok in tokens if tok - self.offsetting_special_char >= 0]
        ).decode("utf-8", errors="ignore")

    def get_vocab_size(self) -> int:
        """Get vocab size like the original tokenizer."""
        return self.vocab_size_unit_1 + self.offsetting_special_char


__all__ = ["BLTTokenizer"]