提交 0f31ed71 编写于 作者: L lyuwenyu 提交者: jzhang533

add grad unittest

上级 9a73283b
......@@ -728,18 +728,25 @@ class TestModelFunction(unittest.TestCase):
inputs = [InputSpec([None, dim], 'float32', 'x')]
labels = [InputSpec([None, 1], 'int64', 'label')]
model = Model(net, inputs, labels)
model.prepare(optim, loss=CrossEntropyLoss(reduction="sum"))
loss1, = model.train_batch([data], [label], update=False)
loss2, = model.train_batch([data], [label], update=True)
np.testing.assert_almost_equal(loss1, loss2, decimal=4)
for amp_cfg in [None, 'O1']:
model = Model(net, inputs, labels)
model.prepare(
optim, loss=CrossEntropyLoss(reduction="sum"), amp_configs='O1')
loss1, = model.train_batch([data], [label], update=False)
loss2, = model.train_batch([data], [label], update=True)
np.testing.assert_almost_equal(loss1, loss2, decimal=4)
optim,
loss=CrossEntropyLoss(reduction="sum"),
amp_configs=amp_cfg)
losses, grads = [], []
for stat in [False, False, True]:
loss, = model.train_batch([data], [label], update=stat)
losses.append(loss)
grads.append([p.grad.numpy() for p in net.parameters()])
for grad1, grad2, grad3 in zip(*grads):
np.testing.assert_almost_equal(grad1 * 2, grad2, decimal=4)
np.testing.assert_almost_equal(
grad3, np.zeros_like(grad3), decimal=4)
np.testing.assert_almost_equal(losses[0], losses[1], decimal=4)
np.testing.assert_almost_equal(losses[0], losses[2], decimal=4)
class TestModelWithLRScheduler(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册