Source code for lightwood.encoder.text.pretrained

"""
"""
import time
import torch
from torch.utils.data import DataLoader
import os
import pandas as pd
from lightwood.encoder.text.helpers.pretrained_helpers import TextEmbed
from lightwood.helpers.device import get_devices
from lightwood.encoder.base import BaseEncoder
from lightwood.helpers.log import log
from lightwood.helpers.torch import LightwoodAutocast
from lightwood.api import dtype
from transformers import (
    DistilBertModel,
    DistilBertForSequenceClassification,
    DistilBertTokenizerFast,
    AdamW,
    get_linear_schedule_with_warmup,
)
from lightwood.helpers.general import is_none
from typing import Iterable


class PretrainedLangEncoder(BaseEncoder):
    """
    Creates a contextualized embedding to represent input text via the [CLS] token vector from DistilBERT (transformers). (Sanh et al. 2019 - https://arxiv.org/abs/1910.01108).

    In certain text tasks, this model can use a transformer to automatically fine-tune on a class of interest (providing there is a 2 column dataset, where the input column is text).
    """  # noqa

    is_trainable_encoder: bool = True

    def __init__(
        self,
        stop_after: float,
        is_target: bool = False,
        batch_size: int = 10,
        max_position_embeddings: int = None,
        frozen: bool = False,
        epochs: int = 1,
        output_type: str = None,
        embed_mode: bool = True,
    ):
        """
        :param stop_after: Time budget (in seconds, compared against ``time.time()`` deltas) for fine-tuning; training stops once it is exceeded.
        :param is_target: Whether this encoder represents the target. NOT functional for text generation yet.
        :param batch_size: Size of each batch while fine-tuning.
        :param max_position_embeddings: Max sequence length (in tokens) of input text; longer inputs are truncated. If None, the tokenizer's default limit is used during fine-tuning and the model config's limit afterwards.
        :param frozen: If True, freezes transformer layers during training and only trains the classification head.
        :param epochs: Number of epochs to train the model with.
        :param output_type: Data dtype of the target; if categorical/binary, the transformer is fine-tuned and the option to return logits is possible.
        :param embed_mode: If True, the output of encode() is the [CLS] embedding (this can be trained or not). If False, returns the logits of the tuned task.
        """  # noqa
        super().__init__(is_target)

        self.output_type = output_type
        self.name = "distilbert text encoder"

        self._max_len = max_position_embeddings
        self._frozen = frozen
        self._batch_size = batch_size
        self._epochs = epochs

        # Model setup
        self._model = None
        self.model_type = None

        # TODO: Other LMs; Distilbert is a good balance of speed/performance
        self._classifier_model_class = DistilBertForSequenceClassification
        self._embeddings_model_class = DistilBertModel
        self._pretrained_model_name = "distilbert-base-uncased"

        self._tokenizer = DistilBertTokenizerFast.from_pretrained(self._pretrained_model_name)

        self.device, _ = get_devices()
        self.stop_after = stop_after

        self.embed_mode = embed_mode
        self.uses_target = True
        self.output_size = None

        if self.embed_mode:
            log.info("Embedding mode on. [CLS] embedding dim output of encode()")
        else:
            log.info("Embedding mode off. Logits are output of encode()")

    def prepare(
        self,
        train_priming_data: Iterable[str],
        dev_priming_data: Iterable[str],
        encoded_target_values: torch.Tensor,
    ):
        """
        Fine-tunes a transformer on the priming data.

        CURRENTLY WIP; train + dev are placeholders for a validation-based approach.

        Train + Dev are concatenated together and a transformer is then fine tuned with weight-decay applied on the transformer parameters. The option to freeze the underlying transformer and only train a linear layer exists if `frozen=True`. This trains faster, with the exception that the performance is often lower than fine-tuning on internal benchmarks.

        If the target is not categorical/binary, no training happens: a plain DistilBERT embedding model is loaded and `embed_mode` is forced on.

        :param train_priming_data: Text data in the train set
        :param dev_priming_data: Text data in the dev set (not currently supported; can be empty)
        :param encoded_target_values: Encoded target labels in Nrows x N_output_dimension

        :raises Exception: if the encoder was already prepared.
        """  # noqa
        if self.is_prepared:
            raise Exception("Encoder is already prepared.")

        os.environ['TOKENIZERS_PARALLELISM'] = 'true'

        # TODO -> we shouldn't be concatenating these together
        if len(dev_priming_data) > 0:
            priming_data = pd.concat([train_priming_data, dev_priming_data]).values
        else:
            priming_data = train_priming_data.tolist()

        # Replace missing values with empty strings, using the same `is_none`
        # check as encode() so training and inference see identical inputs
        # (a bare `is not None` check lets NaN floats reach the tokenizer).
        priming_data = ["" if is_none(x) else x for x in priming_data]

        # If classification, then fine-tune
        if (self.output_type in (dtype.categorical, dtype.binary)):
            log.info("Training model.")

            # Prepare priming data into tokenized form + attention masks.
            # max_length=None falls back to the tokenizer's default limit.
            text = self._tokenizer(
                priming_data, truncation=True, padding=True, max_length=self._max_len
            )

            log.info("\tOutput trained is categorical")

            # Label encode the OHE/binary output for classification
            labels = encoded_target_values.argmax(dim=1)

            # Construct the model
            self._model = self._classifier_model_class.from_pretrained(
                self._pretrained_model_name,
                num_labels=len(encoded_target_values[0]),  # max classes to test
            ).to(self.device)

            # Construct the dataset for training
            xinp = TextEmbed(text, labels)
            dataset = DataLoader(xinp, batch_size=self._batch_size, shuffle=True)

            # Set max length of input string; used by the tokenizer in encode()
            if self._max_len is None:
                self._max_len = self._model.config.max_position_embeddings

            if self._frozen:
                log.info("\tFrozen Model + Training Classifier Layers")
                # Freeze the base transformer model and train a linear layer on top
                for param in self._model.base_model.parameters():
                    param.requires_grad = False

                optimizer_grouped_parameters = self._model.parameters()

            else:
                log.info("\tFine-tuning model")
                # Fine-tuning parameters with weight decay;
                # decay on all terms EXCLUDING bias/layernorms
                no_decay = [
                    "bias",
                    "LayerNorm.weight",
                ]
                optimizer_grouped_parameters = [
                    {
                        "params": [
                            p
                            for n, p in self._model.named_parameters()
                            if not any(nd in n for nd in no_decay)
                        ],
                        "weight_decay": 0.01,
                    },
                    {
                        "params": [
                            p
                            for n, p in self._model.named_parameters()
                            if any(nd in n for nd in no_decay)
                        ],
                        "weight_decay": 0.0,
                    },
                ]

            optimizer = AdamW(optimizer_grouped_parameters, lr=1e-5)
            scheduler = get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=0,  # default value for GLUE
                num_training_steps=len(dataset) * self._epochs,
            )

            # Train model; declare optimizer earlier if desired.
            self._tune_model(
                dataset, optim=optimizer, scheduler=scheduler, n_epochs=self._epochs
            )

        else:
            log.info("Target is not classification; Embeddings Generator only")

            self.model_type = "embeddings_generator"
            self._model = self._embeddings_model_class.from_pretrained(
                self._pretrained_model_name
            ).to(self.device)

            # TODO: Not a great flag
            # Currently, if the task is not classification, you must have
            # an embedding generator only.
            if self.embed_mode is False:
                log.info("Embedding mode must be ON for non-classification targets.")
                self.embed_mode = True

        self.is_prepared = True
        # Encode a single sample to discover the output dimensionality.
        encoded = self.encode(priming_data[0:1])
        self.output_size = len(encoded[0])

    def _tune_model(self, dataset, optim, scheduler, n_epochs=1):
        """
        Given a model, train for n_epochs.
        Specifically intended for tuning; it does NOT use loss/stopping criterion
        other than the `stop_after` time budget.

        :param dataset: torch DataLoader yielding dicts with "input_ids", "attention_mask" and "labels"
        :param optim: optimizer (e.g. transformers AdamW); if None, AdamW(lr=5e-5) over all params is used
        :param scheduler: learning-rate scheduler stepped once per batch; may be None
        :param n_epochs: number of epochs to train
        """  # noqa
        self._model.train()

        if optim is None:
            log.info("No opt. provided, setting all params with AdamW.")
            optim = AdamW(self._model.parameters(), lr=5e-5)
        else:
            log.info("Optimizer provided")

        if scheduler is None:
            log.info("No scheduler provided.")
        else:
            log.info("Scheduler provided.")

        started = time.time()
        for epoch in range(n_epochs):
            total_loss = 0

            for batch in dataset:
                optim.zero_grad()

                with LightwoodAutocast():
                    inpids = batch["input_ids"].to(self.device)
                    attn = batch["attention_mask"].to(self.device)
                    labels = batch["labels"].to(self.device)
                    outputs = self._model(inpids, attention_mask=attn, labels=labels)
                    # With `labels` given, the first output is the training loss.
                    loss = outputs[0]

                total_loss += loss.item()

                loss.backward()
                optim.step()
                if scheduler is not None:
                    scheduler.step()

                # Respect the time budget mid-epoch as well.
                if time.time() - started > self.stop_after:
                    break

            if time.time() - started > self.stop_after:
                break

            self._train_callback(epoch, total_loss / len(dataset))

    def _train_callback(self, epoch, loss):
        """Log the average training loss for a finished epoch (epoch is 0-based)."""
        log.info(f"{self.name} at epoch {epoch+1} and loss {loss}!")

    def encode(self, column_data: Iterable[str]) -> torch.Tensor:
        """
        Converts each text example in a column into encoded state. This can be either a vector embedding of the [CLS] token (represents the full text input) OR the logits prediction of the output.

        The transformer model is of form:
        transformer base + pre-classifier linear layer + classifier layer

        When the model has a pre-classifier layer (the fine-tuned classification model), the embedding returned is the [CLS] token after that layer; from internal testing, we found the latent space most highly separated across classes. Otherwise the raw [CLS] hidden state is returned.

        If `embed_mode` is False, returns the raw (un-normalized) logits of the class vector.

        :param column_data: List of text data as strings; missing values are encoded as empty strings.

        :returns: Embedded vector N_rows x Nembed_dim OR logits vector N_rows x N_classes depending on if `embed_mode` is True or not. An empty input yields a 0 x output_size tensor.
        :raises Exception: if the encoder has not been prepared.
        """  # noqa
        if self.is_prepared is False:
            raise Exception("You need to first prepare the encoder.")

        # Set model to testing/eval mode.
        self._model.eval()

        encoded_representation = []

        with torch.no_grad():
            for text in column_data:

                # Omit NaNs
                if is_none(text):
                    text = ""

                # Tokenize the text with the built-in tokenizer, honoring the
                # configured max sequence length (None -> tokenizer default).
                inp = self._tokenizer.encode(
                    text, truncation=True, max_length=self._max_len, return_tensors="pt"
                ).to(self.device)

                if self.embed_mode:  # Embedding mode ON; return [CLS]
                    output = self._model.base_model(inp).last_hidden_state[:, 0]

                    # If the model has a pre-classifier layer, use this embedding.
                    if hasattr(self._model, "pre_classifier"):
                        output = self._model.pre_classifier(output)

                else:  # Embedding mode off; return classes
                    output = self._model(inp).logits

                encoded_representation.append(output.detach())

        # torch.stack raises on an empty list; return an empty 2-D tensor instead.
        if not encoded_representation:
            return torch.empty((0, self.output_size or 0))

        # Each row is (1, dim); stack -> (N, 1, dim) -> squeeze -> (N, dim).
        return torch.stack(encoded_representation).squeeze(1).to('cpu')

    def decode(self, encoded_values_tensor, max_length=100):
        """
        Text generation via decoding is not supported.

        :raises Exception: always.
        """  # noqa
        raise Exception("Decoder not implemented.")

    def to(self, device, available_devices):
        """
        Converts encoder models to device specified (CPU/GPU)

        Transformers are LARGE models, please run on GPU for fastest implementation.

        :param device: target torch device for every nn.Module attribute.
        :param available_devices: unused; kept for interface compatibility.
        :returns: self
        """  # noqa
        for v in vars(self):
            attr = getattr(self, v)
            if isinstance(attr, torch.nn.Module):
                attr.to(device)
        # encode()/_tune_model() move inputs to self.device, so it must follow
        # the models or inputs and weights end up on different devices.
        self.device = device
        return self