diff --git a/dygraph/transformer/model.py b/dygraph/transformer/model.py
index c17d8ba38abed0a464fadbc5e305a06e8356d95c..55d3aca75924e071708ec56d9ed9b31961e30a01 100644
--- a/dygraph/transformer/model.py
+++ b/dygraph/transformer/model.py
@@ -18,7 +18,7 @@
 import numpy as np
 import paddle.fluid as fluid
 import paddle.fluid.layers as layers
-from paddle.fluid.dygraph import Embedding, LayerNorm, FC, to_variable, Layer, guard
+from paddle.fluid.dygraph import Embedding, LayerNorm, Linear, to_variable, Layer, guard
 from paddle.fluid.dygraph.learning_rate_scheduler import LearningRateDecay
 
 from config import word_emb_param_names, pos_enc_param_names
@@ -71,13 +71,12 @@ class PrePostProcessLayer(Layer):
     """
     PrePostProcessLayer
     """
-    def __init__(self, name_scope, process_cmd, shape_len=None):
-        super(PrePostProcessLayer, self).__init__(name_scope)
+    def __init__(self, process_cmd, normalized_shape=None):
+        super(PrePostProcessLayer, self).__init__()
         for cmd in process_cmd:
             if cmd == "n":
                 self._layer_norm = LayerNorm(
-                    name_scope=self.full_name(),
-                    begin_norm_axis=shape_len - 1,
+                    normalized_shape=normalized_shape,
                     param_attr=fluid.ParamAttr(
                         initializer=fluid.initializer.Constant(1.)),
                     bias_attr=fluid.ParamAttr(
@@ -109,15 +108,13 @@ class PositionwiseFeedForwardLayer(Layer):
     """
     PositionwiseFeedForwardLayer
     """
-    def __init__(self, name_scope, d_inner_hid, d_hid, dropout_rate):
-        super(PositionwiseFeedForwardLayer, self).__init__(name_scope)
-        self._i2h = FC(name_scope=self.full_name(),
-                       size=d_inner_hid,
-                       num_flatten_dims=2,
+    def __init__(self, input_hid, d_inner_hid, d_hid, dropout_rate):
+        super(PositionwiseFeedForwardLayer, self).__init__()
+        self._i2h = Linear(input_dim=input_hid,
+                           output_dim=d_inner_hid,
                        act="relu")
-        self._h2o = FC(name_scope=self.full_name(),
-                       size=d_hid,
-                       num_flatten_dims=2)
+        self._h2o = Linear(input_dim=d_inner_hid,
+                           output_dim=d_hid)
         self._dropout_rate = dropout_rate
 
     def forward(self, x):
@@ -140,7 +137,6 @@ class MultiHeadAttentionLayer(Layer):
     MultiHeadAttentionLayer
     """
     def __init__(self,
-                 name_scope,
                  d_key,
                  d_value,
                  d_model,
@@ -149,28 +145,24 @@
                  cache=None,
                  gather_idx=None,
                  static_kv=False):
-        super(MultiHeadAttentionLayer, self).__init__(name_scope)
+        super(MultiHeadAttentionLayer, self).__init__()
         self._n_head = n_head
         self._d_key = d_key
         self._d_value = d_value
         self._d_model = d_model
         self._dropout_rate = dropout_rate
-        self._q_fc = FC(name_scope=self.full_name(),
-                        size=d_key * n_head,
-                        bias_attr=False,
-                        num_flatten_dims=2)
-        self._k_fc = FC(name_scope=self.full_name(),
-                        size=d_key * n_head,
-                        bias_attr=False,
-                        num_flatten_dims=2)
-        self._v_fc = FC(name_scope=self.full_name(),
-                        size=d_value * n_head,
-                        bias_attr=False,
-                        num_flatten_dims=2)
-        self._proj_fc = FC(name_scope=self.full_name(),
-                           size=self._d_model,
-                           bias_attr=False,
-                           num_flatten_dims=2)
+        self._q_fc = Linear(input_dim=d_model,
+                            output_dim=d_key * n_head,
+                            bias_attr=False)
+        self._k_fc = Linear(input_dim=d_model,
+                            output_dim=d_key * n_head,
+                            bias_attr=False)
+        self._v_fc = Linear(input_dim=d_model,
+                            output_dim=d_value * n_head,
+                            bias_attr=False)
+        self._proj_fc = Linear(input_dim=d_model,
+                               output_dim=self._d_model,
+                               bias_attr=False)
 
     def forward(self,
                 queries,
@@ -194,18 +186,18 @@
         q = self._q_fc(queries)
         k = self._k_fc(keys)
         v = self._v_fc(values)
-        
+
         # split head
         reshaped_q = layers.reshape(x=q,
-                                    shape=[0, 0, self._n_head, self._d_key],
+                                    shape=[q.shape[0], q.shape[1], self._n_head, self._d_key],
                                     inplace=False)
         transpose_q = layers.transpose(x=reshaped_q, perm=[0, 2, 1, 3])
         reshaped_k = layers.reshape(x=k,
-                                    shape=[0, 0, self._n_head, self._d_key],
+                                    shape=[k.shape[0], k.shape[1], self._n_head, self._d_key],
                                     inplace=False)
         transpose_k = layers.transpose(x=reshaped_k, perm=[0, 2, 1, 3])
         reshaped_v = layers.reshape(x=v,
-                                    shape=[0, 0, self._n_head, self._d_value],
+                                    shape=[v.shape[0], v.shape[1], self._n_head, self._d_value],
                                     inplace=False)
         transpose_v = layers.transpose(x=reshaped_v, perm=[0, 2, 1, 3])
@@ -250,7 +242,6 @@ class EncoderSubLayer(Layer):
     EncoderSubLayer
     """
     def __init__(self,
-                 name_scope,
                  n_head,
                  d_key,
                  d_value,
@@ -262,24 +253,21 @@
                  preprocess_cmd="n",
                  postprocess_cmd="da"):
-        super(EncoderSubLayer, self).__init__(name_scope)
+        super(EncoderSubLayer, self).__init__()
         self._preprocess_cmd = preprocess_cmd
         self._postprocess_cmd = postprocess_cmd
         self._prepostprocess_dropout = prepostprocess_dropout
-        self._preprocess_layer = PrePostProcessLayer(self.full_name(),
-                                                     self._preprocess_cmd, 3)
+        self._preprocess_layer = PrePostProcessLayer(self._preprocess_cmd, [d_model])
         self._multihead_attention_layer = MultiHeadAttentionLayer(
-            self.full_name(), d_key, d_value, d_model, n_head,
+            d_key, d_value, d_model, n_head,
             attention_dropout)
-        self._postprocess_layer = PrePostProcessLayer(self.full_name(),
-                                                      self._postprocess_cmd,
+        self._postprocess_layer = PrePostProcessLayer(self._postprocess_cmd,
                                                       None)
-        self._preprocess_layer2 = PrePostProcessLayer(self.full_name(),
-                                                      self._preprocess_cmd, 3)
+        self._preprocess_layer2 = PrePostProcessLayer(self._preprocess_cmd, [d_model])
         self._positionwise_feed_forward = PositionwiseFeedForwardLayer(
-            self.full_name(), d_inner_hid, d_model, relu_dropout)
-        self._postprocess_layer2 = PrePostProcessLayer(self.full_name(),
+            d_model, d_inner_hid, d_model, relu_dropout)
+        self._postprocess_layer2 = PrePostProcessLayer(
             self._postprocess_cmd, None)
@@ -311,7 +299,6 @@ class EncoderLayer(Layer):
     encoder
     """
     def __init__(self,
-                 name_scope,
                  n_layer,
                  n_head,
                  d_key,
@@ -324,18 +311,18 @@
                  preprocess_cmd="n",
                  postprocess_cmd="da"):
-        super(EncoderLayer, self).__init__(name_scope)
+        super(EncoderLayer, self).__init__()
         self._preprocess_cmd = preprocess_cmd
         self._encoder_sublayers = list()
         self._prepostprocess_dropout = prepostprocess_dropout
         self._n_layer = n_layer
-        self._preprocess_layer = PrePostProcessLayer(self.full_name(),
-                                                     self._preprocess_cmd, 3)
+        self._preprocess_layer = PrePostProcessLayer(
+            self._preprocess_cmd, [d_model])
         for i in range(n_layer):
             self._encoder_sublayers.append(
                 self.add_sublayer(
                     'esl_%d' % i,
-                    EncoderSubLayer(self.full_name(), n_head, d_key, d_value,
+                    EncoderSubLayer(n_head, d_key, d_value,
                                     d_model, d_inner_hid,
                                     prepostprocess_dropout, attention_dropout,
                                     relu_dropout, preprocess_cmd,
@@ -361,20 +348,18 @@
     PrepareEncoderDecoderLayer
     """
     def __init__(self,
-                 name_scope,
                  src_vocab_size,
                  src_emb_dim,
                  src_max_len,
                  dropout_rate,
                  word_emb_param_name=None,
                  pos_enc_param_name=None):
-        super(PrepareEncoderDecoderLayer, self).__init__(name_scope)
+        super(PrepareEncoderDecoderLayer, self).__init__()
         self._src_max_len = src_max_len
         self._src_emb_dim = src_emb_dim
         self._src_vocab_size = src_vocab_size
         self._dropout_rate = dropout_rate
-        self._input_emb = Embedding(name_scope=self.full_name(),
-                                    size=[src_vocab_size, src_emb_dim],
+        self._input_emb = Embedding(size=[src_vocab_size, src_emb_dim],
                                     padding_idx=0,
                                     param_attr=fluid.ParamAttr(
                                         name=word_emb_param_name,
@@ -383,7 +368,6 @@ class PrepareEncoderDecoderLayer(Layer):
 
         pos_inp = position_encoding_init(src_max_len, src_emb_dim)
         self._pos_emb = Embedding(
-            name_scope=self.full_name(),
             size=[self._src_max_len, src_emb_dim],
             param_attr=fluid.ParamAttr(
                 name=pos_enc_param_name,
@@ -411,6 +395,7 @@ class PrepareEncoderDecoderLayer(Layer):
         src_pos_emb = self._pos_emb(src_pos)
         src_pos_emb.stop_gradient = True
         enc_input = src_word_emb + src_pos_emb
+        enc_input = layers.reshape(enc_input, shape=[enc_input.shape[0], enc_input.shape[1], -1])
         return layers.dropout(
             enc_input, dropout_prob=self._dropout_rate,
             is_test=False) if self._dropout_rate else enc_input
@@ -420,24 +405,23 @@ class WrapEncoderLayer(Layer):
     """
     encoderlayer
    """
-    def __init__(self, name_cope, src_vocab_size, max_length, n_layer, n_head,
+    def __init__(self, src_vocab_size, max_length, n_layer, n_head,
                  d_key, d_value, d_model, d_inner_hid, prepostprocess_dropout,
                  attention_dropout, relu_dropout, preprocess_cmd,
                  postprocess_cmd, weight_sharing):
         """
         The wrapper assembles together all needed layers for the encoder.
         """
-        super(WrapEncoderLayer, self).__init__(name_cope)
+        super(WrapEncoderLayer, self).__init__()
         self._prepare_encoder_layer = PrepareEncoderDecoderLayer(
-            self.full_name(),
             src_vocab_size,
             d_model,
             max_length,
             prepostprocess_dropout,
             word_emb_param_name=word_emb_param_names[0],
             pos_enc_param_name=pos_enc_param_names[0])
-        self._encoder = EncoderLayer(self.full_name(), n_layer, n_head, d_key,
+        self._encoder = EncoderLayer(n_layer, n_head, d_key,
                                      d_value, d_model, d_inner_hid,
                                      prepostprocess_dropout, attention_dropout,
                                      relu_dropout, preprocess_cmd,
@@ -455,32 +439,32 @@ class DecoderSubLayer(Layer):
     """
     decoder
     """
-    def __init__(self, name_scope, n_head, d_key, d_value, d_model, d_inner_hid,
+    def __init__(self, n_head, d_key, d_value, d_model, d_inner_hid,
                  prepostprocess_dropout, attention_dropout, relu_dropout,
                  preprocess_cmd, postprocess_cmd):
-        super(DecoderSubLayer, self).__init__(name_scope)
+        super(DecoderSubLayer, self).__init__()
         self._postprocess_cmd = postprocess_cmd
         self._preprocess_cmd = preprocess_cmd
         self._prepostprcess_dropout = prepostprocess_dropout
-        self._pre_process_layer = PrePostProcessLayer(self.full_name(),
-                                                      preprocess_cmd, 3)
+        self._pre_process_layer = PrePostProcessLayer(
+            preprocess_cmd, [d_model])
         self._multihead_attention_layer = MultiHeadAttentionLayer(
-            self.full_name(), d_key, d_value, d_model, n_head,
+            d_key, d_value, d_model, n_head,
             attention_dropout)
-        self._post_process_layer = PrePostProcessLayer(self.full_name(),
+        self._post_process_layer = PrePostProcessLayer(
                                                        postprocess_cmd, None)
-        self._pre_process_layer2 = PrePostProcessLayer(self.full_name(),
-                                                       preprocess_cmd, 3)
+        self._pre_process_layer2 = PrePostProcessLayer(
+            preprocess_cmd, [d_model])
         self._multihead_attention_layer2 = MultiHeadAttentionLayer(
-            self.full_name(), d_key, d_value, d_model, n_head,
+            d_key, d_value, d_model, n_head,
             attention_dropout)
-        self._post_process_layer2 = PrePostProcessLayer(self.full_name(),
-                                                        postprocess_cmd, None)
-        self._pre_process_layer3 = PrePostProcessLayer(self.full_name(),
-                                                       preprocess_cmd, 3)
+        self._post_process_layer2 = PrePostProcessLayer(
+            postprocess_cmd, [d_model])
+        self._pre_process_layer3 = PrePostProcessLayer(
+            preprocess_cmd, [d_model])
         self._positionwise_feed_forward_layer = PositionwiseFeedForwardLayer(
-            self.full_name(), d_inner_hid, d_model, relu_dropout)
-        self._post_process_layer3 = PrePostProcessLayer(self.full_name(),
+            d_model, d_inner_hid, d_model, relu_dropout)
+        self._post_process_layer3 = PrePostProcessLayer(
                                                         postprocess_cmd, None)
 
     def forward(self,
@@ -529,12 +513,11 @@
     """
     decoder
     """
-    def __init__(self, name_scope, n_layer, n_head, d_key, d_value, d_model,
+    def __init__(self, n_layer, n_head, d_key, d_value, d_model,
                  d_inner_hid, prepostprocess_dropout, attention_dropout,
                  relu_dropout, preprocess_cmd, postprocess_cmd):
-        super(DecoderLayer, self).__init__(name_scope)
-        self._pre_process_layer = PrePostProcessLayer(self.full_name(),
-                                                      preprocess_cmd, 3)
+        super(DecoderLayer, self).__init__()
+        self._pre_process_layer = PrePostProcessLayer(preprocess_cmd, [d_model])
         self._decoder_sub_layers = list()
         self._n_layer = n_layer
         self._preprocess_cmd = preprocess_cmd
@@ -543,7 +526,7 @@
             self._decoder_sub_layers.append(
                 self.add_sublayer(
                     'dsl_%d' % i,
-                    DecoderSubLayer(self.full_name(), n_head, d_key, d_value,
+                    DecoderSubLayer(n_head, d_key, d_value,
                                     d_model, d_inner_hid,
                                     prepostprocess_dropout, attention_dropout,
                                     relu_dropout, preprocess_cmd,
@@ -581,7 +564,6 @@
     decoder
     """
     def __init__(self,
-                 name_scope,
                  trg_vocab_size,
                  max_length,
                  n_layer,
@@ -600,25 +582,24 @@
         """
         The wrapper assembles together all needed layers for the encoder.
         """
-        super(WrapDecoderLayer, self).__init__(name_scope)
+        super(WrapDecoderLayer, self).__init__()
         self._prepare_decoder_layer = PrepareEncoderDecoderLayer(
-            self.full_name(),
             trg_vocab_size,
             d_model,
             max_length,
             prepostprocess_dropout,
             word_emb_param_name=word_emb_param_names[1],
             pos_enc_param_name=pos_enc_param_names[1])
-        self._decoder_layer = DecoderLayer(self.full_name(), n_layer, n_head,
+        self._decoder_layer = DecoderLayer(n_layer, n_head,
                                            d_key, d_value, d_model, d_inner_hid,
                                            prepostprocess_dropout, attention_dropout,
                                            relu_dropout, preprocess_cmd,
                                            postprocess_cmd)
 
         self._weight_sharing = weight_sharing
         if not weight_sharing:
-            self._fc = FC(self.full_name(),
-                          size=trg_vocab_size,
+            self._fc = Linear(input_dim=d_model,
+                              output_dim=trg_vocab_size,
                           bias_attr=False)
 
     def forward(self, dec_inputs, enc_output, caches=None, gather_idx=None):
@@ -657,7 +638,6 @@ class TransFormer(Layer):
     model
     """
    def __init__(self,
-                 name_scope,
                  src_vocab_size,
                  trg_vocab_size,
                  max_length,
@@ -674,7 +654,7 @@
                  postprocess_cmd,
                  weight_sharing,
                  label_smooth_eps=0.0):
-        super(TransFormer, self).__init__(name_scope)
+        super(TransFormer, self).__init__()
         self._label_smooth_eps = label_smooth_eps
         self._trg_vocab_size = trg_vocab_size
         if weight_sharing:
@@ -682,12 +662,12 @@
                 "Vocabularies in source and target should be same for weight sharing."
             )
         self._wrap_encoder_layer = WrapEncoderLayer(
-            self.full_name(), src_vocab_size, max_length, n_layer, n_head,
+            src_vocab_size, max_length, n_layer, n_head,
             d_key, d_value, d_model, d_inner_hid, prepostprocess_dropout,
             attention_dropout, relu_dropout, preprocess_cmd, postprocess_cmd,
             weight_sharing)
         self._wrap_decoder_layer = WrapDecoderLayer(
-            self.full_name(), trg_vocab_size, max_length, n_layer, n_head,
+            trg_vocab_size, max_length, n_layer, n_head,
             d_key, d_value, d_model, d_inner_hid, prepostprocess_dropout,
             attention_dropout, relu_dropout, preprocess_cmd, postprocess_cmd,
             weight_sharing)
@@ -869,18 +849,25 @@
             topk_scores, topk_ids = layers.topk(flat_curr_scores,
                                                 k=beam_size * 2)
+            print("topk ids", topk_ids)
 
             topk_log_probs = topk_scores * length_penalty
 
             topk_beam_index = topk_ids // self._trg_vocab_size
             topk_ids = topk_ids % self._trg_vocab_size
+
+            print("topk ids2", topk_ids)
+
             # use gather as gather_nd, TODO: use gather_nd
             topk_seq = gather_2d_by_gather(alive_seq, topk_beam_index,
                                            beam_size, batch_size)
+
+            print("topk ids", topk_ids)
+            reshape_temp = layers.reshape(topk_ids, topk_ids.shape + [1])
             topk_seq = layers.concat(
                 [topk_seq,
-                 layers.reshape(topk_ids, topk_ids.shape + [1])],
+                 reshape_temp],
                 axis=2)
             states = update_states(states, topk_beam_index, beam_size)
             eos = layers.fill_constant(shape=topk_ids.shape,
diff --git a/dygraph/transformer/predict.py b/dygraph/transformer/predict.py
index c4da56ee194b6ee9f900a357bee7fc4457a0462d..5b64bc6dd903f55447281bd29f4d1542f2c8a167 100644
--- a/dygraph/transformer/predict.py
+++ b/dygraph/transformer/predict.py
@@ -62,9 +62,9 @@ def prepare_infer_input(insts, src_pad_idx, bos_idx, n_head):
     trg_word = np.asarray([[bos_idx]] * len(insts), dtype="int64")
     trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
                                 [1, 1, 1, 1]).astype("float32")
-    trg_word = trg_word.reshape(-1, 1, 1)
-    src_word = src_word.reshape(-1, src_max_len, 1)
-    src_pos = src_pos.reshape(-1, src_max_len, 1)
+    trg_word = trg_word.reshape(-1, 1, 1 )
+    src_word = src_word.reshape(-1, src_max_len, 1 )
+    src_pos = src_pos.reshape(-1, src_max_len, 1 )
 
     data_inputs = [
         src_word, src_pos, src_slf_attn_bias, trg_word, trg_src_attn_bias
@@ -101,7 +101,7 @@ def infer(args):
         if args.use_data_parallel else fluid.CUDAPlace(0)
     with fluid.dygraph.guard(place):
         transformer = TransFormer(
-            'transformer', ModelHyperParams.src_vocab_size,
+            ModelHyperParams.src_vocab_size,
             ModelHyperParams.trg_vocab_size, ModelHyperParams.max_length + 1,
             ModelHyperParams.n_layer, ModelHyperParams.n_head,
             ModelHyperParams.d_key, ModelHyperParams.d_value,
@@ -129,7 +129,8 @@ def infer(args):
             enc_inputs, dec_inputs = prepare_infer_input(
                 batch, ModelHyperParams.eos_idx, ModelHyperParams.bos_idx,
                 ModelHyperParams.n_head)
-            
+
+            print("enc inputs", enc_inputs[0].shape)
             finished_seq, finished_scores = transformer.beam_search(
                 enc_inputs,
                 dec_inputs,
diff --git a/dygraph/transformer/train.py b/dygraph/transformer/train.py
index 58ee77e36b4396a21a4c7a9e36dff57a9adc7925..e043bca19edcee16f1f6bdf162bc5eaf569f9968 100644
--- a/dygraph/transformer/train.py
+++ b/dygraph/transformer/train.py
@@ -110,7 +110,7 @@ def train(args):
 
         # define model
         transformer = TransFormer(
-            'transformer', ModelHyperParams.src_vocab_size,
+            ModelHyperParams.src_vocab_size,
             ModelHyperParams.trg_vocab_size, ModelHyperParams.max_length + 1,
             ModelHyperParams.n_layer, ModelHyperParams.n_head,
             ModelHyperParams.d_key, ModelHyperParams.d_value,
@@ -123,6 +123,7 @@
         optimizer = fluid.optimizer.Adam(learning_rate=NoamDecay(
             ModelHyperParams.d_model, TrainTaskConfig.warmup_steps,
             TrainTaskConfig.learning_rate),
+                                         parameter_list=transformer.parameters(),
                                          beta1=TrainTaskConfig.beta1,
                                          beta2=TrainTaskConfig.beta2,
                                          epsilon=TrainTaskConfig.eps)
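
Illustrative sketch, not part of the patch above: the substantive change in model.py is the move from the deprecated dygraph FC/name_scope constructors to Linear with explicit input_dim/output_dim, LayerNorm driven by normalized_shape, and (in train.py) an optimizer that receives parameter_list. A minimal self-contained example of that migrated pattern, assuming the Paddle 1.7 dygraph API; the class name and dimensions below are made up for illustration:

    # Sketch of the FC -> Linear / LayerNorm(normalized_shape) migration pattern.
    # Assumption: Paddle 1.7 dygraph API; TinyFFN and the dimensions are examples.
    import numpy as np
    import paddle.fluid as fluid
    from paddle.fluid.dygraph import Layer, Linear, LayerNorm, to_variable


    class TinyFFN(Layer):
        # Old FC(name_scope, size=..., num_flatten_dims=2) inferred the input width
        # at first call; Linear needs input_dim/output_dim up front and takes no
        # name_scope, which is why the __init__ signatures in model.py now carry
        # the input width (e.g. d_model) explicitly.
        def __init__(self, d_model=8, d_inner=32):
            super(TinyFFN, self).__init__()
            self._norm = LayerNorm(normalized_shape=[d_model])
            self._i2h = Linear(input_dim=d_model, output_dim=d_inner, act="relu")
            self._h2o = Linear(input_dim=d_inner, output_dim=d_model)

        def forward(self, x):  # x: [batch, seq_len, d_model]
            return self._h2o(self._i2h(self._norm(x)))


    if __name__ == "__main__":
        with fluid.dygraph.guard():
            net = TinyFFN()
            x = to_variable(np.random.rand(2, 4, 8).astype("float32"))
            out = net(x)  # shape [2, 4, 8]
            # This Paddle version expects the trainable parameters explicitly,
            # mirroring the parameter_list=transformer.parameters() change in train.py.
            opt = fluid.optimizer.Adam(learning_rate=1e-3,
                                       parameter_list=net.parameters())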