import gc
import logging
import os
import random
from collections import Counter
from copy import copy
from itertools import chain
from random import randint
from tempfile import NamedTemporaryFile
import numpy as np
import torch
import torch.nn as nn
from sklearn.feature_extraction import DictVectorizer, FeatureHasher
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch import optim
from torch.utils.data import Dataset, DataLoader
from torchcrf import CRF
from ...exceptions import MindMeldError
logger = logging.getLogger(__name__)
TEST_BATCH_SIZE = 512
class TaggerDataset(Dataset):
    """PyTorch ``Dataset`` wrapping tagger inputs, sequence lengths and (optional) labels.

    Each item is the input tensor for one example plus a boolean padding mask;
    when labels were supplied, the label tensor is appended as a third element.
    """

    def __init__(self, inputs, seq_lens, labels=None):
        self.inputs = inputs
        self.labels = labels
        self.seq_lens = seq_lens
        # Every mask is padded out to the longest sequence in the dataset.
        self.max_seq_length = max(seq_lens)

    def __len__(self):
        return len(self.seq_lens)

    def __getitem__(self, index):
        seq_len = self.seq_lens[index]
        # True for real tokens, False for the padded tail.
        mask = torch.as_tensor(
            [True] * seq_len + [False] * (self.max_seq_length - seq_len),
            dtype=torch.bool,
        )
        if self.labels:
            return self.inputs[index], mask, self.labels[index]
        return self.inputs[index], mask
def diag_concat_coo_tensors(tensors):
    """Concatenate sparse PyTorch COO tensors diagonally so they can be processed in batches.

    Each input tensor becomes one block on the diagonal of the result, so the
    row/column offsets of every block are shifted past all previous blocks.

    Args:
        tensors (tuple of torch.Tensor): Sparse COO tensors to diagonally concatenate.

    Returns:
        torch.Tensor: A single coalesced sparse COO tensor that acts as a single batch.

    Raises:
        ValueError: If ``tensors`` is empty. (Previously an ``assert``, which
            disappears under ``python -O``.)
    """
    if not tensors:
        raise ValueError("Expected at least one tensor to concatenate.")
    # Bound locally so the block is self-contained; same logger as the module-level one.
    log = logging.getLogger(__name__)
    log.debug("Concatenating %s tensors into a diagonal representation.", len(tensors))

    rows, cols, values = [], [], []
    row_offset, col_offset = 0, 0
    for tensor in tensors:
        # coalesce() guarantees unique, sorted indices before we read them.
        tensor = tensor.coalesce()
        indices = tensor.indices()
        # Shift this block's coordinates past all previously placed blocks.
        # (indices() and values() always return tensors, so the original
        # `is not None` guards and the unused nnz counter were dead code.)
        rows.append(indices[0] + row_offset)
        cols.append(indices[1] + col_offset)
        values.append(tensor.values())
        row_offset += tensor.shape[0]
        col_offset += tensor.shape[1]

    stacked_indices = torch.stack([torch.cat(rows, dim=0), torch.cat(cols, dim=0)])
    return torch.sparse_coo_tensor(
        indices=stacked_indices,
        values=torch.cat(values, dim=0),
        size=[row_offset, col_offset],
    ).coalesce()
def collate_tensors_and_masks(sequence):
    """Custom collate function ensuring proper batching of sparse tensors, labels and masks.

    Args:
        sequence (list of tuples): Tuples of (input tensor, mask tensor) or
            (input tensor, mask tensor, label tensor).

    Returns:
        Batched representation of the input, mask and (if present) label sequences.
    """
    tuple_width = len(sequence[0])
    if tuple_width == 3:
        # Training/validation batches carry labels.
        sparse_mats, masks, labels = zip(*sequence)
        batched_inputs = diag_concat_coo_tensors(sparse_mats)
        return batched_inputs, torch.stack(masks), torch.stack(labels)
    if tuple_width == 2:
        # Inference batches: inputs and masks only.
        sparse_mats, masks = zip(*sequence)
        return diag_concat_coo_tensors(sparse_mats), torch.stack(masks)
