提交 17248153 编写于 作者: C chengduoZH

fix code format and doc

上级 9ee8a0d0
......@@ -33,6 +33,8 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
int input_channels = in_dims[1];
int output_channels = filter_dims[0];
PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 5,
"Conv intput should be 4-D or 5-D tensor.");
PADDLE_ENFORCE_EQ(
in_dims.size(), filter_dims.size(),
"Conv input dimension and filter dimension should be the same.");
......@@ -62,26 +64,30 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto,
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"Input",
"The input tensor of convolution operator. "
"(Tensor), the input tensor of convolution operator. "
"The format of input tensor is NCHW. Where N is batch size, C is the "
"number of channels, H and W is the height and width of image.");
AddInput("Filter",
"The filter tensor of convolution operator."
"(Tensor), the filter tensor of convolution operator."
"The format of the filter tensor is MCHW, where M is the number of "
"output image channels, C is the number of input image channels, "
"H and W is height and width of filter. "
"If the groups attribute is greater than 1, C equal the number of "
"input image channels divided by the groups.");
AddOutput("Output",
"The output tensor of convolution operator."
"The format of output tensor is also NCHW.");
AddAttr<std::vector<int>>("strides", "strides of convolution operator.")
"(Tensor), the output tensor of convolution operator."
"The format of output tensor is also NCHW. Where N is batch size, "
"C is the "
"number of channels, H and W is the height and width of image.");
AddAttr<std::vector<int>>(
"strides", "(vector default:{1, 1}), strides of convolution operator.")
.SetDefault({1, 1});
AddAttr<std::vector<int>>("paddings", "paddings of convolution operator.")
AddAttr<std::vector<int>>(
"paddings", "(vector default:{0, 0}), paddings of convolution operator.")
.SetDefault({0, 0});
AddAttr<int>(
"groups",
"group size of convolution operator. "
"(int, default:1), group size of convolution operator. "
"Refer to grouped convolution in Alex Krizhevsky's paper: "
"when group=2, the first half of the filters are only connected to the "
"first half of the input channels, and the second half only connected "
......@@ -91,6 +97,21 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto,
The convolution operation calculates the output based on the input, filter
and strides, paddings, groups parameters. The size of each dimension of the
parameters is checked in the infer-shape.
Input(Input, Filter) and output(Output) are in NCHW format. Where N is batch
size, C is the number of channels, H and W is the height and
width of feature. Parameters(ksize, strides, paddings) are two elements.
These two elements represent height and width, respectively.
The input(X) size and output(Out) size may be different.
Example:
Input:
Input shape: (N, C_in, H_in, W_in)
Filter shape: (C_out, C_in, H_f, W_f)
Output:
Output shape: (N, C_out, H_out, W_out)
where
H_out = (H_in - filter_size[0] + 2 * paddings[0]) / strides[0] + 1;
W_out = (W_in - filter_size[1] + 2 * paddings[1]) / strides[1] + 1;
)DOC");
}
......@@ -115,15 +136,15 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto,
"The format of output tensor is also NCDHW.");
AddAttr<std::vector<int>>(
"strides",
"(vector, default {0,0,0}), the strides of convolution operator.")
"(vector, default:{0, 0, 0}), the strides of convolution operator.")
.SetDefault({1, 1, 1});
AddAttr<std::vector<int>>(
"paddings",
"(vector, default {0,0,0}), the paddings of convolution operator.")
"(vector, default:{0, 0, 0}), the paddings of convolution operator.")
.SetDefault({0, 0, 0});
AddAttr<int>(
"groups",
"(int, default 1) the group size of convolution operator. "
"(int, default:1) the group size of convolution operator. "
"Refer to grouped convolution in Alex Krizhevsky's paper: "
"when group=2, the first half of the filters are only connected to the "
"first half of the input channels, and the second half only connected "
......
......@@ -85,9 +85,7 @@ class GemmConv2DKernel : public framework::OpKernel<T> {
int output_height = output->dims()[2];
int output_width = output->dims()[3];
paddle::operators::math::Im2ColFunctor<
paddle::operators::math::ColFormat::kCFO, Place, T>
im2col;
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
// use col_shape in the im2col calculation
framework::DDim col_shape = {input_channels / groups, filter_height,
filter_width, output_height, output_width};
......@@ -162,12 +160,8 @@ class GemmConvGrad2DKernel : public framework::OpKernel<T> {
int output_height = output_grad->dims()[2];
int output_width = output_grad->dims()[3];
paddle::operators::math::Col2ImFunctor<
paddle::operators::math::ColFormat::kCFO, Place, T>
col2im;
paddle::operators::math::Im2ColFunctor<
paddle::operators::math::ColFormat::kCFO, Place, T>
im2col;
math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
// use col_shape in the im2col and col2im calculation
framework::DDim col_shape = {input_channels / groups, filter_height,
filter_width, output_height, output_width};
......@@ -283,7 +277,7 @@ class GemmConv3DKernel : public framework::OpKernel<T> {
int output_height = output->dims()[3];
int output_width = output->dims()[4];
paddle::operators::math::Vol2ColFunctor<Place, T> vol2col;
math::Vol2ColFunctor<Place, T> vol2col;
// use col_shape in the vol2col calculation
framework::DDim col_shape = {input_channels / groups,
filter_depth,
......@@ -369,8 +363,8 @@ class GemmConvGrad3DKernel : public framework::OpKernel<T> {
int output_height = output_grad->dims()[3];
int output_width = output_grad->dims()[4];
paddle::operators::math::Col2VolFunctor<Place, T> col2vol;
paddle::operators::math::Vol2ColFunctor<Place, T> vol2col;
math::Col2VolFunctor<Place, T> col2vol;
math::Vol2ColFunctor<Place, T> vol2col;
// use col_shape in the vol2col and col2vol calculation
framework::DDim col_shape = {input_channels / groups,
filter_depth,
......
......@@ -103,6 +103,9 @@ class TestWithGroup(TestConv2dOp):
self.op_type = "conv2d"
#----------------Conv2dCudnn----------------
class TestCudnn(TestConv2dOp):
def init_group(self):
self.groups = 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册