---
library_name: transformers
tags: []
---

# Baichuan-M1-14B-Instruct-tokenizer

A fast `transformers` tokenizer for [mlx-community/Baichuan-M1-14B-Instruct-8bit](https://hf.co/mlx-community/Baichuan-M1-14B-Instruct-8bit), converted from the original slow `BaichuanTokenizer`.

Thanks a lot [@Xenova](https://huggingface.co/Xenova) for finding the final fix! 🙌
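
## Usage

Once the files are on the Hub, the tokenizer loads like any other fast tokenizer. A minimal sketch; the repo id below is a placeholder for wherever the tokenizer was pushed:

```py
from transformers import AutoTokenizer

# Placeholder repo id: replace with the namespace the tokenizer was pushed to.
tokenizer = AutoTokenizer.from_pretrained("your-username/Baichuan-M1-14B-Instruct-tokenizer")

text = " {\n"
ids = tokenizer(text)["input_ids"]
print(ids)
print(repr(tokenizer.decode(ids)))
```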

## Conversion

```py
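# The slow tokenizer code (tokenization_baichuan.py) and its sentencepiece model are
# assumed to be in the current working directory, copied from the original model repo.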
from tokenization_baichuan import BaichuanTokenizer

original = BaichuanTokenizer.from_pretrained(".")

from transformers.convert_slow_tokenizer import SpmConverter
from tokenizers import decoders, normalizers, Tokenizer, AddedToken
from tokenizers.models import BPE
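
# SpmConverter does the heavy lifting; the overrides below adjust the vocab, decoding
# and normalization so the fast tokenizer reproduces the slow one exactly.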

class BaichuanConverter(SpmConverter):
    handle_byte_fallback = True

    def vocab(self, proto):
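        # Ids 0-2 come from the slow tokenizer's own token strings; the remaining
        # pieces keep their text and score from the sentencepiece proto.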
        vocab = [
            (self.original_tokenizer.convert_ids_to_tokens(0), 0.0),
            (self.original_tokenizer.convert_ids_to_tokens(1), 0.0),
            (self.original_tokenizer.convert_ids_to_tokens(2), 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        return vocab

    def unk_id(self, proto):
        # Hard-code the unk id to 0 instead of reading it from the trainer spec.
        return 0

    def decoder(self, replacement, add_prefix_space):
        sequence = [
            decoders.Replace("▁", " "),
            decoders.ByteFallback(),
            decoders.Fuse(),
        ]
        return decoders.Sequence(sequence)

    def normalizer(self, proto):
        return normalizers.Replace(pattern=" ", content="▁")
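
    # Spaces become the metaspace symbol "▁" during normalization and are mapped back
    # by the decoder; no pre-tokenizer or post-processor is used, so the BPE model
    # sees the whole normalized string, matching sentencepiece behaviour.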

    def pre_tokenizer(self, replacement, add_prefix_space):
        return None

    def post_processor(self):
        return None

    def tokenizer(self, proto):
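        # Build a tokenizers BPE model from the sentencepiece vocab, extracting the
        # merge rules from the original .model file.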
        vocab_scores = self.vocab(proto)
        _, merges = self.SpmExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
        bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)}
        tokenizer = Tokenizer(
            BPE(
                bpe_vocab,
                merges,
                unk_token=proto.trainer_spec.unk_piece,
                fuse_unk=True,
                byte_fallback=self.handle_byte_fallback,
                dropout=None,
            )
        )

        # control tokens are special
        # user defined symbols are not
        # both user and control tokens are AddedTokens
        # Add user defined symbols (type == 4) from sentencepiece (https://github.com/google/sentencepiece/blob/6225e08edb2577757163b3f5dbba4c0b670ef445/src/sentencepiece_model.proto#L299C29-L299C33)
        spm_added_tokens = [
            (id, p.piece, p.type == 3 or p.piece in self.special_tokens)
            for id, p in enumerate(proto.pieces)
            if p.type in [3, 4]
        ]

        # Reproduce a quirk of the original tokenizer: only register symbols that the
        # slow tokenizer itself encodes as a single token; drop the rest.
        bad_added_tokens = set()
        for _, token, _ in spm_added_tokens:
            encoded = self.original_tokenizer.encode(token)
            if len(encoded) != 1:
                bad_added_tokens.add(token)

        tokenizer.add_tokens(
            [
                AddedToken(token, normalized=True, special=special)
                for id, token, special in sorted(spm_added_tokens, key=lambda x: x[0])
                if token not in bad_added_tokens
            ]
        )

        return tokenizer

converter = BaichuanConverter(original)
converted = converter.converted()

from transformers import PreTrainedTokenizerFast

t_fast = PreTrainedTokenizerFast(
    tokenizer_object=converted,
    model_input_names=original.model_input_names,
    model_max_length=32768,
    clean_up_tokenization_spaces=False,
)
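
# Quick sanity check on whitespace-heavy strings, comparing slow vs. fast encodings.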

test_strings = [
    " {\n",
    "  {\n",
    "x  {\n",
    "----------------------------------------------------------------------------\n",
    "\n \n",
    "\n  \n",
    '// -----------------------------------------------------------------------\n',
    '-----------------------------------------------------------------------\n',
]
for test_string in test_strings:
    print("Original:", original.encode(test_string))
    print("Fast:    ", t_fast.encode(test_string))


# Testing on xnli
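# Every premise in every language must encode and decode identically with both tokenizers.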

from datasets import load_dataset
from tqdm import tqdm

xnli = load_dataset("xnli", "all_languages", split="validation")

def verify(lang, text):
    encoded_original = original.encode(text)
    encoded_fast = t_fast.encode(text)
    assert encoded_fast == encoded_original, f"Fast encode error: {lang} - {text}"
    decoded = original.decode(encoded_original)
    decoded_fast = t_fast.decode(encoded_fast, skip_special_tokens=True)
    assert decoded_fast == decoded, f"Fast decode error: {lang} - {text}"

for p in tqdm(xnli["premise"]):
    for lang, text in p.items():
        verify(lang, text)

# Testing on codeparrot
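# Code samples stress indentation and long runs of whitespace and punctuation.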

ds = load_dataset("codeparrot/github-code", streaming=True, trust_remote_code=True, split="train")

iterator = iter(ds)
for _ in tqdm(range(1000)):
    item = next(iterator)
    code = item["code"]
    lang = item["language"]
    verify(lang, code)

t_fast.push_to_hub("Baichuan-M1-14B-Instruct-tokenizer")
```
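
`push_to_hub` uploads the converted `tokenizer.json` and tokenizer config to the Hub, after which the fast tokenizer can be loaded on its own as shown above.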