Source code for ding.torch_utils.network.rnn

import math
import torch
import torch.nn as nn
from typing import Optional, Union, List, Tuple

import ding
from ding.torch_utils.network.normalization import build_normalization
if ding.enable_hpc_rl:
    from hpc_rll.torch_utils.network.rnn import LSTM as HPCLSTM
else:
    HPCLSTM = None


def is_sequence(data):
    """
    Overview:
        Judge whether the input ``data`` is an instance of ``list`` or ``tuple``.
    """
    return isinstance(data, list) or isinstance(data, tuple)


def sequence_mask(lengths: torch.Tensor, max_len: Optional[int] = None) -> torch.BoolTensor:
    r"""
    Overview:
        Create a mask for a batch of sequences with different lengths.
    Arguments:
        - lengths (:obj:`torch.Tensor`): the length of each sequence, shape could be (n, 1) or (n)
        - max_len (:obj:`int`): the padding size, if max_len is None, the padding size is the \
            max length of sequences
    Returns:
        - masks (:obj:`torch.BoolTensor`): mask of size [n, max_len], on the same device as ``lengths``
    """
    if len(lengths.shape) == 1:
        lengths = lengths.unsqueeze(dim=1)
    bz = lengths.numel()
    if max_len is None:
        max_len = lengths.max()
    else:
        max_len = min(max_len, lengths.max())
    return torch.arange(0, max_len).type_as(lengths).repeat(bz, 1).lt(lengths).to(lengths.device)
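

# Illustrative usage sketch (not part of the original module): ``sequence_mask``
# turns per-sequence lengths into a boolean padding mask; the lengths below are
# made-up example values.
def _example_sequence_mask() -> None:
    lengths = torch.LongTensor([1, 3, 2])
    mask = sequence_mask(lengths, max_len=3)
    # Row i keeps the first lengths[i] positions:
    # [[ True, False, False],
    #  [ True,  True,  True],
    #  [ True,  True, False]]
    assert mask.shape == (3, 3) and mask.dtype == torch.bool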


class LSTMForwardWrapper(object):
    r"""
    Overview:
        A class which provides methods to use before and after ``forward``, in order to wrap the LSTM
        ``forward`` method.
    Interfaces:
        _before_forward, _after_forward
    """

    def _before_forward(self, inputs: torch.Tensor, prev_state: Union[torch.Tensor, list]) -> torch.Tensor:
        r"""
        Overview:
            Preprocess the inputs and the previous state.
        Arguments:
            - inputs (:obj:`torch.Tensor`): input vector of cell, tensor of size [seq_len, batch_size, input_size]
            - prev_state (:obj:`Union[torch.Tensor, list]`): None or tensor of size \
                [num_directions*num_layers, batch_size, hidden_size]. \
                If None, prev_state will be initialized to all zeros.
        Returns:
            - prev_state (:obj:`torch.Tensor`): batch previous state in lstm
        """
        assert hasattr(self, 'num_layers')
        assert hasattr(self, 'hidden_size')
        seq_len, batch_size = inputs.shape[:2]
        if prev_state is None:
            num_directions = 1
            zeros = torch.zeros(
                num_directions * self.num_layers,
                batch_size,
                self.hidden_size,
                dtype=inputs.dtype,
                device=inputs.device
            )
            prev_state = (zeros, zeros)
        elif is_sequence(prev_state):
            if len(prev_state) == 2 and isinstance(prev_state[0], torch.Tensor):
                pass
            else:
                if len(prev_state) != batch_size:
                    raise RuntimeError(
                        "prev_state number is not equal to batch_size: {}/{}".format(len(prev_state), batch_size)
                    )
                num_directions = 1
                zeros = torch.zeros(
                    num_directions * self.num_layers, 1, self.hidden_size, dtype=inputs.dtype, device=inputs.device
                )
                state = []
                for prev in prev_state:
                    if prev is None:
                        state.append([zeros, zeros])
                    else:
                        state.append(prev)
                state = list(zip(*state))
                prev_state = [torch.cat(t, dim=1) for t in state]
        else:
            raise TypeError("not support prev_state type: {}".format(type(prev_state)))
        return prev_state

    def _after_forward(self, next_state: List[Tuple[torch.Tensor]],
                       list_next_state: bool = False) -> Union[torch.Tensor, list]:
        r"""
        Overview:
            Post-process the next_state, return list or tensor type next_state.
        Arguments:
            - next_state (:obj:`List[Tuple[torch.Tensor]]`): list of tuples containing the next (h, c)
            - list_next_state (:obj:`bool`): whether to return next_state in list format, default set to False
        Returns:
            - next_state (:obj:`Union[torch.Tensor, list]`): the formatted next_state
        """
        if list_next_state:
            h, c = [torch.stack(t, dim=0) for t in zip(*next_state)]
            batch_size = h.shape[1]
            next_state = [torch.chunk(h, batch_size, dim=1), torch.chunk(c, batch_size, dim=1)]
            next_state = list(zip(*next_state))
        else:
            next_state = [torch.stack(t, dim=0) for t in zip(*next_state)]
        return next_state
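

# Illustrative sketch (not part of the original module): ``_before_forward``
# turns a ``None`` previous state into zero-initialized (h, c) tensors of shape
# [num_directions*num_layers, batch_size, hidden_size]. The dummy subclass and
# sizes below are made-up example values.
def _example_before_forward() -> None:

    class _Dummy(LSTMForwardWrapper):
        num_layers, hidden_size = 2, 8

    inputs = torch.randn(5, 3, 4)  # [seq_len, batch_size, input_size]
    h, c = _Dummy()._before_forward(inputs, None)
    assert h.shape == (2, 3, 8) and c.shape == (2, 3, 8)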


class LSTM(nn.Module, LSTMForwardWrapper):
    r"""
    Overview:
        Implementation of an LSTM cell
    Interface:
        forward

    .. note::

        For beginners, you can refer to <https://zhuanlan.zhihu.com/p/32085405> to learn the basics about lstm
    """

    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            num_layers: int,
            norm_type: Optional[str] = None,
            dropout: float = 0.
    ) -> None:
        r"""
        Overview:
            Initialize the LSTM cell
        Arguments:
            - input_size (:obj:`int`): size of the input vector
            - hidden_size (:obj:`int`): size of the hidden state vector
            - num_layers (:obj:`int`): number of lstm layers
            - norm_type (:obj:`Optional[str]`): type of the normalization (default: None)
            - dropout (:obj:`float`): dropout rate, default set to 0.
        """
        super(LSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        norm_func = build_normalization(norm_type)
        self.norm = nn.ModuleList([norm_func(hidden_size * 4) for _ in range(2 * num_layers)])
        self.wx = nn.ParameterList()
        self.wh = nn.ParameterList()
        dims = [input_size] + [hidden_size] * num_layers
        for l in range(num_layers):
            self.wx.append(nn.Parameter(torch.zeros(dims[l], dims[l + 1] * 4)))
            self.wh.append(nn.Parameter(torch.zeros(hidden_size, hidden_size * 4)))
        self.bias = nn.Parameter(torch.zeros(num_layers, hidden_size * 4))
        self.use_dropout = dropout > 0.
        if self.use_dropout:
            self.dropout = nn.Dropout(dropout)
        self._init()

    def _init(self):
        gain = math.sqrt(1. / self.hidden_size)
        for l in range(self.num_layers):
            torch.nn.init.uniform_(self.wx[l], -gain, gain)
            torch.nn.init.uniform_(self.wh[l], -gain, gain)
            if self.bias is not None:
                torch.nn.init.uniform_(self.bias[l], -gain, gain)

    def forward(self,
                inputs: torch.Tensor,
                prev_state: torch.Tensor,
                list_next_state: bool = True) -> Tuple[torch.Tensor, Union[torch.Tensor, list]]:
        r"""
        Overview:
            Take the previous state and the input, then calculate the output and the next state.
        Arguments:
            - inputs (:obj:`torch.Tensor`): input vector of cell, tensor of size [seq_len, batch_size, input_size]
            - prev_state (:obj:`torch.Tensor`): None or tensor of size \
                [num_directions*num_layers, batch_size, hidden_size]
            - list_next_state (:obj:`bool`): whether to return next_state in list format, default set to True
        Returns:
            - x (:obj:`torch.Tensor`): output from lstm
            - next_state (:obj:`Union[torch.Tensor, list]`): hidden state from lstm
        """
        seq_len, batch_size = inputs.shape[:2]
        prev_state = self._before_forward(inputs, prev_state)

        H, C = prev_state
        x = inputs
        next_state = []
        for l in range(self.num_layers):
            h, c = H[l], C[l]
            new_x = []
            for s in range(seq_len):
                gate = self.norm[l * 2](torch.matmul(x[s], self.wx[l])
                                        ) + self.norm[l * 2 + 1](torch.matmul(h, self.wh[l]))
                if self.bias is not None:
                    gate += self.bias[l]
                gate = list(torch.chunk(gate, 4, dim=1))
                i, f, o, u = gate
                i = torch.sigmoid(i)
                f = torch.sigmoid(f)
                o = torch.sigmoid(o)
                u = torch.tanh(u)
                c = f * c + i * u
                h = o * torch.tanh(c)
                new_x.append(h)
            next_state.append((h, c))
            x = torch.stack(new_x, dim=0)
            if self.use_dropout and l != self.num_layers - 1:
                x = self.dropout(x)
        next_state = self._after_forward(next_state, list_next_state)
        return x, next_state
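

# Illustrative usage sketch (not part of the original module): a forward pass
# through the ``normal`` LSTM with a zero-initialized state; the sizes below
# are made-up example values.
def _example_lstm_forward() -> None:
    seq_len, batch_size, input_size, hidden_size = 5, 3, 4, 8
    lstm = LSTM(input_size, hidden_size, num_layers=2, norm_type='LN')
    inputs = torch.randn(seq_len, batch_size, input_size)
    output, next_state = lstm(inputs, prev_state=None, list_next_state=False)
    # output: [seq_len, batch_size, hidden_size]; next_state: (h, c) stacked over layers
    assert output.shape == (seq_len, batch_size, hidden_size)
    h, c = next_state
    assert h.shape == (2, batch_size, hidden_size)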


class PytorchLSTM(nn.LSTM, LSTMForwardWrapper):
    r"""
    Overview:
        Wrap the PyTorch nn.LSTM, format the input and output
    Interface:
        forward

    .. note::

        You can refer to <https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#torch.nn.LSTM> for details
    """

    def forward(self,
                inputs: torch.Tensor,
                prev_state: torch.Tensor,
                list_next_state: bool = True) -> Tuple[torch.Tensor, Union[torch.Tensor, list]]:
        r"""
        Overview:
            Wrapped nn.LSTM.forward
        Arguments:
            - inputs (:obj:`torch.Tensor`): input vector of cell, tensor of size \
                [seq_len, batch_size, input_size]
            - prev_state (:obj:`torch.Tensor`): None or tensor of size \
                [num_directions*num_layers, batch_size, hidden_size]
            - list_next_state (:obj:`bool`): whether to return next_state in list format, default set to True
        Returns:
            - output (:obj:`torch.Tensor`): output from lstm
            - next_state (:obj:`Union[torch.Tensor, list]`): hidden state from lstm
        """
        prev_state = self._before_forward(inputs, prev_state)
        output, next_state = nn.LSTM.forward(self, inputs, prev_state)
        next_state = self._after_forward(next_state, list_next_state)
        return output, next_state

    def _after_forward(self, next_state, list_next_state=False):
        if list_next_state:
            h, c = next_state
            batch_size = h.shape[1]
            next_state = [torch.chunk(h, batch_size, dim=1), torch.chunk(c, batch_size, dim=1)]
            return list(zip(*next_state))
        else:
            return next_state


class GRU(nn.GRUCell):
    r"""
    Overview:
        Wrap the nn.GRUCell, format the input and output
    Interface:
        forward

    .. note::

        You can refer to <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html#torch.nn.GRU> for details
    """

    def __init__(self, input_size, hidden_size, num_layers):
        super(GRU, self).__init__(input_size, hidden_size)
        self.hidden_size = hidden_size
        self.num_layers = num_layers

    def _before_forward(self, inputs, prev_state):
        r"""
        Overview:
            Preprocess the inputs and the previous state.
        Arguments:
            - inputs (:obj:`tensor`): input vector of cell, tensor of size [seq_len, batch_size, input_size]
            - prev_state (:obj:`tensor` or :obj:`list`): None or tensor of size \
                [num_directions*num_layers, batch_size, hidden_size], if None then prev_state \
                will be initialized to all zeros.
        Returns:
            - prev_state (:obj:`tensor`): batch previous state in GRU
        """
        assert hasattr(self, 'num_layers')
        assert hasattr(self, 'hidden_size')
        seq_len, batch_size = inputs.shape[:2]
        if prev_state is None:
            num_directions = 1
            zeros = torch.zeros(
                num_directions * self.num_layers,
                batch_size,
                self.hidden_size,
                dtype=inputs.dtype,
                device=inputs.device
            )
            prev_state = (zeros, zeros)
        elif is_sequence(prev_state):
            if len(prev_state) == 2 and isinstance(prev_state[0], torch.Tensor):
                pass
            else:
                if len(prev_state) != batch_size:
                    raise RuntimeError(
                        "prev_state number is not equal to batch_size: {}/{}".format(len(prev_state), batch_size)
                    )
                num_directions = 1
                zeros = torch.zeros(
                    num_directions * self.num_layers, 1, self.hidden_size, dtype=inputs.dtype, device=inputs.device
                )
                state = []
                for prev in prev_state:
                    if prev is None:
                        state.append([zeros, zeros])
                    else:
                        state.append(prev)
                state = list(zip(*state))
                prev_state = [torch.cat(t, dim=1) for t in state]
        else:
            raise TypeError("not support prev_state type: {}".format(type(prev_state)))
        return prev_state

    def forward(self, inputs, prev_state, list_next_state=True):
        r"""
        Overview:
            Wrapped nn.GRUCell.forward
        Arguments:
            - inputs (:obj:`tensor`): input vector of cell, tensor of size [seq_len, batch_size, input_size]
            - prev_state (:obj:`tensor`): None or tensor of size [num_directions*num_layers, batch_size, hidden_size]
            - list_next_state (:obj:`bool`): whether to return next_state in list format, default set to True
        Returns:
            - output (:obj:`tensor`): output from GRU
            - next_state (:obj:`tensor` or :obj:`list`): hidden state from GRU
        """
        prev_state = self._before_forward(inputs, prev_state)[0]
        next_state = nn.GRUCell.forward(self, inputs.squeeze(0), prev_state.squeeze(0))
        next_state.unsqueeze_(0)
        x = next_state
        next_state = self._after_forward([next_state, next_state.clone()], list_next_state)  # for compatibility
        return x, next_state

    def _after_forward(self, next_state, list_next_state=False):
        r"""
        Overview:
            Process the hidden state after GRU, return it in list format or keep it as a tensor.
        Arguments:
            - next_state (:obj:`tensor`): hidden state from GRU
            - list_next_state (:obj:`bool`): whether to return next_state in list format, default set to False
        Returns:
            - next_state (:obj:`tensor` or :obj:`list`): hidden state from GRU
        """
        if list_next_state:
            h, c = next_state
            batch_size = h.shape[1]
            next_state = [torch.chunk(h, batch_size, dim=1), torch.chunk(c, batch_size, dim=1)]
            return list(zip(*next_state))
        else:
            return next_state


def get_lstm(
        lstm_type: str,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        norm_type: str = 'LN',
        dropout: float = 0.,
        seq_len: Optional[int] = None,
        batch_size: Optional[int] = None
) -> Union[LSTM, PytorchLSTM]:
    r"""
    Overview:
        Build and return the corresponding LSTM cell
    Arguments:
        - lstm_type (:obj:`str`): version of rnn cell, now support ['normal', 'pytorch', 'hpc', 'gru']
        - input_size (:obj:`int`): size of the input vector
        - hidden_size (:obj:`int`): size of the hidden state vector
        - num_layers (:obj:`int`): number of lstm layers, default set to 1
        - norm_type (:obj:`str`): type of the normalization, default set to 'LN'
        - dropout (:obj:`float`): dropout rate, default set to 0.
        - seq_len (:obj:`Optional[int]`): sequence length, default set to None
        - batch_size (:obj:`Optional[int]`): batch size, default set to None
    Returns:
        - lstm (:obj:`Union[LSTM, PytorchLSTM]`): the corresponding lstm cell
    """
    assert lstm_type in ['normal', 'pytorch', 'hpc', 'gru']
    if lstm_type == 'normal':
        return LSTM(input_size, hidden_size, num_layers, norm_type, dropout=dropout)
    elif lstm_type == 'pytorch':
        return PytorchLSTM(input_size, hidden_size, num_layers, dropout=dropout)
    elif lstm_type == 'hpc':
        return HPCLSTM(seq_len, batch_size, input_size, hidden_size, num_layers, norm_type, dropout).cuda()
    elif lstm_type == 'gru':
        assert num_layers == 1
        return GRU(input_size, hidden_size, num_layers)
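

# Illustrative usage sketch (not part of the original module): building cells
# through the ``get_lstm`` factory; the hyper-parameters below are made-up
# example values. The 'hpc' variant additionally requires hpc_rll and a CUDA
# device, and the 'gru' variant processes a single timestep (seq_len == 1),
# so neither is exercised here.
def _example_get_lstm() -> None:
    inputs = torch.randn(5, 3, 4)  # [seq_len, batch_size, input_size]
    for lstm_type in ['normal', 'pytorch']:
        lstm = get_lstm(lstm_type, input_size=4, hidden_size=8, num_layers=2)
        output, next_state = lstm(inputs, None)
        assert output.shape == (5, 3, 8)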