# Source code for mindmeld.models.nn_utils.classification

# -*- coding: utf-8 -*-
#
# Copyright (c) 2015 Cisco Systems, Inc. and others.  All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Base for custom modules that are developed on top of nn layers that can do
sequence or token classification
"""

import json
import logging
import os
import random
import shutil
import uuid
from abc import abstractmethod
from itertools import chain
from typing import Dict, Union, List, Tuple

import numpy as np
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import torch
import torch.nn as nn

from .helpers import (
    BatchData,
    get_disk_space_of_model,
    get_num_weights_of_model,
    get_default_params,
    TokenizerType,
    EmbedderType,
    ValidationMetricType,
    TRAIN_DEV_SPLIT_SEED,
    SHUFFLE_TRAINING_SEED,
    LABEL_PAD_TOKEN_IDX,
    DEFAULT_EMB_DIM,
    DEFAULT_TOKENIZER
)
from .input_encoders import InputEncoderFactory
from .._util import _get_module_or_attr
from ..containers import GloVeEmbeddingsContainer
from ...core import Bunch
from ...path import USER_CONFIG_DIR

nn_module = _get_module_or_attr("torch.nn", "Module")

logger = logging.getLogger(__name__)


class BaseClassification(nn_module):
    """
    A base class for sequence & token classification using deep neural nets.
    Both the classification submodules have a common fit() method defined in this base class,
    which also drives the training of pytorch based deep nets. The net's computational graph is
    defined only when the fit() method is called. This base class also holds few common utility
    methods and further defines the skeleton of the children classes through abstract methods.
    """

    def __init__(self):
        super().__init__()
        self.name = type(self).__name__
        # params holds every config value used across fit/dump/load
        self.params = Bunch()
        self.params.update(
            {"name": self.name, "classification_type": self.classification_type}
        )
        self.encoder = None
        # lifecycle flags: `ready` is True once fit() completes or a checkpoint
        # is loaded; `dirty` is True while the latest weights are not on disk
        self.ready = False
        self.dirty = False
        self.out_dim = float("-inf")

    def __repr__(self):
        return f"<{self.name}> ready:{self.ready} dirty:{self.dirty}"
[docs] def get_default_params(self) -> Dict: return get_default_params(self.name)
[docs] def log_and_return_model_info(self, verbose: bool = False) -> str: """ Logs and returns the details of the underlying torch.nn model, such as occupying disk space when dumped, number of parameters, the device on which the model is placed, etc. Args: verbose (bool): Determines the amount of information to be logged and returned. """ msg = f"{self.name} " \ f"ready:{self.ready} dirty:{self.dirty} device:{self.params.device} " \ f"\n\tNumber of weights (trainable, all):{get_num_weights_of_model(self)} " verbose_msg = msg + ( f"\n\tDisk Size (in MB): {get_disk_space_of_model(self):.4f} " if verbose else "" ) logger.info(verbose_msg) return msg
[docs] def to_device(self, batch_data: BatchData) -> BatchData: """ Places pytorch tensors on the device configured through the params Args: batch_data (BatchData): A BatchData object consisting different tensor objects """ for k, v in batch_data.items(): if v is not None and isinstance(v, torch.Tensor): batch_data[k] = v.to(self.params.device) elif isinstance(v, list): batch_data[k] = [ vv.to(self.params.device) if isinstance(vv, torch.Tensor) else vv for vv in v ] elif isinstance(v, dict): batch_data[k] = self.to_device(batch_data[k]) return batch_data
# pylint: disable=too-many-locals
[docs] def fit(self, examples: List[str], labels: Union[List[int], List[List[int]]], **params): """ Trains the underlying neural model on the inputted data and finally retains the best scored model among all iterations. Because of possibly large sized neural models, instead of retaining a copy of best set of model weights on RAM, it is advisable to dump them in a temporary folder and upon completing the training process, load the best checkpoint weights. Args: examples (List[str]): A list of text strings that will be used for model training and validation labels (Union[List[int], List[List[int]]]): A list of labels passed in as integers corresponding to the examples. The encoded labels must have values between 0 and n_classes-1 -- one label per example in case of sequence classification and a sequence of labels per example in case of token classification """ if self.ready: msg = "The model is already fitted or is loaded from a file. Aborting re-fitting." logger.error(msg) # obtain and validate all parameters required to fit the model params = { **self.params, **self.get_default_params(), **params # overwrite keys of default params that are passed-in } params = self._validate_and_update_params(**params) # update params upon preparing encoder and embedder params = self._prepare_input_encoder(examples, **params) params = self._prepare_embedder(**params) # update number of labels in params upon identifying the unique labels try: # labels for sequence classification are of type List[int] whereas for token # classification they are List[List[int]]; this try-except tries to obtain the number # of unique label strings for the purpose of classification num_labels = len(set(labels)) except TypeError: # error will be raised in case of token classification as `set(labels)` attempts to # find a set of lists num_labels = len(set(chain.from_iterable(labels))) params.update({"num_labels": num_labels}) # split input data into train & dev splits, and get data loaders train_examples, 
dev_examples, train_labels, dev_labels = train_test_split( examples, labels, test_size=params["dev_split_ratio"], random_state=TRAIN_DEV_SPLIT_SEED ) # update self.params which will be used throughout the following modeling code self.params.update(params) # init the graph and move model to device, inputs are moved to device on-the-go self._init_graph() self.to(self.params.device) # dumping weights during training process into a temp folder instead of keeping in # memory to reduce memory usage temp_folder = os.path.join(USER_CONFIG_DIR, "tmp", "pytorch_models", str(uuid.uuid4())) temp_weights_save_path = os.path.join(temp_folder, "pytorch_model.bin") os.makedirs(temp_folder, exist_ok=True) # create an optimizer and attach all model params to it num_training_steps = int( len(train_examples) / self.params.batch_size / self.params.gradient_accumulation_steps * self.params.number_of_epochs ) optimizer, scheduler = self._create_optimizer_and_scheduler(num_training_steps) # set verbosity boolean _verbose = ( logger.getEffectiveLevel() == logging.INFO or logger.getEffectiveLevel() == logging.DEBUG ) self.log_and_return_model_info(_verbose) # training w/ validation random.seed(SHUFFLE_TRAINING_SEED) best_dev_score, best_dev_epoch = -np.inf, -1 msg = f"Beginning to train for {self.params.number_of_epochs} number of epochs" logger.info(msg) if self.params.number_of_epochs < 1: raise ValueError("Param 'number_of_epochs' must be a positive integer greater than 0") patience_counter = 0 for epoch in range(1, self.params.number_of_epochs + 1): # patience before terminating due to no dev score improvements if patience_counter >= self.params.patience: msg = f"Set patience of {self.params.patience} epochs reached" logger.info(msg) break # set modules to train phase, reset gradients, do forward-backward propagations self.train() optimizer.zero_grad() train_loss, train_batches = 0.0, 0.0 indices = list(range(len(train_examples))) random.shuffle(indices) train_examples = 
[train_examples[ii] for ii in indices] train_labels = [train_labels[ii] for ii in indices] t = tqdm(range(0, len(train_examples), self.params.batch_size), disable=not _verbose) for start_idx in t: batch_examples = train_examples[start_idx:start_idx + self.params.batch_size] batch_labels = train_labels[start_idx:start_idx + self.params.batch_size] batch_data = self.encoder.batch_encode( examples=batch_examples, padding_length=self.params.padding_length, **({'add_terminals': self.params.add_terminals} if self.params.add_terminals is not None else {}) ) batch_data.update({ "_labels": self._prepare_labels( # `_` 'cause this key is for intermediate use batch_labels, # pad to the max length amongst encoded examples max([len(_split_lengths) for _split_lengths in batch_data["split_lengths"]]) ) }) batch_data = self.forward(batch_data) loss = batch_data["loss"] # .cpu() returns copy of tensor in CPU memory train_loss += loss.cpu().detach().numpy() train_batches += 1 # find gradients loss = loss / self.params.gradient_accumulation_steps loss.backward() # optimizer and scheduler step batch_id = start_idx / self.params.batch_size if ( start_idx + self.params.batch_size >= len(train_examples) or (batch_id + 1) % self.params.gradient_accumulation_steps == 0 ): # update weights when it is the last batch in the epoch or # when specified step is reached or if self.params.max_grad_norm: # clip (accumulated) gradients if required nn.utils.clip_grad_norm_(self.parameters(), self.params.max_grad_norm) optimizer.step() scheduler.step() optimizer.zero_grad() # log progress progress_bar_msg = f"Epoch: {epoch} | Mean loss: {train_loss / train_batches:.4f}" t.set_description(progress_bar_msg, refresh=True) train_loss = train_loss / train_batches # dev evaluation predictions, targets = [], [] t = tqdm(range(0, len(dev_examples), self.params.batch_size), disable=not _verbose) for start_idx in t: batch_examples = dev_examples[start_idx:start_idx + self.params.batch_size] batch_labels_targetted 
= dev_labels[start_idx:start_idx + self.params.batch_size] batch_labels_predicted = self.predict(batch_examples) # validation if len(batch_labels_predicted) != len(batch_labels_targetted): msg = f"Number of predictions ({len(batch_labels_predicted)}) " \ f"not equal to number of targets ({len(batch_labels_targetted)})" logger.error(msg) raise AssertionError(msg) # flatten if required try: batch_labels_predicted = sum(batch_labels_predicted, []) batch_labels_targetted = sum(batch_labels_targetted, []) except TypeError: # raised in case of sequence classification; implies already flattened pass # discard unwanted predictions using _label_padding_idx batch_labels_predicted, batch_labels_targetted = zip(*[ (x, y) for x, y in zip(batch_labels_predicted, batch_labels_targetted) if y != self.params._label_padding_idx ]) predictions.extend(batch_labels_predicted) targets.extend(batch_labels_targetted) progress_bar_msg = f"Epoch: {epoch} | " \ f"Validation Metric: {self.params.validation_metric} " t.set_description(progress_bar_msg, refresh=True) # compute score if ValidationMetricType(self.params.validation_metric) == ValidationMetricType.ACCURACY: dev_score = accuracy_score(targets, predictions, normalize=True) elif ValidationMetricType(self.params.validation_metric) == ValidationMetricType.F1: dev_score = f1_score(targets, predictions, average='weighted') else: msg = f"Invalid 'validation_metric' ({self.params.validation_metric}) provided " \ f"in params. 
Allowed values are only 'accuracy' and 'f1'" raise ValueError(msg) # save model weights in a temp folder; later move it to folder passed through dump() if dev_score >= best_dev_score: torch.save(self.state_dict(), temp_weights_save_path) phrase = ( f"improved from '{best_dev_score:.4f}' to" if dev_score > best_dev_score else "remained at" ) msg = f"Model weights saved after epoch: {epoch} when dev score {phrase} " \ f"'{dev_score:.4f}'\n" logger.info(msg) # update patience counter if dev_score == best_dev_score: patience_counter += 1 else: patience_counter = 0 best_dev_score, best_dev_epoch = dev_score, epoch else: patience_counter += 1 msg = f"No weights saved after epoch: {epoch}. " \ f"The dev score last improved after epoch: {best_dev_epoch}" logger.info(msg) # load back the best model dumped in temporary path and delete the temp folder msg = f"Setting the model weights to checkpoint whose dev " \ f"{self.params.validation_metric} score is {best_dev_score:.4f}" logger.info(msg) # because we are loading to same device, no `map_location` specified self.load_state_dict(torch.load(temp_weights_save_path)) shutil.rmtree(temp_folder) self.ready = True self.dirty = True
@staticmethod def _validate_and_update_params(**params) -> Dict: """Common validation and updation of the params dict before creating encoders and layers""" # populate few required key-values (ensures the key-values are populated if not inputted) params.update({ "add_terminals": params.get("add_terminals", True), "padding_length": params.get("padding_length"), # explicitly obtained for more # transparent param dictionary "tokenizer_type": params.get("tokenizer_type", DEFAULT_TOKENIZER), "_label_padding_idx": LABEL_PAD_TOKEN_IDX, # used to discard unwanted i.e. label # padding indices in the batch predictions in the fit() method }) # validate tokenizer_type param allowed_tokenizer_types = { EmbedderType.GLOVE: [ TokenizerType.WHITESPACE_TOKENIZER, TokenizerType.WHITESPACE_AND_CHAR_DUAL_TOKENIZER, ], EmbedderType.BERT: [TokenizerType.HUGGINGFACE_PRETRAINED_TOKENIZER, ] } if params.get("embedder_type") and params.get("tokenizer_type"): embedder_type = EmbedderType(params.get("embedder_type")) tokenizer_type = TokenizerType(params.get("tokenizer_type")) if embedder_type in allowed_tokenizer_types: if tokenizer_type not in allowed_tokenizer_types[embedder_type]: msg = f"For the selected choice of embedder ({embedder_type.value}), only " \ f"the following tokenizer_type are allowed: " \ f"{[v.value for v in allowed_tokenizer_types[embedder_type]]}." raise ValueError(msg) # validate validation metric validation_metric = params.get("validation_metric") try: validation_metric = ValidationMetricType(validation_metric) except ValueError as e: msg = f"Expected validation_metric amongst " \ f"{[v.value for v in ValidationMetricType.__members__.values()]} " \ f"but found '{validation_metric}'." 
raise ValueError(msg) from e return params def _prepare_input_encoder(self, examples: List[str], **params) -> Dict: """Sets the input encoder and returns an updated param dict""" # create and fit encoder self.encoder = InputEncoderFactory.get_encoder_cls(params.get("tokenizer_type"))(**params) self.encoder.prepare(examples=examples) params.update({ "_num_tokens": len(self.encoder.get_vocab()), "_padding_idx": self.encoder.get_pad_token_idx(), }) return params def _prepare_embedder(self, **params) -> Dict: """Sets the embedder if required and returns an updated param dict""" # check: cannot specify any conflicting params as required by child class if self.encoder is None: raise ValueError("An encoder must be first fitted before calling _prepare_embedder()") # check: cannot specify any conflicting params as required by child class embedder_type = params.get("embedder_type") if EmbedderType(embedder_type) == EmbedderType.GLOVE: # load glove embs token_dimension = params.get("embedder_type", 300) token_pretrained_embedding_filepath = params.get("token_pretrained_embedding_filepath") glove_container = GloVeEmbeddingsContainer( token_dimension=token_dimension, token_pretrained_embedding_filepath=token_pretrained_embedding_filepath ) token2emb = glove_container.get_pretrained_word_to_embeddings_dict() glove_emb_dim = glove_container.token_dimension # validate emb_dim emb_dim = params.get("emb_dim", glove_emb_dim) if emb_dim != glove_emb_dim: msg = f"Provided 'emb_dim':{emb_dim} cannot be used with the provided " \ f"'embedder_type':{embedder_type}. Consider not specifying any 'emb_dim' " \ f"with this embedder." 
raise ValueError(msg) params.update({ "embedder_type": embedder_type, "emb_dim": emb_dim, # overwrite the default value "_embedding_weights": { i: token2emb[t] for t, i in self.encoder.get_vocab().items() if t in token2emb }, }) elif EmbedderType(embedder_type) == EmbedderType.BERT: # the bert model is directly loaded in _init_core() itself params.update({ "embedder_type": embedder_type, "emb_dim": self.encoder.config.hidden_size, # overwrite the default value "pretrained_model_name_or_path": params.get("pretrained_model_name_or_path"), }) if not params.get("emb_dim"): msg = f"Need a valid 'emb_dim' to initialize embedding layers. To specify a " \ f"particular dimension, either pass-in the 'emb_dim' param or provide a valid " \ f"'embedder_type' param. Continuing with a default value:{DEFAULT_EMB_DIM}." logger.error(msg) params.update({"emb_dim": DEFAULT_EMB_DIM}) return params def _create_optimizer_and_scheduler(self, num_training_steps: int) -> Tuple: """Sets an optimizer and scheduler for training torch.nn net""" del num_training_steps # load a torch optimizer optimizer = getattr(torch.optim, self.params.optimizer)( self.parameters(), lr=self.params.learning_rate ) # load a constant lr scheduler scheduler = getattr(torch.optim.lr_scheduler, "LambdaLR")(optimizer, lambda _: 1) return optimizer, scheduler def _get_dumpable_state_dict(self): """ Returns a state dict of the Pytorch module that can be dumped. Overwriting definitions can select a subset of full state dict to be dumped (e.g. like the BERT based ones) """ return self.state_dict()
[docs] def dump(self, path: str): """ Dumps underlying torch.nn model, encoder state and params Args: path (str): The path header for where the files are dumped. The following states are dumped into different files: - Pytorch model weights - Encoder state - Params (including params such as tokenizer_type and emb_dim that are used during loading to create encoder and forward graph) """ # resolve path and create associated folder if required path = os.path.abspath(os.path.splitext(path)[0]) + ".pytorch_model" os.makedirs(path, exist_ok=True) # save weights torch.save(self._get_dumpable_state_dict(), os.path.join(path, "model.bin")) # save encoder's state self.encoder.dump(path) # save all params with open(os.path.join(path, "params.json"), "w") as fp: json.dump(dict(self.params), fp, indent=4) fp.close() msg = f"{self.name} model weights are dumped successfully" logger.info(msg) self.dirty = False
[docs] @classmethod def load(cls, path: str): """ Loads states from a dumped path Args: path (str): The path header wherein dumped files are present. """ # resolve path path = os.path.abspath(os.path.splitext(path)[0]) + ".pytorch_model" # load all params with open(os.path.join(path, "params.json"), "r") as fp: all_params = json.load(fp) fp.close() # create new instance module = cls() if module.name != all_params["name"]: msg = f"The name of the loaded model ({all_params['name']}) from the path '{path}' " \ f"is different from the name of the module instantiated ({module.name})" raise AssertionError(msg) # load encoder's state module.params.update(dict(all_params)) module.encoder = InputEncoderFactory.get_encoder_cls( tokenizer_type=all_params["tokenizer_type"])(**module.params) module.encoder.load(path) # load weights module._init_graph() device = "cuda" if torch.cuda.is_available() else "cpu" if device != module.params.device: msg = f"Model was dumped when on the device:{module.params.device} " \ f"but is now being loaded on device:{device}" logger.warning(msg) module.params.device = device bin_path = os.path.join(path, "model.bin") trained_state_dict = torch.load(bin_path, map_location=torch.device(device)) module_state_dict = module.state_dict() keys_diff = module_state_dict.keys() - trained_state_dict.keys() if keys_diff: msg = f"While loading {module.__class__.__name__}, {len(keys_diff)} keys of the " \ f"total {len(module_state_dict.keys())} of the torch module are not found in " \ f"the file loaded from {bin_path} " msg += "\n- This IS fine if loading a model for which only some parameters were " \ "trained and others frozen. \n- This IS NOT fine if you expect all parameters " \ "were trained." logger.warning(msg) module.load_state_dict(trained_state_dict, strict=False) module.to(device) msg = f"{module.name} model weights are loaded successfully on to the device:{device}" logger.info(msg) module.ready = True module.dirty = False return module
    @abstractmethod
    def _prepare_labels(self, labels: Union[List[int], List[List[int]]], max_length: int):
        # Abstract: subclasses convert raw integer labels into the form consumed by
        # forward(); `max_length` is the padding target length (fit() passes the max
        # split-lengths count amongst the encoded examples of the batch).
        raise NotImplementedError

    @abstractmethod
    def _init_graph(self):
        # Abstract: subclasses build the torch.nn computational graph here; invoked
        # from fit() (after params are finalized) and from load().
        raise NotImplementedError
    @abstractmethod
    def forward(self, batch_data: BatchData) -> BatchData:
        """
        Abstract forward pass. Consumes an encoded BatchData (including the "_labels"
        key populated during fit()) and returns it with at least a "loss" key set,
        which fit() uses for back-propagation and progress logging.
        """
        raise NotImplementedError
    @abstractmethod
    def predict(self, examples: List[str]) -> Union[List[int], List[List[int]]]:
        """
        Returns predicted class labels

        Args:
            examples (List[str]): The list of examples for which predictions are computed and
                returned.

        Returns:
            Union[List[int], List[List[int]]]: One label per example for sequence
                classification, or a sequence of labels per example for token classification.
        """
        raise NotImplementedError
    @abstractmethod
    def predict_proba(self, examples: List[str]) -> Union[List[List[int]], List[List[List[int]]]]:
        """
        Returns predicted class probabilities

        Args:
            examples (List[str]): The list of examples for which class prediction probabilities
                are computed and returned.

        Returns:
            Union[List[List[int]], List[List[List[int]]]]: Per-class scores for each example
                (one score list per example for sequence classification, or one per token
                per example for token classification).
        """
        raise NotImplementedError
    @property
    def classification_type(self) -> str:
        # Abstract property: subclasses return a string identifying the task type;
        # __init__ records it under the "classification_type" key in self.params.
        raise NotImplementedError