未验证 提交 2ae10efd 编写于 作者: Weilong Wu 提交者: GitHub

[Eager] Support transformer tests in eager mode (#41347)

上级 bce9c8c4
...@@ -21,7 +21,7 @@ from paddle.fluid import Embedding, LayerNorm, Linear, Layer ...@@ -21,7 +21,7 @@ from paddle.fluid import Embedding, LayerNorm, Linear, Layer
from paddle.fluid.dygraph import to_variable, guard from paddle.fluid.dygraph import to_variable, guard
from paddle.fluid.dygraph import TracedLayer from paddle.fluid.dygraph import TracedLayer
from test_imperative_base import new_program_scope from test_imperative_base import new_program_scope
from paddle.fluid.framework import _test_eager_guard from paddle.fluid.framework import _test_eager_guard, in_dygraph_mode, _in_legacy_dygraph
from paddle.fluid import core from paddle.fluid import core
import numpy as np import numpy as np
import six import six
...@@ -1041,8 +1041,9 @@ class TestDygraphTransformerSortGradient(unittest.TestCase): ...@@ -1041,8 +1041,9 @@ class TestDygraphTransformerSortGradient(unittest.TestCase):
with guard(): with guard():
fluid.set_flags({'FLAGS_sort_sum_gradient': True}) fluid.set_flags({'FLAGS_sort_sum_gradient': True})
dy_avg_cost_value, dy_sum_cost_value, dy_predict_value, dy_token_num_value, \ if _in_legacy_dygraph():
dy_param_init, dy_param_updated = run_dygraph() dy_avg_cost_value, dy_sum_cost_value, dy_predict_value, dy_token_num_value, \
dy_param_init, dy_param_updated = run_dygraph()
with new_program_scope(): with new_program_scope():
paddle.seed(seed) paddle.seed(seed)
...@@ -1116,21 +1117,22 @@ class TestDygraphTransformerSortGradient(unittest.TestCase): ...@@ -1116,21 +1117,22 @@ class TestDygraphTransformerSortGradient(unittest.TestCase):
for k in range(4, len(out)): for k in range(4, len(out)):
static_param_updated[static_param_name_list[k - static_param_updated[static_param_name_list[k -
4]] = out[k] 4]] = out[k]
if _in_legacy_dygraph():
self.assertTrue( self.assertTrue(
np.array_equal(static_avg_cost_value, dy_avg_cost_value)) np.array_equal(static_avg_cost_value, dy_avg_cost_value))
self.assertTrue( self.assertTrue(
np.array_equal(static_sum_cost_value, dy_sum_cost_value)) np.array_equal(static_sum_cost_value, dy_sum_cost_value))
self.assertTrue(np.array_equal(static_predict_value, dy_predict_value)) self.assertTrue(
self.assertTrue( np.array_equal(static_predict_value, dy_predict_value))
np.array_equal(static_token_num_value, dy_token_num_value)) self.assertTrue(
np.array_equal(static_token_num_value, dy_token_num_value))
for key, value in six.iteritems(static_param_init):
self.assertTrue(np.array_equal(value, dy_param_init[key])) for key, value in six.iteritems(static_param_init):
for key, value in six.iteritems(static_param_updated): self.assertTrue(np.array_equal(value, dy_param_init[key]))
self.assertTrue(np.array_equal(value, dy_param_updated[key])) for key, value in six.iteritems(static_param_updated):
self.assertTrue(np.array_equal(value, dy_param_updated[key]))
# check eager result
# compare eager result with imperative result
with guard(): with guard():
fluid.set_flags({'FLAGS_sort_sum_gradient': False}) fluid.set_flags({'FLAGS_sort_sum_gradient': False})
dy_avg_cost_value, dy_sum_cost_value, dy_predict_value, dy_token_num_value, \ dy_avg_cost_value, dy_sum_cost_value, dy_predict_value, dy_token_num_value, \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册