提交 97e9dd72 编写于 作者: C chengduoZH

add dilation for im2col

上级 91b72482
...@@ -22,8 +22,6 @@ class CudnnConvOpMaker : public Conv2DOpMaker { ...@@ -22,8 +22,6 @@ class CudnnConvOpMaker : public Conv2DOpMaker {
CudnnConvOpMaker(framework::OpProto* proto, CudnnConvOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker) framework::OpAttrChecker* op_checker)
: Conv2DOpMaker(proto, op_checker) { : Conv2DOpMaker(proto, op_checker) {
AddAttr<std::vector<int>>("dilations", "dilations of convolution operator.")
.SetDefault(std::vector<int>{1, 1});
AddAttr<int>("workspace_size_MB", AddAttr<int>("workspace_size_MB",
"workspace size for cudnn, in MB, " "workspace size for cudnn, in MB, "
"workspace is a section of GPU memory which will be " "workspace is a section of GPU memory which will be "
......
...@@ -30,6 +30,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -30,6 +30,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides"); std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings"); std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
int groups = ctx->Attrs().Get<int>("groups"); int groups = ctx->Attrs().Get<int>("groups");
std::vector<int> dilations = ctx->Attrs().Get<std::vector<int>>("dilations");
int input_channels = in_dims[1]; int input_channels = in_dims[1];
int output_channels = filter_dims[0]; int output_channels = filter_dims[0];
...@@ -54,7 +55,8 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -54,7 +55,8 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]}); std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
for (size_t i = 0; i < paddings.size(); ++i) { for (size_t i = 0; i < paddings.size(); ++i) {
output_shape.push_back(OutputSize(in_dims[i + 2], filter_dims[i + 2], output_shape.push_back(OutputSize(in_dims[i + 2], filter_dims[i + 2],
paddings[i], strides[i])); dilations[i], paddings[i], paddings[i],
strides[i]));
} }
ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
} }
...@@ -90,6 +92,10 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto, ...@@ -90,6 +92,10 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto,
"first half of the input channels, while the second half of the filters " "first half of the input channels, while the second half of the filters "
"is only connected to the second half of the input channels.") "is only connected to the second half of the input channels.")
.SetDefault(1); .SetDefault(1);
AddAttr<std::vector<int>>("dilations",
"(vector default:{1, 1}), the dilations of "
"convolution operator.")
.SetDefault(std::vector<int>{1, 1});
AddComment(R"DOC( AddComment(R"DOC(
Convolution Operator. Convolution Operator.
...@@ -151,6 +157,11 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto, ...@@ -151,6 +157,11 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto,
"first half of the input channels, while the second half of the filters " "first half of the input channels, while the second half of the filters "
"is only connected to the second half of the input channels.") "is only connected to the second half of the input channels.")
.SetDefault(1); .SetDefault(1);
AddAttr<std::vector<int>>("dilations",
"(vector default:{1, 1, 1}), the dilations of "
"convolution operator. Currently, conv3d doesn't "
"support dilation.")
.SetDefault(std::vector<int>{1, 1, 1});
AddComment(R"DOC( AddComment(R"DOC(
Convolution3D Operator. Convolution3D Operator.
......
...@@ -27,9 +27,12 @@ using Tensor = framework::Tensor; ...@@ -27,9 +27,12 @@ using Tensor = framework::Tensor;
// Base convolution operator definitions for other conv // like operators to reuse the implementation.
// like operators to reuse the implementation. // like operators to reuse the implementation.
inline int OutputSize(int input_size, int filter_size, int padding, inline int OutputSize(int input_size, int filter_size, int dilation,
int stride) { int padding_up, int padding_down, int stride) {
int output_size = (input_size - filter_size + 2 * padding) / stride + 1; int output_size = (input_size + padding_up + padding_down -
(dilation * (filter_size - 1) + 1)) /
stride +
1;
return output_size; return output_size;
} }
...@@ -76,6 +79,7 @@ class GemmConvKernel : public framework::OpKernel<T> { ...@@ -76,6 +79,7 @@ class GemmConvKernel : public framework::OpKernel<T> {
std::vector<int> strides = context.Attr<std::vector<int>>("strides"); std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
int groups = context.Attr<int>("groups"); int groups = context.Attr<int>("groups");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
const int batch_size = static_cast<int>(input->dims()[0]); const int batch_size = static_cast<int>(input->dims()[0]);
...@@ -139,9 +143,9 @@ class GemmConvKernel : public framework::OpKernel<T> { ...@@ -139,9 +143,9 @@ class GemmConvKernel : public framework::OpKernel<T> {
if (filter_shape_vec.size() == 2) { if (filter_shape_vec.size() == 2) {
// im2col // im2col
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col; math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
im2col(context.device_context(), in_slice, col, strides[0], im2col(context.device_context(), in_slice, col, dilations[0],
strides[1], paddings[0], paddings[0], paddings[1], dilations[1], strides[0], strides[1], paddings[0], paddings[0],
paddings[1]); paddings[1], paddings[1]);
} else if (filter_shape_vec.size() == 3) { } else if (filter_shape_vec.size() == 3) {
// vol2col // vol2col
math::Vol2ColFunctor<Place, T> vol2col; math::Vol2ColFunctor<Place, T> vol2col;
...@@ -181,6 +185,7 @@ class GemmConvGradKernel : public framework::OpKernel<T> { ...@@ -181,6 +185,7 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
std::vector<int> strides = context.Attr<std::vector<int>>("strides"); std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
int groups = context.Attr<int>("groups"); int groups = context.Attr<int>("groups");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
const int batch_size = static_cast<int>(input->dims()[0]); const int batch_size = static_cast<int>(input->dims()[0]);
...@@ -263,9 +268,9 @@ class GemmConvGradKernel : public framework::OpKernel<T> { ...@@ -263,9 +268,9 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
if (filter_shape_vec.size() == 2) { if (filter_shape_vec.size() == 2) {
math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im; math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
col2im(context.device_context(), in_grad_slice, col, strides[0], col2im(context.device_context(), in_grad_slice, col, dilations[0],
strides[1], paddings[0], paddings[0], paddings[1], dilations[1], strides[0], strides[1], paddings[0],
paddings[1]); paddings[0], paddings[1], paddings[1]);
} else if (filter_shape_vec.size() == 3) { } else if (filter_shape_vec.size() == 3) {
math::Col2VolFunctor<Place, T> col2vol; math::Col2VolFunctor<Place, T> col2vol;
...@@ -295,9 +300,9 @@ class GemmConvGradKernel : public framework::OpKernel<T> { ...@@ -295,9 +300,9 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
if (filter_shape_vec.size() == 2) { if (filter_shape_vec.size() == 2) {
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col; math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
im2col(context.device_context(), in_slice, col, strides[0], im2col(context.device_context(), in_slice, col, dilations[0],
strides[1], paddings[0], paddings[0], paddings[1], dilations[1], strides[0], strides[1], paddings[0],
paddings[1]); paddings[0], paddings[1], paddings[1]);
} else if (filter_shape_vec.size() == 3) { } else if (filter_shape_vec.size() == 3) {
math::Vol2ColFunctor<Place, T> vol2col; math::Vol2ColFunctor<Place, T> vol2col;
vol2col(context.device_context(), in_slice, col, strides[0], vol2col(context.device_context(), in_slice, col, strides[0],
......
...@@ -69,6 +69,9 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> { ...@@ -69,6 +69,9 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
// TODO(Zhuoyuan): Paddings can be added in future. // TODO(Zhuoyuan): Paddings can be added in future.
// groups will always be disabled in conv2dtranspose. // groups will always be disabled in conv2dtranspose.
int dilation_h = 1;
int dilation_w = 1;
const int batch_size = static_cast<int>(input->dims()[0]); const int batch_size = static_cast<int>(input->dims()[0]);
// input_shape_vec: {h, w} or {d, h, w} // input_shape_vec: {h, w} or {d, h, w}
...@@ -140,8 +143,8 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> { ...@@ -140,8 +143,8 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
// from (c * k_h * k_w, h * w) to (c, o_h, o_w) // from (c * k_h * k_w, h * w) to (c, o_h, o_w)
math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im; math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
col2im(context.device_context(), output_batch, col, strides[0], col2im(context.device_context(), output_batch, col, dilation_h,
strides[1], 0, 0, 0, 0); dilation_w, strides[0], strides[1], 0, 0, 0, 0);
} else if (filter_shape_vec.size() == 3) { } else if (filter_shape_vec.size() == 3) {
// col2vol: col_matrix -> dy // col2vol: col_matrix -> dy
// from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w) // from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w)
...@@ -174,6 +177,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> { ...@@ -174,6 +177,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// Actually, no paddings and groups allowed in conv transpose. // Actually, no paddings and groups allowed in conv transpose.
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
int dilation_h = 1;
int dilation_w = 1;
const int batch_size = static_cast<int>(input->dims()[0]); const int batch_size = static_cast<int>(input->dims()[0]);
// input_shape_vec: {h, w} or {d, h, w} // input_shape_vec: {h, w} or {d, h, w}
...@@ -248,9 +254,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> { ...@@ -248,9 +254,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// im2col: dy -> col matrix // im2col: dy -> col matrix
// from (c, o_h, o_w) to (c * k_h * k_w, h * w) // from (c, o_h, o_w) to (c * k_h * k_w, h * w)
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col; math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
im2col(context.device_context(), output_grad_batch, col, strides[0], im2col(context.device_context(), output_grad_batch, col, dilation_h,
strides[1], paddings[0], paddings[0], paddings[1], dilation_w, strides[0], strides[1], paddings[0], paddings[0],
paddings[1]); paddings[1], paddings[1]);
} else if (filter_shape_vec.size() == 3) { } else if (filter_shape_vec.size() == 3) {
// vol2col: dy -> col_matrix // vol2col: dy -> col_matrix
// from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w) // from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w)
......
...@@ -95,6 +95,9 @@ class ContextProjectFunctor { ...@@ -95,6 +95,9 @@ class ContextProjectFunctor {
math::Im2ColFunctor<math::ColFormat::kOCF, Place, float> im2col_ocf; math::Im2ColFunctor<math::ColFormat::kOCF, Place, float> im2col_ocf;
int dilation_h = 1;
int dilation_w = 1;
int input_row_begin, input_row_end; int input_row_begin, input_row_end;
int sequence_height, sequence_width; int sequence_height, sequence_width;
sequence_width = in.dims()[1]; sequence_width = in.dims()[1];
...@@ -124,7 +127,7 @@ class ContextProjectFunctor { ...@@ -124,7 +127,7 @@ class ContextProjectFunctor {
sequence_width}); // input_channels, input_height, input_width sequence_width}); // input_channels, input_height, input_width
in_t.Resize(framework::make_ddim(input_shape)); in_t.Resize(framework::make_ddim(input_shape));
im2col_ocf(context, in_t, out_t, im2col_ocf(context, in_t, out_t, dilation_h, dilation_w,
/*stride_height*/ context_stride, /*stride_width*/ 1, up_pad, /*stride_height*/ context_stride, /*stride_width*/ 1, up_pad,
down_pad, 0, 0); down_pad, 0, 0);
out_t.Resize({sequence_height, context_length * sequence_width}); out_t.Resize({sequence_height, context_length * sequence_width});
...@@ -204,6 +207,9 @@ class ContextProjectGradFunctor { ...@@ -204,6 +207,9 @@ class ContextProjectGradFunctor {
math::Col2ImFunctor<math::ColFormat::kOCF, Place, float> col2im_ocf; math::Col2ImFunctor<math::ColFormat::kOCF, Place, float> col2im_ocf;
int dilation_h = 1;
int dilation_w = 1;
int input_row_begin, input_row_end; int input_row_begin, input_row_end;
int sequence_height, sequence_width; int sequence_height, sequence_width;
sequence_width = in.dims()[1]; sequence_width = in.dims()[1];
...@@ -234,7 +240,7 @@ class ContextProjectGradFunctor { ...@@ -234,7 +240,7 @@ class ContextProjectGradFunctor {
sequence_width}); // input_channels, input_height, input_width sequence_width}); // input_channels, input_height, input_width
in_t.Resize(framework::make_ddim(input_shape)); in_t.Resize(framework::make_ddim(input_shape));
col2im_ocf(context, in_t, out_t, col2im_ocf(context, in_t, out_t, dilation_h, dilation_w,
/*stride_height*/ context_stride, /*stride_width*/ 1, /*stride_height*/ context_stride, /*stride_width*/ 1,
up_pad, down_pad, 0, 0); up_pad, down_pad, 0, 0);
out_t.Resize({sequence_height, context_length * sequence_width}); out_t.Resize({sequence_height, context_length * sequence_width});
......
...@@ -29,35 +29,36 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -29,35 +29,36 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col, const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_up, int dilation_h, int dilation_w, int stride_height,
int padding_down, int padding_left, int padding_right) { int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int im_channels = im.dims()[0];
int input_height = im.dims()[1]; int im_height = im.dims()[1];
int input_width = im.dims()[2]; int im_width = im.dims()[2];
int filter_height = col.dims()[1]; int filter_height = col.dims()[1];
int filter_width = col.dims()[2]; int filter_width = col.dims()[2];
int output_height = col.dims()[3]; int col_height = col.dims()[3];
int output_width = col.dims()[4]; int col_width = col.dims()[4];
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down -
(input_height + padding_up + padding_down - filter_height) / ((dilation_h * (filter_height - 1) + 1))) /
stride_height + stride_height +
1, 1,
output_height, col_height,
"Output_height and padding(padding_up, padding_down) are " "Output_height and padding(padding_up, padding_down) are "
"inconsistent."); "inconsistent.");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right -
(input_width + padding_left + padding_right - filter_width) / ((dilation_w * (filter_width - 1) + 1))) /
stride_width + stride_width +
1, 1,
output_width, col_width,
"output_width and padding(padding_left, padding_right) are " "col_width and padding(padding_left, padding_right) are "
"inconsistent."); "inconsistent.");
int channels_col = input_channels * filter_height * filter_width; int channels_col = im_channels * filter_height * filter_width;
const T* im_data = im.data<T>(); const T* im_data = im.data<T>();
T* col_data = col.data<T>(); T* col_data = col.data<T>();
...@@ -66,19 +67,19 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -66,19 +67,19 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
int w_offset = c % filter_width; int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height; int h_offset = (c / filter_width) % filter_height;
int c_im = c / filter_width / filter_height; int c_im = c / filter_width / filter_height;
for (int h = 0; h < output_height; ++h) { for (int h = 0; h < col_height; ++h) {
for (int w = 0; w < output_width; ++w) { for (int w = 0; w < col_width; ++w) {
int im_row_idx = h * stride_height + h_offset - padding_up; int im_row_idx =
int im_col_idx = w * stride_width + w_offset - padding_left; h * stride_height - padding_up + h_offset * dilation_h;
int im_col_idx =
w * stride_width - padding_left + w_offset * dilation_w;
if (im_row_idx < 0 || im_row_idx >= input_height || im_col_idx < 0 || col_data[(c * col_height + h) * col_width + w] =
im_col_idx >= input_width) { (im_row_idx < 0 || im_row_idx >= im_height || im_col_idx < 0 ||
col_data[(c * output_height + h) * output_width + w] = T(0); im_col_idx >= im_width)
} else { ? static_cast<T>(0)
im_row_idx += c_im * input_height; : im_data[(im_row_idx + c_im * im_height) * im_width +
col_data[(c * output_height + h) * output_width + w] = im_col_idx];
im_data[im_row_idx * input_width + im_col_idx];
}
} }
} }
} }
...@@ -95,35 +96,35 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -95,35 +96,35 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
platform::CPUPlace, T> { platform::CPUPlace, T> {
public: public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im, void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height, const framework::Tensor& col, int dilation_h, int dilation_w,
int stride_width, int padding_up, int padding_down, int stride_height, int stride_width, int padding_up,
int padding_left, int padding_right) { int padding_down, int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int im_channels = im.dims()[0];
int input_height = im.dims()[1]; int im_height = im.dims()[1];
int input_width = im.dims()[2]; int im_width = im.dims()[2];
int filter_height = col.dims()[1]; int filter_height = col.dims()[1];
int filter_width = col.dims()[2]; int filter_width = col.dims()[2];
int output_height = col.dims()[3]; int col_height = col.dims()[3];
int output_width = col.dims()[4]; int col_width = col.dims()[4];
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down -
(input_height + padding_up + padding_down - filter_height) / ((dilation_h * (filter_height - 1) + 1))) /
stride_height + stride_height +
1, 1,
output_height, col_height,
"Output_height and padding(padding_up, padding_down) are " "Output_height and padding(padding_up, padding_down) are "
"inconsistent."); "inconsistent.");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right -
(input_width + padding_left + padding_right - filter_width) / ((dilation_w * (filter_width - 1) + 1))) /
stride_width + stride_width +
1, 1,
output_width, col_width,
"output_width and padding(padding_left, padding_right) are " "col_width and padding(padding_left, padding_right) are "
"inconsistent."); "inconsistent.");
int channels_col = input_channels * filter_height * filter_width; int channels_col = im_channels * filter_height * filter_width;
T* im_data = im.data<T>(); T* im_data = im.data<T>();
const T* col_data = col.data<T>(); const T* col_data = col.data<T>();
...@@ -132,16 +133,18 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -132,16 +133,18 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
int w_offset = c % filter_width; int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height; int h_offset = (c / filter_width) % filter_height;
int c_im = c / filter_width / filter_height; int c_im = c / filter_width / filter_height;
for (int h = 0; h < output_height; ++h) { for (int h = 0; h < col_height; ++h) {
for (int w = 0; w < output_width; ++w) { for (int w = 0; w < col_width; ++w) {
int im_row_idx = h * stride_height + h_offset - padding_up; int im_row_idx =
int im_col_idx = w * stride_width + w_offset - padding_left; h * stride_height - padding_up + h_offset * dilation_h;
int im_col_idx =
w * stride_width - padding_left + w_offset * dilation_w;
if ((im_row_idx) >= 0 && (im_row_idx) < input_height && if ((im_row_idx) >= 0 && (im_row_idx) < im_height &&
(im_col_idx) >= 0 && (im_col_idx) < input_width) { (im_col_idx) >= 0 && (im_col_idx) < im_width) {
im_row_idx += c_im * input_height; im_row_idx += c_im * im_height;
im_data[im_row_idx * input_width + im_col_idx] += im_data[im_row_idx * im_width + im_col_idx] +=
col_data[(c * output_height + h) * output_width + w]; col_data[(c * col_height + h) * col_width + w];
} }
} }
} }
...@@ -169,39 +172,38 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -169,39 +172,38 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col, const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_up, int dilation_h, int dilation_w, int stride_height,
int padding_down, int padding_left, int padding_right) { int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int im_channels = im.dims()[0];
int input_height = im.dims()[1]; int im_height = im.dims()[1];
int input_width = im.dims()[2]; int im_width = im.dims()[2];
int filter_height = col.dims()[3]; int filter_height = col.dims()[3];
int filter_width = col.dims()[4]; int filter_width = col.dims()[4];
int output_height = col.dims()[0]; int col_height = col.dims()[0];
int output_width = col.dims()[1]; int col_width = col.dims()[1];
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - filter_height) /
(input_height + padding_up + padding_down - filter_height) / stride_height +
stride_height + 1,
1, col_height,
output_height, "Output_height and padding(padding_up, padding_down) are "
"Output_height and padding(padding_up, padding_down) are " "inconsistent.");
"inconsistent."); PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - filter_width) /
PADDLE_ENFORCE_EQ( stride_width +
(input_width + padding_left + padding_right - filter_width) / 1,
stride_width + col_width,
1, "col_width and padding(padding_left, padding_right) are "
output_width, "inconsistent.");
"output_width and padding(padding_left, padding_right) are "
"inconsistent.");
const T* im_data = im.data<T>(); const T* im_data = im.data<T>();
T* col_data = col.data<T>(); T* col_data = col.data<T>();
for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) { for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) {
for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) { for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) {
for (int channel = 0; channel < input_channels; ++channel) { for (int channel = 0; channel < im_channels; ++channel) {
for (int filter_row_idx = 0; filter_row_idx < filter_height; for (int filter_row_idx = 0; filter_row_idx < filter_height;
++filter_row_idx) { ++filter_row_idx) {
for (int filter_col_idx = 0; filter_col_idx < filter_width; for (int filter_col_idx = 0; filter_col_idx < filter_width;
...@@ -210,22 +212,21 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -210,22 +212,21 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
col_row_idx * stride_height + filter_row_idx - padding_up; col_row_idx * stride_height + filter_row_idx - padding_up;
int im_col_offset = int im_col_offset =
col_col_idx * stride_width + filter_col_idx - padding_left; col_col_idx * stride_width + filter_col_idx - padding_left;
int col_offset = ((((col_row_idx)*output_width + col_col_idx) * int col_offset =
input_channels + ((((col_row_idx)*col_width + col_col_idx) * im_channels +
channel) * channel) *
filter_height + filter_height +
filter_row_idx) * filter_row_idx) *
filter_width + filter_width +
filter_col_idx; filter_col_idx;
if (im_row_offset < 0 || im_row_offset >= input_height ||
im_col_offset < 0 || im_col_offset >= input_width) { int im_offset = (channel * im_height + im_row_offset) * im_width +
col_data[col_offset] = T(0); im_col_offset;
} else { col_data[col_offset] =
int im_offset = (im_row_offset < 0 || im_row_offset >= im_height ||
(channel * input_height + im_row_offset) * input_width + im_col_offset < 0 || im_col_offset >= im_width)
im_col_offset; ? static_cast<T>(0)
col_data[col_offset] = im_data[im_offset]; : im_data[im_offset];
}
} }
} }
} }
...@@ -244,40 +245,38 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -244,40 +245,38 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
platform::CPUPlace, T> { platform::CPUPlace, T> {
public: public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im, void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height, const framework::Tensor& col, int dilation_h, int dilation_w,
int stride_width, int padding_up, int padding_down, int stride_height, int stride_width, int padding_up,
int padding_left, int padding_right) { int padding_down, int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int im_channels = im.dims()[0];
int input_height = im.dims()[1]; int im_height = im.dims()[1];
int input_width = im.dims()[2]; int im_width = im.dims()[2];
int filter_height = col.dims()[3]; int filter_height = col.dims()[3];
int filter_width = col.dims()[4]; int filter_width = col.dims()[4];
int output_height = col.dims()[0]; int col_height = col.dims()[0];
int output_width = col.dims()[1]; int col_width = col.dims()[1];
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - filter_height) /
(input_height + padding_up + padding_down - filter_height) / stride_height +
stride_height + 1,
1, col_height,
output_height, "Output_height and padding(padding_up, padding_down) are "
"Output_height and padding(padding_up, padding_down) are " "inconsistent.");
"inconsistent."); PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - filter_width) /
PADDLE_ENFORCE_EQ( stride_width +
(input_width + padding_left + padding_right - filter_width) / 1,
stride_width + col_width,
1, "col_width and padding(padding_left, padding_right) are "
output_width, "inconsistent.");
"output_width and padding(padding_left, padding_right) are "
"inconsistent.");
T* im_data = im.data<T>(); T* im_data = im.data<T>();
const T* col_data = col.data<T>(); const T* col_data = col.data<T>();
for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) { for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) {
for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) { for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) {
for (int channel = 0; channel < input_channels; ++channel) { for (int channel = 0; channel < im_channels; ++channel) {
for (int filter_row_idx = 0; filter_row_idx < filter_height; for (int filter_row_idx = 0; filter_row_idx < filter_height;
++filter_row_idx) { ++filter_row_idx) {
for (int filter_col_idx = 0; filter_col_idx < filter_width; for (int filter_col_idx = 0; filter_col_idx < filter_width;
...@@ -286,17 +285,17 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -286,17 +285,17 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
col_row_idx * stride_height + filter_row_idx - padding_up; col_row_idx * stride_height + filter_row_idx - padding_up;
int im_col_offset = int im_col_offset =
col_col_idx * stride_width + filter_col_idx - padding_left; col_col_idx * stride_width + filter_col_idx - padding_left;
int col_offset = (((col_row_idx * output_width + col_col_idx) * int col_offset =
input_channels + (((col_row_idx * col_width + col_col_idx) * im_channels +
channel) * channel) *
filter_height + filter_height +
filter_row_idx) * filter_row_idx) *
filter_width + filter_width +
filter_col_idx; filter_col_idx;
if (im_row_offset >= 0 && im_row_offset < input_height && if (im_row_offset >= 0 && im_row_offset < im_height &&
im_col_offset >= 0 && im_col_offset < input_width) { im_col_offset >= 0 && im_col_offset < im_width) {
int im_offset = int im_offset =
(channel * input_height + im_row_offset) * input_width + (channel * im_height + im_row_offset) * im_width +
im_col_offset; im_col_offset;
im_data[im_offset] += col_data[col_offset]; im_data[im_offset] += col_data[col_offset];
} }
......
...@@ -20,36 +20,32 @@ namespace operators { ...@@ -20,36 +20,32 @@ namespace operators {
namespace math { namespace math {
template <class T> template <class T>
__global__ void im2col(const T* data_im, int num_outs, int height, int width, __global__ void im2col(const T* data_im, int num_outs, int im_height,
int im_width, int dilation_h, int dilation_w,
int filter_height, int filter_width, int stride_height, int filter_height, int filter_width, int stride_height,
int stride_width, int padding_height, int padding_width, int stride_width, int padding_height, int padding_width,
int output_height, int output_width, T* data_col) { int col_height, int col_width, T* data_col) {
int index = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; const int index =
(blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
if (index < num_outs) { if (index < num_outs) {
int w_out = index % output_width; int w_out = index % col_width;
index /= output_width; int h_out = (index / col_width) % col_height;
int h_out = index % output_height; int channel_in = index / col_width / col_height;
int channel_in = index / output_height;
int channel_out = channel_in * filter_height * filter_width; int channel_out = channel_in * filter_height * filter_width;
int h_in = h_out * stride_height; int h_in = h_out * stride_height - padding_height;
int w_in = w_out * stride_width; int w_in = w_out * stride_width - padding_width;
data_col += (channel_out * output_height + h_out) * output_width + w_out; data_col += (channel_out * col_height + h_out) * col_width + w_out;
data_im += (channel_in * im_height + h_in) * im_width + w_in;
for (int i = 0; i < filter_height; ++i) { for (int i = 0; i < filter_height; ++i) {
for (int j = 0; j < filter_width; ++j) { for (int j = 0; j < filter_width; ++j) {
int rIdx = int(h_in + i); int rIdx = h_in + i * dilation_h;
int cIdx = int(w_in + j); int cIdx = w_in + j * dilation_w;
if ((rIdx - (int)padding_height) >= (int)height || *data_col =
(rIdx - (int)padding_height) < 0 || (rIdx >= im_height || rIdx < 0 || cIdx >= im_width || cIdx < 0)
(cIdx - (int)padding_width) >= (int)width || ? 0
(cIdx - (int)padding_width) < 0) { : data_im[i * dilation_h * im_width + j * dilation_w];
*data_col = 0; data_col += col_height * col_width;
} else {
rIdx = rIdx + channel_in * height - padding_height;
cIdx = cIdx - padding_width;
*data_col = data_im[rIdx * width + cIdx];
}
data_col += output_height * output_width;
} }
} }
} }
...@@ -66,29 +62,36 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -66,29 +62,36 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col, const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_up, int dilation_h, int dilation_w, int stride_height,
int padding_down, int padding_left, int padding_right) { int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int im_channels = im.dims()[0];
int input_height = im.dims()[1]; int im_height = im.dims()[1];
int input_width = im.dims()[2]; int im_width = im.dims()[2];
int filter_height = col.dims()[1]; int filter_height = col.dims()[1];
int filter_width = col.dims()[2]; int filter_width = col.dims()[2];
int output_height = col.dims()[3]; int col_height = col.dims()[3];
int output_width = col.dims()[4]; int col_width = col.dims()[4];
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) / PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down -
stride_height + (dilation_h * (filter_height - 1) + 1)) /
1 == stride_height +
output_height); 1,
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) / col_height,
stride_width + "Output_height and padding(padding_up, padding_down) are "
1 == "inconsistent.");
output_width); PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right -
(dilation_w * (filter_width - 1) + 1)) /
int num_outputs = input_channels * output_height * output_width; stride_width +
1,
col_width,
"col_width and padding(padding_left, padding_right) are "
"inconsistent.");
int num_outputs = im_channels * col_height * col_width;
int blocks = (num_outputs + 1024 - 1) / 1024; int blocks = (num_outputs + 1024 - 1) / 1024;
int block_x = 512; int block_x = 512;
int block_y = (blocks + 512 - 1) / 512; int block_y = (blocks + 512 - 1) / 512;
...@@ -97,56 +100,56 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -97,56 +100,56 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
im2col<T><<<grid, threads, 0, im2col<T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context) reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>( .stream()>>>(
im.data<T>(), num_outputs, input_height, input_width, filter_height, im.data<T>(), num_outputs, im_height, im_width, dilation_h, dilation_w,
filter_width, stride_height, stride_width, padding_up, padding_left, filter_height, filter_width, stride_height, stride_width, padding_up,
output_height, output_width, col.data<T>()); padding_left, col_height, col_width, col.data<T>());
} }
}; };
template <class T> template <class T>
__global__ void col2im(size_t n, const T* data_col, size_t height, size_t width, __global__ void col2im(int n, const T* data_col, int im_height, int im_width,
size_t channels, size_t filter_height, int dilation_h, int dilation_w, int filter_height,
size_t filter_width, size_t stride_height, int filter_width, int stride_height, int stride_width,
size_t stride_width, size_t padding_height, int padding_height, int padding_width, int col_height,
size_t padding_width, size_t output_height, int col_width, T* data_im) {
size_t output_width, T* data_im) { const int index =
size_t index =
(blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
const int d_filter_height = dilation_h * (filter_height - 1) + 1;
const int d_filter_width = dilation_w * (filter_width - 1) + 1;
if (index < n) { if (index < n) {
T val = 0; T val = 0;
int w = int(index % width); int w = index % im_width;
int h = int((index / width) % height); int h = (index / im_width) % im_height;
int c = int(index / (width * height)); int c = index / (im_width * im_height);
if ((w - (int)padding_width) >= 0 &&
(w - (int)padding_width) < (width - 2 * padding_width) && // compute the start and end of the output
(h - (int)padding_height) >= 0 && int w_col_start =
(h - padding_height) < (height - 2 * padding_height)) { (w < d_filter_width) ? 0 : (w - d_filter_width) / stride_width + 1;
// compute the start and end of the output int w_col_end = min(w / stride_width + 1, col_width);
int w_col_start = (w < (int)filter_width) int h_col_start =
? 0 (h < d_filter_height) ? 0 : (h - d_filter_height) / stride_height + 1;
: (w - int(filter_width)) / (int)stride_width + 1; int h_col_end = min(h / stride_height + 1, col_height);
int w_col_end =
min((int)(w / (int)stride_width + 1), (int)(output_width)); for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
int h_col_start = (h < (int)filter_height) for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
? 0 int h_off = (h - h_col * stride_height);
: (h - (int)filter_height) / (int)stride_height + 1; int w_off = (w - w_col * stride_width);
int h_col_end = min(int(h / stride_height + 1), int(output_height)); if (h_off % dilation_h == 0 && w_off % dilation_w == 0) {
for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { h_off /= dilation_h;
for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { w_off /= dilation_w;
// the col location: [c * width * height + h_out, w_out] int data_col_index =
int c_col = int(c * filter_height * filter_width) + (((c * filter_height + h_off) * filter_width + w_off) *
(h - h_col * (int)stride_height) * (int)filter_width + col_height +
(w - w_col * (int)stride_width); h_col) *
val += col_width +
data_col[(c_col * output_height + h_col) * output_width + w_col]; w_col;
val += data_col[data_col_index];
} }
} }
h -= padding_height;
w -= padding_width;
data_im[c * ((width - 2 * padding_width) *
(height - 2 * padding_height)) +
h * (width - 2 * padding_width) + w] += val;
} }
data_im[index] = val;
} }
} }
...@@ -160,32 +163,36 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -160,32 +163,36 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
platform::GPUPlace, T> { platform::GPUPlace, T> {
public: public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im, void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height, const framework::Tensor& col, int dilation_h, int dilation_w,
int stride_width, int padding_up, int padding_down, int stride_height, int stride_width, int padding_up,
int padding_left, int padding_right) { int padding_down, int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int im_channels = im.dims()[0];
int input_height = im.dims()[1]; int im_height = im.dims()[1];
int input_width = im.dims()[2]; int im_width = im.dims()[2];
int filter_height = col.dims()[1]; int filter_height = col.dims()[1];
int filter_width = col.dims()[2]; int filter_width = col.dims()[2];
int output_height = col.dims()[3]; int col_height = col.dims()[3];
int output_width = col.dims()[4]; int col_width = col.dims()[4];
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) / PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down -
stride_height + (dilation_h * (filter_height - 1) + 1)) /
1 == stride_height +
output_height); 1,
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) / col_height,
stride_width + "Output_height and padding(padding_up, padding_down) are "
1 == "inconsistent.");
output_width); PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right -
(dilation_w * (filter_width - 1) + 1)) /
size_t num_kernels = input_channels * stride_width +
(input_height + padding_up + padding_down) * 1,
(input_width + padding_left + padding_right); col_width,
"col_width and padding(padding_left, padding_right) are "
"inconsistent.");
size_t num_kernels = im_channels * im_height * im_width;
size_t blocks = (num_kernels + 1024 - 1) / 1024; size_t blocks = (num_kernels + 1024 - 1) / 1024;
size_t block_x = 512; size_t block_x = 512;
...@@ -198,10 +205,9 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -198,10 +205,9 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
col2im<T><<<grid, threads, 0, col2im<T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context) reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>( .stream()>>>(
num_kernels, col.data<T>(), input_height + padding_up + padding_down, num_kernels, col.data<T>(), im_height, im_width, dilation_h, dilation_w,
input_width + padding_left + padding_left, input_channels,
filter_height, filter_width, stride_height, stride_width, padding_up, filter_height, filter_width, stride_height, stride_width, padding_up,
padding_left, output_height, output_width, im.data<T>()); padding_left, col_height, col_width, im.data<T>());
} }
}; };
...@@ -215,33 +221,32 @@ template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -215,33 +221,32 @@ template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
platform::GPUPlace, double>; platform::GPUPlace, double>;
template <class T> template <class T>
__global__ void im2colOCF(const T* im_data, T* col_data, int input_channels, __global__ void im2colOCF(const T* im_data, T* col_data, int im_channels,
int input_height, int input_width, int filter_height, int im_height, int im_width, int filter_height,
int filter_width, int stride_height, int stride_width, int filter_width, int stride_height, int stride_width,
int padding_height, int padding_width, int padding_height, int padding_width, int col_height,
int output_height, int output_width) { int col_width) {
int swid = blockIdx.x; int swid = blockIdx.x;
int shid = blockIdx.y; int shid = blockIdx.y;
for (int channelid = threadIdx.z; channelid < input_channels; for (int channelid = threadIdx.z; channelid < im_channels;
channelid += blockDim.z) { channelid += blockDim.z) {
for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) { for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) {
for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) { for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) {
int width_offset = idx + swid * stride_width - padding_width; int width_offset = idx + swid * stride_width - padding_width;
int height_offset = idy + shid * stride_height - padding_height; int height_offset = idy + shid * stride_height - padding_height;
int im_offset = width_offset + height_offset * input_width + int im_offset = width_offset + height_offset * im_width +
channelid * input_height * input_width; channelid * im_height * im_width;
int col_offset = idx + idy * filter_width + int col_offset = idx + idy * filter_width +
channelid * filter_height * filter_width + channelid * filter_height * filter_width +
(shid * output_width + swid) * (shid * col_width + swid) *
(input_channels * filter_height * filter_width); (im_channels * filter_height * filter_width);
if (height_offset >= input_height || height_offset < 0 || col_data[col_offset] =
width_offset >= input_width || width_offset < 0) { (height_offset >= im_height || height_offset < 0 ||
col_data[col_offset] = T(0); width_offset >= im_width || width_offset < 0)
} else { ? T(0)
col_data[col_offset] = im_data[im_offset]; : im_data[im_offset];
}
} }
} }
} }
...@@ -258,26 +263,33 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -258,26 +263,33 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col, const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_up, int dilation_h, int dilation_w, int stride_height,
int padding_down, int padding_left, int padding_right) { int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int im_channels = im.dims()[0];
int input_height = im.dims()[1]; int im_height = im.dims()[1];
int input_width = im.dims()[2]; int im_width = im.dims()[2];
int filter_height = col.dims()[3]; int filter_height = col.dims()[3];
int filter_width = col.dims()[4]; int filter_width = col.dims()[4];
int output_height = col.dims()[0]; int col_height = col.dims()[0];
int output_width = col.dims()[1]; int col_width = col.dims()[1];
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) / PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down -
stride_height + (dilation_h * (filter_height - 1) + 1)) /
1 == stride_height +
output_height); 1,
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) / col_height,
stride_width + "Output_height and padding(padding_up, padding_down) are "
1 == "inconsistent.");
output_width); PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right -
(dilation_w * (filter_width - 1) + 1)) /
stride_width +
1,
col_width,
"col_width and padding(padding_left, padding_right) are "
"inconsistent.");
int block_dim_x = 0; int block_dim_x = 0;
int block_dim_y = 0; int block_dim_y = 0;
...@@ -296,42 +308,41 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -296,42 +308,41 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
} }
int block_dim_z = 1024 / block_dim_x / block_dim_y; int block_dim_z = 1024 / block_dim_x / block_dim_y;
dim3 threads(block_dim_x, block_dim_y, dim3 threads(block_dim_x, block_dim_y, std::min(block_dim_z, im_channels));
std::min(block_dim_z, input_channels)); dim3 grid(col_width, col_height);
dim3 grid(output_width, output_height);
im2colOCF<T><<<grid, threads, 0, im2colOCF<T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context) reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>( .stream()>>>(
im.data<T>(), col.data<T>(), input_channels, input_height, input_width, im.data<T>(), col.data<T>(), im_channels, im_height, im_width,
filter_height, filter_width, stride_height, stride_width, padding_up, filter_height, filter_width, stride_height, stride_width, padding_up,
padding_left, output_height, output_width); padding_left, col_height, col_width);
} }
}; };
template <class T> template <class T>
__global__ void col2imOCF(T* im_data, const T* col_data, int input_channels, __global__ void col2imOCF(T* im_data, const T* col_data, int im_channels,
int input_height, int input_width, int filter_height, int im_height, int im_width, int filter_height,
int filter_width, int stride_height, int stride_width, int filter_width, int stride_height, int stride_width,
int padding_height, int padding_width, int padding_height, int padding_width, int col_height,
int output_height, int output_width) { int col_width) {
int swid = blockIdx.x; int swid = blockIdx.x;
int shid = blockIdx.y; int shid = blockIdx.y;
for (int channelid = threadIdx.z; channelid < input_channels; for (int channelid = threadIdx.z; channelid < im_channels;
channelid += blockDim.z) { channelid += blockDim.z) {
for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) { for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) {
for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) { for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) {
int width_offset = idx + swid * stride_width - padding_width; int width_offset = idx + swid * stride_width - padding_width;
int height_offset = idy + shid * stride_height - padding_height; int height_offset = idy + shid * stride_height - padding_height;
int im_offset = width_offset + height_offset * input_width + int im_offset = width_offset + height_offset * im_width +
channelid * input_height * input_width; channelid * im_height * im_width;
int col_offset = idx + idy * filter_width + int col_offset = idx + idy * filter_width +
channelid * filter_height * filter_width + channelid * filter_height * filter_width +
(shid * output_width + swid) * (shid * col_width + swid) *
(input_channels * filter_height * filter_width); (im_channels * filter_height * filter_width);
if (height_offset >= 0 && height_offset < input_height && if (height_offset >= 0 && height_offset < im_height &&
width_offset >= 0 && width_offset < input_width) { width_offset >= 0 && width_offset < im_width) {
paddle::platform::CudaAtomicAdd(im_data + im_offset, paddle::platform::CudaAtomicAdd(im_data + im_offset,
col_data[col_offset]); col_data[col_offset]);
} }
...@@ -350,27 +361,33 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -350,27 +361,33 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
platform::GPUPlace, T> { platform::GPUPlace, T> {
public: public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im, void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height, const framework::Tensor& col, int dilation_h, int dilation_w,
int stride_width, int padding_up, int padding_down, int stride_height, int stride_width, int padding_up,
int padding_left, int padding_right) { int padding_down, int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int im_channels = im.dims()[0];
int input_height = im.dims()[1]; int im_height = im.dims()[1];
int input_width = im.dims()[2]; int im_width = im.dims()[2];
int filter_height = col.dims()[3]; int filter_height = col.dims()[3];
int filter_width = col.dims()[4]; int filter_width = col.dims()[4];
int output_height = col.dims()[0]; int col_height = col.dims()[0];
int output_width = col.dims()[1]; int col_width = col.dims()[1];
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) / PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down -
stride_height + (dilation_h * (filter_height - 1) + 1)) /
1 == stride_height +
output_height); 1,
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) / col_height,
stride_width + "Output_height and padding(padding_up, padding_down) are "
1 == "inconsistent.");
output_width); PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right -
(dilation_w * (filter_width - 1) + 1)) /
stride_width +
1,
col_width,
"col_width and padding(padding_left, padding_right) are "
"inconsistent.");
int block_dim_x = 0; int block_dim_x = 0;
int block_dim_y = 0; int block_dim_y = 0;
...@@ -389,15 +406,14 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -389,15 +406,14 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
} }
int block_dim_z = 1024 / block_dim_x / block_dim_y; int block_dim_z = 1024 / block_dim_x / block_dim_y;
dim3 threads(block_dim_x, block_dim_y, dim3 threads(block_dim_x, block_dim_y, std::min(block_dim_z, im_channels));
std::min(block_dim_z, input_channels)); dim3 grid(col_width, col_height);
dim3 grid(output_width, output_height);
col2imOCF<T><<<grid, threads, 0, col2imOCF<T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context) reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>( .stream()>>>(
im.data<T>(), col.data<T>(), input_channels, input_height, input_width, im.data<T>(), col.data<T>(), im_channels, im_height, im_width,
filter_height, filter_width, stride_height, stride_width, padding_up, filter_height, filter_width, stride_height, stride_width, padding_up,
padding_left, output_height, output_width); padding_left, col_height, col_width);
} }
}; };
......
...@@ -74,17 +74,18 @@ class Im2ColFunctor { ...@@ -74,17 +74,18 @@ class Im2ColFunctor {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col, const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_up, int dilation_h, int dilation_w, int stride_height,
int padding_down, int padding_left, int padding_right); int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right);
}; };
template <ColFormat Format, typename Place, typename T> template <ColFormat Format, typename Place, typename T>
class Col2ImFunctor { class Col2ImFunctor {
public: public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im, void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height, const framework::Tensor& col, int dilation_h, int dilation_w,
int stride_width, int padding_up, int padding_down, int stride_height, int stride_width, int padding_up,
int padding_left, int padding_right); int padding_down, int padding_left, int padding_right);
}; };
} // namespace math } // namespace math
......
...@@ -47,6 +47,8 @@ void testIm2col() { ...@@ -47,6 +47,8 @@ void testIm2col() {
int filter_size = 2; int filter_size = 2;
int stride = 1; int stride = 1;
int padding = 0; int padding = 0;
int dilation_h = 1;
int dilation_w = 1;
int output_height = (input_height - filter_size + 2 * padding) / stride + 1; int output_height = (input_height - filter_size + 2 * padding) / stride + 1;
int output_width = (input_width - filter_size + 2 * padding) / stride + 1; int output_width = (input_width - filter_size + 2 * padding) / stride + 1;
float* input_ptr = input_tmp.mutable_data<float>( float* input_ptr = input_tmp.mutable_data<float>(
...@@ -85,10 +87,10 @@ void testIm2col() { ...@@ -85,10 +87,10 @@ void testIm2col() {
paddle::operators::math::ColFormat::kOCF, Place, float> paddle::operators::math::ColFormat::kOCF, Place, float>
im2col_ocf; im2col_ocf;
im2col(*context, input, output_cfo, stride, stride, padding, padding, padding, im2col(*context, input, output_cfo, dilation_h, dilation_w, stride, stride,
padding); padding, padding, padding, padding);
im2col_ocf(*context, input, output_ocf, stride, stride, padding, padding, im2col_ocf(*context, input, output_ocf, dilation_h, dilation_w, stride,
padding, padding); stride, padding, padding, padding, padding);
float out_cfo_data[] = {0, 1, 1, 2, 3, 4, 4, 5}; float out_cfo_data[] = {0, 1, 1, 2, 3, 4, 4, 5};
float out_ocf_data[] = {0, 1, 3, 4, 1, 2, 4, 5}; float out_ocf_data[] = {0, 1, 3, 4, 1, 2, 4, 5};
...@@ -131,8 +133,8 @@ void testIm2col() { ...@@ -131,8 +133,8 @@ void testIm2col() {
input.CopyFrom(input_tmp, *place, *context); input.CopyFrom(input_tmp, *place, *context);
} }
col2im(*context, input, output_cfo, stride, stride, padding, padding, padding, col2im(*context, input, output_cfo, dilation_h, dilation_w, stride, stride,
padding); padding, padding, padding, padding);
float* in_ptr; float* in_ptr;
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
...@@ -153,8 +155,8 @@ void testIm2col() { ...@@ -153,8 +155,8 @@ void testIm2col() {
input.CopyFrom(input_tmp, *place, *context); input.CopyFrom(input_tmp, *place, *context);
} }
col2im_ocf(*context, input, output_ocf, stride, stride, padding, padding, col2im_ocf(*context, input, output_ocf, dilation_h, dilation_w, stride,
padding, padding); stride, padding, padding, padding, padding);
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
in_ptr = input.data<float>(); in_ptr = input.data<float>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册