class Encoder:
    """Encoder responsible for feature extraction and label encoding for the PyTorch model."""

    def __init__(self, feature_extractor="hash", num_feats=50000):
        # "dict" selects an exact learned vocabulary (DictVectorizer); anything
        # else falls back to the fixed-width FeatureHasher.
        if feature_extractor == "dict":
            self.feat_extractor = DictVectorizer(dtype=np.float32)
        else:
            self.feat_extractor = FeatureHasher(n_features=num_feats, dtype=np.float32)
        self.label_encoder = LabelEncoder()
        self.num_classes = None
        self.classes = None
        self.num_feats = num_feats

    def get_feats_and_classes(self):
        """Return the (num_feats, num_classes) pair known to the encoder."""
        return self.num_feats, self.num_classes

    def get_tensor_data(self, feat_dicts, labels=None, fit=False):
        """Transform feature dicts and labels into padded PyTorch sparse tensor data.

        Args:
            feat_dicts (list of list of dicts): Feature vectors, one list per training example.
            labels (list of lists): Classification labels, one list per example.
            fit (bool): Whether to fit the feature extractor / label encoder first.

        Returns:
            encoded_tensor_inputs (list of torch.Tensor): Sparse COO tensor
                representations of the encoded, padded input sequences.
            seq_lens (list of int): Actual length of each sequence.
            encoded_tensor_labels (list of torch.Tensor): Tensor representations
                of the encoded, padded label sequences.
        """
        if fit:
            # Only the DictVectorizer learns its vocabulary from the data, so
            # the effective feature count is known only after fitting.
            if isinstance(self.feat_extractor, DictVectorizer):
                self.feat_extractor.fit(list(chain.from_iterable(feat_dicts)))
                self.num_feats = len(self.feat_extractor.get_feature_names_out())
            if labels is not None:
                self.label_encoder.fit(list(chain.from_iterable(labels)))
                self.classes = self.label_encoder.classes_
                self.num_classes = len(self.label_encoder.classes_)

        # Number of tokens in each example.
        seq_lens = [len(example) for example in feat_dicts]
        encoded_tensor_inputs = self.get_padded_transformed_tensors(feat_dicts, seq_lens, is_label=False)
        encoded_tensor_labels = self.get_padded_transformed_tensors(labels, seq_lens, is_label=True)
        return encoded_tensor_inputs, seq_lens, encoded_tensor_labels

    def encode_padded_label(self, current_seq_len, max_seq_len, y):
        """Pad one label sequence to the max sequence length and return it as a tensor.

        Args:
            current_seq_len (int): Number of tokens in the current example sequence.
            max_seq_len (int): Max number of tokens in an example sequence in the dataset.
            y (list): Labels, one for each token in the example sequence.

        Returns:
            torch.Tensor: Padded, integer-encoded label sequence.
        """
        encoded = self.label_encoder.transform(y)
        # The tail is padded with the highest class index.
        encoded = np.pad(encoded, pad_width=(0, max_seq_len - current_seq_len),
                         constant_values=(self.num_classes - 1))
        return torch.as_tensor(encoded, dtype=torch.long)
def compute_l1_params(w):
    """Return the L1 penalty term (sum of absolute values) of the tensor ``w``."""
    return w.abs().sum()
def compute_l2_params(w):
    """Return the L2 penalty term (sum of squares) of the tensor ``w``."""
    return (w * w).sum()
