diff --git a/python/paddle/fluid/tests/unittests/test_imperative_deepcf.py b/python/paddle/fluid/tests/unittests/test_imperative_deepcf.py index 04a0e5e4cd10f7ece370e879986056d508c894ff..3e222e3c658ecd105811f3694a25d20f1826bcda 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_deepcf.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_deepcf.py @@ -24,6 +24,7 @@ import paddle.fluid.core as core from test_imperative_base import new_program_scope from paddle.fluid.dygraph.base import to_variable from paddle.fluid.dygraph import Linear +from paddle.fluid.framework import _test_eager_guard # Can use Amusic dataset as the DeepCF describes. DATA_PATH = os.environ.get('DATA_PATH', '') @@ -294,9 +295,42 @@ class TestDygraphDeepCF(unittest.TestCase): sys.stderr.write('dynamic loss: %s %s\n' % (slice, dy_loss2)) + with fluid.dygraph.guard(): + with _test_eager_guard(): + paddle.seed(seed) + paddle.framework.random._manual_program_seed(seed) + fluid.default_startup_program().random_seed = seed + fluid.default_main_program().random_seed = seed + + deepcf = DeepCF(num_users, num_items, matrix) + adam = fluid.optimizer.AdamOptimizer( + 0.01, parameter_list=deepcf.parameters()) + + for e in range(NUM_EPOCHES): + sys.stderr.write('epoch %d\n' % e) + for slice in range(0, BATCH_SIZE * NUM_BATCHES, BATCH_SIZE): + if slice + BATCH_SIZE >= users_np.shape[0]: + break + prediction = deepcf( + to_variable(users_np[slice:slice + BATCH_SIZE]), + to_variable(items_np[slice:slice + BATCH_SIZE])) + loss = fluid.layers.reduce_sum( + fluid.layers.log_loss(prediction, + to_variable( + labels_np[slice:slice + + BATCH_SIZE]))) + loss.backward() + adam.minimize(loss) + deepcf.clear_gradients() + eager_loss = loss.numpy() + sys.stderr.write('eager loss: %s %s\n' % + (slice, eager_loss)) + self.assertEqual(static_loss, dy_loss) self.assertEqual(static_loss, dy_loss2) + self.assertEqual(static_loss, eager_loss) if __name__ == '__main__': + paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_recurrent_usage.py b/python/paddle/fluid/tests/unittests/test_imperative_recurrent_usage.py index d0b3adc490945377a4dd05f9c414cd5c35c7fae5..f12ca0a93ffd9441761c2da866c2c811a30c6e68 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_recurrent_usage.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_recurrent_usage.py @@ -16,9 +16,11 @@ from __future__ import print_function import unittest import paddle.fluid as fluid +import paddle import paddle.fluid.core as core from paddle.fluid.dygraph.nn import Embedding import paddle.fluid.framework as framework +from paddle.fluid.framework import _test_eager_guard from paddle.fluid.optimizer import SGDOptimizer from paddle.fluid.dygraph.base import to_variable from test_imperative_base import new_program_scope @@ -60,6 +62,25 @@ class TestRecurrentFeed(unittest.TestCase): original_in1.stop_gradient = True rt.clear_gradients() + with fluid.dygraph.guard(): + with _test_eager_guard(): + fluid.default_startup_program().random_seed = seed + fluid.default_main_program().random_seed = seed + original_in1 = to_variable(original_np1) + original_in2 = to_variable(original_np2) + original_in1.stop_gradient = False + original_in2.stop_gradient = False + rt = RecurrentTest("RecurrentTest") + + for i in range(3): + sum_out, out = rt(original_in1, original_in2) + original_in1 = out + eager_sum_out_value = sum_out.numpy() + sum_out.backward() + eager_dyout = out.gradient() + original_in1.stop_gradient = True + rt.clear_gradients() + with new_program_scope(): fluid.default_startup_program().random_seed = seed fluid.default_main_program().random_seed = seed @@ -88,8 +109,11 @@ class TestRecurrentFeed(unittest.TestCase): original_np1 = static_out_value self.assertTrue(np.array_equal(static_sum_out, sum_out_value)) + self.assertTrue(np.array_equal(static_sum_out, eager_sum_out_value)) self.assertTrue(np.array_equal(static_dout, dyout)) + self.assertTrue(np.array_equal(static_dout, eager_dyout)) if __name__ == '__main__': + paddle.enable_static() unittest.main()