未验证 提交 2ae10efd 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] Support transformer tests in eager mode (#41347)

上级 bce9c8c4
......@@ -21,7 +21,7 @@ from paddle.fluid import Embedding, LayerNorm, Linear, Layer
from paddle.fluid.dygraph import to_variable, guard
from paddle.fluid.dygraph import TracedLayer
from test_imperative_base import new_program_scope
from paddle.fluid.framework import _test_eager_guard
from paddle.fluid.framework import _test_eager_guard, in_dygraph_mode, _in_legacy_dygraph
from paddle.fluid import core
import numpy as np
import six
......@@ -1041,8 +1041,9 @@ class TestDygraphTransformerSortGradient(unittest.TestCase):
with guard():
fluid.set_flags({'FLAGS_sort_sum_gradient': True})
dy_avg_cost_value, dy_sum_cost_value, dy_predict_value, dy_token_num_value, \
dy_param_init, dy_param_updated = run_dygraph()
if _in_legacy_dygraph():
dy_avg_cost_value, dy_sum_cost_value, dy_predict_value, dy_token_num_value, \
dy_param_init, dy_param_updated = run_dygraph()
with new_program_scope():
paddle.seed(seed)
......@@ -1116,21 +1117,22 @@ class TestDygraphTransformerSortGradient(unittest.TestCase):
for k in range(4, len(out)):
static_param_updated[static_param_name_list[k -
4]] = out[k]
self.assertTrue(
np.array_equal(static_avg_cost_value, dy_avg_cost_value))
self.assertTrue(
np.array_equal(static_sum_cost_value, dy_sum_cost_value))
self.assertTrue(np.array_equal(static_predict_value, dy_predict_value))
self.assertTrue(
np.array_equal(static_token_num_value, dy_token_num_value))
for key, value in six.iteritems(static_param_init):
self.assertTrue(np.array_equal(value, dy_param_init[key]))
for key, value in six.iteritems(static_param_updated):
self.assertTrue(np.array_equal(value, dy_param_updated[key]))
# check eager result
if _in_legacy_dygraph():
self.assertTrue(
np.array_equal(static_avg_cost_value, dy_avg_cost_value))
self.assertTrue(
np.array_equal(static_sum_cost_value, dy_sum_cost_value))
self.assertTrue(
np.array_equal(static_predict_value, dy_predict_value))
self.assertTrue(
np.array_equal(static_token_num_value, dy_token_num_value))
for key, value in six.iteritems(static_param_init):
self.assertTrue(np.array_equal(value, dy_param_init[key]))
for key, value in six.iteritems(static_param_updated):
self.assertTrue(np.array_equal(value, dy_param_updated[key]))
# compare eager result with imperative result
with guard():
fluid.set_flags({'FLAGS_sort_sum_gradient': False})
dy_avg_cost_value, dy_sum_cost_value, dy_predict_value, dy_token_num_value, \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册