未验证 提交 dbcf8797 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] Support gnn ptb_rnn in eager mode (#39993)

上级 fb0cadfd
......@@ -23,6 +23,7 @@ import paddle.fluid.core as core
from paddle.fluid.optimizer import AdamOptimizer
from test_imperative_base import new_program_scope
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.framework import _test_eager_guard
def gen_data():
......@@ -60,7 +61,7 @@ class GCN(fluid.Layer):
class TestDygraphGNN(unittest.TestCase):
def test_gnn_float32(self):
def func_gnn_float32(self):
paddle.seed(90)
paddle.framework.random._manual_program_seed(90)
startup = fluid.Program()
......@@ -168,6 +169,11 @@ class TestDygraphGNN(unittest.TestCase):
self.assertTrue(np.allclose(static_weight, model2_gc_weight_value))
sys.stderr.write('%s %s\n' % (static_loss, loss_value))
def test_gnn_float32(self):
with _test_eager_guard():
self.func_gnn_float32()
self.func_gnn_float32()
if __name__ == '__main__':
unittest.main()
......@@ -26,10 +26,11 @@ from test_imperative_base import new_program_scope
from test_imperative_ptb_rnn import PtbModel
import numpy as np
import six
from paddle.fluid.framework import _test_eager_guard
class TestDygraphPtbRnnSortGradient(unittest.TestCase):
def test_ptb_rnn_sort_gradient(self):
def func_ptb_rnn_sort_gradient(self):
for is_sparse in [True, False]:
self.ptb_rnn_sort_gradient_cpu_float32(is_sparse)
......@@ -171,6 +172,11 @@ class TestDygraphPtbRnnSortGradient(unittest.TestCase):
for key, value in six.iteritems(static_param_updated):
self.assertTrue(np.array_equal(value, dy_param_updated[key]))
def test_ptb_rnn_sort_gradient(self):
with _test_eager_guard():
self.func_ptb_rnn_sort_gradient()
self.func_ptb_rnn_sort_gradient()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册