From 7eeaae169548566bb051eeb5e9d7c200a40e2276 Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Thu, 19 Oct 2017 17:05:09 -0700 Subject: [PATCH] deconv --- paddle/operators/deconv2d_op.h | 15 +-- .../v2/framework/tests/test_deconv_op.py | 101 ++++++++++++++++++ 2 files changed, 109 insertions(+), 7 deletions(-) create mode 100644 python/paddle/v2/framework/tests/test_deconv_op.py diff --git a/paddle/operators/deconv2d_op.h b/paddle/operators/deconv2d_op.h index 9036801a6..71254c952 100644 --- a/paddle/operators/deconv2d_op.h +++ b/paddle/operators/deconv2d_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include "glog/logging.h" #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" #include "paddle/operators/math/im2col.h" @@ -117,8 +118,7 @@ class GemmDeconv2DKernel : public framework::OpKernel { // of shape (C * K_H * K_W, H * W) math::matmul(context.device_context(), filter, true, input_batch, false, T(1.0), &col_matrix, T(0.0)); - - col2im(context.device_context(), output_batch, col_matrix, strides[0], + col2im(context.device_context(), output_batch, col, strides[0], strides[1], 0, 0); } } @@ -203,8 +203,8 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel { input_grad->Slice(i, i + 1).Resize(input_matrix_shape); // im2col: dy from (C, O_H, O_W) -> (C * K_H * K_W, H * W) - im2col(context.device_context(), output_grad_batch, col_matrix, - strides[0], strides[1], paddings[0], paddings[1]); + im2col(context.device_context(), output_grad_batch, col, strides[0], + strides[1], paddings[0], paddings[1]); // gemm: dx = filter * dy // (M, C * K_H * K_W) * (C * K_H * K_W, H * W) -> (M, C, H) @@ -234,13 +234,14 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel { Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape); // im2col: (C * H * W, K_H * K_W) - im2col(context.device_context(), output_grad_batch, col_matrix_f, - strides[0], strides[1], paddings[0], paddings[1]); + im2col(context.device_context(), output_grad_batch, col, strides[0], + strides[1], paddings[0], paddings[1]); // gemm: d_filter = x * y_grad^T // (M, C * H * W) * (K_H * K_W, C * H * W) -> (M, C, H) math::matmul(context.device_context(), in_batch, false, - col_matrix, true, T(1.0), &filter_grad_, T(1.0)); + col_matrix_f, true, T(1.0), &filter_grad_, + T(1.0)); } } } diff --git a/python/paddle/v2/framework/tests/test_deconv_op.py b/python/paddle/v2/framework/tests/test_deconv_op.py new file mode 100644 index 000000000..c3baea804 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_deconv_op.py @@ -0,0 +1,101 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def deconv2d_forward_naive(input_, filter_, deconv_param): + # [2, 3, 5, 5] + in_n, in_c, in_h, in_w = input_.shape + # [3, 6, 3, 3] + f_c, out_c, f_h, f_w = filter_.shape + assert in_c == f_c + + stride, pad = deconv_param['stride'], deconv_param['pad'] + out_h = (in_h - 1) * stride[0] + f_h + out_w = (in_w - 1) * stride[1] + f_w + + out = np.zeros((in_n, out_c, out_h, out_w)) + + for n in range(in_n): + for i in range(in_h): + for j in range(in_w): + input_masked = input_[n, :, i, j] # (c) + input_masked = np.reshape(input_masked, (in_c, 1, 1)) + input_masked = np.tile(input_masked, (1, f_h, f_w)) + + for k in range(out_c): + tmp_out = np.sum(input_masked * filter_[:, k, :, :], axis=0) + i1, i2 = i * stride[0], i * stride[0] + f_h + j1, j2 = j * stride[0], j * stride[0] + f_w + out[n, k, i1:i2, j1:j2] += tmp_out + + return out + + +class TestDeconv2dOp(OpTest): + def setUp(self): + # init as deconv + self.init_op_type() + + # [2, 3, 5, 5] -> kernel [3, 6, 3, 3] -> output [2, 6, 7, 7] + self.init_test_case() + + deconv2d_param = {'stride': self.stride, 'pad': self.pad} + input_ = np.random.random(self.input_size).astype("float32") + filter_ = np.random.random(self.filter_size).astype("float32") + output = deconv2d_forward_naive(input_, filter_, deconv2d_param) + # print 'deconv output py', output, output.shape + + self.inputs = {'Input': input_, 'Filter': filter_} + self.attrs = { + 'strides': self.stride, + 'paddings': self.pad, + # 'dilations': self.dilations + } + self.outputs = {'Output': output} + + def test_check_output(self): + print 'check output here' + self.check_output() + + def test_check_grad(self): + self.check_grad( + set(['Input', 'Filter']), 'Output', max_relative_error=0.05) + + def test_check_grad_no_filter(self): + self.check_grad( + ['Input'], + 'Output', + max_relative_error=0.05, + no_grad_set=set(['Filter'])) + + def test_check_grad_no_input(self): + self.check_grad( + ['Filter'], + 'Output', + max_relative_error=0.05, + no_grad_set=set(['Input'])) + + def init_test_case(self): + self.pad = [0, 0] + self.stride = [1, 1] + self.dilations = [1, 1] + self.input_size = [2, 3, 5, 5] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 6, 3, 3] + + def init_op_type(self): + self.op_type = "deconv2d" + + +""" +class TestCudnn(TestConv2dOp): + def init_group(self): + self.groups = 1 + + def init_op_type(self): + self.op_type = "conv_cudnn" +""" + +if __name__ == '__main__': + unittest.main() -- GitLab