# Source code for mindmeld.augmentation

# -*- coding: utf-8 -*-
#
# Copyright (c) 2015 Cisco Systems, Inc. and others.  All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This module contains the data augmentation processes for MindMeld."""

import logging
import re
import string
import random
import os
import zipfile

from abc import ABC, abstractmethod
from urllib.request import urlretrieve
from tqdm import tqdm

from ._util import get_pattern, read_path_queries, write_to_file
from .components._util import _is_module_available, _get_module_or_attr
from .components._config import ENGLISH_LANGUAGE_CODE
from .models.helpers import register_augmentor, AUGMENTATION_MAP
from .markup import load_query, dump_query
from .core import Entity, Span, QueryEntity, ProcessedQuery, _get_overlap
from .path import EMBEDDINGS_FOLDER_PATH, \
    PARAPHRASER_FILE_PATH, \
    PARAPHRASER_MODEL_PATH, \
    HUGGINGFACE_PARAPHRASER_MODEL_PATH
from .models.containers import TqdmUpTo

logger = logging.getLogger(__name__)

# pylint: disable=R0201

SUPPORTED_LANGUAGE_CODES = ["en", "es", "fr", "it", "pt", "ro"]
EOS_TOKEN = '</s>'
DEFAULT_NUM_PARAPHRASES = 10
PARAPHRASER_RETAIN_ENTITIES_URL = 'https://mindmeld-binaries.s3.amazonaws.com/paraphraser' \
                                  '/paraphrase_retain_entities.zip'


