提交 b3ab3ce0 编写于 作者: Z zchen0211

deconv -> conv transpose

上级 64c5ecbe
......@@ -12,8 +12,7 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/deconv2d_op.h"
#include "paddle/operators/conv2d_op.h"
#include "paddle/operators/conv2dtranspose_op.h"
namespace paddle {
namespace operators {
......@@ -54,18 +53,18 @@ Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"Input",
"The input tensor of convolution transpose operator. "
"(Tensor) The input tensor of convolution transpose operator. "
"The format of input tensor is NCHW. Where N is batch size, C is the "
"number of input channels, H and W is the height and width of image.");
AddInput("Filter",
"The filter tensor of convolution transpose operator."
"(Tensor) The filter tensor of convolution transpose operator."
"The format of the filter tensor is CMHW, where C is the number of "
"output image channels, M is the number of input image channels, "
"H and W is height and width of filter. "
"We enforce groups number == 1 and padding == 0 in "
"convolution transpose Scenario.");
AddOutput("Output",
"The output tensor of convolution transpose operator."
"(Tensor) The output tensor of convolution transpose operator."
"The format of output tensor is also NCHW.");
AddAttr<std::vector<int>>("strides",
"strides of convolution transpose operator.")
......
......@@ -12,7 +12,7 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/deconv2d_op.h"
#include "paddle/operators/conv2dtranspose_op.h"
namespace ops = paddle::operators;
......
......@@ -14,7 +14,6 @@ limitations under the License. */
#pragma once
#include "glog/logging.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/im2col.h"
......@@ -62,7 +61,8 @@ class GemmConv2DTransposeKernel : public framework::OpKernel<T> {
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
// no paddings and groups allowed in deconv
// TODO(Zhuoyuan): Paddings can be added in future.
// groups will always be disabled in conv2dtranspose.
const int batch_size = input->dims()[0];
const int m = input->dims()[1];
......@@ -91,7 +91,8 @@ class GemmConv2DTransposeKernel : public framework::OpKernel<T> {
// col_matrix shares the same piece of data with col,
// but will be reshaped into a two-dimensional matrix shape
// to call the matrix multiplication interface.
Tensor col_matrix = col;
Tensor col_matrix;
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
DDim output_shape = {c, o_h, o_w};
......@@ -100,7 +101,7 @@ class GemmConv2DTransposeKernel : public framework::OpKernel<T> {
DDim filter_matrix_shape = {m, c * k_h * k_w};
filter.Resize(filter_matrix_shape);
// deconvolution: gemm + col2im (similar to conv-backward on input)
// convolution transpose: gemm + col2im (similar to conv-backward on input)
output->mutable_data<T>(context.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*output);
......@@ -142,7 +143,7 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> {
context.Output<Tensor>(framework::GradVarName("Filter"));
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
// Actually, no paddings and groups allowed in deconv.
// Actually, no paddings and groups allowed in conv transpose.
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
const int batch_size = input->dims()[0];
......@@ -180,11 +181,12 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> {
DDim filter_matrix_shape = {m, c * k_h * k_w};
filter.Resize(filter_matrix_shape);
// deconvolution grad on input:
// convolution transpose grad on input:
// im2col + gemm (similar to conv-forward)
// input need to compute gradient
if (input_grad) {
Tensor col_matrix = col;
Tensor col_matrix;
col_matrix.ShareDataWith(col);
DDim col_matrix_shape = {c * k_h * k_w, h * w};
col_matrix.Resize(col_matrix_shape);
......@@ -216,7 +218,8 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> {
// filter gradient required
if (filter_grad) {
Tensor col_matrix_f = col;
Tensor col_matrix_f;
col_matrix_f.ShareDataWith(col);
DDim col_matrix_shape_f = {c * h * w, k_h * k_w};
col_matrix_f.Resize(col_matrix_shape_f);
......
......@@ -3,14 +3,14 @@ import numpy as np
from op_test import OpTest
def deconv2d_forward_naive(input_, filter_, deconv_param):
def conv2dtranspose_forward_naive(input_, filter_, conv2dtranspose_param):
# [2, 3, 5, 5]
in_n, in_c, in_h, in_w = input_.shape
# [3, 6, 3, 3]
f_c, out_c, f_h, f_w = filter_.shape
assert in_c == f_c
stride, pad = deconv_param['stride'], deconv_param['pad']
stride, pad = conv2dtranspose_param['stride'], conv2dtranspose_param['pad']
out_h = (in_h - 1) * stride[0] + f_h
out_w = (in_w - 1) * stride[1] + f_w
......@@ -32,18 +32,19 @@ def deconv2d_forward_naive(input_, filter_, deconv_param):
return out
class TestDeconv2dOp(OpTest):
class TestConv2dTransposeOp(OpTest):
def setUp(self):
# init as deconv
# init as conv transpose
self.init_op_type()
# [2, 3, 5, 5] -> kernel [3, 6, 3, 3] -> output [2, 6, 7, 7]
self.init_test_case()
deconv2d_param = {'stride': self.stride, 'pad': self.pad}
conv2dtranspose_param = {'stride': self.stride, 'pad': self.pad}
input_ = np.random.random(self.input_size).astype("float32")
filter_ = np.random.random(self.filter_size).astype("float32")
output = deconv2d_forward_naive(input_, filter_, deconv2d_param)
output = conv2dtranspose_forward_naive(input_, filter_,
conv2dtranspose_param)
# print 'conv transpose output py', output, output.shape
self.inputs = {'Input': input_, 'Filter': filter_}
......@@ -85,7 +86,7 @@ class TestDeconv2dOp(OpTest):
self.filter_size = [f_c, 6, 3, 3]
def init_op_type(self):
self.op_type = "deconv2d"
self.op_type = "conv2dtranspose"
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册