提交 dd926498 编写于 作者: Y Yibing Liu

adapt to the new test framework

上级 31cbb343
import unittest
import numpy as np
from gradient_checker import GradientChecker, Operator
from op_test_util import OpTestMeta
from op_test import OpTest
class TestReshapeOp(unittest.TestCase):
__metaclass__ = OpTestMeta
class TestReshapeOp(OpTest):
def setUp(self):
self.type = "reshape"
self.inputs = {'X': np.random.random((37, 51)).astype("float32"), }
self.attrs = {'shape': [51 * 37]}
self.op_type = "reshape"
self.inputs = {'X': np.random.random((10, 20)).astype("float32")}
self.attrs = {'shape': [10 * 20]}
self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])}
def test_check_output(self):
self.check_output()
class TestReshapeGradOp(GradientChecker):
def setUp(self):
self.op = Operator("reshape", X='X', Out='Out', shape=[5, 40])
self.inputs = {"X": np.random.random((10, 20)).astype("float32")}
def test_normal(self):
self.check_grad(self.op, self.inputs, ["X"], "Out")
def test_dev_compare(self):
self.compare_grad(self.op, self.inputs)
def test_check_grad(self):
self.check_grad(["X"], "Out")
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册