# pylint: disable=too-many-instance-attributes
class CRFModel(nn.Module):
    """PyTorch Model Class for Conditional Random Fields"""

    def __init__(self):
        super().__init__()
        # Optimizer/scheduler instances are created in fit(); the
        # hyper-parameter attributes below are populated by set_params().
        self.optim = None
        self.scheduler = None
        self._encoder = None  # Encoder handling feature extraction / label encoding
        self.W = None  # emission weight matrix, shape (num_features, num_classes) — see build_params
        self.b = None  # emission bias, shape (num_classes,) — see build_params
        self.crf_layer = None  # torchcrf.CRF layer, built in build_params
        self.num_classes = None
        self.feat_type = None  # "hash" or "dict"
        self.feat_num = None
        self.stratify_train_val_split = None
        self.drop_input = None  # input-feature dropout probability
        self.batch_size = None
        self.patience = None  # early-stopping patience (epochs)
        self.number_of_epochs = None
        self.dev_split_ratio = None
        self.optimizer = None  # "sgd", "adam" or "lbfgs"
        self.l1_weight = None
        self.l2_weight = None
        self.random_state = None
    def get_encoder(self):
        """Return the Encoder instance attached to this model."""
        return self._encoder
    def set_encoder(self, encoder):
        """Attach an Encoder instance (feature/label encoding) to this model."""
        self._encoder = encoder
    def set_random_states(self):
        """Sets the random seeds across all libraries used for deterministic output.

        Distinct offsets are used so torch, random and numpy do not all share
        the exact same seed value.
        """
        torch.manual_seed(self.random_state)
        random.seed(self.random_state + 1)
        np.random.seed(self.random_state + 2)
    def save_best_weights_path(self, path):
        """Saves the current model weights (state dict) to the given path.

        Args:
            path (str): Path to save the best model weights to
                (conventionally inside the .generated folder).
        """
        torch.save(self.state_dict(), path)
    def load_best_weights_path(self, path):
        """Loads the best model weights from a path in the .generated folder.

        (The original docstring said "Saves" — this method loads.)

        Args:
            path (str): Path to load the best model weights from.

        Raises:
            MindMeldError: If no weights file exists at ``path``.
        """
        if os.path.exists(path):
            self.load_state_dict(torch.load(path))
        else:
            raise MindMeldError("CRF weights not saved. Please re-train model from scratch.")
[docs] def validate_params(self, kwargs):
"""Validate the argument values saved into the CRF model. """
for key in kwargs:
msg = (
"Unexpected param `{param}`, dropping it from model config.".format(
param=key
)
)
logger.warning(msg)
if self.optimizer not in ["sgd", "adam", "lbfgs"]:
raise MindMeldError(
f"Optimizer type {self.optimizer_type} not supported. Supported options are ['sgd', 'adam', 'lbfgs']")
if self.feat_type not in ["hash", "dict"]:
raise MindMeldError(f"Feature type {self.feat_type} not supported. Supported options are ['hash', 'dict']")
if not 0 < self.dev_split_ratio < 1:
raise MindMeldError("Train-dev split should be a value between 0 and 1.")
if not 0 <= self.drop_input < 1:
raise MindMeldError("Drop Input should be a value between 0 (inclusive) and 1.")
if not isinstance(self.patience, int):
raise MindMeldError("Patience should be an integer value.")
if not isinstance(self.number_of_epochs, int):
raise MindMeldError("Number of epochs should be am integer value.")
[docs] def build_params(self, num_features, num_classes):
"""Sets the parameters for the layers in the PyTorch CRF model. Naming convention is kept
consistent with the CRFSuite implementation.
Args:
num_features (int): Number of features to use in a FeatureHasher feature extractor.
num_classes (int): Number of classes in the tagging model.
"""
self.W = nn.Parameter(torch.nn.init.xavier_normal_(torch.empty(size=(num_features, num_classes))),
requires_grad=True)
self.b = nn.Parameter(torch.nn.init.constant_(torch.empty(size=(num_classes,)), val=0.01),
requires_grad=True)
self.crf_layer = CRF(num_classes, batch_first=True)
self.num_classes = num_classes
    def forward(self, inputs, targets, mask, drop_input=0.0):
        """The forward pass of the PyTorch CRF model. Returns the predictions or loss depending on whether
        labels are passed or not.

        Args:
            inputs (torch.Tensor): Batch of input tensors to pass through the model
                (diagonally-stacked sparse COO matrices — see collate_tensors_and_masks).
            targets (torch.Tensor or None): Batch of label tensors, or None for inference.
            mask (torch.Tensor): Batch of mask tensors to account for padded inputs.
            drop_input (float): Percentage of features to drop from the input.

        Returns:
            loss (torch.Tensor or list): Loss from training or predictions for input sequence.
        """
        if drop_input:
            # Feature dropout on the sparse input: zero each stored value with
            # probability `drop_input`. NOTE(review): this mutates the sparse
            # tensor's values in place, so the caller's tensor is modified.
            dp_mask = (torch.FloatTensor(inputs.values().size()).uniform_() > drop_input)
            inputs.values()[:] = inputs.values() * dp_mask
        # Tile W once per batch element so a single addmm over the diagonally
        # stacked sparse inputs yields all emission scores at once.
        dense_w = torch.tile(self.W, dims=(mask.shape[0], 1))
        out_1 = torch.addmm(self.b, inputs, dense_w)
        # Reshape the flat emissions to (batch, seq_len, num_classes) for the CRF layer.
        crf_input = out_1.reshape((mask.shape[0], -1, self.num_classes))
        if targets is None:
            # Inference: best tag sequence per example.
            return self.crf_layer.decode(crf_input, mask=mask)
        # Training: torchcrf returns the (mean-reduced) log-likelihood; negate for a loss.
        loss = - self.crf_layer(crf_input, targets, mask=mask, reduction='mean')
        return loss
