提交 c671189d 编写于 作者: H hedaoyuan

Fix test_conv2d_op.py.

上级 3705de6d
import unittest import unittest
import numpy as np import numpy as np
from gradient_checker import GradientChecker, create_op from op_test import OpTest
from op_test_util import OpTestMeta
from paddle.v2.framework.op import Operator
class TestConv2dOp(unittest.TestCase): class TestConv2dOp(OpTest):
__metaclass__ = OpTestMeta
def setUp(self): def setUp(self):
self.type = "conv2d" self.op_type = "conv2d"
batch_size = 2 batch_size = 2
input_channels = 3 input_channels = 3
input_height = 5 input_height = 5
...@@ -58,8 +54,11 @@ class TestConv2dOp(unittest.TestCase): ...@@ -58,8 +54,11 @@ class TestConv2dOp(unittest.TestCase):
self.outputs = {'Output': output} self.outputs = {'Output': output}
self.attrs = {'strides': [1, 1], 'paddings': [0, 0]} self.attrs = {'strides': [1, 1], 'paddings': [0, 0]}
def test_check_output(self):
self.check_output()
class TestConv2dGradOp(GradientChecker): class TestConv2dGradOp(OpTest):
def setUp(self): def setUp(self):
batch_size = 2 batch_size = 2
input_channels = 3 input_channels = 3
...@@ -79,21 +78,18 @@ class TestConv2dGradOp(GradientChecker): ...@@ -79,21 +78,18 @@ class TestConv2dGradOp(GradientChecker):
(output_channels, input_channels, filter_height, (output_channels, input_channels, filter_height,
filter_width)).astype("float32") filter_width)).astype("float32")
self.op_type = 'conv2d'
self.inputs = {'Input': input, 'Filter': filter} self.inputs = {'Input': input, 'Filter': filter}
self.op = Operator( output = np.ndarray(
"conv2d", (batch_size, output_channels, output_height, output_width))
Input='Input', self.outputs = {'Output': output}
Filter='Filter', self.attrs = {'strides': [1, 1], 'paddings': [0, 0]}
Output='Output',
strides=[1, 1],
paddings=[0, 0])
def test_compare_grad(self): #def test_compare_grad(self):
self.compare_grad(self.op, self.inputs) # self.compare_grad(self.op, self.inputs)
def test_check_grad(self): def test_check_grad(self):
self.check_grad(self.op, self.inputs, self.check_grad(set(['Input', 'Filter']), 'Output')
set(['Input', 'Filter']), 'Output')
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册