提交 97e9dd72 编写于 作者: C chengduoZH

add dilation for im2col

上级 91b72482
......@@ -22,8 +22,6 @@ class CudnnConvOpMaker : public Conv2DOpMaker {
CudnnConvOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: Conv2DOpMaker(proto, op_checker) {
AddAttr<std::vector<int>>("dilations", "dilations of convolution operator.")
.SetDefault(std::vector<int>{1, 1});
AddAttr<int>("workspace_size_MB",
"workspace size for cudnn, in MB, "
"workspace is a section of GPU memory which will be "
......
......@@ -30,6 +30,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
int groups = ctx->Attrs().Get<int>("groups");
std::vector<int> dilations = ctx->Attrs().Get<std::vector<int>>("dilations");
int input_channels = in_dims[1];
int output_channels = filter_dims[0];
......@@ -54,7 +55,8 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
for (size_t i = 0; i < paddings.size(); ++i) {
output_shape.push_back(OutputSize(in_dims[i + 2], filter_dims[i + 2],
paddings[i], strides[i]));
dilations[i], paddings[i], paddings[i],
strides[i]));
}
ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
}
......@@ -90,6 +92,10 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto,
"first half of the input channels, while the second half of the filters "
"is only connected to the second half of the input channels.")
.SetDefault(1);
AddAttr<std::vector<int>>("dilations",
"(vector default:{1, 1}), the dilations of "
"convolution operator.")
.SetDefault(std::vector<int>{1, 1});
AddComment(R"DOC(
Convolution Operator.
......@@ -151,6 +157,11 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto,
"first half of the input channels, while the second half of the filters "
"is only connected to the second half of the input channels.")
.SetDefault(1);
AddAttr<std::vector<int>>("dilations",
"(vector default:{1, 1, 1}), the dilations of "
"convolution operator. Currently, conv3d doesn't "
"support dilation.")
.SetDefault(std::vector<int>{1, 1, 1});
AddComment(R"DOC(
Convolution3D Operator.
......
......@@ -27,9 +27,12 @@ using Tensor = framework::Tensor;
// Base convolution operator definations for other conv
// like operators to reuse the implementation.
inline int OutputSize(int input_size, int filter_size, int padding,
int stride) {
int output_size = (input_size - filter_size + 2 * padding) / stride + 1;
inline int OutputSize(int input_size, int filter_size, int dilation,
int padding_up, int padding_down, int stride) {
int output_size = (input_size + padding_up + padding_down -
(dilation * (filter_size - 1) + 1)) /
stride +
1;
return output_size;
}
......@@ -76,6 +79,7 @@ class GemmConvKernel : public framework::OpKernel<T> {
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
int groups = context.Attr<int>("groups");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
const int batch_size = static_cast<int>(input->dims()[0]);
......@@ -139,9 +143,9 @@ class GemmConvKernel : public framework::OpKernel<T> {
if (filter_shape_vec.size() == 2) {
// im2col
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
im2col(context.device_context(), in_slice, col, strides[0],
strides[1], paddings[0], paddings[0], paddings[1],
paddings[1]);
im2col(context.device_context(), in_slice, col, dilations[0],
dilations[1], strides[0], strides[1], paddings[0], paddings[0],
paddings[1], paddings[1]);
} else if (filter_shape_vec.size() == 3) {
// vol2col
math::Vol2ColFunctor<Place, T> vol2col;
......@@ -181,6 +185,7 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
int groups = context.Attr<int>("groups");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
const int batch_size = static_cast<int>(input->dims()[0]);
......@@ -263,9 +268,9 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
if (filter_shape_vec.size() == 2) {
math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
col2im(context.device_context(), in_grad_slice, col, strides[0],
strides[1], paddings[0], paddings[0], paddings[1],
paddings[1]);
col2im(context.device_context(), in_grad_slice, col, dilations[0],
dilations[1], strides[0], strides[1], paddings[0],
paddings[0], paddings[1], paddings[1]);
} else if (filter_shape_vec.size() == 3) {
math::Col2VolFunctor<Place, T> col2vol;
......@@ -295,9 +300,9 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
if (filter_shape_vec.size() == 2) {
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
im2col(context.device_context(), in_slice, col, strides[0],
strides[1], paddings[0], paddings[0], paddings[1],
paddings[1]);
im2col(context.device_context(), in_slice, col, dilations[0],
dilations[1], strides[0], strides[1], paddings[0],
paddings[0], paddings[1], paddings[1]);
} else if (filter_shape_vec.size() == 3) {
math::Vol2ColFunctor<Place, T> vol2col;
vol2col(context.device_context(), in_slice, col, strides[0],
......
......@@ -69,6 +69,9 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
// TODO(Zhuoyuan): Paddings can be added in future.
// groups will alway be disabled in conv2dtranspose.
int dilation_h = 1;
int dilation_w = 1;
const int batch_size = static_cast<int>(input->dims()[0]);
// input_shape_vec: {h, w} or {d, h, w}
......@@ -140,8 +143,8 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
// from (c * k_h * k_w, h * w) to (c, o_h, o_w)
math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
col2im(context.device_context(), output_batch, col, strides[0],
strides[1], 0, 0, 0, 0);
col2im(context.device_context(), output_batch, col, dilation_h,
dilation_w, strides[0], strides[1], 0, 0, 0, 0);
} else if (filter_shape_vec.size() == 3) {
// col2vol: col_matrix -> dy
// from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w)
......@@ -174,6 +177,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// Actually, no paddings and groups allowed in conv transpose.
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
int dilation_h = 1;
int dilation_w = 1;
const int batch_size = static_cast<int>(input->dims()[0]);
// input_shape_vec: {h, w} or {d, h, w}
......@@ -248,9 +254,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// im2col: dy -> col matrix
// from (c, o_h, o_w) to (c * k_h * k_w, h * w)
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
im2col(context.device_context(), output_grad_batch, col, strides[0],
strides[1], paddings[0], paddings[0], paddings[1],
paddings[1]);
im2col(context.device_context(), output_grad_batch, col, dilation_h,
dilation_w, strides[0], strides[1], paddings[0], paddings[0],
paddings[1], paddings[1]);
} else if (filter_shape_vec.size() == 3) {
// vol2col: dy -> col_matrix
// from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w)
......
......@@ -95,6 +95,9 @@ class ContextProjectFunctor {
math::Im2ColFunctor<math::ColFormat::kOCF, Place, float> im2col_ocf;
int dilation_h = 1;
int dilation_w = 1;
int input_row_begin, input_row_end;
int sequence_height, sequence_width;
sequence_width = in.dims()[1];
......@@ -124,7 +127,7 @@ class ContextProjectFunctor {
sequence_width}); // input_channels, input_height, input_width
in_t.Resize(framework::make_ddim(input_shape));
im2col_ocf(context, in_t, out_t,
im2col_ocf(context, in_t, out_t, dilation_h, dilation_w,
/*stride_height*/ context_stride, /*stride_width*/ 1, up_pad,
down_pad, 0, 0);
out_t.Resize({sequence_height, context_length * sequence_width});
......@@ -204,6 +207,9 @@ class ContextProjectGradFunctor {
math::Col2ImFunctor<math::ColFormat::kOCF, Place, float> col2im_ocf;
int dilation_h = 1;
int dilation_w = 1;
int input_row_begin, input_row_end;
int sequence_height, sequence_width;
sequence_width = in.dims()[1];
......@@ -234,7 +240,7 @@ class ContextProjectGradFunctor {
sequence_width}); // input_channels, input_height, input_width
in_t.Resize(framework::make_ddim(input_shape));
col2im_ocf(context, in_t, out_t,
col2im_ocf(context, in_t, out_t, dilation_h, dilation_w,
/*stride_height*/ context_stride, /*stride_width*/ 1,
up_pad, down_pad, 0, 0);
out_t.Resize({sequence_height, context_length * sequence_width});
......
......@@ -29,35 +29,36 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_up,
int padding_down, int padding_left, int padding_right) {
int dilation_h, int dilation_w, int stride_height,
int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
int input_height = im.dims()[1];
int input_width = im.dims()[2];
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
int filter_height = col.dims()[1];
int filter_width = col.dims()[2];
int output_height = col.dims()[3];
int output_width = col.dims()[4];
int col_height = col.dims()[3];
int col_width = col.dims()[4];
PADDLE_ENFORCE_EQ(
(input_height + padding_up + padding_down - filter_height) /
stride_height +
1,
output_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ(
(input_width + padding_left + padding_right - filter_width) /
stride_width +
1,
output_width,
"output_width and padding(padding_left, padding_right) are "
"inconsistent.");
PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down -
((dilation_h * (filter_height - 1) + 1))) /
stride_height +
1,
col_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right -
((dilation_w * (filter_width - 1) + 1))) /
stride_width +
1,
col_width,
"col_width and padding(padding_left, padding_right) are "
"inconsistent.");
int channels_col = input_channels * filter_height * filter_width;
int channels_col = im_channels * filter_height * filter_width;
const T* im_data = im.data<T>();
T* col_data = col.data<T>();
......@@ -66,19 +67,19 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height;
int c_im = c / filter_width / filter_height;
for (int h = 0; h < output_height; ++h) {
for (int w = 0; w < output_width; ++w) {
int im_row_idx = h * stride_height + h_offset - padding_up;
int im_col_idx = w * stride_width + w_offset - padding_left;
for (int h = 0; h < col_height; ++h) {
for (int w = 0; w < col_width; ++w) {
int im_row_idx =
h * stride_height - padding_up + h_offset * dilation_h;
int im_col_idx =
w * stride_width - padding_left + w_offset * dilation_w;
if (im_row_idx < 0 || im_row_idx >= input_height || im_col_idx < 0 ||
im_col_idx >= input_width) {
col_data[(c * output_height + h) * output_width + w] = T(0);
} else {
im_row_idx += c_im * input_height;
col_data[(c * output_height + h) * output_width + w] =
im_data[im_row_idx * input_width + im_col_idx];
}
col_data[(c * col_height + h) * col_width + w] =
(im_row_idx < 0 || im_row_idx >= im_height || im_col_idx < 0 ||
im_col_idx >= im_width)
? static_cast<T>(0)
: im_data[(im_row_idx + c_im * im_height) * im_width +
im_col_idx];
}
}
}
......@@ -95,35 +96,35 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height,
int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right) {
const framework::Tensor& col, int dilation_h, int dilation_w,
int stride_height, int stride_width, int padding_up,
int padding_down, int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
int input_height = im.dims()[1];
int input_width = im.dims()[2];
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
int filter_height = col.dims()[1];
int filter_width = col.dims()[2];
int output_height = col.dims()[3];
int output_width = col.dims()[4];
int col_height = col.dims()[3];
int col_width = col.dims()[4];
PADDLE_ENFORCE_EQ(
(input_height + padding_up + padding_down - filter_height) /
stride_height +
1,
output_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ(
(input_width + padding_left + padding_right - filter_width) /
stride_width +
1,
output_width,
"output_width and padding(padding_left, padding_right) are "
"inconsistent.");
PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down -
((dilation_h * (filter_height - 1) + 1))) /
stride_height +
1,
col_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right -
((dilation_w * (filter_width - 1) + 1))) /
stride_width +
1,
col_width,
"col_width and padding(padding_left, padding_right) are "
"inconsistent.");
int channels_col = input_channels * filter_height * filter_width;
int channels_col = im_channels * filter_height * filter_width;
T* im_data = im.data<T>();
const T* col_data = col.data<T>();
......@@ -132,16 +133,18 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height;
int c_im = c / filter_width / filter_height;
for (int h = 0; h < output_height; ++h) {
for (int w = 0; w < output_width; ++w) {
int im_row_idx = h * stride_height + h_offset - padding_up;
int im_col_idx = w * stride_width + w_offset - padding_left;
for (int h = 0; h < col_height; ++h) {
for (int w = 0; w < col_width; ++w) {
int im_row_idx =
h * stride_height - padding_up + h_offset * dilation_h;
int im_col_idx =
w * stride_width - padding_left + w_offset * dilation_w;
if ((im_row_idx) >= 0 && (im_row_idx) < input_height &&
(im_col_idx) >= 0 && (im_col_idx) < input_width) {
im_row_idx += c_im * input_height;
im_data[im_row_idx * input_width + im_col_idx] +=
col_data[(c * output_height + h) * output_width + w];
if ((im_row_idx) >= 0 && (im_row_idx) < im_height &&
(im_col_idx) >= 0 && (im_col_idx) < im_width) {
im_row_idx += c_im * im_height;
im_data[im_row_idx * im_width + im_col_idx] +=
col_data[(c * col_height + h) * col_width + w];
}
}
}
......@@ -169,39 +172,38 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_up,
int padding_down, int padding_left, int padding_right) {
int dilation_h, int dilation_w, int stride_height,
int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
int input_height = im.dims()[1];
int input_width = im.dims()[2];
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
int filter_height = col.dims()[3];
int filter_width = col.dims()[4];
int output_height = col.dims()[0];
int output_width = col.dims()[1];
int col_height = col.dims()[0];
int col_width = col.dims()[1];
PADDLE_ENFORCE_EQ(
(input_height + padding_up + padding_down - filter_height) /
stride_height +
1,
output_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ(
(input_width + padding_left + padding_right - filter_width) /
stride_width +
1,
output_width,
"output_width and padding(padding_left, padding_right) are "
"inconsistent.");
PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - filter_height) /
stride_height +
1,
col_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - filter_width) /
stride_width +
1,
col_width,
"col_width and padding(padding_left, padding_right) are "
"inconsistent.");
const T* im_data = im.data<T>();
T* col_data = col.data<T>();
for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) {
for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) {
for (int channel = 0; channel < input_channels; ++channel) {
for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) {
for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) {
for (int channel = 0; channel < im_channels; ++channel) {
for (int filter_row_idx = 0; filter_row_idx < filter_height;
++filter_row_idx) {
for (int filter_col_idx = 0; filter_col_idx < filter_width;
......@@ -210,22 +212,21 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
col_row_idx * stride_height + filter_row_idx - padding_up;
int im_col_offset =
col_col_idx * stride_width + filter_col_idx - padding_left;
int col_offset = ((((col_row_idx)*output_width + col_col_idx) *
input_channels +
channel) *
filter_height +
filter_row_idx) *
filter_width +
filter_col_idx;
if (im_row_offset < 0 || im_row_offset >= input_height ||
im_col_offset < 0 || im_col_offset >= input_width) {
col_data[col_offset] = T(0);
} else {
int im_offset =
(channel * input_height + im_row_offset) * input_width +
im_col_offset;
col_data[col_offset] = im_data[im_offset];
}
int col_offset =
((((col_row_idx)*col_width + col_col_idx) * im_channels +
channel) *
filter_height +
filter_row_idx) *
filter_width +
filter_col_idx;
int im_offset = (channel * im_height + im_row_offset) * im_width +
im_col_offset;
col_data[col_offset] =
(im_row_offset < 0 || im_row_offset >= im_height ||
im_col_offset < 0 || im_col_offset >= im_width)
? static_cast<T>(0)
: im_data[im_offset];
}
}
}
......@@ -244,40 +245,38 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height,
int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right) {
const framework::Tensor& col, int dilation_h, int dilation_w,
int stride_height, int stride_width, int padding_up,
int padding_down, int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
int input_height = im.dims()[1];
int input_width = im.dims()[2];
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
int filter_height = col.dims()[3];
int filter_width = col.dims()[4];
int output_height = col.dims()[0];
int output_width = col.dims()[1];
int col_height = col.dims()[0];
int col_width = col.dims()[1];
PADDLE_ENFORCE_EQ(
(input_height + padding_up + padding_down - filter_height) /
stride_height +
1,
output_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ(
(input_width + padding_left + padding_right - filter_width) /
stride_width +
1,
output_width,
"output_width and padding(padding_left, padding_right) are "
"inconsistent.");
PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - filter_height) /
stride_height +
1,
col_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - filter_width) /
stride_width +
1,
col_width,
"col_width and padding(padding_left, padding_right) are "
"inconsistent.");
T* im_data = im.data<T>();
const T* col_data = col.data<T>();
for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) {
for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) {
for (int channel = 0; channel < input_channels; ++channel) {
for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) {
for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) {
for (int channel = 0; channel < im_channels; ++channel) {
for (int filter_row_idx = 0; filter_row_idx < filter_height;
++filter_row_idx) {
for (int filter_col_idx = 0; filter_col_idx < filter_width;
......@@ -286,17 +285,17 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
col_row_idx * stride_height + filter_row_idx - padding_up;
int im_col_offset =
col_col_idx * stride_width + filter_col_idx - padding_left;
int col_offset = (((col_row_idx * output_width + col_col_idx) *
input_channels +
channel) *
filter_height +
filter_row_idx) *
filter_width +
filter_col_idx;
if (im_row_offset >= 0 && im_row_offset < input_height &&
im_col_offset >= 0 && im_col_offset < input_width) {
int col_offset =
(((col_row_idx * col_width + col_col_idx) * im_channels +
channel) *
filter_height +
filter_row_idx) *
filter_width +
filter_col_idx;
if (im_row_offset >= 0 && im_row_offset < im_height &&
im_col_offset >= 0 && im_col_offset < im_width) {
int im_offset =
(channel * input_height + im_row_offset) * input_width +
(channel * im_height + im_row_offset) * im_width +
im_col_offset;
im_data[im_offset] += col_data[col_offset];
}
......
......@@ -20,36 +20,32 @@ namespace operators {
namespace math {
template <class T>
__global__ void im2col(const T* data_im, int num_outs, int height, int width,
__global__ void im2col(const T* data_im, int num_outs, int im_height,
int im_width, int dilation_h, int dilation_w,
int filter_height, int filter_width, int stride_height,
int stride_width, int padding_height, int padding_width,
int output_height, int output_width, T* data_col) {
int index = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
int col_height, int col_width, T* data_col) {
const int index =
(blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
if (index < num_outs) {
int w_out = index % output_width;
index /= output_width;
int h_out = index % output_height;
int channel_in = index / output_height;
int w_out = index % col_width;
int h_out = (index / col_width) % col_height;
int channel_in = index / col_width / col_height;
int channel_out = channel_in * filter_height * filter_width;
int h_in = h_out * stride_height;
int w_in = w_out * stride_width;
int h_in = h_out * stride_height - padding_height;
int w_in = w_out * stride_width - padding_width;
data_col += (channel_out * output_height + h_out) * output_width + w_out;
data_col += (channel_out * col_height + h_out) * col_width + w_out;
data_im += (channel_in * im_height + h_in) * im_width + w_in;
for (int i = 0; i < filter_height; ++i) {
for (int j = 0; j < filter_width; ++j) {
int rIdx = int(h_in + i);
int cIdx = int(w_in + j);
if ((rIdx - (int)padding_height) >= (int)height ||
(rIdx - (int)padding_height) < 0 ||
(cIdx - (int)padding_width) >= (int)width ||
(cIdx - (int)padding_width) < 0) {
*data_col = 0;
} else {
rIdx = rIdx + channel_in * height - padding_height;
cIdx = cIdx - padding_width;
*data_col = data_im[rIdx * width + cIdx];
}
data_col += output_height * output_width;
int rIdx = h_in + i * dilation_h;
int cIdx = w_in + j * dilation_w;
*data_col =
(rIdx >= im_height || rIdx < 0 || cIdx >= im_width || cIdx < 0)
? 0
: data_im[i * dilation_h * im_width + j * dilation_w];
data_col += col_height * col_width;
}
}
}
......@@ -66,29 +62,36 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_up,
int padding_down, int padding_left, int padding_right) {
int dilation_h, int dilation_w, int stride_height,
int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
int input_height = im.dims()[1];
int input_width = im.dims()[2];
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
int filter_height = col.dims()[1];
int filter_width = col.dims()[2];
int output_height = col.dims()[3];
int output_width = col.dims()[4];
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
stride_height +
1 ==
output_height);
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
stride_width +
1 ==
output_width);
int num_outputs = input_channels * output_height * output_width;
int col_height = col.dims()[3];
int col_width = col.dims()[4];
PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down -
(dilation_h * (filter_height - 1) + 1)) /
stride_height +
1,
col_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right -
(dilation_w * (filter_width - 1) + 1)) /
stride_width +
1,
col_width,
"col_width and padding(padding_left, padding_right) are "
"inconsistent.");
int num_outputs = im_channels * col_height * col_width;
int blocks = (num_outputs + 1024 - 1) / 1024;
int block_x = 512;
int block_y = (blocks + 512 - 1) / 512;
......@@ -97,56 +100,56 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
im2col<T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
im.data<T>(), num_outputs, input_height, input_width, filter_height,
filter_width, stride_height, stride_width, padding_up, padding_left,
output_height, output_width, col.data<T>());
im.data<T>(), num_outputs, im_height, im_width, dilation_h, dilation_w,
filter_height, filter_width, stride_height, stride_width, padding_up,
padding_left, col_height, col_width, col.data<T>());
}
};
template <class T>
__global__ void col2im(size_t n, const T* data_col, size_t height, size_t width,
size_t channels, size_t filter_height,
size_t filter_width, size_t stride_height,
size_t stride_width, size_t padding_height,
size_t padding_width, size_t output_height,
size_t output_width, T* data_im) {
size_t index =
__global__ void col2im(int n, const T* data_col, int im_height, int im_width,
int dilation_h, int dilation_w, int filter_height,
int filter_width, int stride_height, int stride_width,
int padding_height, int padding_width, int col_height,
int col_width, T* data_im) {
const int index =
(blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
const int d_filter_height = dilation_h * (filter_height - 1) + 1;
const int d_filter_width = dilation_w * (filter_width - 1) + 1;
if (index < n) {
T val = 0;
int w = int(index % width);
int h = int((index / width) % height);
int c = int(index / (width * height));
if ((w - (int)padding_width) >= 0 &&
(w - (int)padding_width) < (width - 2 * padding_width) &&
(h - (int)padding_height) >= 0 &&
(h - padding_height) < (height - 2 * padding_height)) {
// compute the start and end of the output
int w_col_start = (w < (int)filter_width)
? 0
: (w - int(filter_width)) / (int)stride_width + 1;
int w_col_end =
min((int)(w / (int)stride_width + 1), (int)(output_width));
int h_col_start = (h < (int)filter_height)
? 0
: (h - (int)filter_height) / (int)stride_height + 1;
int h_col_end = min(int(h / stride_height + 1), int(output_height));
for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
// the col location: [c * width * height + h_out, w_out]
int c_col = int(c * filter_height * filter_width) +
(h - h_col * (int)stride_height) * (int)filter_width +
(w - w_col * (int)stride_width);
val +=
data_col[(c_col * output_height + h_col) * output_width + w_col];
int w = index % im_width;
int h = (index / im_width) % im_height;
int c = index / (im_width * im_height);
// compute the start and end of the output
int w_col_start =
(w < d_filter_width) ? 0 : (w - d_filter_width) / stride_width + 1;
int w_col_end = min(w / stride_width + 1, col_width);
int h_col_start =
(h < d_filter_height) ? 0 : (h - d_filter_height) / stride_height + 1;
int h_col_end = min(h / stride_height + 1, col_height);
for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
int h_off = (h - h_col * stride_height);
int w_off = (w - w_col * stride_width);
if (h_off % dilation_h == 0 && w_off % dilation_w == 0) {
h_off /= dilation_h;
w_off /= dilation_w;
int data_col_index =
(((c * filter_height + h_off) * filter_width + w_off) *
col_height +
h_col) *
col_width +
w_col;
val += data_col[data_col_index];
}
}
h -= padding_height;
w -= padding_width;
data_im[c * ((width - 2 * padding_width) *
(height - 2 * padding_height)) +
h * (width - 2 * padding_width) + w] += val;
}
data_im[index] = val;
}
}
......@@ -160,32 +163,36 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height,
int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right) {
const framework::Tensor& col, int dilation_h, int dilation_w,
int stride_height, int stride_width, int padding_up,
int padding_down, int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
int input_height = im.dims()[1];
int input_width = im.dims()[2];
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
int filter_height = col.dims()[1];
int filter_width = col.dims()[2];
int output_height = col.dims()[3];
int output_width = col.dims()[4];
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
stride_height +
1 ==
output_height);
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
stride_width +
1 ==
output_width);
size_t num_kernels = input_channels *
(input_height + padding_up + padding_down) *
(input_width + padding_left + padding_right);
int col_height = col.dims()[3];
int col_width = col.dims()[4];
PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down -
(dilation_h * (filter_height - 1) + 1)) /
stride_height +
1,
col_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right -
(dilation_w * (filter_width - 1) + 1)) /
stride_width +
1,
col_width,
"col_width and padding(padding_left, padding_right) are "
"inconsistent.");
size_t num_kernels = im_channels * im_height * im_width;
size_t blocks = (num_kernels + 1024 - 1) / 1024;
size_t block_x = 512;
......@@ -198,10 +205,9 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
col2im<T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
num_kernels, col.data<T>(), input_height + padding_up + padding_down,
input_width + padding_left + padding_left, input_channels,
num_kernels, col.data<T>(), im_height, im_width, dilation_h, dilation_w,
filter_height, filter_width, stride_height, stride_width, padding_up,
padding_left, output_height, output_width, im.data<T>());
padding_left, col_height, col_width, im.data<T>());
}
};
......@@ -215,33 +221,32 @@ template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
platform::GPUPlace, double>;
template <class T>
__global__ void im2colOCF(const T* im_data, T* col_data, int input_channels,
int input_height, int input_width, int filter_height,
__global__ void im2colOCF(const T* im_data, T* col_data, int im_channels,
int im_height, int im_width, int filter_height,
int filter_width, int stride_height, int stride_width,
int padding_height, int padding_width,
int output_height, int output_width) {
int padding_height, int padding_width, int col_height,
int col_width) {
int swid = blockIdx.x;
int shid = blockIdx.y;
for (int channelid = threadIdx.z; channelid < input_channels;
for (int channelid = threadIdx.z; channelid < im_channels;
channelid += blockDim.z) {
for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) {
for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) {
int width_offset = idx + swid * stride_width - padding_width;
int height_offset = idy + shid * stride_height - padding_height;
int im_offset = width_offset + height_offset * input_width +
channelid * input_height * input_width;
int im_offset = width_offset + height_offset * im_width +
channelid * im_height * im_width;
int col_offset = idx + idy * filter_width +
channelid * filter_height * filter_width +
(shid * output_width + swid) *
(input_channels * filter_height * filter_width);
if (height_offset >= input_height || height_offset < 0 ||
width_offset >= input_width || width_offset < 0) {
col_data[col_offset] = T(0);
} else {
col_data[col_offset] = im_data[im_offset];
}
(shid * col_width + swid) *
(im_channels * filter_height * filter_width);
col_data[col_offset] =
(height_offset >= im_height || height_offset < 0 ||
width_offset >= im_width || width_offset < 0)
? T(0)
: im_data[im_offset];
}
}
}
......@@ -258,26 +263,33 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_up,
int padding_down, int padding_left, int padding_right) {
int dilation_h, int dilation_w, int stride_height,
int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
int input_height = im.dims()[1];
int input_width = im.dims()[2];
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
int filter_height = col.dims()[3];
int filter_width = col.dims()[4];
int output_height = col.dims()[0];
int output_width = col.dims()[1];
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
stride_height +
1 ==
output_height);
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
stride_width +
1 ==
output_width);
int col_height = col.dims()[0];
int col_width = col.dims()[1];
PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down -
(dilation_h * (filter_height - 1) + 1)) /
stride_height +
1,
col_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right -
(dilation_w * (filter_width - 1) + 1)) /
stride_width +
1,
col_width,
"col_width and padding(padding_left, padding_right) are "
"inconsistent.");
int block_dim_x = 0;
int block_dim_y = 0;
......@@ -296,42 +308,41 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
}
int block_dim_z = 1024 / block_dim_x / block_dim_y;
dim3 threads(block_dim_x, block_dim_y,
std::min(block_dim_z, input_channels));
dim3 grid(output_width, output_height);
dim3 threads(block_dim_x, block_dim_y, std::min(block_dim_z, im_channels));
dim3 grid(col_width, col_height);
im2colOCF<T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
im.data<T>(), col.data<T>(), im_channels, im_height, im_width,
filter_height, filter_width, stride_height, stride_width, padding_up,
padding_left, output_height, output_width);
padding_left, col_height, col_width);
}
};
template <class T>
__global__ void col2imOCF(T* im_data, const T* col_data, int input_channels,
int input_height, int input_width, int filter_height,
__global__ void col2imOCF(T* im_data, const T* col_data, int im_channels,
int im_height, int im_width, int filter_height,
int filter_width, int stride_height, int stride_width,
int padding_height, int padding_width,
int output_height, int output_width) {
int padding_height, int padding_width, int col_height,
int col_width) {
int swid = blockIdx.x;
int shid = blockIdx.y;
for (int channelid = threadIdx.z; channelid < input_channels;
for (int channelid = threadIdx.z; channelid < im_channels;
channelid += blockDim.z) {
for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) {
for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) {
int width_offset = idx + swid * stride_width - padding_width;
int height_offset = idy + shid * stride_height - padding_height;
int im_offset = width_offset + height_offset * input_width +
channelid * input_height * input_width;
int im_offset = width_offset + height_offset * im_width +
channelid * im_height * im_width;
int col_offset = idx + idy * filter_width +
channelid * filter_height * filter_width +
(shid * output_width + swid) *
(input_channels * filter_height * filter_width);
(shid * col_width + swid) *
(im_channels * filter_height * filter_width);
if (height_offset >= 0 && height_offset < input_height &&
width_offset >= 0 && width_offset < input_width) {
if (height_offset >= 0 && height_offset < im_height &&
width_offset >= 0 && width_offset < im_width) {
paddle::platform::CudaAtomicAdd(im_data + im_offset,
col_data[col_offset]);
}
......@@ -350,27 +361,33 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height,
int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right) {
const framework::Tensor& col, int dilation_h, int dilation_w,
int stride_height, int stride_width, int padding_up,
int padding_down, int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
int input_height = im.dims()[1];
int input_width = im.dims()[2];
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
int filter_height = col.dims()[3];
int filter_width = col.dims()[4];
int output_height = col.dims()[0];
int output_width = col.dims()[1];
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
stride_height +
1 ==
output_height);
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
stride_width +
1 ==
output_width);
int col_height = col.dims()[0];
int col_width = col.dims()[1];
PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down -
(dilation_h * (filter_height - 1) + 1)) /
stride_height +
1,
col_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right -
(dilation_w * (filter_width - 1) + 1)) /
stride_width +
1,
col_width,
"col_width and padding(padding_left, padding_right) are "
"inconsistent.");
int block_dim_x = 0;
int block_dim_y = 0;
......@@ -389,15 +406,14 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
}
int block_dim_z = 1024 / block_dim_x / block_dim_y;
dim3 threads(block_dim_x, block_dim_y,
std::min(block_dim_z, input_channels));
dim3 grid(output_width, output_height);
dim3 threads(block_dim_x, block_dim_y, std::min(block_dim_z, im_channels));
dim3 grid(col_width, col_height);
col2imOCF<T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
im.data<T>(), col.data<T>(), im_channels, im_height, im_width,
filter_height, filter_width, stride_height, stride_width, padding_up,
padding_left, output_height, output_width);
padding_left, col_height, col_width);
}
};
......
......@@ -74,17 +74,18 @@ class Im2ColFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_up,
int padding_down, int padding_left, int padding_right);
int dilation_h, int dilation_w, int stride_height,
int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right);
};
template <ColFormat Format, typename Place, typename T>
class Col2ImFunctor {
public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height,
int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right);
const framework::Tensor& col, int dilation_h, int dilation_w,
int stride_height, int stride_width, int padding_up,
int padding_down, int padding_left, int padding_right);
};
} // namespace math
......
......@@ -47,6 +47,8 @@ void testIm2col() {
int filter_size = 2;
int stride = 1;
int padding = 0;
int dilation_h = 1;
int dilation_w = 1;
int output_height = (input_height - filter_size + 2 * padding) / stride + 1;
int output_width = (input_width - filter_size + 2 * padding) / stride + 1;
float* input_ptr = input_tmp.mutable_data<float>(
......@@ -85,10 +87,10 @@ void testIm2col() {
paddle::operators::math::ColFormat::kOCF, Place, float>
im2col_ocf;
im2col(*context, input, output_cfo, stride, stride, padding, padding, padding,
padding);
im2col_ocf(*context, input, output_ocf, stride, stride, padding, padding,
padding, padding);
im2col(*context, input, output_cfo, dilation_h, dilation_w, stride, stride,
padding, padding, padding, padding);
im2col_ocf(*context, input, output_ocf, dilation_h, dilation_w, stride,
stride, padding, padding, padding, padding);
float out_cfo_data[] = {0, 1, 1, 2, 3, 4, 4, 5};
float out_ocf_data[] = {0, 1, 3, 4, 1, 2, 4, 5};
......@@ -131,8 +133,8 @@ void testIm2col() {
input.CopyFrom(input_tmp, *place, *context);
}
col2im(*context, input, output_cfo, stride, stride, padding, padding, padding,
padding);
col2im(*context, input, output_cfo, dilation_h, dilation_w, stride, stride,
padding, padding, padding, padding);
float* in_ptr;
if (paddle::platform::is_cpu_place(*place)) {
......@@ -153,8 +155,8 @@ void testIm2col() {
input.CopyFrom(input_tmp, *place, *context);
}
col2im_ocf(*context, input, output_ocf, stride, stride, padding, padding,
padding, padding);
col2im_ocf(*context, input, output_ocf, dilation_h, dilation_w, stride,
stride, padding, padding, padding, padding);
if (paddle::platform::is_cpu_place(*place)) {
in_ptr = input.data<float>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册