Unverified commit b07584dc, authored by Jiabin Yang, committed by GitHub

test=release/1.4, refine test_imperative_transformer (#16737)

Parent cb9c59bd
@@ -116,7 +116,7 @@ class ModelHyperParams(object):
     # to process after each sub-layer
     postprocess_cmd = "da"  # dropout + residual connection
     # random seed used in dropout for CE.
-    dropout_seed = 1
+    dropout_seed = None
     # the flag indicating whether to share embedding and softmax weights.
     # vocabularies in source and target should be same for weight sharing.
     weight_sharing = True
@@ -166,15 +166,21 @@ def create_data(is_static=False):
         ]
     else:
         enc_inputs = [
-            to_variable(src_word_np), to_variable(src_pos_np),
-            to_variable(src_slf_attn_bias_np)
+            to_variable(
+                src_word_np, name='src_word'), to_variable(
+                    src_pos_np, name='src_pos'), to_variable(
+                        src_slf_attn_bias_np, name='src_slf_attn_bias')
         ]
         dec_inputs = [
-            to_variable(trg_word_np), to_variable(trg_pos_np),
-            to_variable(trg_slf_attn_bias_np), to_variable(trg_src_attn_bias_np)
+            to_variable(
+                trg_word_np, name='trg_word'), to_variable(
+                    trg_pos_np, name='trg_pos'), to_variable(
+                        trg_slf_attn_bias_np, name='trg_slf_attn_bias'),
+            to_variable(
+                trg_src_attn_bias_np, name='trg_src_attn_bias')
         ]
-        label = to_variable(lbl_word_np)
-        weight = to_variable(lbl_weight_np)
+        label = to_variable(lbl_word_np, name='lbl_word')
+        weight = to_variable(lbl_weight_np, name='lbl_weight')
     return enc_inputs, dec_inputs, label, weight
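For reference, this is the pattern the new branch follows when wrapping numpy inputs as named imperative variables. A minimal standalone sketch (the guard and import paths are assumptions about the release/1.4 dygraph API, not lines from this commit):

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable

with fluid.dygraph.guard():
    # Wrap a numpy array as an imperative Variable with an explicit name,
    # mirroring create_data() above; the shape and vocab size here are
    # illustrative only.
    src_word_np = np.random.randint(1, 100, size=(4, 8, 1), dtype='int64')
    src_word = to_variable(src_word_np, name='src_word')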
@@ -211,7 +217,7 @@ def make_all_inputs(input_fields):
 # The placeholder for batch_size in compile time. Must be -1 currently to be
 # consistent with some ops' infer-shape output in compile time, such as the
 # sequence_expand op used in beamsearch decoder.
-batch_size = 32
+batch_size = -1
 # The placeholder for sequence length in compile time.
 seq_len = ModelHyperParams.max_length
 # Here list the data shapes and data types of all inputs.
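The -1 batch_size only matters for the static-graph input declarations that make_all_inputs builds from these shapes; the random test data below uses the concrete TrainTaskConfig.batch_size instead. A hypothetical illustration of such a declaration (the field name and shape are examples, not taken from this file):

# With append_batch_size=False the leading -1 is kept as-is, so the batch
# dimension stays unknown at compile time and is filled in at run time.
src_word = fluid.layers.data(
    name='src_word',
    shape=[-1, ModelHyperParams.max_length, 1],
    dtype='int64',
    append_batch_size=False)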
@@ -304,35 +310,40 @@ sync = False

 batch_num = 5

-np.random.seed = 1
+np.random.seed = 90

 src_word_np = np.random.randint(
     1,
     ModelHyperParams.src_vocab_size - 1,
-    size=(batch_size, seq_len, 1),
+    size=(TrainTaskConfig.batch_size, seq_len, 1),
     dtype='int64')
 src_pos_np = np.random.randint(
-    1, seq_len, size=(batch_size, seq_len, 1), dtype='int64')
+    1, seq_len, size=(TrainTaskConfig.batch_size, seq_len, 1), dtype='int64')
-src_slf_attn_bias_np = np.random.randn(batch_size, ModelHyperParams.n_head,
-                                       seq_len, seq_len).astype('float32')
+src_slf_attn_bias_np = np.random.randn(TrainTaskConfig.batch_size,
+                                       ModelHyperParams.n_head, seq_len,
+                                       seq_len).astype('float32')

 trg_word_np = np.random.randint(
     1,
     ModelHyperParams.src_vocab_size - 1,
-    size=(batch_size, seq_len, 1),
+    size=(TrainTaskConfig.batch_size, seq_len, 1),
     dtype='int64')
 trg_pos_np = np.random.randint(
-    1, seq_len, size=(batch_size, seq_len, 1), dtype='int64')
+    1, seq_len, size=(TrainTaskConfig.batch_size, seq_len, 1), dtype='int64')
-trg_slf_attn_bias_np = np.random.randn(batch_size, ModelHyperParams.n_head,
-                                       seq_len, seq_len).astype('float32')
-trg_src_attn_bias_np = np.random.randn(batch_size, ModelHyperParams.n_head,
-                                       seq_len, seq_len).astype('float32')
+trg_slf_attn_bias_np = np.random.randn(TrainTaskConfig.batch_size,
+                                       ModelHyperParams.n_head, seq_len,
+                                       seq_len).astype('float32')
+trg_src_attn_bias_np = np.random.randn(TrainTaskConfig.batch_size,
+                                       ModelHyperParams.n_head, seq_len,
+                                       seq_len).astype('float32')

 lbl_word_np = np.random.randint(
     1,
     ModelHyperParams.src_vocab_size - 1,
-    size=(batch_size * seq_len, 1),
+    size=(TrainTaskConfig.batch_size * seq_len, 1),
     dtype='int64')
-lbl_weight_np = np.random.randn(batch_size * seq_len, 1).astype('float32')
+lbl_weight_np = np.random.randn(TrainTaskConfig.batch_size * seq_len,
+                                1).astype('float32')

 pos_inp1 = position_encoding_init(ModelHyperParams.max_length,
                                   ModelHyperParams.d_model)
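position_encoding_init itself is not part of this diff. For context, a sketch of the standard sinusoidal table such a helper is commonly implemented as (an assumption about the helper, not code from this file; assumes numpy is imported as np):

def position_encoding_init(n_position, d_pos_vec):
    # Build a [n_position, d_pos_vec] table of sinusoidal position encodings.
    position_enc = np.array([[
        pos / np.power(10000, 2.0 * (j // 2) / d_pos_vec)
        for j in range(d_pos_vec)
    ] for pos in range(n_position)])
    position_enc[:, 0::2] = np.sin(position_enc[:, 0::2])  # even dimensions
    position_enc[:, 1::2] = np.cos(position_enc[:, 1::2])  # odd dimensions
    return position_enc.astype('float32')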
@@ -447,7 +458,7 @@ class MultiHeadAttentionLayer(Layer):
             x=v, shape=[0, 0, self._n_head, self._d_value], inplace=False)
         transpose_v = fluid.layers.transpose(x=reshaped_v, perm=[0, 2, 1, 3])
-        #scale dot product attention
+        # scale dot product attention
         product = fluid.layers.matmul(
             x=transpose_q,
             y=transpose_k,
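Only the comment spacing changes in this hunk. For context, a rough sketch of the scaled dot-product attention the surrounding code computes with fluid ops (attribute names such as self._d_key and the attn_bias variable are assumptions, not taken from this file):

# q, k, v were reshaped/transposed above to [batch, n_head, seq_len, d_per_head]
product = fluid.layers.matmul(
    x=transpose_q,
    y=transpose_k,
    transpose_y=True,
    alpha=self._d_key**-0.5)  # scale the scores by 1/sqrt(d_key)
weights = fluid.layers.softmax(product + attn_bias)  # add mask/bias, normalize
out = fluid.layers.matmul(weights, transpose_v)  # weighted sum over the values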
@@ -971,6 +982,7 @@ class TestDygraphTransformer(unittest.TestCase):
                 enc_inputs, dec_inputs, label, weights = create_data()
                 dy_sum_cost, dy_avg_cost, dy_predict, dy_token_num = transformer(
                     enc_inputs, dec_inputs, label, weights)
+
                 if i == 0:
                     for param in transformer.parameters():
                         dy_param_init[param.name] = param._numpy()
@@ -978,6 +990,7 @@ class TestDygraphTransformer(unittest.TestCase):
                 dy_avg_cost._backward()
                 optimizer.minimize(dy_avg_cost)
                 transformer.clear_gradients()
+
                 if i == batch_num - 1:
                     for param in transformer.parameters():
                         dy_param_updated[param.name] = param._numpy()
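Taken together with the previous hunk, the dygraph side of the test boils down to this per-batch step (condensed from the lines above):

# forward pass, backward on the mean cost, optimizer step, then reset gradients
dy_sum_cost, dy_avg_cost, dy_predict, dy_token_num = transformer(
    enc_inputs, dec_inputs, label, weights)
dy_avg_cost._backward()
optimizer.minimize(dy_avg_cost)
transformer.clear_gradients()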
@@ -1024,7 +1037,6 @@ class TestDygraphTransformer(unittest.TestCase):
             static_param_name_list = list()
             static_sum_cost, static_avg_cost, static_predict, static_token_num = transformer(
                 enc_inputs, dec_inputs, label, weights)
-
             optimizer.minimize(static_avg_cost)
             for param in transformer.parameters():
                 static_param_name_list.append(param.name)
@@ -1042,8 +1054,8 @@ class TestDygraphTransformer(unittest.TestCase):
                 static_sum_cost, static_avg_cost, static_predict,
                 static_token_num
             ]
-            fetch_list.extend(static_param_name_list)
+            fetch_list.extend(static_param_name_list)
             out = exe.run(fluid.default_main_program(),
                           feed=feed_dict,
                           fetch_list=fetch_list)
...
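After both runs, the test lines up the extra entries of out with static_param_name_list and compares them against the dygraph values. A condensed sketch of that check (the offset of 4 follows from the fetch_list built above; the exact assertion used in the file is an assumption):

# out[0:4] are sum cost, avg cost, predict and token_num; the remaining
# entries follow static_param_name_list one-to-one.
for idx, name in enumerate(static_param_name_list):
    static_param_updated[name] = out[4 + idx]

for name in static_param_name_list:
    self.assertTrue(
        np.allclose(static_param_updated[name], dy_param_updated[name]))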