未验证 提交 f605f167 编写于 作者: C chengduo 提交者: GitHub

Merge pull request #6279 from chengduoZH/feature/add_dilation_for_conv_trans

Add dilation for conv_trans_op
......@@ -21,8 +21,6 @@ class CudnnConv2DTransposeOpMaker : public Conv2DTransposeOpMaker {
public:
CudnnConv2DTransposeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: Conv2DTransposeOpMaker(proto, op_checker) {
AddAttr<std::vector<int>>("dilations", "dilations of convolution operator.")
.SetDefault({1, 1});
AddAttr<int>("workspace_size_MB",
"workspace size for cudnn, in MB, "
"workspace is a section of GPU memory which will be "
......@@ -37,8 +35,6 @@ class CudnnConv3DTransposeOpMaker : public Conv3DTransposeOpMaker {
public:
CudnnConv3DTransposeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: Conv3DTransposeOpMaker(proto, op_checker) {
AddAttr<std::vector<int>>("dilations", "dilations of convolution operator.")
.SetDefault({1, 1, 1});
AddAttr<int>("workspace_size_MB",
"workspace size for cudnn, in MB, "
"workspace is a section of GPU memory which will be "
......
......@@ -29,6 +29,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
auto filter_dims = ctx->GetInputDim("Filter");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
std::vector<int> dilations = ctx->Attrs().Get<std::vector<int>>("dilations");
PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 5,
"ConvTransposeOp intput should be 4-D or 5-D tensor.");
......@@ -41,14 +42,18 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_EQ(paddings.size(), strides.size(),
"ConvTransposeOp paddings dimension and strides "
"dimension should be the same.");
PADDLE_ENFORCE_EQ(paddings.size(), dilations.size(),
"ConvTransposeOp paddings dimension and dilations "
"dimension should be the same.");
PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0],
"In ConvTransposeOp, The input channel should be the same "
"as the number of filters.");
std::vector<int64_t> output_shape({in_dims[0], filter_dims[1]});
for (size_t i = 0; i < strides.size(); ++i) {
auto filter_extent = dilations[i] * (filter_dims[i + 2] - 1) + 1;
output_shape.push_back((in_dims[i + 2] - 1) * strides[i] - 2 * paddings[i] +
filter_dims[i + 2]);
filter_extent);
}
ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
}
......@@ -73,6 +78,12 @@ Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(OpProto* proto,
AddOutput("Output",
"(Tensor) The output tensor of convolution transpose operator. "
"The format of output tensor is also NCHW.");
AddAttr<std::vector<int>>("dilations",
"(vector<int> default:{1, 1}), the "
"dilations(h_dilation, w_dilation) of convolution "
"transpose operator.")
.SetDefault({1, 1});
AddAttr<std::vector<int>>(
"strides",
"(vector<int> default:{1, 1}), the strides(h_stride, w_stride) of "
......@@ -87,7 +98,7 @@ Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(OpProto* proto,
Convolution2D Transpose Operator.
The convolution transpose operation calculates the output based on the input, filter
and strides, paddings, groups parameters. The size of each dimension of the
and dilations, strides, paddings, groups parameters. The size of each dimension of the
parameters is checked in the infer-shape.
Input(Input) and output(Output) are in NCHW format. Where N is batchsize, C is the
number of channels, H is the height of the feature, and W is the width of the feature.
......@@ -136,6 +147,13 @@ Conv3DTransposeOpMaker::Conv3DTransposeOpMaker(OpProto* proto,
"Where N is batch size, C is "
"the number of channels, D is the depth of the feature, H is the "
"height of the feature, and W is the width of the feature.");
AddAttr<std::vector<int>>(
"dilations",
"(vector<int> default:{1, 1, 1}), the "
"dilations(d_dilation,h_dilation, w_dilation) of convolution "
"transpose operator.")
.SetDefault({1, 1, 1});
AddAttr<std::vector<int>>("strides",
"(vector<int> default:{1, 1, 1}), the "
"strides{d_stride, h_stride, w_stride} of "
......@@ -149,7 +167,7 @@ Conv3DTransposeOpMaker::Conv3DTransposeOpMaker(OpProto* proto,
Convolution3D Transpose Operator.
The convolution transpose operation calculates the output based on the input, filter
and strides, paddings, groups parameters. The size of each dimension of the
and dilations, strides, paddings, groups parameters. The size of each dimension of the
parameters is checked in the infer-shape.
Input(Input) and output(Output) are in NCDHW format. Where N is batch size, C is the
number of channels, D is the depth of the feature, H is the height of the feature,
......
......@@ -61,6 +61,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
// groups will alway be disabled in conv2dtranspose.
const int batch_size = static_cast<int>(input->dims()[0]);
......@@ -113,7 +114,6 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;
math::Col2VolFunctor<DeviceContext, T> col2vol;
std::vector<int> dilations({1, 1, 1});
// convolution transpose: gemm + col2im or col2vol (similar to conv-backward
// on input)
......@@ -165,6 +165,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
const int batch_size = static_cast<int>(input->dims()[0]);
......@@ -219,7 +220,6 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
math::Vol2ColFunctor<DeviceContext, T> vol2col;
std::vector<int> dilations({1, 1, 1});
if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace());
......
......@@ -704,6 +704,7 @@ def conv2d_transpose(input,
filter_size=None,
padding=None,
stride=None,
dilation=None,
param_attr=None):
"""
The transpose of conv2d layer.
......@@ -727,6 +728,9 @@ def conv2d_transpose(input,
stride(int|tuple): The stride size. If stride is a tuple, it must
contain two integers, (stride_H, stride_W). Otherwise, the
stride_H = stride_W = stride.
dilation(int|tuple): The dilation size. If dilation is a tuple, it must
contain two integers, (dilation_H, dilation_W). Otherwise, the
dilation_H = dilation_W = dilation.
param_attr: Parameter Attribute.
main_program(Program): the main program
startup_program(Program): the startup program
......@@ -747,10 +751,15 @@ def conv2d_transpose(input,
op_attr['paddings'] = padding
if isinstance(stride, int):
op_attr['strides'] = stride
op_attr['strides'] = [stride, stride]
elif stride is not None:
op_attr['strides'] = stride
if isinstance(dilation, int):
op_attr['dilations'] = [dilation, dilation]
elif dilation is not None:
op_attr['dilations'] = dilation
if filter_size is None:
if output_size is None:
raise ValueError("output_size must be set when filter_size is None")
......@@ -759,14 +768,17 @@ def conv2d_transpose(input,
padding = op_attr.get('paddings', [0, 0])
stride = op_attr.get('strides', [1, 1])
dilation = op_attr.get('dilations', [1, 1])
h_in = input.shape[2]
w_in = input.shape[3]
filter_size_h = output_size[0] - \
(h_in - 1) * stride[0] + 2 * padding[0]
filter_size_w = output_size[1] - \
(w_in - 1) * stride[1] + 2 * padding[1]
filter_size_h = (output_size[0] - (h_in - 1) * stride[0] + 2 *
padding[0] - 1) / dilation[0] + 1
filter_size_w = (output_size[1] - (w_in - 1) * stride[1] + 2 *
padding[1] - 1) / dilation[1] + 1
filter_size = [filter_size_h, filter_size_w]
elif isinstance(filter_size, int):
filter_size = [filter_size, filter_size]
......
......@@ -3,14 +3,17 @@ import numpy as np
from op_test import OpTest
def conv2dtranspose_forward_naive(input_, filter_, conv2dtranspose_param):
def conv2dtranspose_forward_naive(input_, filter_, attrs):
in_n, in_c, in_h, in_w = input_.shape
f_c, out_c, f_h, f_w = filter_.shape
assert in_c == f_c
stride, pad = conv2dtranspose_param['stride'], conv2dtranspose_param['pad']
out_h = (in_h - 1) * stride[0] + f_h
out_w = (in_w - 1) * stride[1] + f_w
stride, pad, dilations = attrs['strides'], attrs['paddings'], attrs[
'dilations']
d_bolck_h = dilations[0] * (f_h - 1) + 1
d_bolck_w = dilations[1] * (f_w - 1) + 1
out_h = (in_h - 1) * stride[0] + d_bolck_h
out_w = (in_w - 1) * stride[1] + d_bolck_w
out = np.zeros((in_n, out_c, out_h, out_w))
......@@ -23,9 +26,9 @@ def conv2dtranspose_forward_naive(input_, filter_, conv2dtranspose_param):
for k in range(out_c):
tmp_out = np.sum(input_masked * filter_[:, k, :, :], axis=0)
i1, i2 = i * stride[0], i * stride[0] + f_h
j1, j2 = j * stride[0], j * stride[0] + f_w
out[n, k, i1:i2, j1:j2] += tmp_out
i1, i2 = i * stride[0], i * stride[0] + d_bolck_h
j1, j2 = j * stride[0], j * stride[0] + d_bolck_h
out[n, k, i1:i2:dilations[0], j1:j2:dilations[1]] += tmp_out
out = out[:, :, pad[0]:out_h - pad[0], pad[1]:out_w - pad[1]]
return out
......@@ -37,11 +40,8 @@ class TestConv2dTransposeOp(OpTest):
self.init_op_type()
self.init_test_case()
conv2dtranspose_param = {'stride': self.stride, 'pad': self.pad}
input_ = np.random.random(self.input_size).astype("float32")
filter_ = np.random.random(self.filter_size).astype("float32")
output = conv2dtranspose_forward_naive(
input_, filter_, conv2dtranspose_param).astype('float32')
self.inputs = {'Input': input_, 'Filter': filter_}
self.attrs = {
......@@ -49,6 +49,10 @@ class TestConv2dTransposeOp(OpTest):
'paddings': self.pad,
'dilations': self.dilations
}
output = conv2dtranspose_forward_naive(input_, filter_,
self.attrs).astype('float32')
self.outputs = {'Output': output}
def test_check_output(self):
......@@ -104,11 +108,60 @@ class TestWithStride(TestConv2dTransposeOp):
self.filter_size = [f_c, 6, 3, 3]
class TestWithDilation(TestConv2dTransposeOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
self.dilations = [2, 2]
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
# ------------ test_cudnn ------------
class TestCudnn(TestConv2dTransposeOp):
def init_op_type(self):
self.op_type = "conv2d_transpose_cudnn"
class TestCudnnWithPad(TestWithPad):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
self.dilations = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
def init_op_type(self):
self.op_type = "conv2d_transpose_cudnn"
class TestCudnnWithStride(TestWithStride):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [2, 2]
self.dilations = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
def init_op_type(self):
self.op_type = "conv2d_transpose_cudnn"
# #cudnn v5 does not support dilation conv.
# class TestCudnnWithDilation(TestWithDilation):
# def init_test_case(self):
# self.pad = [1, 1]
# self.stride = [2, 2]
# self.dilations = [2, 2]
# self.input_size = [2, 3, 5, 5] # NCHW
# f_c = self.input_size[1]
# self.filter_size = [f_c, 6, 3, 3]
#
# def init_op_type(self):
# self.op_type = "conv2d_transpose_cudnn"
if __name__ == '__main__':
unittest.main()
......@@ -3,15 +3,20 @@ import numpy as np
from op_test import OpTest
def conv3dtranspose_forward_naive(input_, filter_, conv3dtranspose_param):
def conv3dtranspose_forward_naive(input_, filter_, attrs):
in_n, in_c, in_d, in_h, in_w = input_.shape
f_c, out_c, f_d, f_h, f_w = filter_.shape
assert in_c == f_c
stride, pad = conv3dtranspose_param['stride'], conv3dtranspose_param['pad']
out_d = (in_d - 1) * stride[0] + f_d
out_h = (in_h - 1) * stride[1] + f_h
out_w = (in_w - 1) * stride[2] + f_w
stride, pad, dilations = attrs['strides'], attrs['paddings'], attrs[
'dilations']
d_bolck_d = dilations[0] * (f_d - 1) + 1
d_bolck_h = dilations[1] * (f_h - 1) + 1
d_bolck_w = dilations[2] * (f_w - 1) + 1
out_d = (in_d - 1) * stride[0] + d_bolck_d
out_h = (in_h - 1) * stride[1] + d_bolck_h
out_w = (in_w - 1) * stride[2] + d_bolck_w
out = np.zeros((in_n, out_c, out_d, out_h, out_w))
for n in range(in_n):
......@@ -25,10 +30,11 @@ def conv3dtranspose_forward_naive(input_, filter_, conv3dtranspose_param):
for k in range(out_c):
tmp_out = np.sum(input_masked * filter_[:, k, :, :, :],
axis=0)
d1, d2 = d * stride[0], d * stride[0] + f_d
i1, i2 = i * stride[1], i * stride[1] + f_h
j1, j2 = j * stride[2], j * stride[2] + f_w
out[n, k, d1:d2, i1:i2, j1:j2] += tmp_out
d1, d2 = d * stride[0], d * stride[0] + d_bolck_d
i1, i2 = i * stride[1], i * stride[1] + d_bolck_h
j1, j2 = j * stride[2], j * stride[2] + d_bolck_w
out[n, k, d1:d2:dilations[0], i1:i2:dilations[1], j1:j2:
dilations[2]] += tmp_out
out = out[:, :, pad[0]:out_d - pad[0], pad[1]:out_h - pad[1], pad[2]:out_w -
pad[2]]
......@@ -41,18 +47,19 @@ class TestConv3dTransposeOp(OpTest):
self.init_op_type()
self.init_test_case()
conv3dtranspose_param = {'stride': self.stride, 'pad': self.pad}
input_ = np.random.random(self.input_size).astype("float32")
filter_ = np.random.random(self.filter_size).astype("float32")
output = conv3dtranspose_forward_naive(
input_, filter_, conv3dtranspose_param).astype("float32")
self.inputs = {'Input': input_, 'Filter': filter_}
self.attrs = {
'strides': self.stride,
'paddings': self.pad,
# 'dilations': self.dilations
'dilations': self.dilations
}
output = conv3dtranspose_forward_naive(input_, filter_,
self.attrs).astype("float32")
self.outputs = {'Output': output}
def test_check_output(self):
......@@ -108,11 +115,60 @@ class TestWithStride(TestConv3dTransposeOp):
self.filter_size = [f_c, 6, 3, 3, 3]
class TestWithDilation(TestConv3dTransposeOp):
def init_test_case(self):
self.pad = [1, 1, 1]
self.stride = [1, 1, 1]
self.dilations = [2, 2, 2]
self.input_size = [2, 3, 5, 5, 5] # NCDHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3, 3]
# ------------ test_cudnn ------------
class TestCudnn(TestConv3dTransposeOp):
def init_op_type(self):
self.op_type = "conv3d_transpose_cudnn"
class TestCudnnWithPad(TestWithPad):
def init_test_case(self):
self.pad = [1, 1, 1]
self.stride = [1, 1, 1]
self.dilations = [1, 1, 1]
self.input_size = [2, 3, 5, 5, 5] # NCDHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3, 3]
def init_op_type(self):
self.op_type = "conv3d_transpose_cudnn"
class TestCudnnWithStride(TestWithStride):
def init_test_case(self):
self.pad = [1, 1, 1]
self.stride = [2, 2, 2]
self.dilations = [1, 1, 1]
self.input_size = [2, 3, 5, 5, 5] # NCDHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3, 3]
def init_op_type(self):
self.op_type = "conv3d_transpose_cudnn"
# #cudnn v5 does not support dilation conv.
# class TestCudnnWithDilation(TestWithDilation):
# def init_test_case(self):
# self.pad = [1, 1, 1]
# self.stride = [2, 2, 2]
# self.dilations = [2, 2, 2]
# self.input_size = [2, 3, 5, 5, 5] # NCDHW
# f_c = self.input_size[1]
# self.filter_size = [f_c, 6, 3, 3, 3]
#
# def init_op_type(self):
# self.op_type = "conv3d_transpose_cudnn"
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册