From 0f31ed7128fbf1f8476f38f3819425f41f345410 Mon Sep 17 00:00:00 2001 From: lyuwenyu Date: Tue, 22 Jun 2021 20:31:42 +0800 Subject: [PATCH] add grad unittest --- python/paddle/tests/test_model.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/python/paddle/tests/test_model.py b/python/paddle/tests/test_model.py index 904d5732d2a..a970489b92a 100644 --- a/python/paddle/tests/test_model.py +++ b/python/paddle/tests/test_model.py @@ -728,18 +728,25 @@ class TestModelFunction(unittest.TestCase): inputs = [InputSpec([None, dim], 'float32', 'x')] labels = [InputSpec([None, 1], 'int64', 'label')] - model = Model(net, inputs, labels) - model.prepare(optim, loss=CrossEntropyLoss(reduction="sum")) - loss1, = model.train_batch([data], [label], update=False) - loss2, = model.train_batch([data], [label], update=True) - np.testing.assert_almost_equal(loss1, loss2, decimal=4) - - model = Model(net, inputs, labels) - model.prepare( - optim, loss=CrossEntropyLoss(reduction="sum"), amp_configs='O1') - loss1, = model.train_batch([data], [label], update=False) - loss2, = model.train_batch([data], [label], update=True) - np.testing.assert_almost_equal(loss1, loss2, decimal=4) + for amp_cfg in [None, 'O1']: + model = Model(net, inputs, labels) + model.prepare( + optim, + loss=CrossEntropyLoss(reduction="sum"), + amp_configs=amp_cfg) + losses, grads = [], [] + for stat in [False, False, True]: + loss, = model.train_batch([data], [label], update=stat) + losses.append(loss) + grads.append([p.grad.numpy() for p in net.parameters()]) + + for grad1, grad2, grad3 in zip(*grads): + np.testing.assert_almost_equal(grad1 * 2, grad2, decimal=4) + np.testing.assert_almost_equal( + grad3, np.zeros_like(grad3), decimal=4) + + np.testing.assert_almost_equal(losses[0], losses[1], decimal=4) + np.testing.assert_almost_equal(losses[0], losses[2], decimal=4) class TestModelWithLRScheduler(unittest.TestCase): -- GitLab