from typing import Any, Optional, Tuple, Union

import torch
import transformers


class DistilBertTransferLearningModel(torch.nn.Module):
    def __init__(
        self,
        pretrained_model: str = "distilbert-base-uncased",
        layers: list[Tuple[str, Optional[list[Any]]]] = [
            ('linear', ['in', 'out']),
            ('softmax', None),
        ],
        dim_out: int = 2,
        use_local_file: bool = False,
        device: str = 'cpu',
        state_dict: Optional[Union[str, dict]] = None,
    ):
        super().__init__()

        # Pretrained DistilBERT tokenizer and encoder used as the transfer-learning base.
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained_model, local_files_only=use_local_file
        )
        self.base_model = transformers.AutoModel.from_pretrained(
            pretrained_model, local_files_only=use_local_file
        )

        # Build the classification head from the layer spec: 'in' resolves to the
        # encoder hidden size, 'out' to dim_out; integers are passed through as-is.
        clf_layers = []
        for layer in layers:
            layer_type = layer[0] if isinstance(layer, tuple) else layer
            if layer_type == 'linear':
                layer_in, layer_out = [
                    (
                        self.base_model.config.hidden_size
                        if x == 'in'
                        else dim_out if x == 'out' else x
                    )
                    for x in layer[1]
                ]
                clf_layers.append(torch.nn.Linear(layer_in, layer_out))
            elif layer_type == 'softmax':
                clf_layers.append(torch.nn.Softmax(dim=-1))
        self.clf = torch.nn.Sequential(*clf_layers)

        # Optionally restore weights, either from a .pt checkpoint path or from an
        # already-loaded state dict.
        if state_dict is not None:
            if isinstance(state_dict, str) and state_dict.endswith('.pt'):
                if device == 'cpu':
                    state_dict = torch.load(state_dict, map_location='cpu')
                else:
                    state_dict = torch.load(state_dict)
            self.load_state_dict(state_dict)

    def forward(self, ids: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # Use the [CLS] token embedding as the sequence representation,
        # then run it through the classification head.
        y = self.base_model(ids, attention_mask=mask, return_dict=False)[0][:, 0]
        y = self.clf(y)
        return y

    def predict(self, text: str, device: str) -> torch.Tensor:
        # Tokenize a single string, padding/truncating to DistilBERT's 512-token limit.
        encoded = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            return_token_type_ids=False,
            return_attention_mask=True,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        with torch.no_grad():
            ids = encoded['input_ids'].to(device)
            mask = encoded['attention_mask'].to(device)
            output = self.forward(ids, mask)
        return output.to(device)
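

# Minimal usage sketch (not part of the original module, added for illustration):
# instantiate the model, move it to a device, and classify one string with predict().
# The example sentence and the two-class setup are assumptions; swap in your own
# dim_out and a state_dict checkpoint path as needed.
if __name__ == "__main__":
    device = 'cpu'
    model = DistilBertTransferLearningModel(dim_out=2, device=device)
    model.to(device)
    model.eval()

    probs = model.predict("This library is surprisingly easy to use.", device=device)
    # probs has shape (1, dim_out); the argmax is the predicted class index.
    print(probs, probs.argmax(dim=-1).item())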