提交 a0478084 编写于 作者: M minqiyang

Right transformer

上级 124f45c9
......@@ -517,7 +517,7 @@ class DecoderSubLayer(Layer):
y = self._preprocess_layer(None, input, "n", 0.1)
slf_attn_output = self._multihead_attention_layer(y, None, None,
slf_attn_bias)
return slf_attn_output
return slf_attn_output, y
class TestDygraphTransformer(unittest.TestCase):
......@@ -536,7 +536,7 @@ class TestDygraphTransformer(unittest.TestCase):
dy_param_init = dict()
dy_param_updated = dict()
for i in range(batch_num):
loss = transformer(to_variable(x1), to_variable(x2))
loss, y = transformer(to_variable(x1), to_variable(x2))
loss = fluid.layers.reduce_sum(loss)
print('dy los', loss.shape)
if i == 0:
......@@ -545,6 +545,7 @@ class TestDygraphTransformer(unittest.TestCase):
loss._backward()
optimizer.minimize(loss)
dy_key_value = y._gradient()
transformer.clear_gradients()
if i == batch_num - 1:
for param in transformer.parameters():
......@@ -563,7 +564,7 @@ class TestDygraphTransformer(unittest.TestCase):
data1 = fluid.layers.data(name='X', shape=[4, 512], dtype='float32')
data2 = fluid.layers.data(
name='Y', shape=[8, 4, 4], dtype='float32')
loss = transformer(data1, data2)
loss, y = transformer(data1, data2)
loss = fluid.layers.reduce_sum(loss)
print('loss hspae', loss.shape)
......@@ -580,24 +581,33 @@ class TestDygraphTransformer(unittest.TestCase):
for i in range(len(static_param_name_list)):
static_param_init[static_param_name_list[i]] = out[i]
print(fluid.default_main_program())
for i in range(batch_num):
feed_dict = {"X": x1, "Y": x2}
fetch_list = []
fetch_list = [
"transformer/DecoderSubLayer_0/PrePostProcessLayer_0/LayerNorm_0.tmp_2@GRAD"
]
fetch_list.extend(static_param_name_list)
out = exe.run(fluid.default_main_program(),
feed=feed_dict,
fetch_list=fetch_list)
if i == batch_num - 1:
for k in range(0, len(out)):
static_key_value = out[0]
for k in range(1, len(out)):
static_param_updated[static_param_name_list[k -
0]] = out[k]
1]] = out[k]
for key, value in six.iteritems(static_param_init):
self.assertTrue(np.array_equal(value, dy_param_init[key]))
for key, value in six.iteritems(static_param_updated):
if not (value == dy_param_updated[key]).all():
print(key)
if not np.array_equal(dy_key_value, static_key_value):
print("xxx", dy_key_value, static_key_value)
print("yyy")
print(dy_key_value - static_key_value)
print(np.where(dy_key_value - static_key_value))
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册