# Source code for transfer_nlp.plugins.regularizers

import torch

from transfer_nlp.plugins.config import register_plugin


class RegularizerABC:
    """Abstract interface for weight regularizers.

    Subclasses must implement:
      * ``__call__(parameter)`` — penalty value for a single parameter tensor.
      * ``__str__()`` — human-readable description of the regularizer.
      * ``compute_penalty(model)`` — total penalty accumulated over a model.
    """

    def __call__(self, *args, **kwargs):
        raise NotImplementedError

    def __str__(self):
        # Bug fix: the original raised the `NotImplemented` singleton, which is
        # not an exception — Python would raise TypeError instead. Use the
        # proper NotImplementedError exception class.
        raise NotImplementedError

    def compute_penalty(self, model: torch.nn.Module):
        """Return the total regularization penalty for *model*.

        :param model: the model whose parameters are penalized
        :return: a scalar penalty (implementation-defined by subclasses)
        """
        raise NotImplementedError
@register_plugin
class L1(RegularizerABC):
    """L1 (lasso) regularizer: penalty = alpha * sum(|w|) per parameter tensor."""

    def __init__(self, alpha: float = 0.01) -> None:
        # Scaling coefficient applied to every per-parameter penalty.
        self.alpha = alpha

    def __call__(self, parameter: torch.Tensor) -> torch.Tensor:
        # Scaled sum of absolute values of one parameter tensor.
        return self.alpha * torch.sum(torch.abs(parameter))

    def __str__(self):
        return f"L1(alpha={self.alpha})"

    def compute_penalty(self, model: torch.nn.Module):
        """Compute the penalty uniformly over all model parameters.

        :param model: the model whose parameters are penalized
        :return: scalar tensor, the summed L1 penalty over every parameter
        """
        penalty = 0
        # The parameter name is irrelevant here; only the tensors are used.
        for parameter in model.parameters():
            penalty += self(parameter)
        return penalty


@register_plugin
class L2(RegularizerABC):
    """L2 (ridge) regularizer: penalty = alpha * sum(w ** 2) per parameter tensor."""

    def __init__(self, alpha: float = 0.01) -> None:
        # Scaling coefficient applied to every per-parameter penalty.
        self.alpha = alpha

    def __str__(self):
        return f"L2(alpha={self.alpha})"

    def __call__(self, parameter: torch.Tensor) -> torch.Tensor:
        # Scaled sum of squared values of one parameter tensor.
        return self.alpha * torch.sum(torch.pow(parameter, 2))

    def compute_penalty(self, model: torch.nn.Module):
        """Compute the penalty uniformly over all model parameters.

        :param model: the model whose parameters are penalized
        :return: scalar tensor, the summed L2 penalty over every parameter
        """
        penalty = 0
        # The parameter name is irrelevant here; only the tensors are used.
        for parameter in model.parameters():
            penalty += self(parameter)
        return penalty