未验证 提交 dbcf8797 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] Support gnn ptb_rnn in eager mode (#39993)

上级 fb0cadfd
...@@ -23,6 +23,7 @@ import paddle.fluid.core as core ...@@ -23,6 +23,7 @@ import paddle.fluid.core as core
from paddle.fluid.optimizer import AdamOptimizer from paddle.fluid.optimizer import AdamOptimizer
from test_imperative_base import new_program_scope from test_imperative_base import new_program_scope
from paddle.fluid.dygraph.base import to_variable from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.framework import _test_eager_guard
def gen_data(): def gen_data():
...@@ -60,7 +61,7 @@ class GCN(fluid.Layer): ...@@ -60,7 +61,7 @@ class GCN(fluid.Layer):
class TestDygraphGNN(unittest.TestCase): class TestDygraphGNN(unittest.TestCase):
def test_gnn_float32(self): def func_gnn_float32(self):
paddle.seed(90) paddle.seed(90)
paddle.framework.random._manual_program_seed(90) paddle.framework.random._manual_program_seed(90)
startup = fluid.Program() startup = fluid.Program()
...@@ -168,6 +169,11 @@ class TestDygraphGNN(unittest.TestCase): ...@@ -168,6 +169,11 @@ class TestDygraphGNN(unittest.TestCase):
self.assertTrue(np.allclose(static_weight, model2_gc_weight_value)) self.assertTrue(np.allclose(static_weight, model2_gc_weight_value))
sys.stderr.write('%s %s\n' % (static_loss, loss_value)) sys.stderr.write('%s %s\n' % (static_loss, loss_value))
def test_gnn_float32(self):
with _test_eager_guard():
self.func_gnn_float32()
self.func_gnn_float32()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -26,10 +26,11 @@ from test_imperative_base import new_program_scope ...@@ -26,10 +26,11 @@ from test_imperative_base import new_program_scope
from test_imperative_ptb_rnn import PtbModel from test_imperative_ptb_rnn import PtbModel
import numpy as np import numpy as np
import six import six
from paddle.fluid.framework import _test_eager_guard
class TestDygraphPtbRnnSortGradient(unittest.TestCase): class TestDygraphPtbRnnSortGradient(unittest.TestCase):
def test_ptb_rnn_sort_gradient(self): def func_ptb_rnn_sort_gradient(self):
for is_sparse in [True, False]: for is_sparse in [True, False]:
self.ptb_rnn_sort_gradient_cpu_float32(is_sparse) self.ptb_rnn_sort_gradient_cpu_float32(is_sparse)
...@@ -171,6 +172,11 @@ class TestDygraphPtbRnnSortGradient(unittest.TestCase): ...@@ -171,6 +172,11 @@ class TestDygraphPtbRnnSortGradient(unittest.TestCase):
for key, value in six.iteritems(static_param_updated): for key, value in six.iteritems(static_param_updated):
self.assertTrue(np.array_equal(value, dy_param_updated[key])) self.assertTrue(np.array_equal(value, dy_param_updated[key]))
def test_ptb_rnn_sort_gradient(self):
with _test_eager_guard():
self.func_ptb_rnn_sort_gradient()
self.func_ptb_rnn_sort_gradient()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册