diff --git a/hapi/tests/test_text.py b/hapi/tests/test_text.py index eca5fda0a4927641a1de7aa114dd6cf7e47ab304..6f0d014f53485d64db289560f8f06bb995beda11 100644 --- a/hapi/tests/test_text.py +++ b/hapi/tests/test_text.py @@ -25,7 +25,6 @@ from paddle.fluid.dygraph import Embedding, Linear, Layer from paddle.fluid.layers import BeamSearchDecoder import hapi.text as text from hapi.model import Model, Input, set_device -# from hapi.text.text import BasicLSTMCell, BasicGRUCell, RNN, DynamicDecode, MultiHeadAttention, TransformerEncoder, TransformerCell from hapi.text.text import * @@ -515,15 +514,142 @@ class TestTransformerBeamSearchDecoder(ModuleApiTest): class TestSequenceTagging(ModuleApiTest): def setUp(self): - shape = (2, 4, 128) + self.inputs = [ + np.random.randint(0, 100, (2, 8)).astype("int64"), + np.random.randint(1, 8, (2)).astype("int64"), + np.random.randint(0, 5, (2, 8)).astype("int64") + ] + self.outputs = None + self.attrs = {"vocab_size": 100, "num_labels": 5} + self.param_states = {} + + @staticmethod + def model_init(self, + vocab_size, + num_labels, + word_emb_dim=128, + grnn_hidden_dim=128, + emb_learning_rate=0.1, + crf_learning_rate=0.1, + bigru_num=2, + init_bound=0.1): + self.tagger = SequenceTagging(vocab_size, num_labels, word_emb_dim, + grnn_hidden_dim, emb_learning_rate, + crf_learning_rate, bigru_num, init_bound) + + @staticmethod + def model_forward(self, word, lengths, target=None): + return self.tagger(word, lengths, target) + + def make_inputs(self): + inputs = [ + Input( + [None, None], "int64", name="word"), + Input( + [None], "int64", name="lengths"), + Input( + [None, None], "int64", name="target"), + ] + return inputs + + def test_check_output(self): + self.check_output() + + +class TestSequenceTaggingInfer(TestSequenceTagging): + def setUp(self): + super(TestSequenceTaggingInfer, self).setUp() + self.inputs = self.inputs[:2] # remove target + + def make_inputs(self): + inputs = super(TestSequenceTaggingInfer, + self).make_inputs()[:2] # remove target + return inputs + + +class TestLSTM(ModuleApiTest): + def setUp(self): + shape = (2, 4, 16) self.inputs = [np.random.random(shape).astype("float32")] self.outputs = None - self.attrs = {"input_size": 128, "hidden_size": 128} + self.attrs = {"input_size": 16, "hidden_size": 16, "num_layers": 2} self.param_states = {} @staticmethod - def model_init(self, input_size, hidden_size): - self.module = SequenceTagging(input_size, hidden_size) + def model_init(self, input_size, hidden_size, num_layers): + self.lstm = LSTM(input_size, hidden_size, num_layers=num_layers) + + @staticmethod + def model_forward(self, inputs): + return self.lstm(inputs)[0] + + def make_inputs(self): + inputs = [ + Input( + [None, None, self.inputs[-1].shape[-1]], + "float32", + name="input"), + ] + return inputs + + def test_check_output(self): + self.check_output() + + +class TestBiLSTM(ModuleApiTest): + def setUp(self): + shape = (2, 4, 16) + self.inputs = [np.random.random(shape).astype("float32")] + self.outputs = None + self.attrs = {"input_size": 16, "hidden_size": 16, "num_layers": 2} + self.param_states = {} + + @staticmethod + def model_init(self, + input_size, + hidden_size, + num_layers, + merge_mode="concat", + merge_each_layer=False): + self.bilstm = BidirectionalLSTM( + input_size, + hidden_size, + num_layers=num_layers, + merge_mode=merge_mode, + merge_each_layer=merge_each_layer) + + @staticmethod + def model_forward(self, inputs): + return self.bilstm(inputs)[0] + + def make_inputs(self): + inputs = [ + Input( + [None, 
None, self.inputs[-1].shape[-1]], + "float32", + name="input"), + ] + return inputs + + def test_check_output_merge0(self): + self.check_output() + + def test_check_output_merge1(self): + self.attrs["merge_each_layer"] = True + self.check_output() + + +class TestGRU(ModuleApiTest): + def setUp(self): + shape = (2, 4, 64) + self.inputs = [np.random.random(shape).astype("float32")] + self.outputs = None + self.attrs = {"input_size": 64, "hidden_size": 128, "num_layers": 2} + self.param_states = {} + + @staticmethod + def model_init(self, input_size, hidden_size, num_layers): + self.gru = GRU(input_size, hidden_size, num_layers=num_layers) @staticmethod def model_forward(self, inputs): @@ -542,5 +668,48 @@ class TestSequenceTagging(ModuleApiTest): self.check_output() +class TestBiGRU(ModuleApiTest): + def setUp(self): + shape = (2, 4, 64) + self.inputs = [np.random.random(shape).astype("float32")] + self.outputs = None + self.attrs = {"input_size": 64, "hidden_size": 128, "num_layers": 2} + self.param_states = {} + + @staticmethod + def model_init(self, + input_size, + hidden_size, + num_layers, + merge_mode="concat", + merge_each_layer=False): + self.bigru = BidirectionalGRU( + input_size, + hidden_size, + num_layers=num_layers, + merge_mode=merge_mode, + merge_each_layer=merge_each_layer) + + @staticmethod + def model_forward(self, inputs): + return self.bigru(inputs)[0] + + def make_inputs(self): + inputs = [ + Input( + [None, None, self.inputs[-1].shape[-1]], + "float32", + name="input"), + ] + return inputs + + def test_check_output_merge0(self): + self.check_output() + + def test_check_output_merge1(self): + self.attrs["merge_each_layer"] = True + self.check_output() + + if __name__ == '__main__': unittest.main() diff --git a/hapi/text/text.py b/hapi/text/text.py index 83327000c99e7091f0f7dc9df0e58349aa43b75b..b5a0cf57321eb230dd8ef3367b3f238ca0d50f64 100644 --- a/hapi/text/text.py +++ b/hapi/text/text.py @@ -49,7 +49,9 @@ __all__ = [ 'BeamSearchDecoder', 'MultiHeadAttention', 'FFN', 'TransformerEncoderLayer', 'TransformerEncoder', 'TransformerDecoderLayer', 'TransformerDecoder', 'TransformerCell', 'TransformerBeamSearchDecoder', - 'LinearChainCRF', 'CRFDecoding', 'SequenceTagging', 'GRUEncoder' + 'LinearChainCRF', 'CRFDecoding', 'SequenceTagging', 'GRUEncoder', + 'StackedLSTMCell', 'LSTM', 'BidirectionalLSTM', 'StackedGRUCell', 'GRU', + 'BidirectionalGRU' ] @@ -241,7 +243,7 @@ class BasicLSTMCell(RNNCell): # TODO(guosheng): find better way to resolve constants in __init__ self._forget_bias = layers.create_global_var( shape=[1], dtype=dtype, value=forget_bias, persistable=True) - self._forget_bias.stop_gradient = False + self._forget_bias.stop_gradient = True self._dtype = dtype self._input_size = input_size @@ -468,9 +470,11 @@ class BasicLSTMCell(RNNCell): new_cell = layers.elementwise_add( layers.elementwise_mul( pre_cell, - layers.sigmoid(layers.elementwise_add(f, self._forget_bias))), - layers.elementwise_mul(layers.sigmoid(i), layers.tanh(j))) - new_hidden = layers.tanh(new_cell) * layers.sigmoid(o) + self._gate_activation( + layers.elementwise_add(f, self._forget_bias))), + layers.elementwise_mul( + self._gate_activation(i), self._activation(j))) + new_hidden = self._activation(new_cell) * self._gate_activation(o) return new_hidden, [new_hidden, new_cell] @@ -1029,7 +1033,7 @@ class TransformerCell(Layer): if self.output_fn is not None: outputs = self.output_fn(outputs) if len(outputs.shape) == 3: - # squeeze to adapt to BeamSearchDecoder which use 2D logits + # squeeze to adapt 
to BeamSearchDecoder which use 2D logits outputs = layers.squeeze(outputs, [1]) new_states = [{"k": cache["k"], "v": cache["v"]} for cache in states] return outputs, new_states @@ -1037,9 +1041,9 @@ class TransformerCell(Layer): @property def state_shape(self): return [{ - "k": [self.n_head, 0, self.d_key], - "v": [self.n_head, 0, self.d_value], - } for i in range(len(self.n_layer))] + "k": [self.decoder.n_head, 0, self.decoder.d_key], + "v": [self.decoder.n_head, 0, self.decoder.d_value], + } for i in range(len(self.decoder.n_layer))] class TransformerBeamSearchDecoder(layers.BeamSearchDecoder): @@ -1787,10 +1791,8 @@ class GRUEncoder(Layer): grnn_hidden_dim, init_bound, num_layers=1, - h_0=None, is_bidirection=False): super(GRUEncoder, self).__init__() - self.h_0 = h_0 self.num_layers = num_layers self.is_bidirection = is_bidirection self.gru_list = [] @@ -1827,7 +1829,7 @@ class GRUEncoder(Layer): is_reverse=True, time_major=False))) - def forward(self, input_feature): + def forward(self, input_feature, h0=None): for i in range(self.num_layers): pre_gru, pre_state = self.gru_list[i](input_feature) if self.is_bidirection: @@ -1839,18 +1841,16 @@ class GRUEncoder(Layer): return out -class SequenceTagging(fluid.dygraph.Layer): +class SequenceTagging(Layer): def __init__(self, vocab_size, num_labels, - batch_size, word_emb_dim=128, grnn_hidden_dim=128, emb_learning_rate=0.1, crf_learning_rate=0.1, bigru_num=2, - init_bound=0.1, - length=None): + init_bound=0.1): super(SequenceTagging, self).__init__() """ define the sequence tagging network structure @@ -1868,7 +1868,6 @@ class SequenceTagging(fluid.dygraph.Layer): self.emb_lr = emb_learning_rate self.crf_lr = crf_learning_rate self.bigru_num = bigru_num - self.batch_size = batch_size self.init_bound = 0.1 self.word_embedding = Embedding( @@ -1880,20 +1879,11 @@ class SequenceTagging(fluid.dygraph.Layer): initializer=fluid.initializer.Uniform( low=-self.init_bound, high=self.init_bound))) - h_0 = fluid.layers.create_global_var( - shape=[self.batch_size, self.grnn_hidden_dim], - value=0.0, - dtype='float32', - persistable=True, - force_cpu=True, - name='h_0') - self.gru_encoder = GRUEncoder( input_dim=self.grnn_hidden_dim, grnn_hidden_dim=self.grnn_hidden_dim, init_bound=self.init_bound, num_layers=self.bigru_num, - h_0=h_0, is_bidirection=True) self.fc = Linear( @@ -1936,3 +1926,426 @@ class SequenceTagging(fluid.dygraph.Layer): self.linear_chain_crf.weight = self.crf_decoding.weight crf_decode = self.crf_decoding(input=emission, length=lengths) return crf_decode, lengths + + +class StackedRNNCell(RNNCell): + def __init__(self, cells): + self.cells = [] + for i, cell in enumerate(cells): + self.cells.append(self.add_sublayer("cell_%d" % i, cell)) + + def forward(self, inputs, states): + pass + + @staticmethod + def stack_param_attr(param_attr, n): + if isinstance(param_attr, (list, tuple)): + assert len(param_attr) == n, ( + "length of param_attr should be %d when it is a list/tuple" % + n) + param_attrs = [ + fluid.ParamAttr._to_attr(attr) for attr in param_attr + ] + else: + param_attrs = [] + attr = fluid.ParamAttr._to_attr(param_attr) + for i in range(n): + attr_i = copy.deepcopy(attr) + if attr.name: + attr_i.name = attr_i.name + "_" + str(i) + param_attrs.append(attr_i) + return param_attrs + + @property + def state_shape(self): + return [cell.state_shape for cell in self.cells] + + +class StackedLSTMCell(RNNCell): + """ + """ + + def __init__(self, + input_size, + hidden_size, + gate_activation=None, + activation=None, + 
forget_bias=1.0, + num_layers=1, + dropout=0.0, + param_attr=None, + bias_attr=None, + dtype="float32"): + super(StackedLSTMCell, self).__init__() + self.dropout = utils.convert_to_list(dropout, num_layers, "dropout", + float) + param_attrs = StackedRNNCell.stack_param_attr(param_attr, num_layers) + bias_attrs = StackedRNNCell.stack_param_attr(bias_attr, num_layers) + + self.cells = [] + for i in range(num_layers): + if forget_bias is True: + bias_attrs[ + i].initializer = fluid.initializer.NumpyArrayInitializer( + np.concatenate( + np.zeros(2 * hidden_size), + np.ones(hidden_size), np.zeros(hidden_size)) + .astype(dtype)) + forget_bias = 0.0 + self.cells.append( + self.add_sublayer( + "lstm_%d" % i, + BasicLSTMCell( + input_size=input_size if i == 0 else hidden_size, + hidden_size=hidden_size, + gate_activation=gate_activation, + activation=activation, + forget_bias=forget_bias, + param_attr=param_attrs[i], + bias_attr=bias_attrs[i], + dtype=dtype))) + + def forward(self, step_input, states): + new_states = [] + for i, cell in enumerate(self.cells): + out, new_state = cell(step_input, states[i]) + step_input = layers.dropout( + out, + self.dropout[i], + dropout_implementation='upscale_in_train') if self.dropout[ + i] > 0 else out + new_states.append(new_state) + return step_input, new_states + + @property + def state_shape(self): + return [cell.state_shape for cell in self.cells] + + +class LSTM(Layer): + def __init__(self, + input_size, + hidden_size, + gate_activation=None, + activation=None, + forget_bias=1.0, + num_layers=1, + dropout=0.0, + is_reverse=False, + time_major=False, + param_attr=None, + bias_attr=None, + dtype='float32'): + super(LSTM, self).__init__() + lstm_cell = StackedLSTMCell(input_size, hidden_size, gate_activation, + activation, forget_bias, num_layers, + dropout, param_attr, bias_attr, dtype) + self.lstm = RNN(lstm_cell, is_reverse, time_major) + + def forward(self, inputs, initial_states=None, sequence_length=None): + return self.lstm(inputs, initial_states, sequence_length) + + +class BidirectionalRNN(Layer): + def __init__(self, + cell_fw, + cell_bw, + merge_mode='concat', + time_major=False, + cell_cls=None, + **kwargs): + super(BidirectionalRNN, self).__init__() + self.rnn_fw = RNN(cell_fw, is_reverse=False, time_major=time_major) + self.rnn_bw = RNN(cell_bw, is_reverse=True, time_major=time_major) + if merge_mode == 'concat': + self.merge_func = lambda x, y: layers.concat([x, y], -1) + elif merge_mode == 'sum': + self.merge_func = lambda x, y: layers.elementwise_add(x, y) + elif merge_mode == 'ave': + self.merge_func = lambda x, y: layers.scale( + layers.elementwise_add(x, y), 0.5) + elif merge_mode == 'mul': + self.merge_func = lambda x, y: layers.elementwise_mul(x, y) + elif merge_mode == 'zip': + self.merge_func = lambda x, y: (x, y) + elif merge_mode is None: + self.merge_func = None + else: + raise ValueError('Unsupported value for `merge_mode`: %s' % + merge_mode) + + def forward(self, inputs, initial_states=None, sequence_length=None): + if isinstance(initial_states, (list, tuple)): + assert len( + initial_states + ) == 2, "length of initial_states should be 2 when it is a list/tuple" + else: + initial_states = [initial_states, initial_states] + outputs_fw, states_fw = self.rnn_fw(inputs, initial_states[0], + sequence_length) + outputs_bw, states_bw = self.rnn_bw(inputs, initial_states[1], + sequence_length) + outputs = map_structure( + self.merge_func, outputs_fw, + outputs_bw) if self.merge_func else (outputs_fw, outputs_bw) + return outputs, 
(states_fw, states_bw) + + @staticmethod + def bidirect_param_attr(param_attr): + if isinstance(param_attr, (list, tuple)): + assert len( + param_attr + ) == 2, "length of param_attr should be 2 when it is a list/tuple" + param_attrs = param_attr + else: + param_attrs = [] + attr = fluid.ParamAttr._to_attr(param_attr) + attr_fw = copy.deepcopy(attr) + if attr.name: + attr_fw.name = attr_fw.name + "_fw" + param_attrs.append(attr_fw) + attr_bw = copy.deepcopy(attr) + if attr.name: + attr_bw.name = attr_bw.name + "_bw" + param_attrs.append(attr_bw) + return param_attrs + + +class BidirectionalLSTM(Layer): + def __init__(self, + input_size, + hidden_size, + gate_activation=None, + activation=None, + forget_bias=1.0, + num_layers=1, + dropout=0.0, + merge_mode='concat', + merge_each_layer=False, + time_major=False, + param_attr=None, + bias_attr=None, + dtype='float32'): + super(BidirectionalLSTM, self).__init__() + self.num_layers = num_layers + self.merge_mode = merge_mode + self.merge_each_layer = merge_each_layer + param_attrs = BidirectionalRNN.bidirect_param_attr(param_attr) + bias_attrs = BidirectionalRNN.bidirect_param_attr(bias_attr) + if not merge_each_layer: + cell_fw = StackedLSTMCell(input_size, hidden_size, gate_activation, + activation, forget_bias, num_layers, + dropout, param_attrs[0], bias_attrs[0], + dtype) + cell_bw = StackedLSTMCell(input_size, hidden_size, gate_activation, + activation, forget_bias, num_layers, + dropout, param_attrs[1], bias_attrs[1], + dtype) + self.lstm = BidirectionalRNN( + cell_fw, cell_bw, merge_mode=merge_mode, time_major=time_major) + else: + fw_param_attrs = StackedRNNCell.stack_param_attr(param_attrs[0], + num_layers) + bw_param_attrs = StackedRNNCell.stack_param_attr(param_attrs[1], + num_layers) + fw_bias_attrs = StackedRNNCell.stack_param_attr(bias_attrs[0], + num_layers) + bw_bias_attrs = StackedRNNCell.stack_param_attr(bias_attrs[1], + num_layers) + + # maybe design cell including both forward and backward later + self.lstm = [] + for i in range(num_layers): + cell_fw = StackedLSTMCell( + input_size if i == 0 else (hidden_size * 2 + if merge_mode == 'concat' else + hidden_size), hidden_size, + gate_activation, activation, forget_bias, 1, dropout, + fw_param_attrs[i], fw_bias_attrs[i], dtype) + cell_bw = StackedLSTMCell( + input_size if i == 0 else (hidden_size * 2 + if merge_mode == 'concat' else + hidden_size), hidden_size, + gate_activation, activation, forget_bias, 1, dropout, + bw_param_attrs[i], bw_bias_attrs[i], dtype) + self.lstm.append( + self.add_sublayer( + "lstm_%d" % i, + BidirectionalRNN( + cell_fw, + cell_bw, + merge_mode=merge_mode, + time_major=time_major))) + + def forward(self, inputs, initial_states=None, sequence_length=None): + if not self.merge_each_layer: + return self.lstm(inputs, initial_states, sequence_length) + else: + if isinstance(initial_states, (list, tuple)): + assert len(initial_states) == self.num_layers, ( + "length of initial_states should be %d when it is a list/tuple" + % self.num_layers) + else: + initial_states = [initial_states] * self.num_layers + stacked_states = [] + for i in range(self.num_layers): + outputs, states = self.lstm[i](inputs, initial_states[i], + sequence_length) + inputs = outputs + stacked_states.append(states) + return outputs, stacked_states + + +class StackedGRUCell(RNNCell): + """ + """ + + def __init__(self, + input_size, + hidden_size, + gate_activation=None, + activation=None, + num_layers=1, + dropout=0.0, + param_attr=None, + bias_attr=None, + dtype="float32"): + 
super(StackedGRUCell, self).__init__() + self.dropout = utils.convert_to_list(dropout, num_layers, "dropout", + float) + param_attrs = StackedRNNCell.stack_param_attr(param_attr, num_layers) + bias_attrs = StackedRNNCell.stack_param_attr(bias_attr, num_layers) + + self.cells = [] + for i in range(num_layers): + self.cells.append( + self.add_sublayer( + "gru_%d" % i, + BasicGRUCell( + input_size=input_size if i == 0 else hidden_size, + hidden_size=hidden_size, + gate_activation=gate_activation, + activation=activation, + param_attr=param_attrs[i], + bias_attr=bias_attrs[i], + dtype=dtype))) + + def forward(self, step_input, states): + new_states = [] + for i, cell in enumerate(self.cells): + out, new_state = cell(step_input, states[i]) + step_input = layers.dropout( + out, + self.dropout[i], + dropout_implementation='upscale_in_train') if self.dropout[ + i] > 0 else out + new_states.append(new_state) + return step_input, new_states + + @property + def state_shape(self): + return [cell.state_shape for cell in self.cells] + + +class GRU(Layer): + def __init__(self, + input_size, + hidden_size, + gate_activation=None, + activation=None, + num_layers=1, + dropout=0.0, + is_reverse=False, + time_major=False, + param_attr=None, + bias_attr=None, + dtype='float32'): + super(GRU, self).__init__() + gru_cell = StackedGRUCell(input_size, hidden_size, gate_activation, + activation, num_layers, dropout, param_attr, + bias_attr, dtype) + self.gru = RNN(gru_cell, is_reverse, time_major) + + def forward(self, inputs, initial_states=None, sequence_length=None): + return self.gru(inputs, initial_states, sequence_length) + + +class BidirectionalGRU(Layer): + def __init__(self, + input_size, + hidden_size, + gate_activation=None, + activation=None, + forget_bias=1.0, + num_layers=1, + dropout=0.0, + merge_mode='concat', + merge_each_layer=False, + time_major=False, + param_attr=None, + bias_attr=None, + dtype='float32'): + super(BidirectionalGRU, self).__init__() + self.num_layers = num_layers + self.merge_mode = merge_mode + self.merge_each_layer = merge_each_layer + param_attrs = BidirectionalRNN.bidirect_param_attr(param_attr) + bias_attrs = BidirectionalRNN.bidirect_param_attr(bias_attr) + if not merge_each_layer: + cell_fw = StackedGRUCell(input_size, hidden_size, gate_activation, + activation, num_layers, dropout, + param_attrs[0], bias_attrs[0], dtype) + cell_bw = StackedGRUCell(input_size, hidden_size, gate_activation, + activation, num_layers, dropout, + param_attrs[1], bias_attrs[1], dtype) + self.gru = BidirectionalRNN( + cell_fw, cell_bw, merge_mode=merge_mode, time_major=time_major) + else: + fw_param_attrs = StackedRNNCell.stack_param_attr(param_attrs[0], + num_layers) + bw_param_attrs = StackedRNNCell.stack_param_attr(param_attrs[1], + num_layers) + fw_bias_attrs = StackedRNNCell.stack_param_attr(bias_attrs[0], + num_layers) + bw_bias_attrs = StackedRNNCell.stack_param_attr(bias_attrs[1], + num_layers) + + # maybe design cell including both forward and backward later + self.gru = [] + for i in range(num_layers): + cell_fw = StackedGRUCell(input_size if i == 0 else ( + hidden_size * 2 if merge_mode == 'concat' else + hidden_size), hidden_size, gate_activation, activation, 1, + dropout, fw_param_attrs[i], + fw_bias_attrs[i], dtype) + cell_bw = StackedGRUCell(input_size if i == 0 else ( + hidden_size * 2 if merge_mode == 'concat' else + hidden_size), hidden_size, gate_activation, activation, 1, + dropout, bw_param_attrs[i], + bw_bias_attrs[i], dtype) + self.gru.append( + self.add_sublayer( + 
"gru_%d" % i, + BidirectionalRNN( + cell_fw, + cell_bw, + merge_mode=merge_mode, + time_major=time_major))) + + def forward(self, inputs, initial_states=None, sequence_length=None): + if not self.merge_each_layer: + return self.gru(inputs, initial_states, sequence_length) + else: + if isinstance(initial_states, (list, tuple)): + assert len(initial_states) == self.num_layers, ( + "length of initial_states should be %d when it is a list/tuple" + % self.num_layers) + else: + initial_states = [initial_states] * self.num_layers + stacked_states = [] + for i in range(self.num_layers): + outputs, states = self.gru[i](inputs, initial_states[i], + sequence_length) + inputs = outputs + stacked_states.append(states) + return outputs, stacked_states