# coding=utf-8
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for BLT."""
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...utils import logging


if TYPE_CHECKING:
    from ...tokenization_utils_base import TextInput

logger = logging.get_logger(__name__)

# BLT tokenizer constants
SEP = " "
BOS_ID: int = 1
EOS_ID: int = 2
PAD_ID: int = -1
BOE_ID: int = 0
BPE_ID: int = 3
OFFSET: int = 4
BYTE_UNITS: int = 256
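# With OFFSET = 4, a raw byte value b is mapped to token ID b + 4: IDs 0-3 are reserved for the
# BOE/BOS/EOS/BPE special tokens and IDs 4-259 cover the 256 possible byte values.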

VOCAB_FILES_NAMES = {}  # BLT doesn't require external vocab files


class BLTTokenizer(PreTrainedTokenizer):
    """
    Construct a BLT tokenizer. Based on byte-level tokenization where each byte is treated as a token.

    This tokenizer converts text to UTF-8 bytes and then maps each byte to a token ID with an offset.
    It supports special tokens for beginning of sequence (BOS), end of sequence (EOS),
    beginning of example (BOE), and padding (PAD).
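
    Example (illustrative; token IDs follow the byte value + OFFSET mapping defined above):

        >>> tokenizer = BLTTokenizer()
        >>> tokenizer.encode("hi")
        [1, 108, 109, 2]
        >>> tokenizer.decode([1, 108, 109, 2])
        'hi'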

    Args:
        bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token.
        eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"</s>"`):
            The end of sequence token.
        pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<pad>"`):
            The padding token.
        unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
            The unknown token. Not used in BLT but kept for compatibility.
        boe_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<boe>"`):
            The beginning of example token, specific to BLT.
        add_bos_token (`bool`, *optional*, defaults to `True`):
            Whether or not to add a `bos_token` at the start of sequences.
        add_eos_token (`bool`, *optional*, defaults to `True`):
            Whether or not to add an `eos_token` at the end of sequences.
        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
            Whether or not to clean up spaces after decoding.
        spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
            Whether or not to add spaces between special tokens.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        bos_token="<s>",
        eos_token="</s>",
        pad_token="<pad>",
        unk_token="<unk>",
        boe_token="<boe>",
        add_bos_token=True,
        add_eos_token=True,
        clean_up_tokenization_spaces=False,
        spaces_between_special_tokens=False,
        **kwargs,
    ):
        # Store BLT-specific parameters first
        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        self.vocab_size_unit_1 = BYTE_UNITS
        self.offsetting_special_char = OFFSET

        # BLT token IDs (exactly like original)
        self.boe_id = BOE_ID
        self.bos_id = BOS_ID
        self.eos_id = EOS_ID
        self.pad_id = PAD_ID
        self.bpe_id = BPE_ID
        self.n_words = self.vocab_size_unit_1 + self.offsetting_special_char

        # Convert string tokens to AddedToken objects
        bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
        eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
        pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
        unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
        self.boe_token = (
            AddedToken(boe_token, normalized=False, special=True) if isinstance(boe_token, str) else boe_token
        )

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            unk_token=unk_token,
            add_bos_token=add_bos_token,
            add_eos_token=add_eos_token,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            spaces_between_special_tokens=spaces_between_special_tokens,
            **kwargs,
        )

    @property
    def vocab_size(self):
        """Returns vocab size"""
        return self.vocab_size_unit_1 + self.offsetting_special_char

    def get_vocab(self):
        """Returns vocab as a dict"""
        # Create a mapping for byte values + offset
        vocab = {}

        # Add special tokens (with defensive checks)
        if hasattr(self, "bos_token"):
            vocab[str(self.bos_token)] = self.bos_id
        if hasattr(self, "eos_token"):
            vocab[str(self.eos_token)] = self.eos_id
        if hasattr(self, "pad_token"):
            vocab[str(self.pad_token)] = self.pad_id
        if hasattr(self, "boe_token"):
            vocab[str(self.boe_token)] = self.boe_id

        # Add byte tokens as string representations of byte values
        vocab_size_unit_1 = getattr(self, "vocab_size_unit_1", BYTE_UNITS)
        offsetting_special_char = getattr(self, "offsetting_special_char", OFFSET)
        for i in range(vocab_size_unit_1):
            vocab[str(i)] = i + offsetting_special_char

        # Add any additional tokens if available
        if hasattr(self, "added_tokens_encoder"):
            vocab.update(self.added_tokens_encoder)

        return vocab

    def _tokenize(self, text: str, **kwargs) -> List[str]:
        """
        Converts a string to a list of tokens. For BLT, we work directly with byte values.
        Returns a list of strings that represent the byte values.
        """
        # Convert text to UTF-8 bytes, just like the original
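        # Example (illustrative): "hi" encodes to b"hi", which becomes ["104", "105"]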
        bytes_data = text.encode("utf-8", errors="ignore")

        # Return string representations of byte values for the tokenizer framework
        return [str(byte_val) for byte_val in bytes_data]

    def _convert_token_to_id(self, token: str) -> int:
        """Converts a token (str) to an id using the vocab."""
        # Handle special tokens
        if token == str(self.bos_token):
            return self.bos_id
        elif token == str(self.eos_token):
            return self.eos_id
        elif token == str(self.pad_token):
            return self.pad_id
        elif token == str(self.boe_token):
            return self.boe_id
        else:
            try:
                # Convert byte value string to int and add offset (like original)
                byte_val = int(token)
                if 0 <= byte_val <= 255:
                    return byte_val + self.offsetting_special_char
            except ValueError:
                pass

            # Check if it's in added tokens
            return self.added_tokens_encoder.get(token, self.unk_token_id)

    def _convert_id_to_token(self, index: int) -> str:
        """Converts an index (integer) to a token (str) using the vocab."""
        # Handle special tokens
        if index == self.bos_id:
            return str(self.bos_token)
        elif index == self.eos_id:
            return str(self.eos_token)
        elif index == self.pad_id:
            return str(self.pad_token)
        elif index == self.boe_id:
            return str(self.boe_token)
        elif self.offsetting_special_char <= index < self.vocab_size:
            # Convert back to byte value (like original)
            byte_val = index - self.offsetting_special_char
            return str(byte_val)
        else:
            # Check added tokens
            for token, token_id in self.added_tokens_encoder.items():
                if token_id == index:
                    return token
            return str(self.unk_token)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Converts a sequence of tokens to a single string."""
        byte_values = []
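        # Example (illustrative): ["<s>", "104", "105", "</s>"] -> "hi"; special tokens are skipped
        # and the remaining byte-value strings are decoded as UTF-8.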
        for token in tokens:
            # Skip special tokens
            if token in [str(self.bos_token), str(self.eos_token), str(self.pad_token), str(self.boe_token)]:
                continue
            try:
                # Convert token back to byte value (like original decode method)
                byte_val = int(token)
                if 0 <= byte_val <= 255:
                    byte_values.append(byte_val)
            except ValueError:
                continue

        # Convert byte values back to string (exactly like original)
        try:
            return bytes(byte_values).decode("utf-8", errors="ignore")
        except (UnicodeDecodeError, ValueError):
            return ""

    def encode(self, text: str, add_bos: Optional[bool] = None, add_eos: Optional[bool] = None):
        """
        Encode text exactly like the original BLT tokenizer.
        """
        if add_bos is None:
            add_bos = self.add_bos_token
        if add_eos is None:
            add_eos = self.add_eos_token

        # Since bpe_delim=False, we use the simple byte encoding
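        # Example (illustrative): "hi" -> UTF-8 bytes [104, 105] -> [108, 109] after offsetting,
        # giving [1, 108, 109, 2] once BOS/EOS are added.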
        tokens = bytes(text, encoding="utf-8", errors="ignore")

        # Offsetting (exactly like original)
        tokens = [int(unit) + self.offsetting_special_char for unit in tokens]

        if add_bos:
            tokens.insert(0, self.bos_id)
        if add_eos:
            tokens.append(self.eos_id)

        return tokens

    def decode(self, tokens: List[int], cut_at_eos: bool = False):
        """
        Decode tokens exactly like the original BLT tokenizer.
        """
        if cut_at_eos:
            for k, t in enumerate(tokens):
                if t == self.eos_id:
                    tokens = tokens[: k + 1]
                    break
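
        # Example (illustrative): [1, 108, 109, 2] -> IDs below the offset (BOS/EOS) are dropped,
        # leaving bytes [104, 105], which decode to "hi".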
        return bytes(
            [tok - self.offsetting_special_char for tok in tokens if tok - self.offsetting_special_char >= 0]
        ).decode("utf-8", errors="ignore")

    def get_vocab_size(self) -> int:
        """Get vocab size like the original tokenizer."""
        return self.vocab_size_unit_1 + self.offsetting_special_char


__all__ = ["BLTTokenizer"]