[docs] def compute_regularized_loss(self, l1):
model_parameters = []
for parameter in self.parameters():
model_parameters.append(parameter.view(-1))
if l1:
reg_loss = self.l1_weight * compute_l1_params(torch.cat(model_parameters))
else:
reg_loss = self.l2_weight * compute_l2_params(torch.cat(model_parameters))
return reg_loss
    def _compute_log_alpha(self, emissions, mask, run_backwards):
        """Function used to calculate the alpha and beta probabilities of each token/tag probability.

        This is the forward (alpha) / backward (beta) pass of the CRF
        forward-backward algorithm, computed in log space.
        Implementation is borrowed from https://github.com/kmkurn/pytorch-crf/pull/37.

        Args:
            emissions (torch.Tensor): Emission probabilities of batched input sequence.
            mask (torch.Tensor): Batch of mask tensors to account for padded inputs.
            run_backwards (bool): Flag to decide whether to compute alpha or beta probabilities.

        Returns:
            log_prob (torch.Tensor): alpha or beta log probabilities of input batch.
        """
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.size()[:2] == mask.size()
        assert emissions.size(2) == self.crf_layer.num_tags
        # Every sequence must have at least its first timestep unmasked.
        assert all(mask[0].data)

        seq_length = emissions.size(0)
        mask = mask.float()
        broadcast_transitions = self.crf_layer.transitions.unsqueeze(0)  # (1, num_tags, num_tags)
        emissions_broadcast = emissions.unsqueeze(2)
        seq_iterator = range(1, seq_length)

        if run_backwards:
            # running backwards, so transpose
            broadcast_transitions = broadcast_transitions.transpose(1, 2)  # (1, num_tags, num_tags)
            emissions_broadcast = emissions_broadcast.transpose(2, 3)

            # the starting probability is end_transitions if running backwards
            log_prob = [self.crf_layer.end_transitions.expand(emissions.size(1), -1)]

            # iterate over the sequence backwards
            seq_iterator = reversed(seq_iterator)
        else:
            # Start transition score and first emission
            log_prob = [emissions[0] + self.crf_layer.start_transitions.view(1, -1)]

        for i in seq_iterator:
            # Broadcast log_prob over all possible next tags
            broadcast_log_prob = log_prob[-1].unsqueeze(2)  # (batch_size, num_tags, 1)
            # Sum current log probability, transition, and emission scores
            score = broadcast_log_prob + broadcast_transitions + emissions_broadcast[
                i]  # (batch_size, num_tags, num_tags)
            # Sum over all possible current tags, but we're in log prob space, so a sum
            # becomes a log-sum-exp
            score = torch.logsumexp(score, dim=1)
            # Set log_prob to the score if this timestep is valid (mask == 1), otherwise
            # copy the prior value
            log_prob.append(score * mask[i].unsqueeze(1) +
                            log_prob[-1] * (1. - mask[i]).unsqueeze(1))

        if run_backwards:
            # Restore chronological order so indices line up with alpha's.
            log_prob.reverse()

        return torch.stack(log_prob)
    def compute_marginal_probabilities(self, inputs, mask):
        """Function used to calculate the marginal probabilities of each token per tag.

        Implementation is borrowed from https://github.com/kmkurn/pytorch-crf/pull/37.

        Args:
            inputs (torch.Tensor): Batch of padded input tensors.
            mask (torch.Tensor): Batch of mask tensors to account for padded inputs.

        Returns:
            marginal probabilities for every tag for each token for every sequence.
        """
        # SWITCHING FOR BATCH FIRST DEFAULT
        # Same emission computation as forward(): tile W per batch element and
        # run one addmm over the diagonally-stacked sparse inputs.
        dense_W = torch.tile(self.W, dims=(mask.shape[0], 1))
        out_1 = torch.addmm(self.b, inputs, dense_W)
        emissions = out_1.reshape((mask.shape[0], -1, self.num_classes))
        # _compute_log_alpha expects sequence-first (seq_length, batch, num_tags).
        emissions = emissions.transpose(0, 1)
        mask = mask.transpose(0, 1)
        alpha = self._compute_log_alpha(emissions, mask, run_backwards=False)
        beta = self._compute_log_alpha(emissions, mask, run_backwards=True)
        # z: log partition function; subtracting it normalizes alpha+beta into log-marginals.
        z = torch.logsumexp(alpha[alpha.size(0) - 1] + self.crf_layer.end_transitions, dim=1)
        prob = alpha + beta - z.view(1, -1, 1)
        # Exponentiate and transpose back to batch-first.
        return torch.exp(prob).transpose(0, 1)
    # pylint: disable=too-many-arguments
    def set_params(self, feat_type="hash", feat_num=50000, stratify_train_val_split=True, drop_input=0.2, batch_size=8,
                   number_of_epochs=100, patience=3, dev_split_ratio=0.2, optimizer="sgd", l1_weight=0, l2_weight=0,
                   random_state=None, **kwargs):
        """Set the parameters for the PyTorch CRF model and also validates the parameters.

        Args:
            feat_type (str): The type of feature extractor. Supported options are 'dict' and 'hash'.
            feat_num (int): The number of features to be used by the FeatureHasher. Is not supported with the DictVectorizer
            stratify_train_val_split (bool): Flag to check whether inputs should be stratified during train-dev split.
            drop_input (float): The percentage at which to apply a dropout to the input features.
            batch_size (int): Training batch size for the model.
            number_of_epochs (int): The number of epochs (passes over the training data) to train the model for.
            patience (int): Number of epochs to wait for before stopping training if dev score does not improve.
            dev_split_ratio (float): Percentage of training data to be used for validation.
            optimizer (str): Type of optimizer used for the model. Supported options are 'sgd', 'adam' and 'lbfgs'
                (validate_params accepts all three).
            random_state (int): Integer value to set random seeds for deterministic output. NOTE: any falsy
                value (including 0) is replaced by a randomly drawn seed.
            l1_weight (float): Regularization weight for L1-penalty
            l2_weight (float): Regularization weight for L2-penalty
            **kwargs: Unsupported extra parameters; they are logged and dropped by validate_params.
        """
        self.feat_type = feat_type  # ["hash", "dict"]
        self.feat_num = feat_num
        self.stratify_train_val_split = stratify_train_val_split
        self.drop_input = drop_input
        self.batch_size = batch_size
        self.patience = patience
        self.number_of_epochs = number_of_epochs
        self.dev_split_ratio = dev_split_ratio
        self.optimizer = optimizer  # ["sgd", "adam", "lbfgs"]
        self.l1_weight = l1_weight
        self.l2_weight = l2_weight
        # Falls back to a random seed when random_state is None (or any falsy value).
        self.random_state = random_state or randint(1, 10000001)
        self.validate_params(kwargs)

        logger.debug("Random state for torch-crf is %s", self.random_state)
        if self.feat_type == "dict":
            logger.warning(
                "WARNING: Number of features is compatible with only `hash` feature type. This value is ignored with `dict` setting")
