From db601f70ccfbadf8b9bbdcb6deb3477c8b3d4466 Mon Sep 17 00:00:00 2001
From: liym27 <33742067+liym27@users.noreply.github.com>
Date: Wed, 17 Jun 2020 17:06:32 +0800
Subject: [PATCH] [Dy2Stat] Add test for dygraph seq2seq model. (#25054)

* Temporarily allow the arg of append() to be a non-Tensor.
* Add Seq2Seq as a ProgramTranslator unit test.
* Set dtype of vocab_size_tensor to int64 to pass Windows-CI.
---
 .../dygraph_to_static/list_transformer.py     |  38 +-
 .../seq2seq_dygraph_model.py                  | 450 ++++++++++++++++++
 .../dygraph_to_static/seq2seq_utils.py        | 135 ++++++
 .../dygraph_to_static/test_seq2seq.py         | 172 +++++++
 4 files changed, 777 insertions(+), 18 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/seq2seq_dygraph_model.py
 create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/seq2seq_utils.py
 create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_seq2seq.py

diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py
index de9acabe24..03b0de9907 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py
@@ -213,27 +213,29 @@ class ListTransformer(gast.NodeTransformer):
         if value_name not in self.list_name_to_updated:
             return False
 
-        # 3. The arg of append() is one `Tensor`
+        # 3. The number of args of append() is one
         # Only one argument is supported in Python list.append()
         if len(node.args) != 1:
             return False
-        arg = node.args[0]
-        if isinstance(arg, gast.Name):
-            # TODO: `arg.id` may be not in scope_var_type_dict if `arg.id` is the arg of decorated function
-            # Need a better way to confirm whether `arg.id` is a Tensor.
-            try:
-                var_type_set = self.scope_var_type_dict[arg.id]
-            except KeyError:
-                return False
-
-            if NodeVarType.NUMPY_NDARRAY in var_type_set:
-                return False
-            if NodeVarType.TENSOR not in var_type_set and NodeVarType.PADDLE_RETURN_TYPES not in var_type_set:
-                return False
-        # else:
-        # Todo: Consider that `arg` may be a gast.Call about Paddle Api.
-        #       eg: list_a.append(fluid.layers.reshape(x))
-        # return True
+
+        # TODO(liym27): The arg of append() should be a Tensor. But because static analysis
+        # often infers a wrong type for the arg, the arg is not required to be a Tensor here.
+        # 4. The arg of append() is a Tensor
+        # arg = node.args[0]
+        # if isinstance(arg, gast.Name):
+        #     # TODO: `arg.id` may not be in scope_var_type_dict if `arg.id` is the arg of the decorated function.
+        #     # Need a better way to confirm whether `arg.id` is a Tensor.
+        #     try:
+        #         var_type_set = self.scope_var_type_dict[arg.id]
+        #     except KeyError:
+        #         return False
+        #     if NodeVarType.NUMPY_NDARRAY in var_type_set:
+        #         return False
+        #     if NodeVarType.TENSOR not in var_type_set and NodeVarType.PADDLE_RETURN_TYPES not in var_type_set:
+        #         return False
+        #     # TODO: Consider that `arg` may be a gast.Call to a Paddle API. e.g. list_a.append(fluid.layers.reshape(x))
+        #     # else:
+        #     # return True
 
         self.list_name_to_updated[value_name.strip()] = True
         return True
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/seq2seq_dygraph_model.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/seq2seq_dygraph_model.py
new file mode 100644
index 0000000000..809e4d51a7
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/seq2seq_dygraph_model.py
@@ -0,0 +1,450 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import paddle.fluid as fluid +from paddle.fluid import ParamAttr +from paddle.fluid import layers +from paddle.fluid.dygraph import Layer +from paddle.fluid.dygraph.base import to_variable +from paddle.fluid.dygraph.jit import declarative +from paddle.fluid.dygraph.nn import Embedding +from seq2seq_utils import Seq2SeqModelHyperParams as args + +INF = 1. * 1e5 +alpha = 0.6 +uniform_initializer = lambda x: fluid.initializer.UniformInitializer(low=-x, high=x) +zero_constant = fluid.initializer.Constant(0.0) + + +class BasicLSTMUnit(Layer): + def __init__(self, + hidden_size, + input_size, + param_attr=None, + bias_attr=None, + gate_activation=None, + activation=None, + forget_bias=1.0, + dtype='float32'): + super(BasicLSTMUnit, self).__init__(dtype) + + self._hiden_size = hidden_size + self._param_attr = param_attr + self._bias_attr = bias_attr + self._gate_activation = gate_activation or layers.sigmoid + self._activation = activation or layers.tanh + self._forget_bias = forget_bias + self._dtype = dtype + self._input_size = input_size + + self._weight = self.create_parameter( + attr=self._param_attr, + shape=[self._input_size + self._hiden_size, 4 * self._hiden_size], + dtype=self._dtype) + + self._bias = self.create_parameter( + attr=self._bias_attr, + shape=[4 * self._hiden_size], + dtype=self._dtype, + is_bias=True) + + def forward(self, input, pre_hidden, pre_cell): + concat_input_hidden = layers.concat([input, pre_hidden], 1) + gate_input = layers.matmul(x=concat_input_hidden, y=self._weight) + + gate_input = layers.elementwise_add(gate_input, self._bias) + i, j, f, o = layers.split(gate_input, num_or_sections=4, dim=-1) + new_cell = layers.elementwise_add( + layers.elementwise_mul(pre_cell, + layers.sigmoid(f + self._forget_bias)), + layers.elementwise_mul(layers.sigmoid(i), layers.tanh(j))) + + new_hidden = layers.tanh(new_cell) * layers.sigmoid(o) + + return new_hidden, new_cell + + +class BaseModel(fluid.dygraph.Layer): + def __init__(self, + hidden_size, + src_vocab_size, + tar_vocab_size, + batch_size, + num_layers=1, + init_scale=0.1, + dropout=None, + beam_size=1, + beam_start_token=1, + beam_end_token=2, + beam_max_step_num=2, + mode='train'): + super(BaseModel, self).__init__() + self.hidden_size = hidden_size + self.src_vocab_size = src_vocab_size + self.tar_vocab_size = tar_vocab_size + self.batch_size = batch_size + self.num_layers = num_layers + self.init_scale = init_scale + self.dropout = dropout + self.beam_size = beam_size + self.beam_start_token = beam_start_token + self.beam_end_token = beam_end_token + self.beam_max_step_num = beam_max_step_num + self.mode = mode + self.kinf = 1e9 + + param_attr = ParamAttr(initializer=uniform_initializer(self.init_scale)) + bias_attr = ParamAttr(initializer=zero_constant) + forget_bias = 1.0 + + self.src_embeder = Embedding( + size=[self.src_vocab_size, 
self.hidden_size], + param_attr=fluid.ParamAttr( + initializer=uniform_initializer(init_scale))) + + self.tar_embeder = Embedding( + size=[self.tar_vocab_size, self.hidden_size], + is_sparse=False, + param_attr=fluid.ParamAttr( + initializer=uniform_initializer(init_scale))) + + self.enc_units = [] + for i in range(num_layers): + self.enc_units.append( + self.add_sublayer( + "enc_units_%d" % i, + BasicLSTMUnit( + hidden_size=self.hidden_size, + input_size=self.hidden_size, + param_attr=param_attr, + bias_attr=bias_attr, + forget_bias=forget_bias))) + + self.dec_units = [] + for i in range(num_layers): + self.dec_units.append( + self.add_sublayer( + "dec_units_%d" % i, + BasicLSTMUnit( + hidden_size=self.hidden_size, + input_size=self.hidden_size, + param_attr=param_attr, + bias_attr=bias_attr, + forget_bias=forget_bias))) + + self.fc = fluid.dygraph.nn.Linear( + self.hidden_size, + self.tar_vocab_size, + param_attr=param_attr, + bias_attr=False) + + def _transpose_batch_time(self, x): + return fluid.layers.transpose(x, [1, 0] + list(range(2, len(x.shape)))) + + def _merge_batch_beams(self, x): + return fluid.layers.reshape(x, shape=(-1, x.shape[2])) + + def _split_batch_beams(self, x): + return fluid.layers.reshape(x, shape=(-1, self.beam_size, x.shape[1])) + + def _expand_to_beam_size(self, x): + x = fluid.layers.unsqueeze(x, [1]) + expand_times = [1] * len(x.shape) + expand_times[1] = self.beam_size + x = fluid.layers.expand(x, expand_times) + return x + + def _real_state(self, state, new_state, step_mask): + new_state = fluid.layers.elementwise_mul(new_state, step_mask, axis=0) - \ + fluid.layers.elementwise_mul(state, (step_mask - 1), axis=0) + return new_state + + def _gather(self, x, indices, batch_pos): + topk_coordinates = fluid.layers.stack([batch_pos, indices], axis=2) + return fluid.layers.gather_nd(x, topk_coordinates) + + @declarative + def forward(self, inputs): + src, tar, label, src_sequence_length, tar_sequence_length = inputs + if src.shape[0] < self.batch_size: + self.batch_size = src.shape[0] + + src_emb = self.src_embeder(self._transpose_batch_time(src)) + + # NOTE: modify model code about `enc_hidden` and `enc_cell` to transforme dygraph code successfully. + # Because nested list can't be transformed now. + enc_hidden_0 = to_variable( + np.zeros( + (self.batch_size, self.hidden_size), dtype='float32')) + enc_cell_0 = to_variable( + np.zeros( + (self.batch_size, self.hidden_size), dtype='float32')) + zero = fluid.layers.zeros(shape=[1], dtype="int64") + enc_hidden = fluid.layers.create_array(dtype="float32") + enc_cell = fluid.layers.create_array(dtype="float32") + for i in range(self.num_layers): + index = zero + i + enc_hidden = fluid.layers.array_write( + enc_hidden_0, index, array=enc_hidden) + enc_cell = fluid.layers.array_write( + enc_cell_0, index, array=enc_cell) + + max_seq_len = src_emb.shape[0] + + enc_len_mask = fluid.layers.sequence_mask( + src_sequence_length, maxlen=max_seq_len, dtype="float32") + enc_len_mask = fluid.layers.transpose(enc_len_mask, [1, 0]) + + # TODO: Because diff exits if call while_loop in static graph. + # In while block, a Variable created in parent block participates in the calculation of gradient, + # the gradient is wrong because each step scope always returns the same value generated by last step. + # NOTE: Replace max_seq_len(Tensor src_emb.shape[0]) with args.max_seq_len(int) to avoid this bug temporarily. 
+ for k in range(args.max_seq_len): + enc_step_input = src_emb[k] + step_mask = enc_len_mask[k] + new_enc_hidden, new_enc_cell = [], [] + for i in range(self.num_layers): + enc_new_hidden, enc_new_cell = self.enc_units[i]( + enc_step_input, enc_hidden[i], enc_cell[i]) + if self.dropout != None and self.dropout > 0.0: + enc_step_input = fluid.layers.dropout( + enc_new_hidden, + dropout_prob=self.dropout, + dropout_implementation='upscale_in_train') + else: + enc_step_input = enc_new_hidden + + new_enc_hidden.append( + self._real_state(enc_hidden[i], enc_new_hidden, step_mask)) + new_enc_cell.append( + self._real_state(enc_cell[i], enc_new_cell, step_mask)) + + enc_hidden, enc_cell = new_enc_hidden, new_enc_cell + + dec_hidden, dec_cell = enc_hidden, enc_cell + tar_emb = self.tar_embeder(self._transpose_batch_time(tar)) + max_seq_len = tar_emb.shape[0] + dec_output = [] + for step_idx in range(max_seq_len): + j = step_idx + 0 + step_input = tar_emb[j] + new_dec_hidden, new_dec_cell = [], [] + for i in range(self.num_layers): + new_hidden, new_cell = self.dec_units[i]( + step_input, dec_hidden[i], dec_cell[i]) + new_dec_hidden.append(new_hidden) + new_dec_cell.append(new_cell) + if self.dropout != None and self.dropout > 0.0: + step_input = fluid.layers.dropout( + new_hidden, + dropout_prob=self.dropout, + dropout_implementation='upscale_in_train') + else: + step_input = new_hidden + dec_output.append(step_input) + + dec_output = fluid.layers.stack(dec_output) + dec_output = self.fc(self._transpose_batch_time(dec_output)) + loss = fluid.layers.softmax_with_cross_entropy( + logits=dec_output, label=label, soft_label=False) + loss = fluid.layers.squeeze(loss, axes=[2]) + max_tar_seq_len = fluid.layers.shape(tar)[1] + tar_mask = fluid.layers.sequence_mask( + tar_sequence_length, maxlen=max_tar_seq_len, dtype='float32') + loss = loss * tar_mask + loss = fluid.layers.reduce_mean(loss, dim=[0]) + loss = fluid.layers.reduce_sum(loss) + + return loss + + @declarative + def beam_search(self, inputs): + src, tar, label, src_sequence_length, tar_sequence_length = inputs + if src.shape[0] < self.batch_size: + self.batch_size = src.shape[0] + + src_emb = self.src_embeder(self._transpose_batch_time(src)) + enc_hidden_0 = to_variable( + np.zeros( + (self.batch_size, self.hidden_size), dtype='float32')) + enc_cell_0 = to_variable( + np.zeros( + (self.batch_size, self.hidden_size), dtype='float32')) + zero = fluid.layers.zeros(shape=[1], dtype="int64") + enc_hidden = fluid.layers.create_array(dtype="float32") + enc_cell = fluid.layers.create_array(dtype="float32") + for j in range(self.num_layers): + index = zero + j + enc_hidden = fluid.layers.array_write( + enc_hidden_0, index, array=enc_hidden) + enc_cell = fluid.layers.array_write( + enc_cell_0, index, array=enc_cell) + + max_seq_len = src_emb.shape[0] + + enc_len_mask = fluid.layers.sequence_mask( + src_sequence_length, maxlen=max_seq_len, dtype="float32") + enc_len_mask = fluid.layers.transpose(enc_len_mask, [1, 0]) + + for k in range(args.max_seq_len): + enc_step_input = src_emb[k] + step_mask = enc_len_mask[k] + + new_enc_hidden, new_enc_cell = [], [] + + for i in range(self.num_layers): + enc_new_hidden, enc_new_cell = self.enc_units[i]( + enc_step_input, enc_hidden[i], enc_cell[i]) + if self.dropout != None and self.dropout > 0.0: + enc_step_input = fluid.layers.dropout( + enc_new_hidden, + dropout_prob=self.dropout, + dropout_implementation='upscale_in_train') + else: + enc_step_input = enc_new_hidden + + new_enc_hidden.append( + 
self._real_state(enc_hidden[i], enc_new_hidden, step_mask)) + new_enc_cell.append( + self._real_state(enc_cell[i], enc_new_cell, step_mask)) + + enc_hidden, enc_cell = new_enc_hidden, new_enc_cell + + # beam search + batch_beam_shape = (self.batch_size, self.beam_size) + vocab_size_tensor = to_variable( + np.full((1), self.tar_vocab_size).astype("int64")) + start_token_tensor = to_variable( + np.full( + batch_beam_shape, self.beam_start_token, dtype='int64')) + end_token_tensor = to_variable( + np.full( + batch_beam_shape, self.beam_end_token, dtype='int64')) + step_input = self.tar_embeder(start_token_tensor) + beam_finished = to_variable( + np.full( + batch_beam_shape, 0, dtype='float32')) + beam_state_log_probs = to_variable( + np.array( + [[0.] + [-self.kinf] * (self.beam_size - 1)], dtype="float32")) + beam_state_log_probs = fluid.layers.expand(beam_state_log_probs, + [self.batch_size, 1]) + dec_hidden, dec_cell = enc_hidden, enc_cell + dec_hidden = [self._expand_to_beam_size(ele) for ele in dec_hidden] + dec_cell = [self._expand_to_beam_size(ele) for ele in dec_cell] + + batch_pos = fluid.layers.expand( + fluid.layers.unsqueeze( + to_variable(np.arange( + 0, self.batch_size, 1, dtype="int64")), [1]), + [1, self.beam_size]) + predicted_ids = [] + parent_ids = [] + + for step_idx in range(self.beam_max_step_num): + if fluid.layers.reduce_sum(1 - beam_finished).numpy()[0] == 0: + break + step_input = self._merge_batch_beams(step_input) + new_dec_hidden, new_dec_cell = [], [] + state = 0 + dec_hidden = [ + self._merge_batch_beams(state) for state in dec_hidden + ] + dec_cell = [self._merge_batch_beams(state) for state in dec_cell] + + for i in range(self.num_layers): + new_hidden, new_cell = self.dec_units[i]( + step_input, dec_hidden[i], dec_cell[i]) + new_dec_hidden.append(new_hidden) + new_dec_cell.append(new_cell) + if self.dropout != None and self.dropout > 0.0: + step_input = fluid.layers.dropout( + new_hidden, + dropout_prob=self.dropout, + dropout_implementation='upscale_in_train') + else: + step_input = new_hidden + cell_outputs = self._split_batch_beams(step_input) + cell_outputs = self.fc(cell_outputs) + + step_log_probs = fluid.layers.log( + fluid.layers.softmax(cell_outputs)) + noend_array = [-self.kinf] * self.tar_vocab_size + noend_array[self.beam_end_token] = 0 + noend_mask_tensor = to_variable( + np.array( + noend_array, dtype='float32')) + + step_log_probs = fluid.layers.elementwise_mul( + fluid.layers.expand(fluid.layers.unsqueeze(beam_finished, [2]), [1, 1, self.tar_vocab_size]), + noend_mask_tensor, axis=-1) - \ + fluid.layers.elementwise_mul(step_log_probs, (beam_finished - 1), axis=0) + log_probs = fluid.layers.elementwise_add( + x=step_log_probs, y=beam_state_log_probs, axis=0) + scores = fluid.layers.reshape( + log_probs, [-1, self.beam_size * self.tar_vocab_size]) + topk_scores, topk_indices = fluid.layers.topk( + input=scores, k=self.beam_size) + + beam_indices = fluid.layers.elementwise_floordiv(topk_indices, + vocab_size_tensor) + token_indices = fluid.layers.elementwise_mod(topk_indices, + vocab_size_tensor) + next_log_probs = self._gather(scores, topk_indices, batch_pos) + + x = 0 + new_dec_hidden = [ + self._split_batch_beams(state) for state in new_dec_hidden + ] + new_dec_cell = [ + self._split_batch_beams(state) for state in new_dec_cell + ] + new_dec_hidden = [ + self._gather(x, beam_indices, batch_pos) for x in new_dec_hidden + ] + new_dec_cell = [ + self._gather(x, beam_indices, batch_pos) for x in new_dec_cell + ] + + new_dec_hidden = [ + 
self._gather(x, beam_indices, batch_pos) for x in new_dec_hidden + ] + new_dec_cell = [ + self._gather(x, beam_indices, batch_pos) for x in new_dec_cell + ] + next_finished = self._gather(beam_finished, beam_indices, batch_pos) + next_finished = fluid.layers.cast(next_finished, "bool") + next_finished = fluid.layers.logical_or( + next_finished, + fluid.layers.equal(token_indices, end_token_tensor)) + next_finished = fluid.layers.cast(next_finished, "float32") + + dec_hidden, dec_cell = new_dec_hidden, new_dec_cell + beam_finished = next_finished + beam_state_log_probs = next_log_probs + step_input = self.tar_embeder(token_indices) + predicted_ids.append(token_indices) + parent_ids.append(beam_indices) + + predicted_ids = fluid.layers.stack(predicted_ids) + parent_ids = fluid.layers.stack(parent_ids) + predicted_ids = fluid.layers.gather_tree(predicted_ids, parent_ids) + predicted_ids = self._transpose_batch_time(predicted_ids) + return predicted_ids diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/seq2seq_utils.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/seq2seq_utils.py new file mode 100644 index 0000000000..7f9766c535 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/seq2seq_utils.py @@ -0,0 +1,135 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +SEED = 2020 + + +def build_fake_sentence(seed): + random = np.random.RandomState(seed) + sentence_len = random.randint(5, 15) + token_ids = [random.randint(0, 1000) for _ in range(sentence_len - 1)] + return token_ids + + +def get_data_iter(batch_size, mode='train', cache_num=20): + + self_random = np.random.RandomState(SEED) + + def to_pad_np(data, source=False): + max_len = 0 + bs = min(batch_size, len(data)) + for ele in data: + if len(ele) > max_len: + max_len = len(ele) + + ids = np.ones((bs, max_len), dtype='int64') * 2 + mask = np.zeros((bs), dtype='int32') + + for i, ele in enumerate(data): + ids[i, :len(ele)] = ele + if not source: + mask[i] = len(ele) - 1 + else: + mask[i] = len(ele) + + return ids, mask + + b_src = [] + + if mode != "train": + cache_num = 1 + data_len = 1000 + for j in range(data_len): + if len(b_src) == batch_size * cache_num: + if mode == 'infer': + new_cache = b_src + else: + new_cache = sorted(b_src, key=lambda k: len(k[0])) + + for i in range(cache_num): + batch_data = new_cache[i * batch_size:(i + 1) * batch_size] + src_cache = [w[0] for w in batch_data] + tar_cache = [w[1] for w in batch_data] + src_ids, src_mask = to_pad_np(src_cache, source=True) + tar_ids, tar_mask = to_pad_np(tar_cache) + yield (src_ids, src_mask, tar_ids, tar_mask) + + b_src = [] + src_seed = self_random.randint(0, data_len) + tar_seed = self_random.randint(0, data_len) + src_data = build_fake_sentence(src_seed) + tar_data = build_fake_sentence(tar_seed) + b_src.append((src_data, tar_data)) + + if len(b_src) == batch_size * cache_num or mode == 'infer': + if mode == 'infer': + new_cache = b_src + else: + new_cache = sorted(b_src, key=lambda k: len(k[0])) + + for i in range(cache_num): + batch_end = min(len(new_cache), (i + 1) * batch_size) + batch_data = new_cache[i * batch_size:batch_end] + src_cache = [w[0] for w in batch_data] + tar_cache = [w[1] for w in batch_data] + src_ids, src_mask = to_pad_np(src_cache, source=True) + tar_ids, tar_mask = to_pad_np(tar_cache) + yield (src_ids, src_mask, tar_ids, tar_mask) + + +class Seq2SeqModelHyperParams(object): + # Whether use attention model + attention = False + + # learning rate for optimizer + learning_rate = 0.01 + + # layers number of encoder and decoder + num_layers = 2 + + # hidden size of encoder and decoder + hidden_size = 8 + + src_vocab_size = 1000 + tar_vocab_size = 1000 + batch_size = 8 + max_epoch = 12 + + # max length for source and target sentence + max_len = 30 + + # drop probability + dropout = 0.0 + + # init scale for parameter + init_scale = 0.1 + + # max grad norm for global norm clip + max_grad_norm = 5.0 + + # model path for model to save + model_path = "dy2stat/model/seq2seq" + + # reload model to inference + reload_model = "model/epoch_0.pdparams" + + beam_size = 10 + + max_seq_len = 3 diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_seq2seq.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_seq2seq.py new file mode 100644 index 0000000000..c44b5375d2 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_seq2seq.py @@ -0,0 +1,172 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time +import unittest + +import numpy as np +import paddle.fluid as fluid +from paddle.fluid.clip import GradientClipByGlobalNorm +from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator + +from seq2seq_dygraph_model import BaseModel +from seq2seq_utils import Seq2SeqModelHyperParams as args +from seq2seq_utils import get_data_iter +place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace( +) +program_translator = ProgramTranslator() +STEP_NUM = 10 +PRINT_STEP = 2 + + +def prepare_input(batch): + src_ids, src_mask, tar_ids, tar_mask = batch + src_ids = src_ids.reshape((src_ids.shape[0], src_ids.shape[1])) + in_tar = tar_ids[:, :-1] + label_tar = tar_ids[:, 1:] + + in_tar = in_tar.reshape((in_tar.shape[0], in_tar.shape[1])) + label_tar = label_tar.reshape((label_tar.shape[0], label_tar.shape[1], 1)) + inputs = [src_ids, in_tar, label_tar, src_mask, tar_mask] + return inputs, np.sum(tar_mask) + + +def train(): + with fluid.dygraph.guard(place): + fluid.default_startup_program().random_seed = 2020 + fluid.default_main_program().random_seed = 2020 + + model = BaseModel( + args.hidden_size, + args.src_vocab_size, + args.tar_vocab_size, + args.batch_size, + num_layers=args.num_layers, + init_scale=args.init_scale, + dropout=args.dropout) + + gloabl_norm_clip = GradientClipByGlobalNorm(args.max_grad_norm) + optimizer = fluid.optimizer.SGD(args.learning_rate, + parameter_list=model.parameters(), + grad_clip=gloabl_norm_clip) + + model.train() + train_data_iter = get_data_iter(args.batch_size) + + batch_times = [] + for batch_id, batch in enumerate(train_data_iter): + total_loss = 0 + word_count = 0.0 + batch_start_time = time.time() + input_data_feed, word_num = prepare_input(batch) + input_data_feed = [ + fluid.dygraph.to_variable(np_inp) for np_inp in input_data_feed + ] + word_count += word_num + loss = model(input_data_feed) + loss.backward() + optimizer.minimize(loss) + model.clear_gradients() + total_loss += loss * args.batch_size + batch_end_time = time.time() + batch_time = batch_end_time - batch_start_time + batch_times.append(batch_time) + if batch_id % PRINT_STEP == 0: + print( + "Batch:[%d]; Time: %.5f s; loss: %.5f; total_loss: %.5f; word num: %.5f; ppl: %.5f" + % (batch_id, batch_time, loss.numpy(), total_loss.numpy(), + word_count, np.exp(total_loss.numpy() / word_count))) + if batch_id + 1 >= STEP_NUM: + break + model_dir = os.path.join(args.model_path) + if not os.path.exists(model_dir): + os.makedirs(model_dir) + fluid.save_dygraph(model.state_dict(), model_dir) + return loss.numpy() + + +def infer(): + with fluid.dygraph.guard(place): + model = BaseModel( + args.hidden_size, + args.src_vocab_size, + args.tar_vocab_size, + args.batch_size, + beam_size=args.beam_size, + num_layers=args.num_layers, + init_scale=args.init_scale, + dropout=0.0, + mode='beam_search') + state_dict, _ = fluid.dygraph.load_dygraph(args.model_path) + model.set_dict(state_dict) + model.eval() + train_data_iter = get_data_iter(args.batch_size, mode='infer') + batch_times = [] + for batch_id, batch in enumerate(train_data_iter): + 
batch_start_time = time.time() + input_data_feed, word_num = prepare_input(batch) + input_data_feed = [ + fluid.dygraph.to_variable(np_inp) for np_inp in input_data_feed + ] + outputs = model.beam_search(input_data_feed) + batch_end_time = time.time() + batch_time = batch_end_time - batch_start_time + batch_times.append(batch_time) + if batch_id > STEP_NUM: + break + + return outputs.numpy() + + +class TestSeq2seq(unittest.TestCase): + def run_dygraph(self, mode="train"): + program_translator.enable(False) + if mode == "train": + return train() + else: + return infer() + + def run_static(self, mode="train"): + program_translator.enable(True) + if mode == "train": + return train() + else: + return infer() + + def _test_train(self): + dygraph_loss = self.run_dygraph(mode="train") + static_loss = self.run_static(mode="train") + result = np.allclose(dygraph_loss, static_loss) + self.assertTrue( + result, + msg="\ndygraph_loss = {} \nstatic_loss = {}".format(dygraph_loss, + static_loss)) + + def _test_predict(self): + pred_dygraph = self.run_dygraph(mode="test") + pred_static = self.run_static(mode="test") + result = np.allclose(pred_static, pred_dygraph) + self.assertTrue( + result, + msg="\npred_dygraph = {} \npred_static = {}".format(pred_dygraph, + pred_static)) + + def test_check_result(self): + self._test_train() + self._test_predict() + + +if __name__ == '__main__': + unittest.main() -- GitLab
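
Note for readers of the patch: the trickiest part of BaseModel.beam_search above is the index bookkeeping after topk, i.e. recovering the parent beam and the emitted token from indices over the flattened beam_size * vocab_size axis, and then reordering per-beam state with _gather. The following standalone NumPy sketch (illustrative only, not part of the patch; shapes and variable names are assumptions taken from the model code) shows the same arithmetic.

import numpy as np

# Illustrative sketch (not part of the patch): the beam-search index bookkeeping
# used in BaseModel.beam_search, reproduced with NumPy. Shapes follow the model:
# `scores` is [batch_size, beam_size * vocab_size] after the reshape of log_probs.
batch_size, beam_size, vocab_size = 2, 3, 5
rng = np.random.RandomState(0)
scores = rng.rand(batch_size, beam_size * vocab_size).astype("float32")

# Top-k over the flattened beam*vocab axis (fluid.layers.topk in the model).
topk_indices = np.argsort(-scores, axis=1)[:, :beam_size]        # [batch, beam]
topk_scores = np.take_along_axis(scores, topk_indices, axis=1)

# Recover the parent beam and the emitted token from each flat index
# (elementwise_floordiv / elementwise_mod with vocab_size_tensor in the model).
beam_indices = topk_indices // vocab_size                         # parent beam id
token_indices = topk_indices % vocab_size                         # token id

# _gather stacks batch ids with the selected indices and calls gather_nd,
# i.e. it picks entry [b, indices[b, k]] for every (b, k) pair.
batch_pos = np.tile(np.arange(batch_size)[:, None], [1, beam_size])


def gather(x, indices):
    # NumPy equivalent of gather_nd(x, stack([batch_pos, indices], axis=2)).
    return x[batch_pos, indices]


# Example: reorder a per-beam hidden state [batch, beam, hidden] by parent beam.
hidden = rng.rand(batch_size, beam_size, 4).astype("float32")
reordered_hidden = gather(hidden, beam_indices)
print(reordered_hidden.shape)  # (2, 3, 4)

In the model, the same gather is applied to new_dec_hidden, new_dec_cell and beam_finished so that every surviving hypothesis carries forward the state of the parent beam it was expanded from, before gather_tree finally reconstructs the predicted token sequences.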