未验证 提交 c47ae621 编写于 作者: H hong 提交者: GitHub

add eager test in rnn and fc; test=develop (#40149)

上级 d50fb43e
...@@ -24,6 +24,7 @@ import paddle.fluid.core as core ...@@ -24,6 +24,7 @@ import paddle.fluid.core as core
from test_imperative_base import new_program_scope from test_imperative_base import new_program_scope
from paddle.fluid.dygraph.base import to_variable from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph import Linear from paddle.fluid.dygraph import Linear
from paddle.fluid.framework import _test_eager_guard
# Can use Amusic dataset as the DeepCF describes. # Can use Amusic dataset as the DeepCF describes.
DATA_PATH = os.environ.get('DATA_PATH', '') DATA_PATH = os.environ.get('DATA_PATH', '')
...@@ -294,9 +295,42 @@ class TestDygraphDeepCF(unittest.TestCase): ...@@ -294,9 +295,42 @@ class TestDygraphDeepCF(unittest.TestCase):
sys.stderr.write('dynamic loss: %s %s\n' % sys.stderr.write('dynamic loss: %s %s\n' %
(slice, dy_loss2)) (slice, dy_loss2))
with fluid.dygraph.guard():
with _test_eager_guard():
paddle.seed(seed)
paddle.framework.random._manual_program_seed(seed)
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
deepcf = DeepCF(num_users, num_items, matrix)
adam = fluid.optimizer.AdamOptimizer(
0.01, parameter_list=deepcf.parameters())
for e in range(NUM_EPOCHES):
sys.stderr.write('epoch %d\n' % e)
for slice in range(0, BATCH_SIZE * NUM_BATCHES, BATCH_SIZE):
if slice + BATCH_SIZE >= users_np.shape[0]:
break
prediction = deepcf(
to_variable(users_np[slice:slice + BATCH_SIZE]),
to_variable(items_np[slice:slice + BATCH_SIZE]))
loss = fluid.layers.reduce_sum(
fluid.layers.log_loss(prediction,
to_variable(
labels_np[slice:slice +
BATCH_SIZE])))
loss.backward()
adam.minimize(loss)
deepcf.clear_gradients()
eager_loss = loss.numpy()
sys.stderr.write('eager loss: %s %s\n' %
(slice, eager_loss))
self.assertEqual(static_loss, dy_loss) self.assertEqual(static_loss, dy_loss)
self.assertEqual(static_loss, dy_loss2) self.assertEqual(static_loss, dy_loss2)
self.assertEqual(static_loss, eager_loss)
if __name__ == '__main__': if __name__ == '__main__':
paddle.enable_static()
unittest.main() unittest.main()
...@@ -16,9 +16,11 @@ from __future__ import print_function ...@@ -16,9 +16,11 @@ from __future__ import print_function
import unittest import unittest
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid.dygraph.nn import Embedding from paddle.fluid.dygraph.nn import Embedding
import paddle.fluid.framework as framework import paddle.fluid.framework as framework
from paddle.fluid.framework import _test_eager_guard
from paddle.fluid.optimizer import SGDOptimizer from paddle.fluid.optimizer import SGDOptimizer
from paddle.fluid.dygraph.base import to_variable from paddle.fluid.dygraph.base import to_variable
from test_imperative_base import new_program_scope from test_imperative_base import new_program_scope
...@@ -60,6 +62,25 @@ class TestRecurrentFeed(unittest.TestCase): ...@@ -60,6 +62,25 @@ class TestRecurrentFeed(unittest.TestCase):
original_in1.stop_gradient = True original_in1.stop_gradient = True
rt.clear_gradients() rt.clear_gradients()
with fluid.dygraph.guard():
with _test_eager_guard():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
original_in1 = to_variable(original_np1)
original_in2 = to_variable(original_np2)
original_in1.stop_gradient = False
original_in2.stop_gradient = False
rt = RecurrentTest("RecurrentTest")
for i in range(3):
sum_out, out = rt(original_in1, original_in2)
original_in1 = out
eager_sum_out_value = sum_out.numpy()
sum_out.backward()
eager_dyout = out.gradient()
original_in1.stop_gradient = True
rt.clear_gradients()
with new_program_scope(): with new_program_scope():
fluid.default_startup_program().random_seed = seed fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed fluid.default_main_program().random_seed = seed
...@@ -88,8 +109,11 @@ class TestRecurrentFeed(unittest.TestCase): ...@@ -88,8 +109,11 @@ class TestRecurrentFeed(unittest.TestCase):
original_np1 = static_out_value original_np1 = static_out_value
self.assertTrue(np.array_equal(static_sum_out, sum_out_value)) self.assertTrue(np.array_equal(static_sum_out, sum_out_value))
self.assertTrue(np.array_equal(static_sum_out, eager_sum_out_value))
self.assertTrue(np.array_equal(static_dout, dyout)) self.assertTrue(np.array_equal(static_dout, dyout))
self.assertTrue(np.array_equal(static_dout, eager_dyout))
if __name__ == '__main__': if __name__ == '__main__':
paddle.enable_static()
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册