[docs] def get_params(self):
"""
Get the parameters for the PyTorch CRF model.
"""
return {
"feat_type": self.feat_type,
"feat_num": self.feat_num,
"stratify_train_val_split": self.stratify_train_val_split,
"drop_input": self.drop_input,
"batch_size": self.batch_size,
"patience": self.patience,
"number_of_epochs": self.number_of_epochs,
"dev_split_ratio": self.dev_split_ratio,
"optimizer": self.optimizer,
"l1_weight": self.l1_weight,
"l2_weight": self.l2_weight,
"random_state": self.random_state
}
[docs] def get_dataloader(self, X, y, is_train):
"""Creates and returns the PyTorch dataloader instance for the training/test data.
Args:
X (list of list of dicts): Generally a list of feature vectors, one for each training example
y (list of lists or None): A list of classification labels (encoded by the label_encoder, NOT MindMeld
entity objects)
is_train (bool): Whether the dataloader returned is going to be used for training.
Returns:
torch_dataloader (torch.utils.data.dataloader.DataLoader): returns PyTorch dataloader object that can be
used to iterate across the data.
"""
if self.optimizer == "lbfgs" and is_train:
self.batch_size = len(X)
tensor_inputs, input_seq_lens, tensor_labels = self._encoder.get_tensor_data(X, y, fit=is_train)
tensor_dataset = TaggerDataset(tensor_inputs, input_seq_lens, tensor_labels)
torch_dataloader = DataLoader(tensor_dataset, batch_size=self.batch_size if is_train else TEST_BATCH_SIZE,
shuffle=is_train, collate_fn=collate_tensors_and_masks)
return torch_dataloader
    def fit(self, X, y):
        """Trains the entire PyTorch CRF model.

        Args:
            X (list of list of dicts): Generally a list of feature vectors, one for each training example
            y (list of lists): A list of classification labels (encoded by the label_encoder, NOT MindMeld
                entity objects)
        """
        self.set_random_states()
        self._encoder = Encoder(feature_extractor=self.feat_type, num_feats=self.feat_num)
        stratify_tuples = None
        if self.stratify_train_val_split:
            # NOTE(review): stratify_input is defined elsewhere in this module;
            # presumably it builds per-example keys usable as sklearn's
            # `stratify` argument — confirm against its definition.
            X, y, stratify_tuples = stratify_input(X, y)
        # TODO: Rewrite our own train_test_split function to handle FileBackedList and avoid duplicating unique labels
        train_X, dev_X, train_y, dev_y = train_test_split(X, y, test_size=self.dev_split_ratio,
                                                          stratify=stratify_tuples, random_state=self.random_state)
        train_dataloader = self.get_dataloader(train_X, train_y, is_train=True)
        dev_dataloader = self.get_dataloader(dev_X, dev_y, is_train=False)

        # desperate attempt to save some memory
        del X, y, train_X, train_y, dev_X, dev_y, stratify_tuples
        gc.collect()

        self.build_params(*self._encoder.get_feats_and_classes())

        # L2 is delegated to the optimizer's weight_decay for sgd/adam; for
        # lbfgs it is added to the loss explicitly in train_one_epoch.
        if self.optimizer == "sgd":
            self.optim = optim.SGD(self.parameters(), lr=0.01, momentum=0.9, nesterov=True,
                                   weight_decay=self.l2_weight)
        if self.optimizer == "adam":
            self.optim = optim.Adam(self.parameters(), lr=0.001, weight_decay=self.l2_weight)
        if self.optimizer == "lbfgs":
            self.optim = optim.LBFGS(self.parameters(), lr=1, max_iter=100, history_size=6,
                                     line_search_fn="strong_wolfe")
        # Halve the LR when the dev F1 plateaus (mode='max' since higher F1 is better).
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optim, mode='max',
                                                              patience=max(self.patience - 2, 1),
                                                              factor=0.5)
        # Train while checkpointing the best dev-score weights to a temp file,
        # then restore those best weights before returning.
        with NamedTemporaryFile(suffix=".pt", prefix="best_crf_wts") as tmp_file:
            self.training_loop(train_dataloader, dev_dataloader, tmp_file.name)
            self.load_state_dict(torch.load(tmp_file.name))
    def training_loop(self, train_dataloader, dev_dataloader, tmp_save_path):
        """Contains the training loop process where we train the model for specified number of epochs.

        Training stops early once the dev F1 score has failed to improve for
        `self.patience` consecutive epochs.

        Args:
            train_dataloader (torch.utils.data.dataloader.DataLoader): Dataloader for training data
            dev_dataloader (torch.utils.data.dataloader.DataLoader): Dataloader for validation data
            tmp_save_path (str): File path where the best model weights are checkpointed.
        """
        best_dev_score, best_dev_epoch = -np.inf, -1
        _patience_counter = 0

        for epoch in range(self.number_of_epochs):
            if _patience_counter >= self.patience:
                break
            self.train_one_epoch(train_dataloader)
            dev_f1_score = self.run_predictions(dev_dataloader, calc_f1=True)
            logger.info("Epoch %s finished. Dev F1: %s", epoch, dev_f1_score)

            self.scheduler.step(dev_f1_score)
            if dev_f1_score <= best_dev_score:
                _patience_counter += 1
            else:
                # New best dev score: reset patience and checkpoint the weights.
                _patience_counter = 0
                best_dev_score, best_dev_epoch = dev_f1_score, epoch
                torch.save(self.state_dict(), tmp_save_path)
                logger.debug("Model weights saved for best dev epoch %s.", best_dev_epoch)
    def train_one_epoch(self, train_dataloader):
        """Contains the training code for one epoch.

        Args:
            train_dataloader (torch.utils.data.dataloader.DataLoader): Dataloader for training data
        """
        self.train()
        train_loss = 0
        for batch_idx, (inputs, mask, labels) in enumerate(train_dataloader):
            # The loss computation is wrapped in a closure because optim.LBFGS
            # re-evaluates it multiple times per step; the other optimizers
            # simply call it once.
            def closure():
                nonlocal train_loss
                self.optim.zero_grad()
                # pylint: disable=cell-var-from-loop
                loss = self.forward(inputs, labels, mask, drop_input=self.drop_input)
                # sgd/adam get L2 via weight_decay (see fit); lbfgs needs it in the loss.
                if self.l2_weight > 0 and self.optimizer == "lbfgs":
                    loss += self.compute_regularized_loss(l1=False)
                if self.l1_weight > 0:
                    loss += self.compute_regularized_loss(l1=True)
                train_loss += loss.item()
                loss.backward()
                return loss

            if self.optimizer == "lbfgs":
                self.optim.step(closure)
            else:
                closure()
                self.optim.step()
            if batch_idx % 20 == 0:
                logger.debug("Batch: %s Mean Loss: %s", batch_idx,
                             (train_loss / (batch_idx + 1)))
    def run_predictions(self, dataloader, calc_f1=False):
        """Get predictions for the data by running a inference pass of the model.

        Args:
            dataloader (torch.utils.data.dataloader.DataLoader): Dataloader for test/validation data
            calc_f1 (bool): Flag to return dev f1 score or return predictions for each token.

        Returns:
            Dev F1 score or predictions for each token in a sequence.
        """
        self.eval()
        predictions = []
        targets = []
        with torch.no_grad():
            for inputs, *mask_and_labels in dataloader:
                if calc_f1:
                    # Batches carry labels; keep only unpadded target tokens.
                    mask, labels = mask_and_labels
                    targets.extend(torch.masked_select(labels, mask).tolist())
                else:
                    # Inference batches carry only (inputs, mask).
                    mask = mask_and_labels.pop()
                preds = self.forward(inputs, None, mask)
                # For F1, flatten per-sequence predictions down to token level.
                predictions.extend([x for lst in preds for x in lst] if calc_f1 else preds)
        if calc_f1:
            dev_score = f1_score(targets, predictions, average='weighted')
            return dev_score
        else:
            return predictions
    def predict_marginals(self, X):
        """Get marginal probabilites for each tag per token for each sequence.

        Args:
            X (list of list of dicts): Feature vectors for data to predict marginal probabilities on.

        Returns:
            marginals_dict (list of list of dicts): Returns the probability of every tag for each token in a sequence.
        """
        dataloader = self.get_dataloader(X, None, is_train=False)
        marginals_dict = []
        self.eval()
        with torch.no_grad():
            for inputs, mask in dataloader:
                probs = self.compute_marginal_probabilities(inputs, mask).tolist()
                mask = mask.tolist()

                # This is basically to create a nested list-dict structure in which we have the probability values
                # for each token for each sequence.
                for seq, mask_seq in zip(probs, mask):
                    one_seq_list = []
                    for (token_probs, valid_token) in zip(seq, mask_seq):
                        # Skip padded positions (mask == 0).
                        if valid_token:
                            one_seq_list.append(dict(zip(self._encoder.classes, token_probs)))
                    marginals_dict.append(one_seq_list)
        return marginals_dict
[docs] def predict(self, X):
"""Gets predicted labels for the data.
Args:
X (list of list of dicts): Feature vectors for data to predict labels on.
Returns:
preds (list of lists): Predictions for each token in each sequence.
"""
dataloader = self.get_dataloader(X, None, is_train=False)
preds = self.run_predictions(dataloader, calc_f1=False)
return [self._encoder.label_encoder.inverse_transform(x).tolist() for x in preds]