未验证 提交 023d8771 编写于 作者: W Weilong Wu 提交者: GitHub

Update ResNet test cases (#40953)

上级 5c5a2a83
...@@ -20,7 +20,7 @@ import unittest ...@@ -20,7 +20,7 @@ import unittest
from unittest import TestCase from unittest import TestCase
import numpy as np import numpy as np
import paddle.compat as cpt import paddle.compat as cpt
from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph, _in_eager_without_dygraph_check
import paddle.fluid.core as core import paddle.fluid.core as core
...@@ -568,13 +568,32 @@ class TestRaiseNoDoubleGradOp(TestCase): ...@@ -568,13 +568,32 @@ class TestRaiseNoDoubleGradOp(TestCase):
self.assertRaises(RuntimeError, self.raise_no_grad_op) self.assertRaises(RuntimeError, self.raise_no_grad_op)
class TestDoubleGradResNetBase(TestCase): class TestDoubleGradResNet(TestCase):
def setUp(self):
paddle.seed(123)
paddle.framework.random._manual_program_seed(123)
self.data = np.random.rand(1, 3, 224, 224).astype(np.float32)
@dygraph_guard @dygraph_guard
def check_resnet(self): def test_resnet_resnet50(self):
data = np.random.rand(1, 3, 224, 224).astype(np.float32) with _test_eager_guard():
data = paddle.to_tensor(data) model = resnet50(pretrained=False)
egr_data = paddle.to_tensor(self.data)
egr_data.stop_gradient = False
egr_out = model(egr_data)
egr_preds = paddle.argmax(egr_out, axis=1)
egr_label_onehot = paddle.nn.functional.one_hot(
paddle.to_tensor(egr_preds), num_classes=egr_out.shape[1])
egr_target = paddle.sum(egr_out * egr_label_onehot, axis=1)
egr_g = paddle.grad(outputs=egr_target, inputs=egr_out)[0]
egr_g_numpy = egr_g.numpy()
self.assertEqual(list(egr_g_numpy.shape), list(egr_out.shape))
model = resnet50(pretrained=False)
data = paddle.to_tensor(self.data)
data.stop_gradient = False data.stop_gradient = False
out = self.model(data) out = model(data)
preds = paddle.argmax(out, axis=1) preds = paddle.argmax(out, axis=1)
label_onehot = paddle.nn.functional.one_hot( label_onehot = paddle.nn.functional.one_hot(
paddle.to_tensor(preds), num_classes=out.shape[1]) paddle.to_tensor(preds), num_classes=out.shape[1])
...@@ -584,21 +603,40 @@ class TestDoubleGradResNetBase(TestCase): ...@@ -584,21 +603,40 @@ class TestDoubleGradResNetBase(TestCase):
g_numpy = g.numpy() g_numpy = g.numpy()
self.assertEqual(list(g_numpy.shape), list(out.shape)) self.assertEqual(list(g_numpy.shape), list(out.shape))
self.assertTrue(np.array_equal(egr_out, out))
self.assertTrue(np.array_equal(egr_g_numpy, g_numpy))
class TestDoubleGradResNet50(TestDoubleGradResNetBase): @dygraph_guard
def setUp(self): def test_resnet_resnet101(self):
self.model = resnet50(pretrained=False) with _test_eager_guard():
model = resnet101(pretrained=False)
def test_main(self): egr_data = paddle.to_tensor(self.data)
self.check_resnet() egr_data.stop_gradient = False
egr_out = model(egr_data)
egr_preds = paddle.argmax(egr_out, axis=1)
egr_label_onehot = paddle.nn.functional.one_hot(
paddle.to_tensor(egr_preds), num_classes=egr_out.shape[1])
egr_target = paddle.sum(egr_out * egr_label_onehot, axis=1)
egr_g = paddle.grad(outputs=egr_target, inputs=egr_out)[0]
egr_g_numpy = egr_g.numpy()
self.assertEqual(list(egr_g_numpy.shape), list(egr_out.shape))
model = resnet101(pretrained=False)
data = paddle.to_tensor(self.data)
data.stop_gradient = False
out = model(data)
preds = paddle.argmax(out, axis=1)
label_onehot = paddle.nn.functional.one_hot(
paddle.to_tensor(preds), num_classes=out.shape[1])
target = paddle.sum(out * label_onehot, axis=1)
class TestDoubleGradResNet101(TestDoubleGradResNetBase): g = paddle.grad(outputs=target, inputs=out)[0]
def setUp(self): g_numpy = g.numpy()
self.model = resnet101(pretrained=False) self.assertEqual(list(g_numpy.shape), list(out.shape))
def test_main(self): self.assertTrue(np.array_equal(egr_out, out))
self.check_resnet() self.assertTrue(np.array_equal(egr_g_numpy, g_numpy))
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册