#!/usr/bin/env python3
# -*- coding:utf-8 -*-

"""
This module implements the utilities for CKIP Transformers NLP drivers.
"""

__author__ = "Mu Yang <http://muyang.pro>"
__copyright__ = "2023 CKIP Lab"
__license__ = "GPL-3.0"


from abc import (
    ABCMeta,
    abstractmethod,
)

from typing import (
    List,
    NamedTuple,
    Optional,
    Tuple,
    Union,
)

from tqdm import tqdm

import numpy as np

import torch
from torch.utils.data import (
    DataLoader,
    TensorDataset,
)

from transformers import (
    AutoModelForTokenClassification,
    BatchEncoding,
    BertTokenizerFast,
)

################################################################################################################################


class CkipTokenClassification(metaclass=ABCMeta):
    """The base class for token classification tasks.

    Parameters
    ----------
    model_name : ``str``
        The pretrained model name (e.g. ``'ckiplab/bert-base-chinese-ws'``).
    tokenizer_name : ``str``, *optional*, defaults to **model_name**
        The pretrained tokenizer name (e.g. ``'bert-base-chinese'``).
    device : ``int`` or ``torch.device``, *optional*, defaults to -1
        Device ordinal for CPU/GPU support. Setting this to -1 will leverage CPU;
        a non-negative value will run the model on the associated CUDA device ID.
    """

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        *,
        device: Union[int, torch.device] = -1,
    ):
        self.model = AutoModelForTokenClassification.from_pretrained(model_name)
        self.tokenizer = BertTokenizerFast.from_pretrained(tokenizer_name or model_name)

        # Allow passing a customized torch.device.
        if isinstance(device, torch.device):
            self.device = device
        else:
            self.device = torch.device("cpu" if device < 0 else f"cuda:{device}")
        self.model.to(self.device)

    ########################################################################################################################

    @classmethod
    @abstractmethod
    def _model_names(cls):
        return NotImplemented  # pragma: no cover

    def _get_model_name(
        self,
        model: str,
    ):
        try:
            model_name = self._model_names[model]
        except KeyError as exc:
            raise KeyError(f"Invalid model {model}") from exc
        return model_name

    ########################################################################################################################
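
    # A hypothetical example (not in the original source): a concrete driver
    # would typically provide ``_model_names`` as a mapping such as
    #
    #     _model_names = {
    #         "bert-base": "ckiplab/bert-base-chinese-ws",
    #     }
    #
    # so that ``self._get_model_name("bert-base")`` resolves the short alias to
    # the full pretrained model name and raises ``KeyError`` for unknown aliases.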

    def __call__(
        self,
        input_text: Union[List[str], List[List[str]]],
        *,
        use_delim: bool = False,
        delim_set: Optional[str] = ",,。::;;!!??",
        batch_size: int = 256,
        max_length: Optional[int] = None,
        show_progress: bool = True,
        pin_memory: bool = True,
    ):
        """Call the driver.

        Parameters
        ----------
        input_text : ``List[str]`` or ``List[List[str]]``
            The input sentences. Each sentence is a string or a list of strings.
        use_delim : ``bool``, *optional*, defaults to False
            Segment sentences (internally) using ``delim_set``.
        delim_set : ``str``, *optional*, defaults to ``',,。::;;!!??'``
            Used for sentence segmentation if ``use_delim=True``.
        batch_size : ``int``, *optional*, defaults to 256
            The size of a mini-batch.
        max_length : ``int``, *optional*
            The maximum length of a sentence; must not be longer than the maximum
            sequence length for this model (i.e. ``tokenizer.model_max_length``).
        show_progress : ``bool``, *optional*, defaults to True
            Show a progress bar.
        pin_memory : ``bool``, *optional*, defaults to True
            Pin memory in order to accelerate the speed of data transfer to the GPU.
            This option is incompatible with multiprocessing. Disabled on CPU devices.
        """

        # Disable pin memory on CPU devices
        if self.device.type == "cpu":
            pin_memory = False

        # Check max length
        model_max_length = self.tokenizer.model_max_length - 2  # Reserve room for [CLS] and [SEP]
        if max_length:
            assert max_length <= model_max_length, (
                "Sequence length is longer than the maximum sequence length for this model "
                f"({max_length} > {model_max_length})."
            )
        else:
            max_length = model_max_length

        # Apply delimiter cut
        delim_index = self._find_delim(
            input_text=input_text,
            use_delim=use_delim,
            delim_set=delim_set,
        )

        # Get worded input IDs
        if show_progress:
            input_text = tqdm(input_text, desc="Tokenization")
        input_ids_worded = [
            [self.tokenizer.convert_tokens_to_ids(list(input_word)) for input_word in input_sent]
            for input_sent in input_text
        ]

        # Flatten input IDs
        (input_ids, index_map,) = self._flatten_input_ids(
            input_ids_worded=input_ids_worded,
            max_length=max_length,
            delim_index=delim_index,
        )

        # Pad and segment input IDs
        (input_ids, attention_mask,) = self._pad_input_ids(
            input_ids=input_ids,
        )

        # Convert input format
        encoded_input = BatchEncoding(
            data=dict(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ),
            tensor_type="pt",
        )

        # Create dataset
        dataset = TensorDataset(*encoded_input.values())
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            pin_memory=pin_memory,
        )
        if show_progress:
            dataloader = tqdm(dataloader, desc="Inference")

        # Call the model
        logits = []
        with torch.no_grad():
            for batch in dataloader:
                batch = tuple(tensor.to(self.device) for tensor in batch)
                (batch_logits,) = self.model(**dict(zip(encoded_input.keys(), batch)), return_dict=False)
                batch_logits = batch_logits.cpu().numpy()[:, 1:, :]  # Remove [CLS]
                logits.append(batch_logits)

        # Concatenate the batch results
        logits = np.concatenate(logits, axis=0)

        return logits, index_map
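
    # A minimal consumption sketch (assumed usage, not part of this module):
    #
    #     logits, index_map = driver(input_text)  # per-token label scores
    #     label_ids = logits.argmax(axis=-1)      # greedy per-token decoding
    #
    # ``index_map[sent_idx][word_idx]`` maps each input word to its
    # ``(line index, token index)`` position in ``logits`` (``None`` for empty
    # words); subclass drivers use it to map predictions back onto the words.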

    @staticmethod
    def _find_delim(
        *,
        input_text,
        use_delim,
        delim_set,
    ):
        if not use_delim:
            return set()

        delim_index = set()
        delim_set = set(delim_set)
        for sent_idx, input_sent in enumerate(input_text):
            for word_idx, input_word in enumerate(input_sent):
                if input_word in delim_set:
                    delim_index.add((sent_idx, word_idx))
        return delim_index

    @staticmethod
    def _flatten_input_ids(
        *,
        input_ids_worded,
        max_length,
        delim_index,
    ):
        input_ids = []
        index_map = []

        input_ids_sent = []
        index_map_sent = []

        for sent_idx, input_ids_worded_sent in enumerate(input_ids_worded):
            for word_idx, word_ids in enumerate(input_ids_worded_sent):
                word_length = len(word_ids)

                if word_length == 0:
                    index_map_sent.append(None)
                    continue

                # Check if sentence segmentation is needed
                if len(input_ids_sent) + word_length > max_length:
                    input_ids.append(input_ids_sent)
                    input_ids_sent = []

                # Insert tokens
                index_map_sent.append(
                    (
                        len(input_ids),  # line index
                        len(input_ids_sent),  # token index
                    )
                )
                input_ids_sent += word_ids

                if (sent_idx, word_idx) in delim_index:
                    input_ids.append(input_ids_sent)
                    input_ids_sent = []

            # End of a sentence
            if input_ids_sent:
                input_ids.append(input_ids_sent)
                input_ids_sent = []
            index_map.append(index_map_sent)
            index_map_sent = []

        return input_ids, index_map

    def _pad_input_ids(
        self,
        *,
        input_ids,
    ):
        max_length = max(map(len, input_ids))

        padded_input_ids = []
        attention_mask = []
        for input_ids_sent in input_ids:
            token_count = len(input_ids_sent)
            pad_count = max_length - token_count
            padded_input_ids.append(
                [self.tokenizer.cls_token_id]
                + input_ids_sent
                + [self.tokenizer.sep_token_id]
                + [self.tokenizer.pad_token_id] * pad_count
            )
            attention_mask.append(
                [1] * (token_count + 2)  # [CLS] & input & [SEP]
                + [0] * pad_count  # [PAD]s
            )
        return padded_input_ids, attention_mask
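
# A minimal sketch (illustrative only; this subclass is not part of the module)
# of how a concrete driver builds on ``CkipTokenClassification``:
#
#     class MyWordSegmenter(CkipTokenClassification):
#         _model_names = {"bert-base": "ckiplab/bert-base-chinese-ws"}
#
#         def __init__(self, model: str = "bert-base", **kwargs):
#             super().__init__(model_name=self._get_model_name(model), **kwargs)
#
#     driver = MyWordSegmenter(device=-1)  # -1 selects the CPU
#     logits, index_map = driver(["今天天氣很好。"])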
################################################################################################################################

class NerToken(NamedTuple):
    """A named-entity recognition token."""

    word: str  #: ``str``, the token word.
    ner: str  #: ``str``, the NER tag.
    idx: Tuple[int, int]  #: ``Tuple[int, int]``, the starting / ending index in the sentence.
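
# An illustrative construction (assumed example, not from the original source):
#
#     token = NerToken(word="台北", ner="GPE", idx=(0, 2))
#
# i.e. the span ``sentence[0:2]`` was recognized as a geopolitical entity.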