提交 dd926498 编写于 作者: Y Yibing Liu

adapt to the new test framework

上级 31cbb343
import unittest import unittest
import numpy as np import numpy as np
from gradient_checker import GradientChecker, Operator from op_test import OpTest
from op_test_util import OpTestMeta
class TestReshapeOp(unittest.TestCase): class TestReshapeOp(OpTest):
__metaclass__ = OpTestMeta
def setUp(self): def setUp(self):
self.type = "reshape" self.op_type = "reshape"
self.inputs = {'X': np.random.random((37, 51)).astype("float32"), } self.inputs = {'X': np.random.random((10, 20)).astype("float32")}
self.attrs = {'shape': [51 * 37]} self.attrs = {'shape': [10 * 20]}
self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])} self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])}
def test_check_output(self):
self.check_output()
class TestReshapeGradOp(GradientChecker): def test_check_grad(self):
def setUp(self): self.check_grad(["X"], "Out")
self.op = Operator("reshape", X='X', Out='Out', shape=[5, 40])
self.inputs = {"X": np.random.random((10, 20)).astype("float32")}
def test_normal(self):
self.check_grad(self.op, self.inputs, ["X"], "Out")
def test_dev_compare(self):
self.compare_grad(self.op, self.inputs)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册