提交 1a55f7d3 编写于 作者: M minqiyang

Change from width-first backward to deep-first backward process

test=develop
上级 a0478084
......@@ -129,14 +129,14 @@ class Autograd {
std::map<OpBase*, int> dep_counts = ComputeDepCounts(var->PreOp());
while (!ready.empty()) {
OpBase* ready_op = ready.front();
ready.pop_front();
OpBase* ready_op = ready.back();
ready.pop_back();
std::map<std::string, std::vector<VarBase*>> input_grads =
ready_op->ApplyGrad();
for (auto it : input_grads) {
const std::vector<VarBase*>& ingrads = it.second;
for (size_t i = 0; i < ingrads.size(); ++i) {
for (int64_t i = ingrads.size() - 1; i >= 0; --i) {
if (!ingrads[i]) continue;
if (ready_op->input_vars_[it.first][i]->IsStopGradient()) {
continue;
......
......@@ -106,7 +106,7 @@ class ModelHyperParams(object):
# number of head used in multi-head attention.
n_head = 8
# number of sub-layers to be stacked in the encoder and decoder.
n_layer = 1
n_layer = 6
# dropout rates of different modules.
prepostprocess_dropout = 0.1
attention_dropout = 0.1
......@@ -303,7 +303,7 @@ use_py_reader = False
sync = False
# how many batches we use
batch_num = 1
batch_num = 2
np.random.seed = 1
src_word_np = np.random.randint(
......@@ -359,6 +359,59 @@ pos_inp2 = position_encoding_init(ModelHyperParams.max_length,
ModelHyperParams.d_model)
class PrePostProcessLayer(Layer):
def __init__(self, name_scope, process_cmd, shape_len=None):
super(PrePostProcessLayer, self).__init__(name_scope)
for cmd in process_cmd:
if cmd == "n":
self._layer_norm = LayerNorm(
name_scope=self.full_name(),
begin_norm_axis=shape_len - 1,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(1.)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.)))
def forward(self, prev_out, out, process_cmd, dropout_rate=0.):
for cmd in process_cmd:
if cmd == "a": # add residual connection
out = out + prev_out if prev_out else out
elif cmd == "n": # add layer normalization
out = self._layer_norm(out)
elif cmd == "d": # add dropout
if dropout_rate:
out = fluid.layers.dropout(
out,
dropout_prob=dropout_rate,
seed=ModelHyperParams.dropout_seed,
is_test=False)
return out
class PositionwiseFeedForwardLayer(Layer):
def __init__(self, name_scope, d_inner_hid, d_hid, dropout_rate):
super(PositionwiseFeedForwardLayer, self).__init__(name_scope)
self._i2h = FC(name_scope=self.full_name(),
size=d_inner_hid,
num_flatten_dims=2,
act="relu")
self._h2o = FC(name_scope=self.full_name(),
size=d_hid,
num_flatten_dims=2)
self._dropout_rate = dropout_rate
def forward(self, x):
hidden = self._i2h(x)
if self._dropout_rate:
hidden = fluid.layers.dropout(
hidden,
dropout_prob=self._dropout_rate,
seed=ModelHyperParams.dropout_seed,
is_test=False)
out = self._h2o(hidden)
return out
class MultiHeadAttentionLayer(Layer):
def __init__(self,
name_scope,
......@@ -393,22 +446,11 @@ class MultiHeadAttentionLayer(Layer):
bias_attr=False,
num_flatten_dims=2)
def _mm(self, input):
input_shape = input.shape
param_shape = [
reduce(lambda a, b: a * b, input_shape[self._num_flatten_dims:], 1)
] + [self._size]
self.x = self.create_parameter(
attr=None, shape=param_shape, dtype=self._dtype, is_bias=False)
def forward(self, queries, keys, values, attn_bias):
# compute q ,k ,v
keys = queries if keys is None else keys
values = keys if values is None else values
# q = queries
# k = keys
# v = values
q = self._q_fc(queries)
k = self._k_fc(keys)
v = self._v_fc(values)
......@@ -453,38 +495,181 @@ class MultiHeadAttentionLayer(Layer):
inplace=False)
# fc to output
print(final_out.shape)
proj_out = self._proj_fc(final_out)
return proj_out
class PrePostProcessLayer(Layer):
def __init__(self, name_scope, process_cmd, shape_len=None):
super(PrePostProcessLayer, self).__init__(name_scope)
for cmd in process_cmd:
if cmd == "n":
self._layer_norm = LayerNorm(
class EncoderSubLayer(Layer):
def __init__(self,
name_scope,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd="n",
postprocess_cmd="da"):
super(EncoderSubLayer, self).__init__(name_scope)
self._preprocess_cmd = preprocess_cmd
self._postprocess_cmd = postprocess_cmd
self._prepostprocess_dropout = prepostprocess_dropout
self._preprocess_layer = PrePostProcessLayer(self.full_name(),
self._preprocess_cmd, 3)
self._multihead_attention_layer = MultiHeadAttentionLayer(
self.full_name(), d_key, d_value, d_model, n_head,
attention_dropout)
self._postprocess_layer = PrePostProcessLayer(
self.full_name(), self._postprocess_cmd, None)
self._preprocess_layer2 = PrePostProcessLayer(self.full_name(),
self._preprocess_cmd, 3)
self._positionwise_feed_forward = PositionwiseFeedForwardLayer(
self.full_name(), d_inner_hid, d_model, relu_dropout)
self._postprocess_layer2 = PrePostProcessLayer(
self.full_name(), self._postprocess_cmd, None)
def forward(self, enc_input, attn_bias):
pre_process_multihead = self._preprocess_layer(
None, enc_input, self._preprocess_cmd, self._prepostprocess_dropout)
attn_output = self._multihead_attention_layer(pre_process_multihead,
None, None, attn_bias)
attn_output = self._postprocess_layer(enc_input, attn_output,
self._postprocess_cmd,
self._prepostprocess_dropout)
pre_process2_output = self._preprocess_layer2(
None, attn_output, self._preprocess_cmd,
self._prepostprocess_dropout)
ffd_output = self._positionwise_feed_forward(pre_process2_output)
return self._postprocess_layer2(attn_output, ffd_output,
self._postprocess_cmd,
self._prepostprocess_dropout)
class EncoderLayer(Layer):
def __init__(self,
name_scope,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd="n",
postprocess_cmd="da"):
super(EncoderLayer, self).__init__(name_scope)
self._preprocess_cmd = preprocess_cmd
self._encoder_sublayers = list()
self._prepostprocess_dropout = prepostprocess_dropout
self._n_layer = n_layer
self._preprocess_layer = PrePostProcessLayer(self.full_name(),
self._preprocess_cmd, 3)
for i in range(n_layer):
self._encoder_sublayers.append(
self.add_sublayer(
'esl_%d' % i,
EncoderSubLayer(
self.full_name(), n_head, d_key, d_value, d_model,
d_inner_hid, prepostprocess_dropout, attention_dropout,
relu_dropout, preprocess_cmd, postprocess_cmd)))
def forward(self, enc_input, attn_bias):
for i in range(self._n_layer):
enc_output = self._encoder_sublayers[i](enc_input, attn_bias)
enc_input = enc_output
return self._preprocess_layer(None, enc_output, self._preprocess_cmd,
self._prepostprocess_dropout)
class PrepareEncoderDecoderLayer(Layer):
def __init__(self,
name_scope,
src_vocab_size,
src_emb_dim,
src_max_len,
dropout_rate,
word_emb_param_name=None,
pos_enc_param_name=None):
super(PrepareEncoderDecoderLayer, self).__init__(name_scope)
self._src_max_len = src_max_len
self._src_emb_dim = src_emb_dim
self._src_vocab_size = src_vocab_size
self._dropout_rate = dropout_rate
self._input_emb = Embedding(
name_scope=self.full_name(),
begin_norm_axis=shape_len - 1,
size=[src_vocab_size, src_emb_dim],
padding_idx=0,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(1.)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.)))
name=word_emb_param_name,
initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5)))
def forward(self, prev_out, out, process_cmd, dropout_rate=0.):
for cmd in process_cmd:
if cmd == "a": # add residual connection
out = out + prev_out if prev_out else out
elif cmd == "n": # add layer normalization
out = self._layer_norm(out)
elif cmd == "d": # add dropout
if dropout_rate:
out = fluid.layers.dropout(
out,
dropout_prob=dropout_rate,
if pos_enc_param_name is pos_enc_param_names[0]:
pos_inp = pos_inp1
else:
pos_inp = pos_inp2
self._pos_emb = Embedding(
name_scope=self.full_name(),
size=[self._src_max_len, src_emb_dim],
param_attr=fluid.ParamAttr(
name=pos_enc_param_name,
initializer=fluid.initializer.NumpyArrayInitializer(pos_inp),
trainable=False))
# use in dygraph_mode to fit different length batch
# self._pos_emb._w = to_variable(
# position_encoding_init(self._src_max_len, self._src_emb_dim))
def forward(self, src_word, src_pos):
src_word_emb = self._input_emb(src_word)
src_word_emb = fluid.layers.scale(
x=src_word_emb, scale=self._src_emb_dim**0.5)
# # TODO change this to fit dynamic length input
src_pos_emb = self._pos_emb(src_pos)
src_pos_emb.stop_gradient = True
enc_input = src_word_emb + src_pos_emb
return fluid.layers.dropout(
enc_input,
dropout_prob=self._dropout_rate,
seed=ModelHyperParams.dropout_seed,
is_test=False)
return out
is_test=False) if self._dropout_rate else enc_input
class WrapEncoderLayer(Layer):
def __init__(self, name_cope, src_vocab_size, max_length, n_layer, n_head,
d_key, d_value, d_model, d_inner_hid, prepostprocess_dropout,
attention_dropout, relu_dropout, preprocess_cmd,
postprocess_cmd, weight_sharing):
"""
The wrapper assembles together all needed layers for the encoder.
"""
super(WrapEncoderLayer, self).__init__(name_cope)
self._prepare_encoder_layer = PrepareEncoderDecoderLayer(
self.full_name(),
src_vocab_size,
d_model,
max_length,
prepostprocess_dropout,
word_emb_param_name=word_emb_param_names[0],
pos_enc_param_name=pos_enc_param_names[0])
self._encoder = EncoderLayer(
self.full_name(), n_layer, n_head, d_key, d_value, d_model,
d_inner_hid, prepostprocess_dropout, attention_dropout,
relu_dropout, preprocess_cmd, postprocess_cmd)
def forward(self, enc_inputs):
src_word, src_pos, src_slf_attn_bias = enc_inputs
enc_input = self._prepare_encoder_layer(src_word, src_pos)
enc_output = self._encoder(enc_input, src_slf_attn_bias)
return enc_output
class DecoderSubLayer(Layer):
......@@ -494,12 +679,19 @@ class DecoderSubLayer(Layer):
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
cache=None,
preprocess_cmd="n",
gather_idx=None):
super(DecoderSubLayer, self).__init__(name_scope)
self._preprocess_layer = PrePostProcessLayer(self.full_name(),
self._postprocess_cmd = postprocess_cmd
self._preprocess_cmd = preprocess_cmd
self._prepostprcess_dropout = prepostprocess_dropout
self._pre_process_layer = PrePostProcessLayer(self.full_name(),
preprocess_cmd, 3)
self._multihead_attention_layer = MultiHeadAttentionLayer(
self.full_name(),
......@@ -510,42 +702,300 @@ class DecoderSubLayer(Layer):
attention_dropout,
cache=cache,
gather_idx=gather_idx)
self._post_process_layer = PrePostProcessLayer(self.full_name(),
postprocess_cmd, None)
self._pre_process_layer2 = PrePostProcessLayer(self.full_name(),
preprocess_cmd, 3)
self._multihead_attention_layer2 = MultiHeadAttentionLayer(
self.full_name(),
d_key,
d_value,
d_model,
n_head,
attention_dropout,
cache=cache,
gather_idx=gather_idx,
static_kv=True)
self._post_process_layer2 = PrePostProcessLayer(self.full_name(),
postprocess_cmd, None)
self._pre_process_layer3 = PrePostProcessLayer(self.full_name(),
preprocess_cmd, 3)
self._positionwise_feed_forward_layer = PositionwiseFeedForwardLayer(
self.full_name(), d_inner_hid, d_model, relu_dropout)
self._post_process_layer3 = PrePostProcessLayer(self.full_name(),
postprocess_cmd, None)
def forward(self, dec_input, enc_output, slf_attn_bias, dec_enc_attn_bias):
pre_process_rlt = self._pre_process_layer(
None, dec_input, self._preprocess_cmd, self._prepostprcess_dropout)
slf_attn_output = self._multihead_attention_layer(pre_process_rlt, None,
None, slf_attn_bias)
slf_attn_output_pp = self._post_process_layer(
dec_input, slf_attn_output, self._postprocess_cmd,
self._prepostprcess_dropout)
pre_process_rlt2 = self._pre_process_layer2(None, slf_attn_output_pp,
self._preprocess_cmd,
self._prepostprcess_dropout)
enc_attn_output_pp = self._multihead_attention_layer2(
pre_process_rlt2, enc_output, enc_output, dec_enc_attn_bias)
enc_attn_output = self._post_process_layer2(
slf_attn_output, enc_attn_output_pp, self._postprocess_cmd,
self._prepostprcess_dropout)
pre_process_rlt3 = self._pre_process_layer3(None, enc_attn_output,
self._preprocess_cmd,
self._prepostprcess_dropout)
ffd_output = self._positionwise_feed_forward_layer(pre_process_rlt3)
dec_output = self._post_process_layer3(enc_attn_output, ffd_output,
self._postprocess_cmd,
self._prepostprcess_dropout)
return dec_output
class DecoderLayer(Layer):
def __init__(self,
name_scope,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
caches=None,
gather_idx=None):
super(DecoderLayer, self).__init__(name_scope)
self._pre_process_layer = PrePostProcessLayer(self.full_name(),
preprocess_cmd, 3)
self._decoder_sub_layers = list()
self._n_layer = n_layer
self._preprocess_cmd = preprocess_cmd
self._prepostprocess_dropout = prepostprocess_dropout
for i in range(n_layer):
self._decoder_sub_layers.append(
self.add_sublayer(
'dsl_%d' % i,
DecoderSubLayer(
self.full_name(),
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
cache=None if caches is None else caches[i],
gather_idx=gather_idx)))
def forward(self, dec_input, enc_output, dec_slf_attn_bias,
dec_enc_attn_bias):
for i in range(self._n_layer):
tmp_dec_output = self._decoder_sub_layers[i](
dec_input, enc_output, dec_slf_attn_bias, dec_enc_attn_bias)
dec_input = tmp_dec_output
dec_output = self._pre_process_layer(None, tmp_dec_output,
self._preprocess_cmd,
self._prepostprocess_dropout)
return dec_output
class WrapDecoderLayer(Layer):
def __init__(self,
name_scope,
trg_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
caches=None,
gather_idx=None):
"""
The wrapper assembles together all needed layers for the encoder.
"""
super(WrapDecoderLayer, self).__init__(name_scope)
self._prepare_decoder_layer = PrepareEncoderDecoderLayer(
self.full_name(),
trg_vocab_size,
d_model,
max_length,
prepostprocess_dropout,
word_emb_param_name=word_emb_param_names[1],
pos_enc_param_name=pos_enc_param_names[1])
self._decoder_layer = DecoderLayer(
self.full_name(),
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
caches=caches,
gather_idx=gather_idx)
self._weight_sharing = weight_sharing
if not weight_sharing:
self._fc = FC(self.full_name(),
size=trg_vocab_size,
bias_attr=False)
def forward(self, dec_inputs=None, enc_output=None):
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_inputs
dec_input = self._prepare_decoder_layer(trg_word, trg_pos)
dec_output = self._decoder_layer(dec_input, enc_output,
trg_slf_attn_bias, trg_src_attn_bias)
dec_output_reshape = fluid.layers.reshape(
dec_output, shape=[-1, dec_output.shape[-1]], inplace=False)
def forward(self, input, slf_attn_bias):
print(input.shape)
print(slf_attn_bias.shape)
y = self._preprocess_layer(None, input, "n", 0.1)
slf_attn_output = self._multihead_attention_layer(y, None, None,
slf_attn_bias)
return slf_attn_output, y
if self._weight_sharing:
predict = fluid.layers.matmul(
x=dec_output_reshape,
y=self._prepare_decoder_layer._input_emb._w,
transpose_y=True)
else:
predict = self._fc(dec_output_reshape)
if dec_inputs is None:
# Return probs for independent decoder program.
predict_out = fluid.layers.softmax(predict)
return predict_out
return predict
class TransFormer(Layer):
def __init__(self,
name_scope,
src_vocab_size,
trg_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
label_smooth_eps,
use_py_reader=False,
is_test=False):
super(TransFormer, self).__init__(name_scope)
self._label_smooth_eps = label_smooth_eps
self._trg_vocab_size = trg_vocab_size
if weight_sharing:
assert src_vocab_size == trg_vocab_size, (
"Vocabularies in source and target should be same for weight sharing."
)
self._wrap_encoder_layer = WrapEncoderLayer(
self.full_name(), src_vocab_size, max_length, n_layer, n_head,
d_key, d_value, d_model, d_inner_hid, prepostprocess_dropout,
attention_dropout, relu_dropout, preprocess_cmd, postprocess_cmd,
weight_sharing)
self._wrap_decoder_layer = WrapDecoderLayer(
self.full_name(), trg_vocab_size, max_length, n_layer, n_head,
d_key, d_value, d_model, d_inner_hid, prepostprocess_dropout,
attention_dropout, relu_dropout, preprocess_cmd, postprocess_cmd,
weight_sharing)
if weight_sharing:
self._wrap_decoder_layer._prepare_decoder_layer._input_emb._w = self._wrap_encoder_layer._prepare_encoder_layer._input_emb._w
def forward(self, enc_inputs, dec_inputs, label, weights):
enc_output = self._wrap_encoder_layer(enc_inputs)
predict = self._wrap_decoder_layer(dec_inputs, enc_output)
if self._label_smooth_eps:
label_out = fluid.layers.label_smooth(
label=fluid.layers.one_hot(
input=label, depth=self._trg_vocab_size),
epsilon=self._label_smooth_eps)
cost = fluid.layers.softmax_with_cross_entropy(
logits=predict,
label=label_out,
soft_label=True if self._label_smooth_eps else False)
weighted_cost = cost * weights
sum_cost = fluid.layers.reduce_sum(weighted_cost)
token_num = fluid.layers.reduce_sum(weights)
token_num.stop_gradient = True
avg_cost = sum_cost / token_num
return sum_cost, avg_cost, predict, token_num
class TestDygraphTransformer(unittest.TestCase):
def test_transformer_float32(self):
seed = 90
x1 = np.ones([32, 4, 512]).astype('float32')
x2 = np.ones([32, 8, 4, 4]).astype('float32')
with guard(place=fluid.CPUPlace()):
with guard():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
transformer = DecoderSubLayer(
'transformer', ModelHyperParams.n_head, ModelHyperParams.d_key,
ModelHyperParams.d_value, ModelHyperParams.d_model,
ModelHyperParams.attention_dropout)
transformer = TransFormer(
'transformer',
ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size,
ModelHyperParams.max_length + 1,
ModelHyperParams.n_layer,
ModelHyperParams.n_head,
ModelHyperParams.d_key,
ModelHyperParams.d_value,
ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid,
ModelHyperParams.prepostprocess_dropout,
ModelHyperParams.attention_dropout,
ModelHyperParams.relu_dropout,
ModelHyperParams.preprocess_cmd,
ModelHyperParams.postprocess_cmd,
ModelHyperParams.weight_sharing,
TrainTaskConfig.label_smooth_eps,
use_py_reader=use_py_reader,
is_test=False)
if sync:
lr_decay = fluid.layers.learning_rate_scheduler.noam_decay(
ModelHyperParams.d_model, TrainTaskConfig.warmup_steps)
with fluid.default_main_program()._lr_schedule_guard():
learning_rate = lr_decay * TrainTaskConfig.learning_rate
optimizer = fluid.optimizer.Adam(
learning_rate=learning_rate,
beta1=TrainTaskConfig.beta1,
beta2=TrainTaskConfig.beta2,
epsilon=TrainTaskConfig.eps)
else:
optimizer = fluid.optimizer.SGD(learning_rate=0.003)
dy_param_init = dict()
dy_param_updated = dict()
for i in range(batch_num):
loss, y = transformer(to_variable(x1), to_variable(x2))
loss = fluid.layers.reduce_sum(loss)
print('dy los', loss.shape)
enc_inputs, dec_inputs, label, weights = create_data()
dy_sum_cost, dy_avg_cost, dy_predict, dy_token_num = transformer(
enc_inputs, dec_inputs, label, weights)
if i == 0:
for param in transformer.parameters():
dy_param_init[param.name] = param._numpy()
loss._backward()
optimizer.minimize(loss)
dy_key_value = y._gradient()
dy_avg_cost._backward()
optimizer.minimize(dy_avg_cost)
transformer.clear_gradients()
if i == batch_num - 1:
for param in transformer.parameters():
......@@ -554,60 +1004,92 @@ class TestDygraphTransformer(unittest.TestCase):
with new_program_scope():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
transformer = DecoderSubLayer(
'transformer', ModelHyperParams.n_head, ModelHyperParams.d_key,
ModelHyperParams.d_value, ModelHyperParams.d_model,
ModelHyperParams.attention_dropout)
exe = fluid.Executor(fluid.CPUPlace())
transformer = TransFormer(
'transformer',
ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size,
ModelHyperParams.max_length + 1,
ModelHyperParams.n_layer,
ModelHyperParams.n_head,
ModelHyperParams.d_key,
ModelHyperParams.d_value,
ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid,
ModelHyperParams.prepostprocess_dropout,
ModelHyperParams.attention_dropout,
ModelHyperParams.relu_dropout,
ModelHyperParams.preprocess_cmd,
ModelHyperParams.postprocess_cmd,
ModelHyperParams.weight_sharing,
TrainTaskConfig.label_smooth_eps,
use_py_reader=use_py_reader,
is_test=False)
exe = fluid.Executor(fluid.CPUPlace(
) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
optimizer = fluid.optimizer.SGD(learning_rate=0.003)
data1 = fluid.layers.data(name='X', shape=[4, 512], dtype='float32')
data2 = fluid.layers.data(
name='Y', shape=[8, 4, 4], dtype='float32')
loss, y = transformer(data1, data2)
loss = fluid.layers.reduce_sum(loss)
print('loss hspae', loss.shape)
optimizer.minimize(loss)
data_input_names = encoder_data_input_fields + decoder_data_input_fields[:
-1] + label_data_input_fields
all_inputs = make_all_inputs(data_input_names)
enc_inputs_len = len(encoder_data_input_fields)
dec_inputs_len = len(decoder_data_input_fields[:-1])
enc_inputs = all_inputs[0:enc_inputs_len]
dec_inputs = all_inputs[enc_inputs_len:enc_inputs_len +
dec_inputs_len]
label = all_inputs[-2]
weights = all_inputs[-1]
static_param_updated = dict()
static_param_init = dict()
static_param_name_list = list()
static_sum_cost, static_avg_cost, static_predict, static_token_num = transformer(
enc_inputs, dec_inputs, label, weights)
static_param_init = {}
static_param_name_list = []
static_param_updated = {}
optimizer.minimize(static_avg_cost)
for param in transformer.parameters():
static_param_name_list.append(param.name)
out = exe.run(fluid.default_startup_program(),
fetch_list=static_param_name_list)
for i in range(len(static_param_name_list)):
static_param_init[static_param_name_list[i]] = out[i]
print(fluid.default_main_program())
static_sum_cost_value = None
static_avg_cost_value = None
static_predict_value = None
static_token_num_value = None
for i in range(batch_num):
feed_dict = {"X": x1, "Y": x2}
feed_dict = create_feed_dict_list(create_data(True))
fetch_list = [
"transformer/DecoderSubLayer_0/PrePostProcessLayer_0/LayerNorm_0.tmp_2@GRAD"
static_sum_cost, static_avg_cost, static_predict,
static_token_num
]
fetch_list.extend(static_param_name_list)
out = exe.run(fluid.default_main_program(),
feed=feed_dict,
fetch_list=fetch_list)
static_sum_cost_value = out[0]
static_avg_cost_value = out[1]
static_predict_value = out[2]
static_token_num_value = out[3]
if i == batch_num - 1:
static_key_value = out[0]
for k in range(1, len(out)):
for k in range(4, len(out)):
static_param_updated[static_param_name_list[k -
1]] = out[k]
4]] = out[k]
self.assertTrue(
np.allclose(static_avg_cost_value, dy_avg_cost._numpy()))
self.assertTrue(
np.allclose(static_sum_cost_value, dy_sum_cost._numpy()))
self.assertTrue(
np.allclose(
static_predict_value, dy_predict._numpy(), atol=1e-5))
self.assertTrue(
np.allclose(static_token_num_value, dy_token_num._numpy()))
for key, value in six.iteritems(static_param_init):
self.assertTrue(np.array_equal(value, dy_param_init[key]))
self.assertTrue(np.allclose(value, dy_param_init[key]))
for key, value in six.iteritems(static_param_updated):
if not (value == dy_param_updated[key]).all():
print(key)
if not np.array_equal(dy_key_value, static_key_value):
print("xxx", dy_key_value, static_key_value)
print("yyy")
print(dy_key_value - static_key_value)
print(np.where(dy_key_value - static_key_value))
self.assertTrue(
np.allclose(
value, dy_param_updated[key], atol=1e-4))
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册