[docs]class UnsupportedLanguageError(Exception): pass
[docs]class AugmentorFactory: """Creates an Augmentor object. Attributes: config (dict): A model configuration. language (str): Language for data augmentation. resource_loader (object): Resource Loader object for the application. """ def __init__(self, config, language, resource_loader): self.config = config self.language = language self.resource_loader = resource_loader
[docs] def create_augmentor(self): """Creates an augmentor instance using the provided configuration Returns: Augmentor: An Augmentor class Raises: ValueError: When model configuration is invalid or required key is missing """ if "augmentor_class" not in self.config: raise KeyError( "Missing required argument in AUGMENTATION_CONFIG: 'augmentor_class'" ) # Validate configuration input batch_size = self.config.get("batch_size", 8) paths = self.config.get( "paths", [ { "domains": ".*", "intents": ".*", "files": ".*", } ], ) path_suffix = self.config.get("path_suffix", "-augment.txt") retain_entities = self.config.get("retain_entities", False) register_all_augmentors() try: return AUGMENTATION_MAP[self.config["augmentor_class"]]( batch_size=batch_size, language=self.language, retain_entities=retain_entities, paths=paths, path_suffix=path_suffix, resource_loader=self.resource_loader, ) except KeyError as e: msg = "Invalid model configuration: Unknown model type {!r}" raise ValueError(msg.format(self.config["augmentor_class"])) from e
[docs]class Augmentor(ABC): """ Abstract Augmentor class. """ def __init__(self, language, paths, path_suffix, resource_loader): """Initializes an augmentor. Args: language (str): The language code for paraphrasing paths (list): Path rules for fetching relevant files to Paraphrase. path_suffix (str): Suffix to be added to new augmented files. resource_loader (object): Resource Loader object for the application. """ self.language_code = language self.files_to_augment = paths self.path_suffix = path_suffix self._resource_loader = resource_loader self._check_dependencies() self._check_language_support() def _check_dependencies(self): """Checks module dependencies.""" if not _is_module_available("torch"): raise ModuleNotFoundError( "Library not found: 'torch'. Run 'pip install mindmeld[augment]' to install." ) if not _is_module_available("transformers"): raise ModuleNotFoundError( "Library not found: 'transformers'. Run 'pip install mindmeld[augment]' to install." ) def _check_language_support(self): """Checks if language is currently supported for augmentation.""" if self.language_code not in SUPPORTED_LANGUAGE_CODES: raise UnsupportedLanguageError( f"'{self.language_code}' is not supported yet. " "English (en), French (fr), and Italian (it), Portuguese (pt), Romanian (ro) " " and Spanish (es) are currently supported." )
[docs] def augment(self, **kwargs): """Augments queries given initial queries in application.""" filtered_paths = self._get_files(path_rules=self.files_to_augment) for path in tqdm(filtered_paths): queries = self._get_processed_queries_to_paraphrase(path) # To-Do: Use generator to write files incrementally. augmented_queries = self.augment_queries(queries, **kwargs) write_to_file(path, augmented_queries, suffix=self.path_suffix)
[docs] @abstractmethod def augment_queries(self, queries): """Generates augmented data given application queries. Args: queries (list): List of queries. Return: augmented_queries (list): List of augmented queries. """ raise NotImplementedError("Subclasses must implement this method")
@abstractmethod def _prepare_inputs(self, queries): """Prepare data to be fed to the models as input Args: queries (list(str)): List of queries to be paraphrased Returns: formatted queries (list(str)) """ raise NotImplementedError("Subclasses must implement this method") def _validate_generated_query(self, query): """Validates whether augmented query has atleast one alphanumeric character Args: query (str): Generated query to be validated. """ pattern = re.compile("^.*[a-zA-Z0-9].*$") return pattern.search(query) and True def _get_processed_queries_to_paraphrase(self, path): """Returns a list of processed queries for a given file path Args: path (str): Path to text file with queries Return: Processed queries (list(ProcessedQuery)) """ queries = read_path_queries(path) processed_queries = [] for query in queries: processed_query = load_query(query, query_factory=self._resource_loader.query_factory) processed_queries.append(processed_query) return processed_queries def _get_files(self, path_rules=None): """Fetches relevant files given the path rules specified in the config. Args: path_rules (list): Path rules for fetching relevant files. Return: filtered_paths (list): List of file paths to be augmeted. """ all_file_paths = self._resource_loader.get_all_file_paths() if not path_rules: logger.warning( """'paths' field is not configured or misconfigured in the `config.py`. Can't find files to augment.""" ) return [] filtered_paths = [] for rule in path_rules: pattern = get_pattern(rule) compiled_pattern = re.compile(pattern) filtered_paths.extend( self._resource_loader.filter_file_paths( compiled_pattern=compiled_pattern, file_paths=all_file_paths ) ) return filtered_paths
[docs]class EnglishParaphraser(Augmentor): """Paraphraser class for generating English paraphrases.""" def __init__(self, batch_size, language, retain_entities, paths, path_suffix, resource_loader): """Initializes an English paraphraser. Args: batch_size (int): Batch size for batch processing. language (str): The language code for paraphrasing. paths (list): Path rules for fetching relevant files to Paraphrase. path_suffix (str): Suffix to be added to new augmented files. resource_loader (object): Resource Loader object for the application. """ if language != ENGLISH_LANGUAGE_CODE: raise UnsupportedLanguageError( f"'{language}' is not supported by the English Augmentor class" ) super().__init__( language=language, paths=paths, path_suffix=path_suffix, resource_loader=resource_loader, ) PegasusTokenizer = _get_module_or_attr("transformers", "PegasusTokenizer") PegasusForConditionalGeneration = _get_module_or_attr( "transformers", "PegasusForConditionalGeneration" ) self.retain_entities = retain_entities if self.retain_entities: if not os.path.exists(PARAPHRASER_MODEL_PATH): self._download_model() model_name = PARAPHRASER_MODEL_PATH else: model_name = HUGGINGFACE_PARAPHRASER_MODEL_PATH self.torch_device = ( "cuda" if _get_module_or_attr("torch.cuda", "is_available")() else "cpu" ) self.tokenizer = PegasusTokenizer.from_pretrained(model_name) self.model = PegasusForConditionalGeneration.from_pretrained( model_name).to(self.torch_device) self.model.eval() # Update default params with user model config self.batch_size = batch_size self.default_paraphraser_model_params = { "max_length": 60, "num_beams": DEFAULT_NUM_PARAPHRASES, "num_return_sequences": DEFAULT_NUM_PARAPHRASES, "temperature": 1.5, } self.default_tokenizer_params = { "truncation": True, "padding": "longest", "max_length": 60, } def _download_model(self): logger.info("Downloading paraphrase model from %s", PARAPHRASER_RETAIN_ENTITIES_URL) # Make the folder that will contain the model folder if not 
os.path.exists(EMBEDDINGS_FOLDER_PATH): os.makedirs(EMBEDDINGS_FOLDER_PATH) with TqdmUpTo(unit="B", unit_scale=True, miniters=1, desc='') as t: try: urlretrieve(PARAPHRASER_RETAIN_ENTITIES_URL, PARAPHRASER_FILE_PATH, reporthook=t.update_to) except ConnectionError as e: logger.error("Model download failed with error: %s", e) return try: with zipfile.ZipFile(PARAPHRASER_FILE_PATH, 'r') as zip_ref: zip_ref.extractall(EMBEDDINGS_FOLDER_PATH) os.remove(PARAPHRASER_FILE_PATH) except zipfile.BadZipfile: logger.error("Unable to extract zip file. Try downloading the model again.") def _prepare_inputs(self, processed_queries): """Processes input as expected by the two different English models Example: The default model requires just <unannotated text> The retain_entities model requires <unannotated text> <EOS> <entity values> Args: processed queries (list(ProcessedQuery)): List of ProcessedQuery objects from the app Return: model_inputs (list(str)): List of queries to be paraphrased in the format required by the models processed_queries (list(ProcessedQuery)): List of ProcessedQuery with entity annotations """ model_inputs = [] for processed_query in processed_queries: processed_query_text = processed_query.query.text.strip() if self.retain_entities: text = [processed_query_text.lower(), EOS_TOKEN] for entity in processed_query.entities: text.append(entity.text.lower()) model_inputs.append(' '.join(text)) else: model_inputs.append(processed_query_text) return model_inputs def _replace_with_random_gaz_entity(self, paraphrase_text, entity_matches): """Replaces values of annotated entities with randomly sampled ones from gazetteers Args: paraphrase_text (str): The paraphrased unannotated text entity_matches (List((Entity,Span))): List of (Entity, Span) values found in the paraphrase_text Return: processed paraphrases (ProcessedQuery): ProcessedQuery of the paraphrase_text """ new_paraphrase_text = [] # Start replacing entities in ascending order of span starts 
entity_matches.sort(key=lambda x: x[1].start, reverse=False) running_start = 0 previous_end = 0 replaced_spans_entities = [] for (entity, span) in entity_matches: # Calculate new start based on previously replaced entity length # For the first entity in the query, previous_end will be 0 running_start += (span.start - previous_end) # Append text seen between entities new_paraphrase_text.append(paraphrase_text[previous_end:span.start]) gaz = None # If not a system entity and gazetteer is available, load it if not Entity.is_system_entity(entity.type): gaz = self._resource_loader.get_gazetteer(entity.type)['entities'] if gaz: # Create new Entity and Span based on random gaz entry for entity type random_gaz_entity_text = random.sample(gaz, 1)[0] new_span = Span(start=running_start, end=running_start + len(random_gaz_entity_text) - 1) new_entity = Entity(text=random_gaz_entity_text, entity_type=entity.type, role=entity.role, value=None) else: new_entity = entity new_span = Span(start=running_start, end=running_start + len(entity.text) - 1) running_start += len(new_entity.text) replaced_spans_entities.append([new_entity, new_span]) new_paraphrase_text.append(new_entity.text) previous_end = span.end + 1 # Append any leftover text new_paraphrase_text.append(paraphrase_text[previous_end:]) processed_query = self._resource_loader.query_factory.create_query( ''.join(new_paraphrase_text)) final_entities = [QueryEntity.from_query(query=processed_query, span=span, entity=entity) for (entity, span) in replaced_spans_entities] return ProcessedQuery(query=processed_query, entities=tuple(final_entities)) def _annotate_entities(self, paraphrases, processed_queries): """Annotates entities in the generated paraphrases with the entities in the original query Args: paraphrases (list(str)): List of unannotated paraphrases of queries processed_queries (list(ProcessedQuery)): List of their corresponding original ProcessedQuery Return: paraphrases (list(str)): List of paraphrased queries. 
""" valid_paraphrases = [] for i, processed_query in enumerate(processed_queries): # sort entities so we annotate the longest one first entities = sorted(list(processed_query.entities), key=lambda x: len(x.text), reverse=True) # fetch paraphrases for the query from the batch queries = paraphrases[(i * DEFAULT_NUM_PARAPHRASES): (i * DEFAULT_NUM_PARAPHRASES) + DEFAULT_NUM_PARAPHRASES] for query in queries: if not query: continue all_matches = [] for entity in entities: found_matches = re.finditer(entity.text.lower(), query) for match in found_matches: matched_span = Span(start=match.start(0), end=match.end(0)-1) matched_entity = Entity(text=match.group(0), entity_type=entity.entity.type, role=entity.entity.role, value=None) # check if found entity has no overlaps with previously matched entities no_overlaps = [not _get_overlap(m[1], matched_span) for m in all_matches] if all(no_overlaps): all_matches.append((matched_entity, matched_span)) # We are taking a call here to only return paraphrases that contain all entities # that were present in the original query if len(all_matches) == len(entities): processed_paraphrase = self._replace_with_random_gaz_entity(query, all_matches) # Dump the processed paraphrase queries in the mindmeld markdown format valid_paraphrases.append(dump_query(processed_paraphrase)) return valid_paraphrases @staticmethod def _normalize_paraphrases(queries): # This function removes punctuations since these generative models # have a tendency to repeat them. # Since most classifiers use normalized text, this should not be an issue. without_puncts = [s.lower().translate(str.maketrans(string.punctuation, " " * len(string.punctuation))) for s in queries] queries = [' '.join(s.split()) for s in without_puncts if s] return queries def _generate_paraphrases(self, processed_queries): """Generates paraphrase responses for given query. Args: queries (list(str)): List of application queries. Return: paraphrases (list(str)): List of paraphrased queries. 
""" all_generated_queries = [] for pos in range(0, len(processed_queries), self.batch_size): processed_input_queries = processed_queries[pos:pos + self.batch_size] tokenizer_input = self._prepare_inputs(processed_input_queries) batch = self.tokenizer.prepare_seq2seq_batch( tokenizer_input, **self.default_tokenizer_params, return_tensors="pt", ).to(self.torch_device) with _get_module_or_attr("torch", "no_grad")(): generated = self.model.generate( **batch, **self.default_paraphraser_model_params, ) decoded_queries = self.tokenizer.batch_decode(generated, skip_special_tokens=True) if self.retain_entities: decoded_queries = self._normalize_paraphrases(decoded_queries) decoded_queries = self._annotate_entities(decoded_queries, processed_input_queries) all_generated_queries.extend(decoded_queries) return all_generated_queries
[docs] def augment_queries(self, queries, **kwargs): augmented_queries = list( set( p.lower() for p in self._generate_paraphrases(queries, **kwargs) if self._validate_generated_query(p) ) ) return augmented_queries
[docs]class MultiLingualParaphraser(Augmentor): """Paraphraser class for generating paraphrases based on language code of the app (currently supports: French, Italian, Portuguese, Romanian and Spanish). """ def __init__(self, batch_size, language, retain_entities, paths, path_suffix, resource_loader): """Initializes a multi-lingual paraphraser. Args: batch_size (int): Batch size for batch processing. language (str): The language code for paraphrasing. paths (list): Path rules for fetching relevant files to Paraphrase. path_suffix (str): Suffix to be added to new augmented files. resource_loader (object): Resource Loader object for the application. """ if language not in ["es", "fr", "it", "pt", "ro"]: raise UnsupportedLanguageError( f"'{language}' is not supported by the MultiLingual Augmentor class" ) super().__init__( language=language, paths=paths, path_suffix=path_suffix, resource_loader=resource_loader, ) self.torch_device = ( "cuda" if _get_module_or_attr("torch.cuda", "is_available")() else "cpu" ) self.retain_entities = retain_entities MarianTokenizer = _get_module_or_attr("transformers", "MarianTokenizer") MarianMTModel = _get_module_or_attr("transformers", "MarianMTModel") en_model_name = "Helsinki-NLP/opus-mt-ROMANCE-en" self.en_tokenizer = MarianTokenizer.from_pretrained(en_model_name) self.en_model = MarianMTModel.from_pretrained(en_model_name) self.en_model.to(self.torch_device) self.en_model.eval() target_model_name = "Helsinki-NLP/opus-mt-en-ROMANCE" self.target_tokenizer = MarianTokenizer.from_pretrained(target_model_name) self.target_model = MarianMTModel.from_pretrained(target_model_name).to( self.torch_device ) self.target_model.eval() # Update default params with user model config self.batch_size = batch_size self.default_forward_params = { "max_length": 60, "num_beams": 5, "num_return_sequences": 5, "temperature": 1.0, "top_k": 0, } self.default_reverse_params = { "max_length": 60, "num_beams": 3, "num_return_sequences": 3, "temperature": 1.0, 
"top_k": 0, } def _translate(self, *, queries, model, tokenizer, **kwargs): """The core translation step for forward and reverse translation. Args: template (lambda func): Structure input text to model. queries (list(str)): List of input queries. model: Machine translation model (en-ROMANCE or ROMANCE-en). tokenizer: Language tokenizer for input query text. """ all_translated_queries = [] for pos in range(0, len(queries), self.batch_size): encoded = tokenizer.prepare_seq2seq_batch( queries[pos : pos + self.batch_size], return_tensors="pt" ).to(self.torch_device) for key in encoded: encoded[key] = encoded[key].to(self.torch_device) with _get_module_or_attr("torch", "no_grad")(): translated = model.generate(**encoded, **kwargs) translated_queries = tokenizer.batch_decode( translated, skip_special_tokens=True ) all_translated_queries.extend(translated_queries) return all_translated_queries def _prepare_inputs(self, processed_queries): """Removes any markdown formatting in the query Args: queries (list(str)): List of queries to be paraphrased Returns: unannotated queries (list(str)) """ unannotated_queries = [processed_query.query.text.strip() for processed_query in processed_queries] return unannotated_queries
[docs] def augment_queries(self, processed_queries): translated_queries = self._translate( queries=self._prepare_inputs(processed_queries), model=self.en_model, tokenizer=self.en_tokenizer, **self.default_forward_params, ) def template(text): return f">>{self.language_code}<< {text}" translated_queries = [template(query) for query in set(translated_queries)] reverse_translated_queries = self._translate( queries=translated_queries, model=self.target_model, tokenizer=self.target_tokenizer, **self.default_reverse_params, ) augmented_queries = list( set( p.lower() for p in reverse_translated_queries if self._validate_generated_query(p) ) ) return augmented_queries
[docs]def register_all_augmentors(): register_augmentor("EnglishParaphraser", EnglishParaphraser) register_augmentor("MultiLingualParaphraser", MultiLingualParaphraser)