Source code for ding.torch_utils.network.activation

import torch
import torch.nn as nn
import torch.nn.functional as F


class GLU(nn.Module):
    r"""
    Overview:
        Gating Linear Unit. A sigmoid gate is computed from ``context``, used to scale the input
        element-wise, and the gated input is then projected to the requested output size:

        .. code:: python

            # Inputs: input, context
            gate = sigmoid(layer1(context))   # layer1: context_dim -> input_dim
            gated_input = gate * input
            output = layer2(gated_input)      # layer2: input_dim -> output_dim
            return output

    Interfaces:
        forward

    .. tip::

        With ``input_type='conv2d'`` both layers are 1x1 convolutions (stride 1, no padding), so
        ``input`` and ``context`` must be image-like tensors whose batch/spatial sizes match; their
        channel counts are ``input_dim`` and ``context_dim`` respectively.
    """

    def __init__(self, input_dim: int, output_dim: int, context_dim: int, input_type: str = 'fc') -> None:
        r"""
        Overview:
            Build the gate layer (``layer1``) and the output projection layer (``layer2``).
        Arguments:
            - input_dim (:obj:`int`): size (features or channels) of the input being gated
            - output_dim (:obj:`int`): size (features or channels) of the output
            - context_dim (:obj:`int`): size (features or channels) of the context that drives the gate
            - input_type (:obj:`str`): layer family to use, one of ['fc', 'conv2d']
        """
        super().__init__()
        assert input_type in ('fc', 'conv2d')
        if input_type == 'conv2d':
            # A 1x1 convolution with stride 1 and no padding is a per-pixel linear map:
            # it mixes channels but keeps the spatial size unchanged.
            def make_layer(in_size, out_size):
                return nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0)
        else:
            make_layer = nn.Linear
        # Creation order (gate layer first) is kept stable so parameter names/initialisation match.
        self.layer1 = make_layer(context_dim, input_dim)
        self.layer2 = make_layer(input_dim, output_dim)

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        r"""
        Overview:
            Gate ``x`` with ``sigmoid(layer1(context))`` and project the result through ``layer2``.
        Arguments:
            - x (:obj:`torch.Tensor`): the input tensor to be gated
            - context (:obj:`torch.Tensor`): the tensor the gate is computed from
        Returns:
            - output (:obj:`torch.Tensor`): ``layer2(sigmoid(layer1(context)) * x)``
        """
        gate = torch.sigmoid(self.layer1(context))
        return self.layer2(gate * x)
class Swish(nn.Module):
    r"""
    Overview:
        Swish activation, computed element-wise as ``x * sigmoid(x)`` (the same formula as SiLU).
    Interfaces:
        forward
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        r"""
        Overview:
            Return ``x * sigmoid(x)``.
        Arguments:
            - x (:obj:`torch.Tensor`): the input tensor
        Returns:
            - x (:obj:`torch.Tensor`): the activated tensor, same shape as the input
        """
        return x * torch.sigmoid(x)


def build_activation(activation: str, inplace: bool = None) -> nn.Module:
    r"""
    Overview:
        Return the activation module for the given type. A new object is created on every call.
    Arguments:
        - activation (:obj:`str`): the type of activation, one of ['relu', 'glu', 'prelu', 'swish']
        - inplace (:obj:`bool`): whether ``relu`` operates in-place. Only valid together with
          ``activation='relu'``. Default ``None``, which is treated as ``False``.
    Returns:
        - act_func (:obj:`nn.Module`): the corresponding activation module. Note that ``'glu'``
          returns the ``GLU`` class itself (not an instance), because it needs dimension arguments
          the caller must supply.
    Raises:
        - AssertionError: if ``inplace`` is given for any activation other than ``'relu'``
        - KeyError: if ``activation`` is not a supported type
    """
    if inplace is not None:
        assert activation == 'relu', 'inplace argument is not compatible with {}'.format(activation)
    else:
        inplace = False
    # Zero-argument factories: only the requested activation is constructed, instead of building
    # every module (including a parameterised PReLU) on each call and discarding all but one.
    act_factories = {
        'relu': lambda: nn.ReLU(inplace=inplace),
        'glu': lambda: GLU,
        'prelu': nn.PReLU,
        'swish': Swish,
    }
    if activation not in act_factories:
        raise KeyError("invalid key for activation: {}".format(activation))
    return act_factories[activation]()