diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc index 5e264d730c4577805cebdfafb8340e43cfd85766..1250900d154c4c4236636615a58cb420b6e8803b 100644 --- a/paddle/operators/conv_op.cc +++ b/paddle/operators/conv_op.cc @@ -33,6 +33,8 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { int input_channels = in_dims[1]; int output_channels = filter_dims[0]; + PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 5, + "Conv intput should be 4-D or 5-D tensor."); PADDLE_ENFORCE_EQ( in_dims.size(), filter_dims.size(), "Conv input dimension and filter dimension should be the same."); @@ -62,26 +64,30 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto, : OpProtoAndCheckerMaker(proto, op_checker) { AddInput( "Input", - "The input tensor of convolution operator. " + "(Tensor), the input tensor of convolution operator. " "The format of input tensor is NCHW. Where N is batch size, C is the " "number of channels, H and W is the height and width of image."); AddInput("Filter", - "The filter tensor of convolution operator." + "(Tensor), the filter tensor of convolution operator." "The format of the filter tensor is MCHW, where M is the number of " "output image channels, C is the number of input image channels, " "H and W is height and width of filter. " "If the groups attribute is greater than 1, C equal the number of " "input image channels divided by the groups."); AddOutput("Output", - "The output tensor of convolution operator." - "The format of output tensor is also NCHW."); - AddAttr>("strides", "strides of convolution operator.") + "(Tensor), the output tensor of convolution operator." + "The format of output tensor is also NCHW. Where N is batch size, " + "C is the " + "number of channels, H and W is the height and width of image."); + AddAttr>( + "strides", "(vector default:{1, 1}), strides of convolution operator.") .SetDefault({1, 1}); - AddAttr>("paddings", "paddings of convolution operator.") + AddAttr>( + "paddings", "(vector default:{0, 0}), paddings of convolution operator.") .SetDefault({0, 0}); AddAttr( "groups", - "group size of convolution operator. " + "(int, default:1), group size of convolution operator. " "Refer to grouped convolution in Alex Krizhevsky's paper: " "when group=2, the first half of the filters are only connected to the " "first half of the input channels, and the second half only connected " @@ -91,6 +97,21 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto, The convolution operation calculates the output based on the input, filter and strides, paddings, groups parameters. The size of each dimension of the parameters is checked in the infer-shape. +Input(Input, Filter) and output(Output) are in NCHW format. Where N is batch +size, C is the number of channels, H and W is the height and +width of feature. Parameters(ksize, strides, paddings) are two elements. +These two elements represent height and width, respectively. +The input(X) size and output(Out) size may be different. + +Example: + Input: + Input shape: (N, C_in, H_in, W_in) + Filter shape: (C_out, C_in, H_f, W_f) + Output: + Output shape: (N, C_out, H_out, W_out) + where + H_out = (H_in - filter_size[0] + 2 * paddings[0]) / strides[0] + 1; + W_out = (W_in - filter_size[1] + 2 * paddings[1]) / strides[1] + 1; )DOC"); } @@ -115,15 +136,15 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto, "The format of output tensor is also NCDHW."); AddAttr>( "strides", - "(vector, default {0,0,0}), the strides of convolution operator.") + "(vector, default:{0, 0, 0}), the strides of convolution operator.") .SetDefault({1, 1, 1}); AddAttr>( "paddings", - "(vector, default {0,0,0}), the paddings of convolution operator.") + "(vector, default:{0, 0, 0}), the paddings of convolution operator.") .SetDefault({0, 0, 0}); AddAttr( "groups", - "(int, default 1) the group size of convolution operator. " + "(int, default:1) the group size of convolution operator. " "Refer to grouped convolution in Alex Krizhevsky's paper: " "when group=2, the first half of the filters are only connected to the " "first half of the input channels, and the second half only connected " diff --git a/paddle/operators/conv_op.h b/paddle/operators/conv_op.h index 7e8f5d75bb6be75b2d9b64b5b723fe63024baa85..198e51e4ad4c45b6a7fc155a9eb1d1343e263a28 100644 --- a/paddle/operators/conv_op.h +++ b/paddle/operators/conv_op.h @@ -85,9 +85,7 @@ class GemmConv2DKernel : public framework::OpKernel { int output_height = output->dims()[2]; int output_width = output->dims()[3]; - paddle::operators::math::Im2ColFunctor< - paddle::operators::math::ColFormat::kCFO, Place, T> - im2col; + math::Im2ColFunctor im2col; // use col_shape in the im2col calculation framework::DDim col_shape = {input_channels / groups, filter_height, filter_width, output_height, output_width}; @@ -162,12 +160,8 @@ class GemmConvGrad2DKernel : public framework::OpKernel { int output_height = output_grad->dims()[2]; int output_width = output_grad->dims()[3]; - paddle::operators::math::Col2ImFunctor< - paddle::operators::math::ColFormat::kCFO, Place, T> - col2im; - paddle::operators::math::Im2ColFunctor< - paddle::operators::math::ColFormat::kCFO, Place, T> - im2col; + math::Col2ImFunctor col2im; + math::Im2ColFunctor im2col; // use col_shape in the im2col and col2im calculation framework::DDim col_shape = {input_channels / groups, filter_height, filter_width, output_height, output_width}; @@ -283,7 +277,7 @@ class GemmConv3DKernel : public framework::OpKernel { int output_height = output->dims()[3]; int output_width = output->dims()[4]; - paddle::operators::math::Vol2ColFunctor vol2col; + math::Vol2ColFunctor vol2col; // use col_shape in the vol2col calculation framework::DDim col_shape = {input_channels / groups, filter_depth, @@ -369,8 +363,8 @@ class GemmConvGrad3DKernel : public framework::OpKernel { int output_height = output_grad->dims()[3]; int output_width = output_grad->dims()[4]; - paddle::operators::math::Col2VolFunctor col2vol; - paddle::operators::math::Vol2ColFunctor vol2col; + math::Col2VolFunctor col2vol; + math::Vol2ColFunctor vol2col; // use col_shape in the vol2col and col2vol calculation framework::DDim col_shape = {input_channels / groups, filter_depth, diff --git a/python/paddle/v2/framework/tests/test_conv2d_op.py b/python/paddle/v2/framework/tests/test_conv2d_op.py index f58b96463cf78103b2acb3d80652ef0aa988ad49..6bd4bad8e2db53907a0e1e442215ca8912d2c300 100644 --- a/python/paddle/v2/framework/tests/test_conv2d_op.py +++ b/python/paddle/v2/framework/tests/test_conv2d_op.py @@ -103,6 +103,9 @@ class TestWithGroup(TestConv2dOp): self.op_type = "conv2d" +#----------------Conv2dCudnn---------------- + + class TestCudnn(TestConv2dOp): def init_group(self): self.groups = 1