Source code for transfer_nlp.plugins.trainers

"""
This class contains the abstraction interface to customize runners.
For the training loop, we use the engine logic from pytorch-ignite

Check experiments for examples of experiment json files

"""
import inspect
import logging
from itertools import zip_longest
from typing import Dict, List, Any

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from ignite.contrib.handlers.tqdm_logger import ProgressBar
from ignite.engine import Events
from ignite.engine.engine import Engine
from ignite.metrics import Loss, Metric, RunningAverage
from ignite.utils import convert_tensor
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OutputHandler, OptimizerParamsHandler, WeightsScalarHandler, WeightsHistHandler, \
    GradsScalarHandler
from tensorboardX import SummaryWriter


from transfer_nlp.loaders.loaders import DatasetSplits
from transfer_nlp.plugins.config import register_plugin, ExperimentConfig, PluginFactory
from transfer_nlp.plugins.regularizers import RegularizerABC

logger = logging.getLogger(__name__)


def set_seed_everywhere(seed: int, cuda: bool):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cuda:
        torch.cuda.manual_seed_all(seed)

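# Usage sketch: calling set_seed_everywhere(42, cuda=torch.cuda.is_available())
# before building the model makes numpy and torch sampling reproducible (the
# seed value 42 is illustrative). BasicTrainer calls this in setup() when a
# seed is configured.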

def _prepare_batch(batch: Dict, device=None, non_blocking: bool = False):
    """Prepare a batch for training: move every tensor in the batch to the given device."""
    return {key: convert_tensor(value, device=device, non_blocking=non_blocking) for key, value in batch.items()}
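# Example (a minimal sketch): every value in the batch dict is converted and
# moved to the target device, e.g.
#
#     batch = {'x_in': torch.zeros(2, 3), 'y_target': torch.tensor([0, 1])}
#     batch = _prepare_batch(batch, device=torch.device('cpu'))
#
# The key 'x_in' is illustrative; 'y_target' is the label key the trainer
# below assumes.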


class TrainingMetric(Metric):

    def __init__(self, metric: Metric):
        self.source_metric = metric
        self.reset()
        super().__init__(lambda x: x[:-1])

    def reset(self):
        self.source_metric.reset()

    def update(self, output):
        self.source_metric.update(output)

    def compute(self):
        return self.source_metric.compute()
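# Example (a sketch, assuming an ignite Accuracy metric and an existing
# training engine): the training engine outputs (y_pred, y_target, loss), and
# TrainingMetric's output_transform (lambda x: x[:-1]) drops the trailing loss
# before updating the wrapped metric:
#
#     from ignite.metrics import Accuracy
#
#     training_accuracy = TrainingMetric(Accuracy())
#     training_accuracy.attach(trainer_engine, 'Accuracy')  # trainer_engine is hypothetical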
@register_plugin
class BasicTrainer:

    def __init__(self,
                 model: nn.Module,
                 dataset_splits: DatasetSplits,
                 loss: nn.Module,
                 optimizer: optim.Optimizer,
                 metrics: Dict[str, Metric],
                 experiment_config: ExperimentConfig,
                 device: str = None,
                 num_epochs: int = 1,
                 seed: int = None,
                 cuda: bool = None,
                 loss_accumulation_steps: int = 4,
                 scheduler: Any = None,  # no common parent class?
                 regularizer: RegularizerABC = None,
                 gradient_clipping: float = 1.0,
                 output_transform=None,
                 tensorboard_logs: str = None,
                 embeddings_name: str = None,
                 finetune: bool = False):

        self.model: nn.Module = model

        # Inspect the model's forward signature so that batches can be mapped
        # onto the forward parameters by name, falling back to declared defaults
        self.forward_param_defaults = {}
        model_spec = inspect.getfullargspec(model.forward)
        self.forward_params: List[str] = model_spec.args[1:]
        for fparam, pdefault in zip(reversed(model_spec.args[1:]), reversed(model_spec.defaults if model_spec.defaults else [])):
            self.forward_param_defaults[fparam] = pdefault

        self.dataset_splits: DatasetSplits = dataset_splits
        self.loss: nn.Module = loss
        self.optimizer: optim.Optimizer = optimizer
        self.metrics: List[Metric] = list(metrics.values())
        self.experiment_config: ExperimentConfig = experiment_config
        self.device: str = device
        self.num_epochs: int = num_epochs
        self.scheduler: Any = scheduler
        self.seed: int = seed
        self.cuda: bool = cuda
        if self.cuda is None:
            # If cuda is not specified, check whether it is available and pick the device accordingly
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.loss_accumulation_steps: int = loss_accumulation_steps
        self.regularizer: RegularizerABC = regularizer
        self.gradient_clipping: float = gradient_clipping
        self.output_transform = output_transform
        self.tensorboard_logs = tensorboard_logs
        if self.tensorboard_logs:
            self.writer = SummaryWriter(log_dir=self.tensorboard_logs)
        self.embeddings_name = embeddings_name

        if self.output_transform:
            self.trainer, self.training_metrics = self.create_supervised_trainer(output_transform=self.output_transform)
            self.evaluator = self.create_supervised_evaluator(output_transform=self.output_transform)
        else:
            self.trainer, self.training_metrics = self.create_supervised_trainer()
            self.evaluator = self.create_supervised_evaluator()

        self.finetune = finetune
        self.optimizer_factory: PluginFactory = None

        loss_metrics = [m for m in self.metrics if isinstance(m, Loss)]
        if self.scheduler:
            if not loss_metrics:
                raise ValueError('A loss metric must be configured')
            elif len(loss_metrics) > 1:
                logger.warning('Multiple loss metrics detected, using %s for LR scheduling', loss_metrics[0])
            self.loss_metric = loss_metrics[0]

        self.setup(self.training_metrics)

    def setup(self, training_metrics):

        def metric_name(n) -> str:
            if n.endswith('Accuracy'):
                n = 'acc'
            else:
                n = n[:-6] if n.endswith('Metric') else n
            return n

        def print_metrics(metrics) -> str:
            rv = ''
            metric_keys = sorted(k for k in metrics)
            for k in metric_keys:
                if k == 'Accuracy':
                    rv += f'{metric_name(k)}: {metrics[k]:.3} '
                else:
                    rv += f'{metric_name(k)}: {metrics[k]:.6} '
            return rv.strip()

        if self.seed:
            set_seed_everywhere(self.seed, self.cuda)

        pbar = ProgressBar()
        names = []
        for k, v in training_metrics.items():
            name = f'r{k}'
            names.append(name)
            RunningAverage(v).attach(self.trainer, name)

        # The engine output is (y_pred, y_target, loss): rescale the loss to
        # undo the division by the number of gradient accumulation steps
        RunningAverage(None, output_transform=lambda x: x[-1] * self.loss_accumulation_steps).attach(self.trainer, 'rloss')
        names.append('rloss')
        pbar.attach(self.trainer, names)

        pbar = ProgressBar()
        pbar.attach(self.evaluator)

        # A few event handlers. To add or modify event handlers, extend the __init__ method of BasicTrainer.
        # Ignite provides the necessary abstractions and a furnished repository of useful tools

        @self.trainer.on(Events.EPOCH_COMPLETED)
        def log_validation_results(trainer):
            self.evaluator.run(self.dataset_splits.val_data_loader())
            metrics = self.evaluator.state.metrics
            logger.info(f"Validation Results - Epoch: {trainer.state.epoch} {print_metrics(metrics)}")
            if self.scheduler:
                self.scheduler.step(metrics[self.loss_metric.__class__.__name__])

        @self.trainer.on(Events.COMPLETED)
        def log_test_results(trainer):
            self.evaluator.run(self.dataset_splits.test_data_loader())
            metrics = self.evaluator.state.metrics
            logger.info(f"Test Results - Epoch: {trainer.state.epoch} {print_metrics(metrics)}")

        if self.tensorboard_logs:
            tb_logger = TensorboardLogger(log_dir=self.tensorboard_logs)
            tb_logger.attach(self.trainer,
                             log_handler=OutputHandler(tag="training", output_transform=lambda loss: {'loss': loss}),
                             event_name=Events.ITERATION_COMPLETED)
            tb_logger.attach(self.evaluator,
                             log_handler=OutputHandler(tag="validation", metric_names=["LossMetric"], another_engine=self.trainer),
                             event_name=Events.EPOCH_COMPLETED)
            tb_logger.attach(self.trainer, log_handler=OptimizerParamsHandler(self.optimizer), event_name=Events.ITERATION_STARTED)
            tb_logger.attach(self.trainer, log_handler=WeightsScalarHandler(self.model), event_name=Events.ITERATION_COMPLETED)
            tb_logger.attach(self.trainer, log_handler=WeightsHistHandler(self.model), event_name=Events.EPOCH_COMPLETED)
            tb_logger.attach(self.trainer, log_handler=GradsScalarHandler(self.model), event_name=Events.ITERATION_COMPLETED)

            # This is important to close the tensorboard file logger
            @self.trainer.on(Events.COMPLETED)
            def end_tensorboard(trainer):
                logger.info("Training completed")
                tb_logger.close()

        if self.embeddings_name:
            @self.trainer.on(Events.COMPLETED)
            def log_embeddings(trainer):
                if hasattr(self.model, self.embeddings_name) and hasattr(self.dataset_splits, "vectorizer"):
                    logger.info(f"Logging embeddings ({self.embeddings_name}) to Tensorboard!")
                    embeddings = getattr(self.model, self.embeddings_name).weight.data
                    metadata = [str(self.dataset_splits.vectorizer.data_vocab._id2token[token_index]).encode('utf-8')
                                for token_index in range(embeddings.shape[0])]
                    self.writer.add_embedding(mat=embeddings, metadata=metadata, global_step=self.trainer.state.epoch)

    def _forward(self, batch):
        model_inputs = {}
        for p in self.forward_params:
            val = batch.get(p)
            if val is None:
                if p in self.forward_param_defaults:
                    val = self.forward_param_defaults[p]
                else:
                    raise ValueError(f'missing model parameter "{p}"')
            model_inputs[p] = val

        return self.model(**model_inputs)

    def create_supervised_trainer(self, prepare_batch=_prepare_batch, non_blocking=False,
                                  output_transform=lambda y_pred, y_target, loss: (y_pred, y_target, loss)):
        if self.device:
            self.model.to(self.device)

        # Gradient accumulation trick adapted from:
        # https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255
        accumulation_steps = self.loss_accumulation_steps

        def _update(engine, batch):
            self.model.train()
            batch = prepare_batch(batch, device=self.device, non_blocking=non_blocking)
            y_pred = self._forward(batch)
            loss = self.loss(input=y_pred, target=batch['y_target'])

            # Add a regularisation term at train time only
            if self.regularizer:
                loss += self.regularizer.compute_penalty(model=self.model)

            loss /= accumulation_steps
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clipping)
            if engine.state.iteration % accumulation_steps == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()
            return output_transform(y_pred, batch['y_target'], loss.item())

        engine = Engine(_update)

        metrics = {}
        for metric in self.metrics:
            if not isinstance(metric, Loss):
                n = metric.__class__.__name__
                tm = TrainingMetric(metric)
                metrics[n] = tm
                tm.attach(engine, n)

        return engine, metrics

    def create_supervised_evaluator(self, prepare_batch=_prepare_batch, non_blocking=False,
                                    output_transform=lambda y_pred, y_target: (y_pred, y_target)):
        if self.device:
            self.model.to(self.device)

        def _inference(engine, batch):
            self.model.eval()
            with torch.no_grad():
                batch = prepare_batch(batch, device=self.device, non_blocking=non_blocking)
                y_pred = self._forward(batch)
                return output_transform(y_pred, batch['y_target'])

        engine = Engine(_inference)

        for metric in self.metrics:
            metric.attach(engine, metric.__class__.__name__)

        return engine
    def freeze_and_replace_final_layer(self):
        """
        Freeze all layers and replace the last layer with a custom Linear projection on the predicted classes
        Note: this method assumes that the pre-trained model ends with a `classifier` layer, which we want to learn
        :return:
        """
        # Freeze all layers
        for param in self.model.parameters():
            param.requires_grad = False

        # Number of input features to the final classification layer
        number_features = self.model.classifier.in_features

        # If `classifier` has several layers itself, this only drops the last one; otherwise the list is empty
        features = list(self.model.classifier.children())[:-1]
        logger.info(f"Keeping layers {list(self.model.classifier.children())[:-1]} from the classifier layer")
        logger.info(f"Appending layer {torch.nn.Linear(number_features, self.model.num_labels)} to the classifier")

        # Create the final linear layer for classification
        features.append(torch.nn.Linear(number_features, self.model.num_labels))
        self.model.classifier = torch.nn.Sequential(*features)
        self.model = self.model.to(self.device)
    def train(self):
        """
        Launch the ignite training pipeline.
        If fine-tuning mode is enabled in the config file, freeze all layers, replace the classification
        layer with a Linear layer, and reset the optimizer
        :return:
        """
        if self.finetune:
            logger.info("Fine-tuning the last classification layer to the data")

            trainer_key = [k for k, v in self.experiment_config.experiment.items() if v is self]
            if trainer_key:
                self.optimizer_factory = self.experiment_config.factories['optimizer']
            else:
                raise ValueError('this trainer object was not found in config')

            self.freeze_and_replace_final_layer()
            # Rebuild the optimizer so that it only sees the newly created, trainable parameters
            self.optimizer = self.optimizer_factory.create()

        self.trainer.run(self.dataset_splits.train_data_loader(), max_epochs=self.num_epochs)
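# Usage sketch (hypothetical: the file name and the "trainer" config key are
# illustrative, not defined by this module). In practice a BasicTrainer is
# built by ExperimentConfig from an experiment JSON file rather than
# instantiated by hand, and train() sources the optimizer factory from that
# same config when fine-tuning:
#
#     from transfer_nlp.plugins.config import ExperimentConfig
#
#     experiment_config = ExperimentConfig('path/to/experiment.json')
#     trainer: BasicTrainer = experiment_config.experiment['trainer']
#     trainer.train()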