# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from shutil import copyfile
from typing import Dict, List, Optional, Tuple, Union

import sentencepiece as spm
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.utils import logging

logger = logging.get_logger(__name__)


class Ernie4_5_Tokenizer(PreTrainedTokenizer):
    """ERNIE tokenizer backed by a SentencePiece model.

    Tokenization, token/ID conversion and detokenization are all delegated to
    a ``sentencepiece.SentencePieceProcessor`` loaded from ``vocab_file``.
    """

    vocab_files_names = {
        "vocab_file": "tokenizer.model",
    }
    # Model input names expected by the tokenizer
    model_input_names = ["input_ids", "position_ids", "attention_mask", "labels"]
    # Padding side (where to add padding tokens)
    padding_side = "right"

    # NOTE(review): every special-token default below is the empty string and
    # the two default additional special tokens are identical. This looks like
    # token text may have been lost somewhere upstream -- confirm against the
    # released tokenizer configuration before relying on these defaults.
    def __init__(
        self,
        vocab_file,
        bos_token="",
        cls_token="",
        eos_token="",
        mask_token="",
        pad_token="",
        sep_token="",
        unk_token="",
        additional_special_tokens=None,
        verbose=False,
        **kwargs,
    ):
        """
        Initialize the ERNIE tokenizer.

        Args:
            vocab_file (str): Path to the SentencePiece model file.
            bos_token (str, optional): Beginning of sentence token. Defaults to "".
            cls_token (str, optional): Classification token. Defaults to "".
            eos_token (str, optional): End of sentence token. Defaults to "".
            mask_token (str, optional): Mask token. Defaults to "".
            pad_token (str, optional): Padding token. Defaults to "".
            sep_token (str, optional): Separator token. Defaults to "".
            unk_token (str, optional): Unknown token. Defaults to "".
            additional_special_tokens (List[str], optional): Additional special tokens.
                Defaults to ["", ""].
            verbose (bool, optional): Whether to print detailed logs or progress
                information during execution.
            **kwargs: Additional keyword arguments passed to the parent class.
        """
        self.vocab_file = vocab_file
        # The SentencePiece model is loaded before the parent constructor runs
        # so that vocabulary lookups are already usable at that point.
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(vocab_file)

        if additional_special_tokens is None:
            additional_special_tokens = ["", ""]

        super().__init__(
            bos_token=bos_token,
            cls_token=cls_token,
            eos_token=eos_token,
            mask_token=mask_token,
            pad_token=pad_token,
            sep_token=sep_token,
            unk_token=unk_token,
            additional_special_tokens=additional_special_tokens,
            verbose=verbose,
            **kwargs,
        )
        # Snapshot of the special tokens taken once at construction time; it is
        # used by `_decode` for fast membership tests and is not refreshed here
        # if special tokens are added later.
        self.all_spec_tok = set(self.all_special_tokens)

    @property
    def vocab_size(self):
        """Returns the size of the vocabulary.

        Returns:
            int: The number of tokens in the SentencePiece vocabulary.
        """
        return self.sp_model.vocab_size()

    def get_vocab(self):
        """Get the vocabulary as a dictionary mapping tokens to their IDs.

        The SentencePiece pieces are listed first, then entries from
        ``added_tokens_encoder`` are merged in (overriding on key collision).

        Returns:
            dict: A dictionary mapping tokens to their corresponding IDs.
        """
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text):
        """Tokenize text using SentencePiece.

        Args:
            text (str): The text to tokenize.

        Returns:
            list: A list of tokens.
        """
        return self.sp_model.encode_as_pieces(text)

    def _convert_token_to_id(self, token):
        """Convert a token (str) to an ID using the vocabulary.

        Args:
            token (str): The token to convert.

        Returns:
            int: The corresponding token ID.
        """
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, id):
        """Convert an ID to a token (str) using the vocabulary.

        Args:
            id (int): The token ID to convert.

        Returns:
            str: The corresponding token, or ``unk_token`` when ``id`` is at or
            beyond the SentencePiece vocabulary size.
        """
        # IDs past the end of the SentencePiece vocabulary have no piece.
        if id >= self.vocab_size:
            return self.unk_token
        return self.sp_model.id_to_piece(id)

    def convert_tokens_to_string(self, tokens):
        """Convert a sequence of tokens back to a single string.

        Args:
            tokens (List[str]): A list of tokens to convert.

        Returns:
            str: The reconstructed string.
        """
        return self.sp_model.decode(tokens)

    def prepare_for_model(self, *args, **kwargs):
        """Call the parent ``prepare_for_model`` without ``add_special_tokens``.

        Any ``add_special_tokens`` keyword supplied by the caller is discarded,
        so the parent implementation always runs with its own default for it.
        All other positional and keyword arguments are forwarded unchanged.
        """
        kwargs.pop("add_special_tokens", None)
        return super().prepare_for_model(*args, **kwargs)

    def save_vocabulary(
        self, save_directory, filename_prefix: Optional[str] = None
    ) -> Optional[Tuple[str]]:
        """
        Save the vocabulary file to a directory.

        If ``self.vocab_file`` exists on disk it is copied (unless it already is
        the destination file); otherwise the in-memory SentencePiece model is
        serialized to the destination.

        Args:
            save_directory (str): The directory in which to save the vocabulary.
            filename_prefix (Optional[str]): Optional prefix for the saved filename.

        Returns:
            Optional[Tuple[str]]: One-element tuple with the path of the saved
            file, or ``None`` when ``save_directory`` is not an existing
            directory (an error is logged in that case; nothing is raised).
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return None

        out_vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "")
            + self.vocab_files_names["vocab_file"],
        )

        if not os.path.isfile(self.vocab_file):
            # The original model file is gone: dump the loaded model instead.
            with open(out_vocab_file, "wb") as fi:
                fi.write(self.sp_model.serialized_model_proto())
        elif os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            # Skip the copy when source and destination are the same file.
            copyfile(self.vocab_file, out_vocab_file)

        return (out_vocab_file,)

    def _decode(
        self,
        token_ids: Union[int, List[int]],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = False,
        spaces_between_special_tokens: bool = False,
        **kwargs,
    ) -> str:
        """Decode token IDs into a string.

        Args:
            token_ids (Union[int, List[int]]): A single ID or a list of IDs.
            skip_special_tokens (bool): Drop tokens found in ``all_spec_tok``.
            clean_up_tokenization_spaces (Optional[bool]): When truthy, pass the
                result through ``clean_up_tokenization``.
            spaces_between_special_tokens (bool): Accepted for signature
                compatibility. All remaining tokens are decoded as one segment,
                so this flag does not change the output.
            **kwargs: ``use_source_tokenizer`` is consumed; the rest is ignored.

        Returns:
            str: The decoded text (empty string when no tokens remain).
        """
        self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)

        tokens = self.convert_ids_to_tokens(
            token_ids, skip_special_tokens=skip_special_tokens
        )
        # A single id yields a single str; wrap it so it is not iterated
        # character by character below.
        if isinstance(tokens, str):
            tokens = [tokens]

        # Second filter, by token text, against the special-token set captured
        # in __init__ (in addition to the id-based filtering requested above).
        if skip_special_tokens:
            tokens = [tok for tok in tokens if tok not in self.all_spec_tok]

        text = self.convert_tokens_to_string(tokens) if tokens else ""

        if clean_up_tokenization_spaces:
            return self.clean_up_tokenization(text)
        return text