File size: 10,820 Bytes
724be6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
# coding=utf-8
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tokenization classes for BLT."""

import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...utils import logging


if TYPE_CHECKING:
    from ...tokenization_utils_base import TextInput

logger = logging.get_logger(__name__)

# Module-level constants used by the BLT byte tokenizer.

# Reserved ids for special tokens. Byte tokens are shifted by OFFSET so they come after these.
BOE_ID: int = 0
BOS_ID: int = 1
EOS_ID: int = 2
BPE_ID: int = 3
PAD_ID: int = -1

# Number of distinct byte values, and the shift added to a byte value to obtain its token id.
BYTE_UNITS: int = 256
OFFSET: int = 4

SEP = " "

# BLT doesn't require external vocab files, so there is nothing to register here.
VOCAB_FILES_NAMES = {}


class BLTTokenizer(PreTrainedTokenizer):
    """
    Construct a BLT tokenizer. Based on byte-level tokenization where each byte is treated as a token.

    This tokenizer converts text to UTF-8 bytes and then maps each byte to a token ID with an offset.
    It supports special tokens for beginning of sequence (BOS), end of sequence (EOS),
    beginning of example (BOE), and padding (PAD).

    Internally, a "token" is the decimal string of a byte value (`"0"` .. `"255"`), and its id is that byte
    value plus `offsetting_special_char`. Ids below the offset are reserved for the special tokens.

    Args:
        bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token.
        eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"</s>"`):
            The end of sequence token.
        pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<pad>"`):
            The padding token.
        unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
            The unknown token. Not used in BLT but kept for compatibility.
        boe_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<boe>"`):
            The beginning of example token, specific to BLT.
        add_bos_token (`bool`, *optional*, defaults to `True`):
            Whether or not to add a `bos_token` at the start of sequences.
        add_eos_token (`bool`, *optional*, defaults to `True`):
            Whether or not to add an `eos_token` at the end of sequences.
        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
            Whether or not to cleanup spaces after decoding.
        spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
            Whether or not to add spaces between special tokens.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        bos_token="<s>",
        eos_token="</s>",
        pad_token="<pad>",
        unk_token="<unk>",
        boe_token="<boe>",
        add_bos_token=True,
        add_eos_token=True,
        clean_up_tokenization_spaces=False,
        spaces_between_special_tokens=False,
        **kwargs,
    ):
        # BLT-specific state is stored before `super().__init__()`. `get_vocab()` reads these attributes
        # defensively, presumably because the base constructor may call it -- TODO confirm.
        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        self.vocab_size_unit_1 = BYTE_UNITS
        self.offsetting_special_char = OFFSET

        # Fixed ids of the special tokens, taken from the module-level constants.
        self.boe_id = BOE_ID
        self.bos_id = BOS_ID
        self.eos_id = EOS_ID
        self.pad_id = PAD_ID
        self.bpe_id = BPE_ID
        self.n_words = self.vocab_size_unit_1 + self.offsetting_special_char

        # Plain strings are wrapped in `AddedToken`; anything else is passed through unchanged.
        bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
        eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
        pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
        unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
        # `boe_token` is kept on the instance only; it is not forwarded to the base constructor.
        self.boe_token = AddedToken(boe_token, normalized=False, special=True) if isinstance(boe_token, str) else boe_token

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            unk_token=unk_token,
            add_bos_token=add_bos_token,
            add_eos_token=add_eos_token,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            spaces_between_special_tokens=spaces_between_special_tokens,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        """Return the number of byte tokens plus the number of ids reserved below the byte offset."""
        return self.vocab_size_unit_1 + self.offsetting_special_char

    def get_vocab(self):
        """
        Return the vocabulary as a `token -> id` dict.

        The dict contains the special tokens, one entry per byte value (keyed by its decimal string), and
        finally any added tokens, which override earlier entries with the same key.
        """
        vocab = {}

        # Defensive `hasattr` checks: only special tokens that already exist on the instance are included.
        if hasattr(self, "bos_token"):
            vocab[str(self.bos_token)] = self.bos_id
        if hasattr(self, "eos_token"):
            vocab[str(self.eos_token)] = self.eos_id
        if hasattr(self, "pad_token"):
            vocab[str(self.pad_token)] = self.pad_id
        if hasattr(self, "boe_token"):
            vocab[str(self.boe_token)] = self.boe_id

        # Byte tokens: the key is the decimal string of the byte value, the id is that value plus the offset.
        byte_count = getattr(self, "vocab_size_unit_1", BYTE_UNITS)
        offset = getattr(self, "offsetting_special_char", OFFSET)
        vocab.update({str(byte_val): byte_val + offset for byte_val in range(byte_count)})

        if hasattr(self, "added_tokens_encoder"):
            vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text: str, **kwargs) -> List[str]:
        """
        Convert a string to a list of tokens, one per UTF-8 byte.

        Args:
            text (`str`): The text to tokenize.

        Returns:
            `List[str]`: The decimal string of each byte of `text` encoded as UTF-8.
        """
        # `errors="ignore"` drops anything that cannot be encoded (e.g. lone surrogates) instead of raising,
        # so no exception handling is needed around the call.
        return [str(byte_val) for byte_val in text.encode("utf-8", errors="ignore")]

    def _convert_token_to_id(self, token: str) -> int:
        """
        Convert a token (str) to an id.

        Special tokens are matched first (BOS, EOS, PAD, BOE, in that order), then decimal byte strings in
        the range 0..255, then added tokens. Anything else maps to `unk_token_id`.
        """
        if token == str(self.bos_token):
            return self.bos_id
        if token == str(self.eos_token):
            return self.eos_id
        if token == str(self.pad_token):
            return self.pad_id
        if token == str(self.boe_token):
            return self.boe_id

        try:
            byte_val = int(token)
        except ValueError:
            byte_val = None
        if byte_val is not None and 0 <= byte_val <= 255:
            return byte_val + self.offsetting_special_char

        # Resolve `unk_token_id` only on a miss rather than eagerly as a `.get()` default.
        added_tokens = self.added_tokens_encoder
        if token in added_tokens:
            return added_tokens[token]
        return self.unk_token_id

    def _convert_id_to_token(self, index: int) -> str:
        """
        Convert an id (integer) to a token (str).

        Special-token ids map to their token strings, ids in `[offsetting_special_char, vocab_size)` map to
        the decimal string of the corresponding byte, and any other id is looked up among the added tokens,
        falling back to the unknown token.
        """
        if index == self.bos_id:
            return str(self.bos_token)
        if index == self.eos_id:
            return str(self.eos_token)
        if index == self.pad_id:
            return str(self.pad_token)
        if index == self.boe_id:
            return str(self.boe_token)
        if self.offsetting_special_char <= index < self.vocab_size:
            return str(index - self.offsetting_special_char)

        # Reverse lookup among added tokens; the first token registered under this id wins.
        for added_token, added_id in self.added_tokens_encoder.items():
            if added_id == index:
                return added_token
        return str(self.unk_token)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """
        Convert a sequence of tokens to a single string.

        Special tokens, non-numeric tokens and numbers outside 0..255 are skipped. The remaining byte
        values are decoded as UTF-8, ignoring invalid sequences.
        """
        # Built once per call. A tuple keeps the `==`-based membership test a list would give.
        special_tokens = (str(self.bos_token), str(self.eos_token), str(self.pad_token), str(self.boe_token))

        byte_values = []
        for token in tokens:
            if token in special_tokens:
                continue
            try:
                byte_val = int(token)
            except ValueError:
                continue
            if 0 <= byte_val <= 255:
                byte_values.append(byte_val)

        # Every value is already within 0..255 and decoding ignores errors, so this cannot raise.
        return bytes(byte_values).decode("utf-8", errors="ignore")

    # NOTE(review): `encode` and `decode` below replace the inherited methods with narrower signatures.
    # Callers passing base-class keyword arguments (e.g. `add_special_tokens`, `skip_special_tokens`)
    # would presumably get a `TypeError` -- confirm this is intended.
    def encode(self, text: str, add_bos: Optional[bool] = None, add_eos: Optional[bool] = None) -> List[int]:
        """
        Encode text into byte-level token ids.

        Args:
            text (`str`): The text to encode. It is converted to UTF-8 bytes; unencodable characters are dropped.
            add_bos (`bool`, *optional*): Whether to prepend `bos_id`. Defaults to `self.add_bos_token`.
            add_eos (`bool`, *optional*): Whether to append `eos_id`. Defaults to `self.add_eos_token`.

        Returns:
            `List[int]`: Each byte value plus `offsetting_special_char`, optionally wrapped in BOS/EOS ids.
        """
        if add_bos is None:
            add_bos = self.add_bos_token
        if add_eos is None:
            add_eos = self.add_eos_token

        offset = self.offsetting_special_char
        tokens = [byte_val + offset for byte_val in bytes(text, encoding="utf-8", errors="ignore")]

        if add_bos:
            tokens.insert(0, self.bos_id)
        if add_eos:
            tokens.append(self.eos_id)

        return tokens

    def decode(self, tokens: List[int], cut_at_eos: bool = False) -> str:
        """
        Decode byte-level token ids back into text.

        Args:
            tokens (`List[int]`): The token ids to decode.
            cut_at_eos (`bool`, *optional*, defaults to `False`):
                If `True`, everything after the first `eos_id` is discarded before decoding.

        Returns:
            `str`: The UTF-8 decoding of the byte tokens. Ids that do not map to a byte (special-token ids
            and ids at or above `offsetting_special_char + 256`) are skipped, and invalid UTF-8 is ignored.
        """
        if cut_at_eos:
            for position, token_id in enumerate(tokens):
                if token_id == self.eos_id:
                    tokens = tokens[: position + 1]
                    break

        offset = self.offsetting_special_char
        shifted = (token_id - offset for token_id in tokens)
        # `bytes()` only accepts values in 0..255, so out-of-range ids (e.g. added tokens) are dropped
        # rather than left to raise a ValueError.
        return bytes([byte_val for byte_val in shifted if 0 <= byte_val <= 255]).decode("utf-8", errors="ignore")

    def get_vocab_size(self) -> int:
        """Return the vocabulary size; same value as the `vocab_size` property."""
        return self.vocab_size

# Public names exported by this module.
__all__ = ["BLTTokenizer"]