diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc
index a848b9b49cd2fe8ccf2f2dfeb4e99c26f4462f4d..e1a11a38b3e56c3b91b016b7bdbf266052174160 100644
--- a/paddle/operators/conv_op.cc
+++ b/paddle/operators/conv_op.cc
@@ -13,6 +13,7 @@ limitations under the License. */

 #include "paddle/operators/conv_op.h"
+#include <vector>

 namespace paddle {
 namespace operators {
@@ -53,7 +54,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
                  "The number of output channels should be divided by groups.");

   std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
-  for (size_t i = 0; i < paddings.size(); ++i) {
+  for (size_t i = 0; i < strides.size(); ++i) {
     PADDLE_ENFORCE(in_dims[i + 2] + 2 * paddings[i] -
                            (dilations[i] * (filter_dims[i + 2] - 1) + 1) >
                        0,
@@ -61,8 +62,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
                    "dilations, the output size is less than 0, please check "
                    "again.");
     output_shape.push_back(OutputSize(in_dims[i + 2], filter_dims[i + 2],
-                                      dilations[i], paddings[i], paddings[i],
-                                      strides[i]));
+                                      dilations[i], paddings[i], strides[i]));
   }
   ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
 }
@@ -86,9 +86,15 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto,
   AddOutput("Output",
             "(Tensor) The output tensor of convolution operator. "
             "The format of output tensor is also NCHW.");
-  AddAttr<std::vector<int>>("strides", "strides of convolution operator.")
+  AddAttr<std::vector<int>>("strides",
+                            "(vector<int> default:{1, 1}), the "
+                            "strides(h_stride, w_stride) of "
+                            "convolution operator.")
       .SetDefault({1, 1});
-  AddAttr<std::vector<int>>("paddings", "paddings of convolution operator.")
+  AddAttr<std::vector<int>>("paddings",
+                            "(vector<int> default:{0, 0}), the "
+                            "paddings(h_pad, w_pad) of "
+                            "convolution operator.")
       .SetDefault({0, 0});
   AddAttr<int>(
       "groups",
@@ -99,9 +105,10 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto,
       "is only connected to the second half of the input channels.")
       .SetDefault(1);
   AddAttr<std::vector<int>>("dilations",
-                            "(vector<int> default:{1, 1}), the dilations of "
+                            "(vector<int> default:{1, 1}), the "
+                            "dilations(h_dilation, w_dilation) of "
                             "convolution operator.")
-      .SetDefault(std::vector<int>{1, 1});
+      .SetDefault({1, 1});
   AddComment(R"DOC(
 Convolution Operator.

@@ -147,13 +154,15 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto,
   AddOutput("Output",
             "(Tensor) The output tensor of convolution operator."
             "The format of output tensor is also NCDHW.");
-  AddAttr<std::vector<int>>(
-      "strides",
-      "(vector<int>, default:{0, 0, 0}), the strides of convolution operator.")
+  AddAttr<std::vector<int>>("strides",
+                            "(vector<int>, default:{1, 1, 1}), the "
+                            "strides(d_stride, h_stride, w_stride) of "
+                            "convolution operator.")
       .SetDefault({1, 1, 1});
-  AddAttr<std::vector<int>>(
-      "paddings",
-      "(vector<int>, default:{0, 0, 0}), the paddings of convolution operator.")
+  AddAttr<std::vector<int>>("paddings",
+                            "(vector<int>, default:{0, 0, 0}), the "
+                            "paddings(d_pad, h_pad, w_pad) of convolution "
+                            "operator.")
       .SetDefault({0, 0, 0});
   AddAttr<int>(
       "groups",
@@ -164,10 +173,11 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto,
       "is only connected to the second half of the input channels.")
       .SetDefault(1);
   AddAttr<std::vector<int>>("dilations",
-                            "(vector<int> default:{1, 1, 1}), the dilations of "
+                            "(vector<int> default:{1, 1, 1}), the "
+                            "dilations(d_dilation, h_dilation, w_dilation) of "
                             "convolution operator. Currently, conv3d doesn't "
                             "support dilation.")
-      .SetDefault(std::vector<int>{1, 1, 1});
+      .SetDefault({1, 1, 1});
   AddComment(R"DOC(
 Convolution3D Operator.
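Note: the simplified OutputSize used above folds the former padding_up/padding_down pair into a single symmetric padding, so the shape rule becomes (input + 2*padding - dilated_kernel) / stride + 1. A minimal standalone sketch of that arithmetic follows (plain C++, no Paddle headers; the concrete sizes in main() are illustrative only):

#include <cassert>

// Sketch of the simplified output-size rule: dkernel is the effective
// (dilated) filter extent; the result counts valid filter positions
// along one spatial dimension under symmetric padding.
inline int OutputSize(int input_size, int filter_size, int dilation,
                      int padding, int stride) {
  const int dkernel = dilation * (filter_size - 1) + 1;
  return (input_size + 2 * padding - dkernel) / stride + 1;
}

int main() {
  // 32x32 input, 3x3 filter, dilation 2 -> effective 5x5 kernel;
  // padding 2, stride 1 preserves the size: (32 + 4 - 5) / 1 + 1 = 32.
  assert(OutputSize(32, 3, 2, 2, 1) == 32);
  // Stride 2 halves it: (32 + 4 - 5) / 2 + 1 = 16.
  assert(OutputSize(32, 3, 2, 2, 2) == 16);
  return 0;
}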
diff --git a/paddle/operators/conv_op.h b/paddle/operators/conv_op.h
index af2c8fb163eca06522b2be38c81a51312c2d7281..fac5f1d0e25fe205f89fc7eeb9fadfd8431517d5 100644
--- a/paddle/operators/conv_op.h
+++ b/paddle/operators/conv_op.h
@@ -28,24 +28,22 @@ using Tensor = framework::Tensor;
 // Base convolution operator definations for other conv
 // like operators to reuse the implementation.
 inline int OutputSize(int input_size, int filter_size, int dilation,
-                      int padding_up, int padding_down, int stride) {
-  int output_size = (input_size + padding_up + padding_down -
-                     (dilation * (filter_size - 1) + 1)) /
-                        stride +
-                    1;
+                      int padding, int stride) {
+  const int dkernel = dilation * (filter_size - 1) + 1;
+  const int output_size = (input_size + 2 * padding - dkernel) / stride + 1;
   return output_size;
 }
-inline bool NotExpand(std::vector<int64_t>& filter_dim,
-                      std::vector<int>& strides, std::vector<int>& paddings,
-                      std::vector<int>& dilations) {
+inline bool IsExpand(std::vector<int64_t>& filter_dim,
+                     std::vector<int>& strides, std::vector<int>& paddings,
+                     std::vector<int>& dilations) {
   bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true;
   for (size_t j = 0; j < strides.size(); ++j) {
-    filter_1 &= (static_cast<int>(filter_dim[j]) == 1);
-    strides_1 &= (strides[j] == 1);
-    padding_0 &= (paddings[j] == 0);
-    dilation_1 &= (dilations[j] == 1);
+    filter_1 = filter_1 && (static_cast<int>(filter_dim[j]) == 1);
+    strides_1 = strides_1 && (strides[j] == 1);
+    padding_0 = padding_0 && (paddings[j] == 0);
+    dilation_1 = dilation_1 && (dilations[j] == 1);
   }
-  return filter_1 && strides_1 && padding_0 && dilation_1;
+  return !(filter_1 && strides_1 && padding_0 && dilation_1);
 }

 // Define Op classes in .h file so that other conv
@@ -65,14 +63,12 @@ class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker {
 class ConvOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-
   void InferShape(framework::InferShapeContext* ctx) const override;
 };

 class ConvOpGrad : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-
   void InferShape(framework::InferShapeContext* ctx) const override;
 };

@@ -88,9 +84,9 @@ class GemmConvKernel : public framework::OpKernel<T> {
     Tensor* output = context.Output<Tensor>("Output");
     output->mutable_data<T>(context.GetPlace());

+    int groups = context.Attr<int>("groups");
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
-    int groups = context.Attr<int>("groups");
     std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");

     const int batch_size = static_cast<int>(input->dims()[0]);
@@ -122,13 +118,13 @@ class GemmConvKernel : public framework::OpKernel<T> {
     framework::DDim col_matrix_shape =
         framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);

-    bool not_expand = NotExpand(filter_shape_vec, strides, paddings, dilations);
+    bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
     Tensor col;
     // col_matrix shares the same piece of data with col,
     // but will be reshaped into a two-dimensional matrix shape
     // to call the matrix multiplication interface.
Tensor col_matrix; - if (!not_expand) { + if (is_expand) { col.mutable_data(col_shape, context.GetPlace()); col_matrix.ShareDataWith(col); col_matrix.Resize(col_matrix_shape); @@ -149,51 +145,37 @@ class GemmConvKernel : public framework::OpKernel { int in_step = static_cast(input->dims()[1]) / groups; int out_step = static_cast(output->dims()[1]) / groups; - if (!not_expand) { - for (int i = 0; i < batch_size; i++) { - Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); - Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); - for (int g = 0; g < groups; g++) { - Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); + math::Vol2ColFunctor vol2col; + math::Im2ColFunctor im2col; - if (filter_shape_vec.size() == 2) { - // im2col - math::Im2ColFunctor im2col; - im2col(context.device_context(), in_slice, col, dilations[0], - dilations[1], strides[0], strides[1], paddings[0], - paddings[0], paddings[1], paddings[1]); - } else if (filter_shape_vec.size() == 3) { - // vol2col - math::Vol2ColFunctor vol2col; - vol2col(context.device_context(), in_slice, col, dilations[0], - dilations[1], dilations[2], strides[0], strides[1], - strides[2], paddings[0], paddings[1], paddings[2]); - } + for (int i = 0; i < batch_size; i++) { + Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); + Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); - // gemm - Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::matmul(context.device_context(), filter_slice, false, - col_matrix, false, T(1.0), &out_slice, T(0.0)); - } - } - } else { - for (int i = 0; i < batch_size; i++) { - Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); - Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); - for (int g = 0; g < groups; g++) { - Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); + for (int g = 0; g < groups; g++) { + Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); + if (!is_expand) { col.ShareDataWith(in_slice); col_matrix.ShareDataWith(col); col_matrix.Resize(col_matrix_shape); - - // gemm - Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::matmul(context.device_context(), filter_slice, false, - col_matrix, false, T(1.0), &out_slice, T(0.0)); + } else if (filter_shape_vec.size() == 2) { + // im2col + im2col(context.device_context(), in_slice, dilations, strides, + std::vector{paddings[0], paddings[1], paddings[0], + paddings[1]}, + &col); + } else if (filter_shape_vec.size() == 3) { + // vol2col + vol2col(context.device_context(), in_slice, dilations, strides, + paddings, &col); } + + // gemm + Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); + math::matmul(context.device_context(), filter_slice, false, + col_matrix, false, T(1.0), &out_slice, T(0.0)); } } } @@ -217,9 +199,9 @@ class GemmConvGradKernel : public framework::OpKernel { if (!input_grad && !filter_grad) return; + int groups = context.Attr("groups"); std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); - int groups = context.Attr("groups"); std::vector dilations = context.Attr>("dilations"); const int batch_size = static_cast(input->dims()[0]); @@ -270,13 +252,13 @@ class GemmConvGradKernel : public 
framework::OpKernel { int in_step = static_cast(input->dims()[1]) / groups; int out_step = static_cast(output_grad->dims()[1]) / groups; - bool not_expand = NotExpand(filter_shape_vec, strides, paddings, dilations); + bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations); Tensor col; // col_matrix shares the same piece of data with col, // but will be reshaped into a two-dimensional matrix shape // to call the matrix multiplication interface. Tensor col_matrix; - if (!not_expand) { + if (is_expand) { col.mutable_data(col_shape, context.GetPlace()); col_matrix.ShareDataWith(col); col_matrix.Resize(col_matrix_shape); @@ -288,61 +270,38 @@ class GemmConvGradKernel : public framework::OpKernel { input_grad->mutable_data(context.GetPlace()); set_zero(context.device_context(), input_grad, static_cast(0)); - if (!not_expand) { - for (int i = 0; i < batch_size; i++) { - Tensor out_grad_batch = - output_grad->Slice(i, i + 1).Resize(output_matrix_shape); - Tensor in_grad_batch = - input_grad->Slice(i, i + 1).Resize(input_shape); - for (int g = 0; g < groups; g++) { - // gemm - Tensor out_grad_slice = - out_grad_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor filter_slice = - filter.Slice(g * out_step, (g + 1) * out_step); - math::matmul(context.device_context(), filter_slice, true, - out_grad_slice, false, T(1.0), &col_matrix, - T(0.0)); - Tensor in_grad_slice = - in_grad_batch.Slice(g * in_step, (g + 1) * in_step); - - if (filter_shape_vec.size() == 2) { - math::Col2ImFunctor col2im; - col2im(context.device_context(), in_grad_slice, col, dilations[0], - dilations[1], strides[0], strides[1], paddings[0], - paddings[0], paddings[1], paddings[1]); - - } else if (filter_shape_vec.size() == 3) { - math::Col2VolFunctor col2vol; - col2vol(context.device_context(), in_grad_slice, col, - dilations[0], dilations[1], dilations[2], strides[0], - strides[1], strides[2], paddings[0], paddings[1], - paddings[2]); - } - } - } - } else { - for (int i = 0; i < batch_size; i++) { - Tensor out_grad_batch = - output_grad->Slice(i, i + 1).Resize(output_matrix_shape); - Tensor in_grad_batch = - input_grad->Slice(i, i + 1).Resize(input_shape); - for (int g = 0; g < groups; g++) { - // gemm - Tensor out_grad_slice = - out_grad_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor filter_slice = - filter.Slice(g * out_step, (g + 1) * out_step); - - Tensor in_grad_slice = - in_grad_batch.Slice(g * in_step, (g + 1) * in_step); + math::Col2VolFunctor col2vol; + math::Col2ImFunctor col2im; + for (int i = 0; i < batch_size; i++) { + Tensor out_grad_batch = + output_grad->Slice(i, i + 1).Resize(output_matrix_shape); + Tensor in_grad_batch = input_grad->Slice(i, i + 1).Resize(input_shape); + for (int g = 0; g < groups; g++) { + // gemm + Tensor out_grad_slice = + out_grad_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); + + Tensor in_grad_slice = + in_grad_batch.Slice(g * in_step, (g + 1) * in_step); + + if (!is_expand) { col_matrix.ShareDataWith(in_grad_slice); col_matrix.Resize(col_matrix_shape); - - math::matmul(context.device_context(), filter_slice, true, - out_grad_slice, false, T(1.0), &col_matrix, - T(0.0)); + } + math::matmul(context.device_context(), filter_slice, true, + out_grad_slice, false, T(1.0), &col_matrix, + T(0.0)); + + if (is_expand && filter_shape_vec.size() == 2) { + col2im(context.device_context(), col, dilations, strides, + std::vector{paddings[0], paddings[1], paddings[0], + paddings[1]}, + &in_grad_slice); 
+ } else if (is_expand && filter_shape_vec.size() == 3) { + col2vol(context.device_context(), col, dilations, strides, paddings, + &in_grad_slice); } } } @@ -353,60 +312,38 @@ class GemmConvGradKernel : public framework::OpKernel { Tensor filter_grad_ = *filter_grad; filter_grad_.Resize(filter_matrix_shape); set_zero(context.device_context(), filter_grad, static_cast(0)); + math::Im2ColFunctor im2col; + math::Vol2ColFunctor vol2col; + for (int i = 0; i < batch_size; i++) { + Tensor out_grad_batch = + output_grad->Slice(i, i + 1).Resize(output_matrix_shape); + Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); + for (int g = 0; g < groups; g++) { + // im2col + Tensor out_grad_slice = + out_grad_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); - if (!not_expand) { - for (int i = 0; i < batch_size; i++) { - Tensor out_grad_batch = - output_grad->Slice(i, i + 1).Resize(output_matrix_shape); - Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); - for (int g = 0; g < groups; g++) { - // im2col - Tensor out_grad_slice = - out_grad_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); - - if (filter_shape_vec.size() == 2) { - math::Im2ColFunctor im2col; - im2col(context.device_context(), in_slice, col, dilations[0], - dilations[1], strides[0], strides[1], paddings[0], - paddings[0], paddings[1], paddings[1]); - } else if (filter_shape_vec.size() == 3) { - math::Vol2ColFunctor vol2col; - vol2col(context.device_context(), in_slice, col, dilations[0], - dilations[1], dilations[2], strides[0], strides[1], - strides[2], paddings[0], paddings[1], paddings[2]); - } - - // gemm - Tensor filter_grad_slice = - filter_grad_.Slice(g * out_step, (g + 1) * out_step); - math::matmul(context.device_context(), out_grad_slice, - false, col_matrix, true, T(1.0), - &filter_grad_slice, T(1.0)); - } - } - } else { - for (int i = 0; i < batch_size; i++) { - Tensor out_grad_batch = - output_grad->Slice(i, i + 1).Resize(output_matrix_shape); - Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); - for (int g = 0; g < groups; g++) { - // im2col - Tensor out_grad_slice = - out_grad_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); - + if (!is_expand) { col.ShareDataWith(in_slice); col_matrix.ShareDataWith(col); col_matrix.Resize(col_matrix_shape); - - // gemm - Tensor filter_grad_slice = - filter_grad_.Slice(g * out_step, (g + 1) * out_step); - math::matmul(context.device_context(), out_grad_slice, - false, col_matrix, true, T(1.0), - &filter_grad_slice, T(1.0)); + } else if (filter_shape_vec.size() == 2) { + im2col(context.device_context(), in_slice, dilations, strides, + std::vector{paddings[0], paddings[1], paddings[0], + paddings[1]}, + &col); + } else if (filter_shape_vec.size() == 3) { + vol2col(context.device_context(), in_slice, dilations, strides, + paddings, &col); } + + // gemm + Tensor filter_grad_slice = + filter_grad_.Slice(g * out_step, (g + 1) * out_step); + math::matmul(context.device_context(), out_grad_slice, + false, col_matrix, true, T(1.0), + &filter_grad_slice, T(1.0)); } } } diff --git a/paddle/operators/conv_transpose_op.cc b/paddle/operators/conv_transpose_op.cc index 50081779a5ea3c81884007d4e4b7832dc4ea2bdd..6f47a6d6a0f089eb44d1d62ff4d631cb86c73603 100644 --- a/paddle/operators/conv_transpose_op.cc +++ b/paddle/operators/conv_transpose_op.cc @@ -51,7 +51,7 @@ void 
ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
                  "as the number of filters.");

   std::vector<int64_t> output_shape({in_dims[0], filter_dims[1]});
-  for (size_t i = 0; i < paddings.size(); ++i) {
+  for (size_t i = 0; i < strides.size(); ++i) {
     output_shape.push_back((in_dims[i + 2] - 1) * strides[i] +
                            filter_dims[i + 2]);
   }
@@ -77,13 +77,14 @@ Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(
   AddOutput("Output",
             "(Tensor) The output tensor of convolution transpose operator. "
             "The format of output tensor is also NCHW.");
-  AddAttr<std::vector<int>>(
-      "strides",
-      "(vector defalut:{1, 1}), strides of convolution transpose operator.")
+  AddAttr<std::vector<int>>("strides",
+                            "(vector default:{1, 1}), strides of "
+                            "convolution transpose operator.")
       .SetDefault({1, 1});
   AddAttr<std::vector<int>>(
       "paddings",
-      "(vector defalut:{0, 0}), paddings of convolution transpose operator.")
+      "(vector default:{0, 0}), paddings(h_pad, w_pad) of convolution "
+      "transpose operator.")
       .SetDefault({0, 0});
   AddComment(R"DOC(
 Convolution2D Transpose Operator.
@@ -132,13 +133,13 @@ Conv3DTransposeOpMaker::Conv3DTransposeOpMaker(
             "Where N is batch size, C is "
             "the number of channels, D is the depth of the feature, H is the "
             "height of the feature, and W is the width of the feature.");
-  AddAttr<std::vector<int>>(
-      "strides",
-      "(vector defalut:{1, 1, 1}), strides of convolution transpose operator.")
+  AddAttr<std::vector<int>>("strides",
+                            "(vector default:{1, 1, 1}), strides of "
+                            "convolution transpose operator.")
       .SetDefault({1, 1, 1});
-  AddAttr<std::vector<int>>(
-      "paddings",
-      "(vector defalut:{0, 0, 0}), paddings of convolution transpose operator.")
+  AddAttr<std::vector<int>>("paddings",
+                            "(vector default:{0, 0, 0}), paddings(d_pad, "
+                            "h_pad, w_pad) of convolution transpose operator.")
       .SetDefault({0, 0, 0});
   AddComment(R"DOC(
 Convolution3D Transpose Operator.
diff --git a/paddle/operators/conv_transpose_op.h b/paddle/operators/conv_transpose_op.h
index 18ca6b20e0349e9e3da788e8b20f0a2fed41a172..4b2bd60437da8f58054d8cdd5e6ba1fdac05f0d5 100644
--- a/paddle/operators/conv_transpose_op.h
+++ b/paddle/operators/conv_transpose_op.h
@@ -43,16 +43,12 @@ class Conv3DTransposeOpMaker : public framework::OpProtoAndCheckerMaker {
 class ConvTransposeOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
   void InferShape(framework::InferShapeContext* ctx) const override;
 };

 class ConvTransposeOpGrad : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
   void InferShape(framework::InferShapeContext* ctx) const override;
 };

@@ -66,13 +62,11 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
     Tensor* output = context.Output<Tensor>("Output");

     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
+    // Actually, no paddings and groups allowed in conv transpose.
+    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");

     // TODO(Zhuoyuan): Paddings can be added in future.
     // groups will alway be disabled in conv2dtranspose.
- int dilaiton_d = 1; - int dilation_h = 1; - int dilation_w = 1; - const int batch_size = static_cast(input->dims()[0]); // input_shape_vec: {h, w} or {d, h, w} @@ -124,6 +118,10 @@ class GemmConvTransposeKernel : public framework::OpKernel { math::SetConstant set_zero; set_zero(context.device_context(), output, static_cast(0)); + math::Col2ImFunctor col2im; + math::Col2VolFunctor col2vol; + std::vector dilations({1, 1, 1}); + // convolution transpose: gemm + col2im or col2vol (similar to conv-backward // on input) for (int i = 0; i < batch_size; i++) { @@ -142,17 +140,16 @@ class GemmConvTransposeKernel : public framework::OpKernel { if (filter_shape_vec.size() == 2) { // col2im: col_matrix -> dy // from (c * k_h * k_w, h * w) to (c, o_h, o_w) - math::Col2ImFunctor col2im; - - col2im(context.device_context(), output_batch, col, dilation_h, - dilation_w, strides[0], strides[1], 0, 0, 0, 0); + col2im(context.device_context(), col, + std::vector{dilations[0], dilations[1]}, strides, + std::vector{paddings[0], paddings[1], paddings[0], + paddings[1]}, + &output_batch); } else if (filter_shape_vec.size() == 3) { // col2vol: col_matrix -> dy // from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w) - math::Col2VolFunctor col2vol; - col2vol(context.device_context(), output_batch, col, dilaiton_d, - dilation_h, dilation_w, strides[0], strides[1], strides[2], 0, - 0, 0); + col2vol(context.device_context(), col, dilations, strides, + std::vector{0, 0, 0}, &output_batch); } } } @@ -179,10 +176,6 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { // Actually, no paddings and groups allowed in conv transpose. std::vector paddings = context.Attr>("paddings"); - int dilaiton_d = 1; - int dilation_h = 1; - int dilation_w = 1; - const int batch_size = static_cast(input->dims()[0]); // input_shape_vec: {h, w} or {d, h, w} @@ -237,6 +230,10 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { Tensor filter_grad_; math::SetConstant set_zero; + math::Im2ColFunctor im2col; + math::Vol2ColFunctor vol2col; + std::vector dilations({1, 1, 1}); + if (input_grad) { input_grad->mutable_data(context.GetPlace()); set_zero(context.device_context(), input_grad, static_cast(0)); @@ -256,17 +253,16 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { if (filter_shape_vec.size() == 2) { // im2col: dy -> col matrix // from (c, o_h, o_w) to (c * k_h * k_w, h * w) - math::Im2ColFunctor im2col; - im2col(context.device_context(), output_grad_batch, col, dilation_h, - dilation_w, strides[0], strides[1], paddings[0], paddings[0], - paddings[1], paddings[1]); + im2col(context.device_context(), output_grad_batch, + std::vector{dilations[0], dilations[1]}, strides, + std::vector{paddings[0], paddings[1], paddings[0], + paddings[1]}, + &col); } else if (filter_shape_vec.size() == 3) { // vol2col: dy -> col_matrix // from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w) - math::Vol2ColFunctor vol2col; - vol2col(context.device_context(), output_grad_batch, col, dilaiton_d, - dilation_h, dilation_w, strides[0], strides[1], strides[2], - paddings[0], paddings[1], paddings[2]); + vol2col(context.device_context(), output_grad_batch, dilations, + strides, paddings, &col); } if (input_grad) { diff --git a/paddle/operators/math/context_project.h b/paddle/operators/math/context_project.h index c67d84528fdd397d0ef29bb85a335c77eb7cb65f..d9f952c387f2cfd9169ec5bbde54a525d6c438f4 100644 --- a/paddle/operators/math/context_project.h +++ b/paddle/operators/math/context_project.h @@ -95,8 +95,9 
@@ class ContextProjectFunctor {
     math::Im2ColFunctor<math::ColFormat::kOCF, Place, float> im2col_ocf;

-    int dilation_h = 1;
-    int dilation_w = 1;
+    std::vector<int> dilation({1, 1});
+    std::vector<int> padding({up_pad, 0, down_pad, 0});
+    std::vector<int> stride({context_stride, 1});

     int input_row_begin, input_row_end;
     int sequence_height, sequence_width;
@@ -126,10 +127,7 @@ class ContextProjectFunctor {
             {1, input_row_end - input_row_begin,
              sequence_width});  // input_channels, input_height, input_width
         in_t.Resize(framework::make_ddim(input_shape));
-
-        im2col_ocf(context, in_t, out_t, dilation_h, dilation_w,
-                   /*stride_height*/ context_stride, /*stride_width*/ 1, up_pad,
-                   down_pad, 0, 0);
+        im2col_ocf(context, in_t, dilation, stride, padding, &out_t);
         out_t.Resize({sequence_height, context_length * sequence_width});
       }
     }
@@ -207,8 +205,9 @@ class ContextProjectGradFunctor {
     math::Col2ImFunctor<math::ColFormat::kOCF, Place, float> col2im_ocf;

-    int dilation_h = 1;
-    int dilation_w = 1;
+    std::vector<int> dilation({1, 1});
+    std::vector<int> padding({up_pad, 0, down_pad, 0});
+    std::vector<int> stride({context_stride, 1});

     int input_row_begin, input_row_end;
     int sequence_height, sequence_width;
@@ -240,9 +239,7 @@ class ContextProjectGradFunctor {
                  sequence_width});  // input_channels, input_height, input_width
             in_t.Resize(framework::make_ddim(input_shape));
-            col2im_ocf(context, in_t, out_t, dilation_h, dilation_w,
-                       /*stride_height*/ context_stride, /*stride_width*/ 1,
-                       up_pad, down_pad, 0, 0);
+            col2im_ocf(context, out_t, dilation, stride, padding, &in_t);
             out_t.Resize({sequence_height, context_length * sequence_width});
           }
         }
diff --git a/paddle/operators/math/im2col.cc b/paddle/operators/math/im2col.cc
index 2af55fa71f86acf0c08c741ccb9f4dbfb38baf70..c10c44c52076c8ee56eee3a0d82c31df70a1c9c7 100644
--- a/paddle/operators/math/im2col.cc
+++ b/paddle/operators/math/im2col.cc
@@ -28,40 +28,39 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, platform::CPUPlace, T> {
  public:
   void operator()(const platform::DeviceContext& context,
-                  const framework::Tensor& im, framework::Tensor& col,
-                  int dilation_h, int dilation_w, int stride_height,
-                  int stride_width, int padding_up, int padding_down,
-                  int padding_left, int padding_right) {
+                  const framework::Tensor& im, const std::vector<int>& dilation,
+                  const std::vector<int>& stride,
+                  const std::vector<int>& padding, framework::Tensor* col) {
     PADDLE_ENFORCE(im.dims().size() == 3);
-    PADDLE_ENFORCE(col.dims().size() == 5);
+    PADDLE_ENFORCE(col->dims().size() == 5);
     int im_channels = im.dims()[0];
     int im_height = im.dims()[1];
     int im_width = im.dims()[2];
-    int filter_height = col.dims()[1];
-    int filter_width = col.dims()[2];
-    int col_height = col.dims()[3];
-    int col_width = col.dims()[4];
+    int filter_height = col->dims()[1];
+    int filter_width = col->dims()[2];
+    int col_height = col->dims()[3];
+    int col_width = col->dims()[4];

-    PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down -
-                       ((dilation_h * (filter_height - 1) + 1))) /
-                              stride_height +
+    PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
+                       ((dilation[0] * (filter_height - 1) + 1))) /
+                              stride[0] +
                           1,
                       col_height,
                       "Output_height and padding(padding_up, padding_down) are "
                       "inconsistent.");
-    PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right -
-                       ((dilation_w * (filter_width - 1) + 1))) /
-                              stride_width +
+    PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
+                       ((dilation[1] * (filter_width - 1) + 1))) /
+                              stride[1] +
                           1,
                       col_width,
                       "col_width and padding(padding_left, padding_right) are "
                       "inconsistent.");

     int channels_col = im_channels * filter_height * filter_width;
const T* im_data = im.data(); - T* col_data = col.data(); + T* col_data = col->data(); for (int c = 0; c < channels_col; ++c) { int w_offset = c % filter_width; @@ -69,10 +68,8 @@ class Im2ColFunctor class Col2ImFunctor { public: - void operator()(const platform::DeviceContext& context, framework::Tensor& im, - const framework::Tensor& col, int dilation_h, int dilation_w, - int stride_height, int stride_width, int padding_up, - int padding_down, int padding_left, int padding_right) { - PADDLE_ENFORCE(im.dims().size() == 3); + void operator()(const platform::DeviceContext& context, + const framework::Tensor& col, + const std::vector& dilation, + const std::vector& stride, + const std::vector& padding, framework::Tensor* im) { + PADDLE_ENFORCE(im->dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); - int im_channels = im.dims()[0]; - int im_height = im.dims()[1]; - int im_width = im.dims()[2]; + int im_channels = im->dims()[0]; + int im_height = im->dims()[1]; + int im_width = im->dims()[2]; int filter_height = col.dims()[1]; int filter_width = col.dims()[2]; int col_height = col.dims()[3]; int col_width = col.dims()[4]; - PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - - ((dilation_h * (filter_height - 1) + 1))) / - stride_height + + PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] - + ((dilation[0] * (filter_height - 1) + 1))) / + stride[0] + 1, col_height, "Output_height and padding(padding_up, padding_down) are " "inconsistent."); - PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - - ((dilation_w * (filter_width - 1) + 1))) / - stride_width + + PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] - + ((dilation[1] * (filter_width - 1) + 1))) / + stride[1] + 1, col_width, - "col_width and padding(padding_left, padding_right) are " + "Output_height and padding(padding_up, padding_down) are " "inconsistent."); int channels_col = im_channels * filter_height * filter_width; - T* im_data = im.data(); + T* im_data = im->data(); const T* col_data = col.data(); for (int c = 0; c < channels_col; ++c) { @@ -135,10 +133,8 @@ class Col2ImFunctor= 0 && (im_row_idx) < im_height && (im_col_idx) >= 0 && (im_col_idx) < im_width) { @@ -171,35 +167,32 @@ class Im2ColFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& im, framework::Tensor& col, - int dilation_h, int dilation_w, int stride_height, - int stride_width, int padding_up, int padding_down, - int padding_left, int padding_right) { + const framework::Tensor& im, const std::vector& dilation, + const std::vector& stride, + const std::vector& padding, framework::Tensor* col) { PADDLE_ENFORCE(im.dims().size() == 3); - PADDLE_ENFORCE(col.dims().size() == 5); + PADDLE_ENFORCE(col->dims().size() == 5); int im_channels = im.dims()[0]; int im_height = im.dims()[1]; int im_width = im.dims()[2]; - int filter_height = col.dims()[3]; - int filter_width = col.dims()[4]; - int col_height = col.dims()[0]; - int col_width = col.dims()[1]; + int filter_height = col->dims()[3]; + int filter_width = col->dims()[4]; + int col_height = col->dims()[0]; + int col_width = col->dims()[1]; - PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - filter_height) / - stride_height + - 1, - col_height, - "Output_height and padding(padding_up, padding_down) are " - "inconsistent."); - PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - filter_width) / - stride_width + - 1, - col_width, - "col_width and padding(padding_left, padding_right) are " - "inconsistent."); + 
PADDLE_ENFORCE_EQ( + (im_height + padding[0] + padding[2] - filter_height) / stride[0] + 1, + col_height, + "Output_height and padding(padding_up, padding_down) are " + "inconsistent."); + PADDLE_ENFORCE_EQ( + (im_width + padding[1] + padding[3] - filter_width) / stride[1] + 1, + col_width, + "col_width and padding(padding_left, padding_right) are " + "inconsistent."); const T* im_data = im.data(); - T* col_data = col.data(); + T* col_data = col->data(); for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) { for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) { @@ -209,9 +202,9 @@ class Im2ColFunctor class Col2ImFunctor { public: - void operator()(const platform::DeviceContext& context, framework::Tensor& im, - const framework::Tensor& col, int dilation_h, int dilation_w, - int stride_height, int stride_width, int padding_up, - int padding_down, int padding_left, int padding_right) { - PADDLE_ENFORCE(im.dims().size() == 3); + void operator()(const platform::DeviceContext& context, + const framework::Tensor& col, + const std::vector& dilation, + const std::vector& stride, + const std::vector& padding, framework::Tensor* im) { + PADDLE_ENFORCE(im->dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); - int im_channels = im.dims()[0]; - int im_height = im.dims()[1]; - int im_width = im.dims()[2]; + int im_channels = im->dims()[0]; + int im_height = im->dims()[1]; + int im_width = im->dims()[2]; int filter_height = col.dims()[3]; int filter_width = col.dims()[4]; int col_height = col.dims()[0]; int col_width = col.dims()[1]; - PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - filter_height) / - stride_height + - 1, - col_height, - "Output_height and padding(padding_up, padding_down) are " - "inconsistent."); - PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - filter_width) / - stride_width + - 1, - col_width, - "col_width and padding(padding_left, padding_right) are " - "inconsistent."); + PADDLE_ENFORCE_EQ( + (im_height + padding[0] + padding[2] - filter_height) / stride[0] + 1, + col_height, + "Output_height and padding(padding_up, padding_down) are " + "inconsistent."); + PADDLE_ENFORCE_EQ( + (im_width + padding[1] + padding[3] - filter_width) / stride[1] + 1, + col_width, + "col_width and padding(padding_left, padding_right) are " + "inconsistent."); - T* im_data = im.data(); + T* im_data = im->data(); const T* col_data = col.data(); for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) { @@ -282,9 +274,9 @@ class Col2ImFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& im, framework::Tensor& col, - int dilation_h, int dilation_w, int stride_height, - int stride_width, int padding_up, int padding_down, - int padding_left, int padding_right) { + const framework::Tensor& im, const std::vector& dilation, + const std::vector& stride, + const std::vector& padding, framework::Tensor* col) { PADDLE_ENFORCE(im.dims().size() == 3); - PADDLE_ENFORCE(col.dims().size() == 5); + PADDLE_ENFORCE(col->dims().size() == 5); int im_channels = im.dims()[0]; int im_height = im.dims()[1]; int im_width = im.dims()[2]; - int filter_height = col.dims()[1]; - int filter_width = col.dims()[2]; - int col_height = col.dims()[3]; - int col_width = col.dims()[4]; - - PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - - (dilation_h * (filter_height - 1) + 1)) / - stride_height + + int filter_height = col->dims()[1]; + int filter_width = col->dims()[2]; + int col_height = col->dims()[3]; + int 
col_width = col->dims()[4]; + + PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] - + (dilation[0] * (filter_height - 1) + 1)) / + stride[0] + 1, col_height, "Output_height and padding(padding_up, padding_down) are " "inconsistent."); - PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - - (dilation_w * (filter_width - 1) + 1)) / - stride_width + + PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] - + (dilation[1] * (filter_width - 1) + 1)) / + stride[1] + 1, col_width, "col_width and padding(padding_left, padding_right) are " @@ -100,9 +99,9 @@ class Im2ColFunctor<<(context) .stream()>>>( - im.data(), num_outputs, im_height, im_width, dilation_h, dilation_w, - filter_height, filter_width, stride_height, stride_width, padding_up, - padding_left, col_height, col_width, col.data()); + im.data(), num_outputs, im_height, im_width, dilation[0], + dilation[1], filter_height, filter_width, stride[0], stride[1], + padding[0], padding[1], col_height, col_width, col->data()); } }; @@ -163,31 +162,32 @@ template class Col2ImFunctor { public: - void operator()(const platform::DeviceContext& context, framework::Tensor& im, - const framework::Tensor& col, int dilation_h, int dilation_w, - int stride_height, int stride_width, int padding_up, - int padding_down, int padding_left, int padding_right) { - PADDLE_ENFORCE(im.dims().size() == 3); + void operator()(const platform::DeviceContext& context, + const framework::Tensor& col, + const std::vector& dilation, + const std::vector& stride, + const std::vector& padding, framework::Tensor* im) { + PADDLE_ENFORCE(im->dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); - int im_channels = im.dims()[0]; - int im_height = im.dims()[1]; - int im_width = im.dims()[2]; + int im_channels = im->dims()[0]; + int im_height = im->dims()[1]; + int im_width = im->dims()[2]; int filter_height = col.dims()[1]; int filter_width = col.dims()[2]; int col_height = col.dims()[3]; int col_width = col.dims()[4]; - PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - - (dilation_h * (filter_height - 1) + 1)) / - stride_height + + PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] - + (dilation[0] * (filter_height - 1) + 1)) / + stride[0] + 1, col_height, "Output_height and padding(padding_up, padding_down) are " "inconsistent."); - PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - - (dilation_w * (filter_width - 1) + 1)) / - stride_width + + PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] - + (dilation[1] * (filter_width - 1) + 1)) / + stride[1] + 1, col_width, "col_width and padding(padding_left, padding_right) are " @@ -206,9 +206,9 @@ class Col2ImFunctor<<(context) .stream()>>>( - num_kernels, col.data(), im_height, im_width, dilation_h, dilation_w, - filter_height, filter_width, stride_height, stride_width, padding_up, - padding_left, col_height, col_width, im.data()); + num_kernels, col.data(), im_height, im_width, dilation[0], + dilation[1], filter_height, filter_width, stride[0], stride[1], + padding[0], padding[2], col_height, col_width, im->data()); } }; @@ -222,11 +222,11 @@ template class Col2ImFunctor; template -__global__ void im2colOCF(const T* im_data, T* col_data, int im_channels, - int im_height, int im_width, int filter_height, - int filter_width, int stride_height, int stride_width, +__global__ void im2colOCF(const T* im_data, int im_channels, int im_height, + int im_width, int filter_height, int filter_width, + int stride_height, int stride_width, int padding_height, int padding_width, int col_height, - int 
col_width) { + int col_width, T* col_data) { int swid = blockIdx.x; int shid = blockIdx.y; for (int channelid = threadIdx.z; channelid < im_channels; @@ -263,30 +263,29 @@ class Im2ColFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& im, framework::Tensor& col, - int dilation_h, int dilation_w, int stride_height, - int stride_width, int padding_up, int padding_down, - int padding_left, int padding_right) { + const framework::Tensor& im, const std::vector& dilation, + const std::vector& stride, + const std::vector& padding, framework::Tensor* col) { PADDLE_ENFORCE(im.dims().size() == 3); - PADDLE_ENFORCE(col.dims().size() == 5); + PADDLE_ENFORCE(col->dims().size() == 5); int im_channels = im.dims()[0]; int im_height = im.dims()[1]; int im_width = im.dims()[2]; - int filter_height = col.dims()[3]; - int filter_width = col.dims()[4]; - int col_height = col.dims()[0]; - int col_width = col.dims()[1]; - - PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - - (dilation_h * (filter_height - 1) + 1)) / - stride_height + + int filter_height = col->dims()[3]; + int filter_width = col->dims()[4]; + int col_height = col->dims()[0]; + int col_width = col->dims()[1]; + + PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] - + (dilation[0] * (filter_height - 1) + 1)) / + stride[0] + 1, col_height, "Output_height and padding(padding_up, padding_down) are " "inconsistent."); - PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - - (dilation_w * (filter_width - 1) + 1)) / - stride_width + + PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] - + (dilation[1] * (filter_width - 1) + 1)) / + stride[1] + 1, col_width, "col_width and padding(padding_left, padding_right) are " @@ -314,18 +313,18 @@ class Im2ColFunctor<<(context) .stream()>>>( - im.data(), col.data(), im_channels, im_height, im_width, - filter_height, filter_width, stride_height, stride_width, padding_up, - padding_left, col_height, col_width); + im.data(), im_channels, im_height, im_width, filter_height, + filter_width, stride[0], stride[1], padding[0], padding[1], col_height, + col_width, col->data()); } }; template -__global__ void col2imOCF(T* im_data, const T* col_data, int im_channels, - int im_height, int im_width, int filter_height, - int filter_width, int stride_height, int stride_width, +__global__ void col2imOCF(const T* col_data, int im_channels, int im_height, + int im_width, int filter_height, int filter_width, + int stride_height, int stride_width, int padding_height, int padding_width, int col_height, - int col_width) { + int col_width, T* im_data) { int swid = blockIdx.x; int shid = blockIdx.y; for (int channelid = threadIdx.z; channelid < im_channels; @@ -361,30 +360,31 @@ template class Col2ImFunctor { public: - void operator()(const platform::DeviceContext& context, framework::Tensor& im, - const framework::Tensor& col, int dilation_h, int dilation_w, - int stride_height, int stride_width, int padding_up, - int padding_down, int padding_left, int padding_right) { - PADDLE_ENFORCE(im.dims().size() == 3); + void operator()(const platform::DeviceContext& context, + const framework::Tensor& col, + const std::vector& dilation, + const std::vector& stride, + const std::vector& padding, framework::Tensor* im) { + PADDLE_ENFORCE(im->dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); - int im_channels = im.dims()[0]; - int im_height = im.dims()[1]; - int im_width = im.dims()[2]; + int im_channels = im->dims()[0]; + int im_height = im->dims()[1]; + int 
im_width = im->dims()[2]; int filter_height = col.dims()[3]; int filter_width = col.dims()[4]; int col_height = col.dims()[0]; int col_width = col.dims()[1]; - PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - - (dilation_h * (filter_height - 1) + 1)) / - stride_height + + PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] - + (dilation[0] * (filter_height - 1) + 1)) / + stride[0] + 1, col_height, "Output_height and padding(padding_up, padding_down) are " "inconsistent."); - PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - - (dilation_w * (filter_width - 1) + 1)) / - stride_width + + PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] - + (dilation[1] * (filter_width - 1) + 1)) / + stride[1] + 1, col_width, "col_width and padding(padding_left, padding_right) are " @@ -412,9 +412,9 @@ class Col2ImFunctor<<(context) .stream()>>>( - im.data(), col.data(), im_channels, im_height, im_width, - filter_height, filter_width, stride_height, stride_width, padding_up, - padding_left, col_height, col_width); + col.data(), im_channels, im_height, im_width, filter_height, + filter_width, stride[0], stride[1], padding[0], padding[1], col_height, + col_width, im->data()); } }; diff --git a/paddle/operators/math/im2col.h b/paddle/operators/math/im2col.h index d1c9595a328d349809daa32e3a3c2bf36b87af3d..deb60051beef56437cf75f0fa2cef90bbc0a209a 100644 --- a/paddle/operators/math/im2col.h +++ b/paddle/operators/math/im2col.h @@ -35,6 +35,15 @@ enum class ColFormat { kCFO = 0, kOCF = 1 }; * \param colData Column data. * \param colShape The shape of colData. * + * \param dilations dilation data. + * \param 2-dimension [dilation_height, dilation_width]. + * + * \param strides stride data. + * \param 2-dimension [stride_height, stride_width]. + * + * \param paddings padding data. + * \param 4-dimension [up_pad, left_pad, down_pad, right_pad]. 
+ *
  * If the template argument Format is kCFO, the shape of colData is:
  * [input_channels, filter_height, filter_width, output_height, output_width]
  * So, it is easy to reshape into a convolution matrix for convolution
@@ -73,19 +82,19 @@ template <ColFormat Format, typename Place, typename T>
 class Im2ColFunctor {
  public:
   void operator()(const platform::DeviceContext& context,
-                  const framework::Tensor& im, framework::Tensor& col,
-                  int dilation_h, int dilation_w, int stride_height,
-                  int stride_width, int padding_up, int padding_down,
-                  int padding_left, int padding_right);
+                  const framework::Tensor& im, const std::vector<int>& dilation,
+                  const std::vector<int>& stride,
+                  const std::vector<int>& padding, framework::Tensor* col);
 };

 template <ColFormat Format, typename Place, typename T>
 class Col2ImFunctor {
  public:
-  void operator()(const platform::DeviceContext& context, framework::Tensor& im,
-                  const framework::Tensor& col, int dilation_h, int dilation_w,
-                  int stride_height, int stride_width, int padding_up,
-                  int padding_down, int padding_left, int padding_right);
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& col,
+                  const std::vector<int>& dilation,
+                  const std::vector<int>& stride,
+                  const std::vector<int>& padding, framework::Tensor* im);
 };

 }  // namespace math
diff --git a/paddle/operators/math/im2col_test.cc b/paddle/operators/math/im2col_test.cc
index 3385fe8721cb41dbb294763677dd54e9218b229a..10c28da72ba9d3b94bb59c5cf00e7f5a2f28fd06 100644
--- a/paddle/operators/math/im2col_test.cc
+++ b/paddle/operators/math/im2col_test.cc
@@ -45,12 +45,14 @@ void testIm2col() {
   int input_height = 2;
   int input_width = 3;
   int filter_size = 2;
-  int stride = 1;
-  int padding = 0;
-  int dilation_h = 1;
-  int dilation_w = 1;
-  int output_height = (input_height - filter_size + 2 * padding) / stride + 1;
-  int output_width = (input_width - filter_size + 2 * padding) / stride + 1;
+  std::vector<int> stride({1, 1});  // stride_y, stride_x
+  std::vector<int> padding(
+      {0, 0, 0, 0});  // up_pad, left_pad, down_pad, right_pad
+  std::vector<int> dilation({1, 1});  // dilation_y, dilation_x
+  int output_height =
+      (input_height - filter_size + padding[0] + padding[2]) / stride[0] + 1;
+  int output_width =
+      (input_width - filter_size + padding[1] + padding[3]) / stride[1] + 1;
   float* input_ptr = input_tmp.mutable_data<float>(
       {1, input_height, input_width}, paddle::platform::CPUPlace());
   float arr[6] = {0, 1, 2, 3, 4, 5};
@@ -87,10 +89,8 @@ void testIm2col() {
       paddle::operators::math::ColFormat::kOCF, Place, float>
       im2col_ocf;

-  im2col(*context, input, output_cfo, dilation_h, dilation_w, stride, stride,
-         padding, padding, padding, padding);
-  im2col_ocf(*context, input, output_ocf, dilation_h, dilation_w, stride,
-             stride, padding, padding, padding, padding);
+  im2col(*context, input, dilation, stride, padding, &output_cfo);
+  im2col_ocf(*context, input, dilation, stride, padding, &output_ocf);

   float out_cfo_data[] = {0, 1, 1, 2, 3, 4, 4, 5};
   float out_ocf_data[] = {0, 1, 3, 4, 1, 2, 4, 5};
@@ -133,8 +133,7 @@ void testIm2col() {
     input.CopyFrom(input_tmp, *place, *context);
   }

-  col2im(*context, input, output_cfo, dilation_h, dilation_w, stride, stride,
-         padding, padding, padding, padding);
+  col2im(*context, output_cfo, dilation, stride, padding, &input);

   float* in_ptr;
   if (paddle::platform::is_cpu_place(*place)) {
@@ -155,8 +154,7 @@ void testIm2col() {
     input.CopyFrom(input_tmp, *place, *context);
   }

-  col2im_ocf(*context, input, output_ocf, dilation_h, dilation_w, stride,
-             stride, padding, padding, padding, padding);
+  col2im_ocf(*context, output_ocf, dilation, stride, padding, &input);

  if
(paddle::platform::is_cpu_place(*place)) { in_ptr = input.data(); diff --git a/paddle/operators/math/vol2col.cc b/paddle/operators/math/vol2col.cc index bd509a94f3fb176402b3dead54c3c0c933f24307..99eb7fd46de42400a915d86706580d15b08a74a2 100644 --- a/paddle/operators/math/vol2col.cc +++ b/paddle/operators/math/vol2col.cc @@ -28,51 +28,51 @@ template class Vol2ColFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& vol, framework::Tensor& col, - int dilation_d, int dilation_h, int dilation_w, - int stride_depth, int stride_height, int stride_width, - int padding_depth, int padding_height, - int padding_width) const { + const framework::Tensor& vol, + const std::vector& dilations, + const std::vector& strides, + const std::vector& paddings, + framework::Tensor* col) const { PADDLE_ENFORCE(vol.dims().size() == 4); - PADDLE_ENFORCE(col.dims().size() == 7); + PADDLE_ENFORCE(col->dims().size() == 7); int input_channels = vol.dims()[0]; int input_depth = vol.dims()[1]; int input_height = vol.dims()[2]; int input_width = vol.dims()[3]; - int filter_depth = col.dims()[1]; - int filter_height = col.dims()[2]; - int filter_width = col.dims()[3]; - int output_depth = col.dims()[4]; - int output_height = col.dims()[5]; - int output_width = col.dims()[6]; + int filter_depth = col->dims()[1]; + int filter_height = col->dims()[2]; + int filter_width = col->dims()[3]; + int output_depth = col->dims()[4]; + int output_height = col->dims()[5]; + int output_width = col->dims()[6]; int channels_col = input_channels * filter_depth * filter_height * filter_width; - PADDLE_ENFORCE_EQ((input_depth + 2 * padding_depth - - ((dilation_d * (filter_depth - 1) + 1))) / - stride_depth + + PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] - + ((dilations[0] * (filter_depth - 1) + 1))) / + strides[0] + 1, output_depth, "input_depth and output_depth are " - "Mismatching."); - PADDLE_ENFORCE_EQ((input_height + 2 * padding_height - - ((dilation_h * (filter_height - 1) + 1))) / - stride_height + + "mismatching."); + PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] - + ((dilations[1] * (filter_height - 1) + 1))) / + strides[1] + 1, output_height, "input_height and output_height are " - "Mismatching."); - PADDLE_ENFORCE_EQ((input_width + 2 * padding_width - - ((dilation_w * (filter_width - 1) + 1))) / - stride_width + + "mismatching."); + PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] - + ((dilations[2] * (filter_width - 1) + 1))) / + strides[2] + 1, output_width, "input_width and output_width are " - "Mismatching."); + "mismatching."); const T* vol_data = vol.data(); - T* col_data = col.data(); + T* col_data = col->data(); for (int c = 0; c < channels_col; ++c) { int w_offset = c % filter_width; @@ -80,13 +80,11 @@ class Vol2ColFunctor { int d_offset = (c / filter_width / filter_height) % filter_depth; int c_in = c / filter_width / filter_height / filter_depth; for (int d = 0; d < output_depth; ++d) { - int d_pad = d * stride_depth - padding_depth + d_offset * dilation_d; + int d_pad = d * strides[0] - paddings[0] + d_offset * dilations[0]; for (int h = 0; h < output_height; ++h) { - int h_pad = - h * stride_height - padding_height + h_offset * dilation_h; + int h_pad = h * strides[1] - paddings[1] + h_offset * dilations[1]; for (int w = 0; w < output_width; ++w) { - int w_pad = - w * stride_width - padding_width + w_offset * dilation_w; + int w_pad = w * strides[2] - paddings[2] + w_offset * dilations[2]; int col_idx = ((c * output_depth + d) * output_height + h) * output_width + 
w; @@ -116,18 +114,18 @@ template class Col2VolFunctor { public: void operator()(const platform::DeviceContext& context, - framework::Tensor& vol, const framework::Tensor& col, - int dilation_d, int dilation_h, int dilation_w, - int stride_depth, int stride_height, int stride_width, - int padding_depth, int padding_height, - int padding_width) const { - PADDLE_ENFORCE(vol.dims().size() == 4); + const framework::Tensor& col, + const std::vector& dilations, + const std::vector& strides, + const std::vector& paddings, + framework::Tensor* vol) const { + PADDLE_ENFORCE(vol->dims().size() == 4); PADDLE_ENFORCE(col.dims().size() == 7); - int input_channels = vol.dims()[0]; - int input_depth = vol.dims()[1]; - int input_height = vol.dims()[2]; - int input_width = vol.dims()[3]; + int input_channels = vol->dims()[0]; + int input_depth = vol->dims()[1]; + int input_height = vol->dims()[2]; + int input_width = vol->dims()[3]; int filter_depth = col.dims()[1]; int filter_height = col.dims()[2]; int filter_width = col.dims()[3]; @@ -137,28 +135,28 @@ class Col2VolFunctor { int channels_col = input_channels * filter_depth * filter_height * filter_width; - PADDLE_ENFORCE_EQ((input_depth + 2 * padding_depth - - ((dilation_d * (filter_depth - 1) + 1))) / - stride_depth + + PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] - + ((dilations[0] * (filter_depth - 1) + 1))) / + strides[0] + 1, output_depth, "input_depth and output_depth are " - "Mismatching."); - PADDLE_ENFORCE_EQ((input_height + 2 * padding_height - - ((dilation_h * (filter_height - 1) + 1))) / - stride_height + + "mismatching."); + PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] - + ((dilations[1] * (filter_height - 1) + 1))) / + strides[1] + 1, output_height, "input_height and output_height are " - "Mismatching."); - PADDLE_ENFORCE_EQ((input_width + 2 * padding_width - - ((dilation_w * (filter_width - 1) + 1))) / - stride_width + + "mismatching."); + PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] - + ((dilations[2] * (filter_width - 1) + 1))) / + strides[2] + 1, output_width, "input_width and output_width are " - "Mismatching."); - T* vol_data = vol.data(); + "mismatching."); + T* vol_data = vol->data(); const T* col_data = col.data(); for (int c = 0; c < channels_col; ++c) { @@ -167,13 +165,11 @@ class Col2VolFunctor { int d_offset = (c / filter_width / filter_height) % filter_depth; int cIm = c / filter_width / filter_height / filter_depth; for (int d = 0; d < output_depth; ++d) { - int d_pad = d * stride_depth - padding_depth + d_offset * dilation_d; + int d_pad = d * strides[0] - paddings[0] + d_offset * dilations[0]; for (int h = 0; h < output_height; ++h) { - int h_pad = - h * stride_height - padding_height + h_offset * dilation_h; + int h_pad = h * strides[1] - paddings[1] + h_offset * dilations[1]; for (int w = 0; w < output_width; ++w) { - int w_pad = - w * stride_width - padding_width + w_offset * dilation_w; + int w_pad = w * strides[2] - paddings[2] + w_offset * dilations[2]; if (h_pad >= 0 && h_pad < input_height && w_pad >= 0 && w_pad < input_width && d_pad >= 0 && d_pad < input_depth) { diff --git a/paddle/operators/math/vol2col.cu b/paddle/operators/math/vol2col.cu index 080d3e5466704f3676e4eae4ac79a3025c6bb3ec..addae3caf89eea2fee784ecb87d3a496ae773ad6 100644 --- a/paddle/operators/math/vol2col.cu +++ b/paddle/operators/math/vol2col.cu @@ -71,42 +71,42 @@ template class Vol2ColFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& vol, framework::Tensor& col, - int dilation_d, 
int dilation_h, int dilation_w, - int stride_depth, int stride_height, int stride_width, - int padding_depth, int padding_height, - int padding_width) const { + const framework::Tensor& vol, + const std::vector& dilations, + const std::vector& strides, + const std::vector& paddings, + framework::Tensor* col) const { PADDLE_ENFORCE(vol.dims().size() == 4); - PADDLE_ENFORCE(col.dims().size() == 7); + PADDLE_ENFORCE(col->dims().size() == 7); int input_channels = vol.dims()[0]; int input_depth = vol.dims()[1]; int input_height = vol.dims()[2]; int input_width = vol.dims()[3]; - int filter_depth = col.dims()[1]; - int filter_height = col.dims()[2]; - int filter_width = col.dims()[3]; - int output_depth = col.dims()[4]; - int output_height = col.dims()[5]; - int output_width = col.dims()[6]; + int filter_depth = col->dims()[1]; + int filter_height = col->dims()[2]; + int filter_width = col->dims()[3]; + int output_depth = col->dims()[4]; + int output_height = col->dims()[5]; + int output_width = col->dims()[6]; - PADDLE_ENFORCE_EQ((input_depth + 2 * padding_depth - - ((dilation_d * (filter_depth - 1) + 1))) / - stride_depth + + PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] - + ((dilations[0] * (filter_depth - 1) + 1))) / + strides[0] + 1, output_depth, "input_depth and output_depth are " "Mismatching."); - PADDLE_ENFORCE_EQ((input_height + 2 * padding_height - - ((dilation_h * (filter_height - 1) + 1))) / - stride_height + + PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] - + ((dilations[1] * (filter_height - 1) + 1))) / + strides[1] + 1, output_height, "input_height and output_height are " "Mismatching."); - PADDLE_ENFORCE_EQ((input_width + 2 * padding_width - - ((dilation_w * (filter_width - 1) + 1))) / - stride_width + + PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] - + ((dilations[2] * (filter_width - 1) + 1))) / + strides[2] + 1, output_width, "input_width and output_width are " @@ -121,10 +121,10 @@ class Vol2ColFunctor { reinterpret_cast(context) .stream()>>>( num_outputs, vol.data(), input_depth, input_height, input_width, - dilation_d, dilation_h, dilation_w, filter_depth, filter_height, - filter_width, stride_depth, stride_height, stride_width, padding_depth, - padding_height, padding_width, output_depth, output_height, - output_width, col.data()); + dilations[0], dilations[1], dilations[2], filter_depth, filter_height, + filter_width, strides[0], strides[1], strides[2], paddings[0], + paddings[1], paddings[2], output_depth, output_height, output_width, + col->data()); } }; @@ -200,18 +200,18 @@ template class Col2VolFunctor { public: void operator()(const platform::DeviceContext& context, - framework::Tensor& vol, const framework::Tensor& col, - int dilation_d, int dilation_h, int dilation_w, - int stride_depth, int stride_height, int stride_width, - int padding_depth, int padding_height, - int padding_width) const { - PADDLE_ENFORCE(vol.dims().size() == 4); + const framework::Tensor& col, + const std::vector& dilations, + const std::vector& strides, + const std::vector& paddings, + framework::Tensor* vol) const { + PADDLE_ENFORCE(vol->dims().size() == 4); PADDLE_ENFORCE(col.dims().size() == 7); - int input_channels = vol.dims()[0]; - int input_depth = vol.dims()[1]; - int input_height = vol.dims()[2]; - int input_width = vol.dims()[3]; + int input_channels = vol->dims()[0]; + int input_depth = vol->dims()[1]; + int input_height = vol->dims()[2]; + int input_width = vol->dims()[3]; int filter_depth = col.dims()[1]; int filter_height = col.dims()[2]; int filter_width = 
col.dims()[3]; @@ -219,23 +219,23 @@ class Col2VolFunctor { int output_height = col.dims()[5]; int output_width = col.dims()[6]; - PADDLE_ENFORCE_EQ((input_depth + 2 * padding_depth - - ((dilation_d * (filter_depth - 1) + 1))) / - stride_depth + + PADDLE_ENFORCE_EQ((input_depth + 2 * paddings[0] - + ((dilations[0] * (filter_depth - 1) + 1))) / + strides[0] + 1, output_depth, "input_depth and output_depth are " "Mismatching."); - PADDLE_ENFORCE_EQ((input_height + 2 * padding_height - - ((dilation_h * (filter_height - 1) + 1))) / - stride_height + + PADDLE_ENFORCE_EQ((input_height + 2 * paddings[1] - + ((dilations[1] * (filter_height - 1) + 1))) / + strides[1] + 1, output_height, "input_height and output_height are " "Mismatching."); - PADDLE_ENFORCE_EQ((input_width + 2 * padding_width - - ((dilation_w * (filter_width - 1) + 1))) / - stride_width + + PADDLE_ENFORCE_EQ((input_width + 2 * paddings[2] - + ((dilations[2] * (filter_width - 1) + 1))) / + strides[2] + 1, output_width, "input_width and output_width are " @@ -250,10 +250,10 @@ class Col2VolFunctor { reinterpret_cast(context) .stream()>>>( num_kernels, col.data(), input_depth, input_height, input_width, - dilation_d, dilation_h, dilation_w, filter_depth, filter_height, - filter_width, stride_depth, stride_height, stride_width, padding_depth, - padding_height, padding_width, output_depth, output_height, - output_width, vol.data()); + dilations[0], dilations[1], dilations[2], filter_depth, filter_height, + filter_width, strides[0], strides[1], strides[2], paddings[0], + paddings[1], paddings[2], output_depth, output_height, output_width, + vol->data()); } }; diff --git a/paddle/operators/math/vol2col.h b/paddle/operators/math/vol2col.h index c2d8257c0ba5b0c6105669a501b187456ba052a5..cbc30bd754608dd6e6def1a4097d69bdf0c942c3 100644 --- a/paddle/operators/math/vol2col.h +++ b/paddle/operators/math/vol2col.h @@ -31,6 +31,15 @@ namespace math { * \param colData Column data. * \param colShape The shape of colData. * + * \param dilations dilation data. + * \param 3-dimension [dilation_depth, dilation_height, dilation_width]. + * + * \param strides stride data. + * \param 3-dimension [stride_depth, stride_height, stride_width]. + * + * \param paddings padding data. + * \param 3-dimension [d_pad, h_pad, w_pad]. 
+ * * The shape of colData is: * [input_channels, filter_depth, filter_height, filter_width, output_depth, * output_height, output_width] @@ -57,22 +66,22 @@ template class Vol2ColFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& vol, framework::Tensor& col, - int dilation_d, int dilation_h, int dilation_w, - int stride_depth, int stride_height, int stride_width, - int padding_depth, int padding_height, - int padding_width) const; + const framework::Tensor& vol, + const std::vector& dilations, + const std::vector& strides, + const std::vector& paddings, + framework::Tensor* col) const; }; template class Col2VolFunctor { public: void operator()(const platform::DeviceContext& context, - framework::Tensor& vol, const framework::Tensor& col, - int dilation_d, int dilation_h, int dilation_w, - int stride_depth, int stride_height, int stride_width, - int padding_depth, int padding_height, - int padding_width) const; + const framework::Tensor& col, + const std::vector& dilations, + const std::vector& strides, + const std::vector& paddings, + framework::Tensor* vol) const; }; } // namespace math diff --git a/paddle/operators/math/vol2col_test.cc b/paddle/operators/math/vol2col_test.cc index 9d673ad36cfed90962f2eb9b6c5859e71473bcf7..c31c716842f30de67c29b803866b8c82ddcf4a41 100644 --- a/paddle/operators/math/vol2col_test.cc +++ b/paddle/operators/math/vol2col_test.cc @@ -62,12 +62,15 @@ void testVol2col() { int input_height = 2; int input_width = 3; int filter_size = 2; - int stride = 1; - int padding = 0; - int dilation = 1; - int output_depth = (input_depth - filter_size + 2 * padding) / stride + 1; - int output_height = (input_height - filter_size + 2 * padding) / stride + 1; - int output_width = (input_width - filter_size + 2 * padding) / stride + 1; + std::vector strides({1, 1, 1}); + std::vector paddings({0, 0, 0}); + std::vector dilations({1, 1, 1}); + int output_depth = + (input_depth - filter_size + 2 * paddings[0]) / strides[0] + 1; + int output_height = + (input_height - filter_size + 2 * paddings[1]) / strides[1] + 1; + int output_width = + (input_width - filter_size + 2 * paddings[2]) / strides[2] + 1; // Vol2Col test float* input_ptr = @@ -86,8 +89,7 @@ void testVol2col() { *place); paddle::operators::math::Vol2ColFunctor vol2col; - vol2col(*context, input, output, dilation, dilation, dilation, stride, stride, - stride, padding, padding, padding); + vol2col(*context, input, dilations, strides, paddings, &output); float vol_2_col[] = {0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 7, 8, 9, 10, 10, 11}; float* out_cfo_ptr; @@ -112,8 +114,7 @@ void testVol2col() { } paddle::operators::math::Col2VolFunctor col2vol; - col2vol(*context, input, output, dilation, dilation, dilation, stride, stride, - stride, padding, padding, padding); + col2vol(*context, output, dilations, strides, paddings, &input); float* in_ptr; if (paddle::platform::is_cpu_place(*place)) {
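Taken together, the refactor gives all four functors (im2col/col2im, vol2col/col2vol) one contract: the read-only tensor comes first, dilation/stride/padding are packed into std::vector<int> (im2col padding ordered {up_pad, left_pad, down_pad, right_pad}), and the tensor being written is passed last as a pointer. Below is a self-contained sketch of that contract for the kCFO layout, with plain std::vector standing in for framework::Tensor; Im2ColCFO is an illustrative name, not part of the Paddle API. It reproduces the expected column data from im2col_test.cc above:

#include <cassert>
#include <vector>

// Sketch of the refactored kCFO im2col contract: vector-valued dilation,
// stride, and padding ({up, left, down, right}), output via pointer.
void Im2ColCFO(const std::vector<float>& im, int channels, int height,
               int width, int filter_h, int filter_w,
               const std::vector<int>& dilation,
               const std::vector<int>& stride,
               const std::vector<int>& padding, std::vector<float>* col) {
  // Output extent per dimension, matching the PADDLE_ENFORCE_EQ checks.
  const int col_h =
      (height + padding[0] + padding[2] - (dilation[0] * (filter_h - 1) + 1)) /
          stride[0] +
      1;
  const int col_w =
      (width + padding[1] + padding[3] - (dilation[1] * (filter_w - 1) + 1)) /
          stride[1] +
      1;
  const int channels_col = channels * filter_h * filter_w;
  col->assign(static_cast<size_t>(channels_col) * col_h * col_w, 0.f);
  for (int c = 0; c < channels_col; ++c) {
    int w_offset = c % filter_w;
    int h_offset = (c / filter_w) % filter_h;
    int c_im = c / (filter_w * filter_h);
    for (int h = 0; h < col_h; ++h) {
      int im_row = h * stride[0] - padding[0] + h_offset * dilation[0];
      for (int w = 0; w < col_w; ++w) {
        int im_col = w * stride[1] - padding[1] + w_offset * dilation[1];
        // Out-of-image positions stay zero (padding).
        if (im_row >= 0 && im_row < height && im_col >= 0 && im_col < width) {
          (*col)[(c * col_h + h) * col_w + w] =
              im[(c_im * height + im_row) * width + im_col];
        }
      }
    }
  }
}

int main() {
  // The 1x2x3 input and 2x2 filter from im2col_test.cc above.
  std::vector<float> im = {0, 1, 2, 3, 4, 5};
  std::vector<float> col;
  Im2ColCFO(im, 1, 2, 3, 2, 2, {1, 1}, {1, 1}, {0, 0, 0, 0}, &col);
  const std::vector<float> expected = {0, 1, 1, 2, 3, 4, 4, 5};
  assert(col == expected);  // matches out_cfo_data in the test
  return 0;
}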