提交 0f31ed71 编写于 作者: L lyuwenyu 提交者: jzhang533

add grad unittest

上级 9a73283b
...@@ -728,18 +728,25 @@ class TestModelFunction(unittest.TestCase): ...@@ -728,18 +728,25 @@ class TestModelFunction(unittest.TestCase):
inputs = [InputSpec([None, dim], 'float32', 'x')] inputs = [InputSpec([None, dim], 'float32', 'x')]
labels = [InputSpec([None, 1], 'int64', 'label')] labels = [InputSpec([None, 1], 'int64', 'label')]
model = Model(net, inputs, labels) for amp_cfg in [None, 'O1']:
model.prepare(optim, loss=CrossEntropyLoss(reduction="sum")) model = Model(net, inputs, labels)
loss1, = model.train_batch([data], [label], update=False) model.prepare(
loss2, = model.train_batch([data], [label], update=True) optim,
np.testing.assert_almost_equal(loss1, loss2, decimal=4) loss=CrossEntropyLoss(reduction="sum"),
amp_configs=amp_cfg)
model = Model(net, inputs, labels) losses, grads = [], []
model.prepare( for stat in [False, False, True]:
optim, loss=CrossEntropyLoss(reduction="sum"), amp_configs='O1') loss, = model.train_batch([data], [label], update=stat)
loss1, = model.train_batch([data], [label], update=False) losses.append(loss)
loss2, = model.train_batch([data], [label], update=True) grads.append([p.grad.numpy() for p in net.parameters()])
np.testing.assert_almost_equal(loss1, loss2, decimal=4)
for grad1, grad2, grad3 in zip(*grads):
np.testing.assert_almost_equal(grad1 * 2, grad2, decimal=4)
np.testing.assert_almost_equal(
grad3, np.zeros_like(grad3), decimal=4)
np.testing.assert_almost_equal(losses[0], losses[1], decimal=4)
np.testing.assert_almost_equal(losses[0], losses[2], decimal=4)
class TestModelWithLRScheduler(unittest.TestCase): class TestModelWithLRScheduler(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册