提交 97e9dd72 编写于 作者: C chengduoZH

add dilation for im2col

上级 91b72482
...@@ -22,8 +22,6 @@ class CudnnConvOpMaker : public Conv2DOpMaker { ...@@ -22,8 +22,6 @@ class CudnnConvOpMaker : public Conv2DOpMaker {
CudnnConvOpMaker(framework::OpProto* proto, CudnnConvOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker) framework::OpAttrChecker* op_checker)
: Conv2DOpMaker(proto, op_checker) { : Conv2DOpMaker(proto, op_checker) {
AddAttr<std::vector<int>>("dilations", "dilations of convolution operator.")
.SetDefault(std::vector<int>{1, 1});
AddAttr<int>("workspace_size_MB", AddAttr<int>("workspace_size_MB",
"workspace size for cudnn, in MB, " "workspace size for cudnn, in MB, "
"workspace is a section of GPU memory which will be " "workspace is a section of GPU memory which will be "
......
...@@ -30,6 +30,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -30,6 +30,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides"); std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings"); std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
int groups = ctx->Attrs().Get<int>("groups"); int groups = ctx->Attrs().Get<int>("groups");
std::vector<int> dilations = ctx->Attrs().Get<std::vector<int>>("dilations");
int input_channels = in_dims[1]; int input_channels = in_dims[1];
int output_channels = filter_dims[0]; int output_channels = filter_dims[0];
...@@ -54,7 +55,8 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -54,7 +55,8 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]}); std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
for (size_t i = 0; i < paddings.size(); ++i) { for (size_t i = 0; i < paddings.size(); ++i) {
output_shape.push_back(OutputSize(in_dims[i + 2], filter_dims[i + 2], output_shape.push_back(OutputSize(in_dims[i + 2], filter_dims[i + 2],
paddings[i], strides[i])); dilations[i], paddings[i], paddings[i],
strides[i]));
} }
ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
} }
...@@ -90,6 +92,10 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto, ...@@ -90,6 +92,10 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto,
"first half of the input channels, while the second half of the filters " "first half of the input channels, while the second half of the filters "
"is only connected to the second half of the input channels.") "is only connected to the second half of the input channels.")
.SetDefault(1); .SetDefault(1);
AddAttr<std::vector<int>>("dilations",
"(vector default:{1, 1}), the dilations of "
"convolution operator.")
.SetDefault(std::vector<int>{1, 1});
AddComment(R"DOC( AddComment(R"DOC(
Convolution Operator. Convolution Operator.
...@@ -151,6 +157,11 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto, ...@@ -151,6 +157,11 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto,
"first half of the input channels, while the second half of the filters " "first half of the input channels, while the second half of the filters "
"is only connected to the second half of the input channels.") "is only connected to the second half of the input channels.")
.SetDefault(1); .SetDefault(1);
AddAttr<std::vector<int>>("dilations",
"(vector default:{1, 1, 1}), the dilations of "
"convolution operator. Currently, conv3d doesn't "
"support dilation.")
.SetDefault(std::vector<int>{1, 1, 1});
AddComment(R"DOC( AddComment(R"DOC(
Convolution3D Operator. Convolution3D Operator.
......
...@@ -27,9 +27,12 @@ using Tensor = framework::Tensor; ...@@ -27,9 +27,12 @@ using Tensor = framework::Tensor;
// Base convolution operator definations for other conv // Base convolution operator definations for other conv
// like operators to reuse the implementation. // like operators to reuse the implementation.
inline int OutputSize(int input_size, int filter_size, int padding, inline int OutputSize(int input_size, int filter_size, int dilation,
int stride) { int padding_up, int padding_down, int stride) {
int output_size = (input_size - filter_size + 2 * padding) / stride + 1; int output_size = (input_size + padding_up + padding_down -
(dilation * (filter_size - 1) + 1)) /
stride +
1;
return output_size; return output_size;
} }
...@@ -76,6 +79,7 @@ class GemmConvKernel : public framework::OpKernel<T> { ...@@ -76,6 +79,7 @@ class GemmConvKernel : public framework::OpKernel<T> {
std::vector<int> strides = context.Attr<std::vector<int>>("strides"); std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
int groups = context.Attr<int>("groups"); int groups = context.Attr<int>("groups");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
const int batch_size = static_cast<int>(input->dims()[0]); const int batch_size = static_cast<int>(input->dims()[0]);
...@@ -139,9 +143,9 @@ class GemmConvKernel : public framework::OpKernel<T> { ...@@ -139,9 +143,9 @@ class GemmConvKernel : public framework::OpKernel<T> {
if (filter_shape_vec.size() == 2) { if (filter_shape_vec.size() == 2) {
// im2col // im2col
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col; math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
im2col(context.device_context(), in_slice, col, strides[0], im2col(context.device_context(), in_slice, col, dilations[0],
strides[1], paddings[0], paddings[0], paddings[1], dilations[1], strides[0], strides[1], paddings[0], paddings[0],
paddings[1]); paddings[1], paddings[1]);
} else if (filter_shape_vec.size() == 3) { } else if (filter_shape_vec.size() == 3) {
// vol2col // vol2col
math::Vol2ColFunctor<Place, T> vol2col; math::Vol2ColFunctor<Place, T> vol2col;
...@@ -181,6 +185,7 @@ class GemmConvGradKernel : public framework::OpKernel<T> { ...@@ -181,6 +185,7 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
std::vector<int> strides = context.Attr<std::vector<int>>("strides"); std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
int groups = context.Attr<int>("groups"); int groups = context.Attr<int>("groups");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
const int batch_size = static_cast<int>(input->dims()[0]); const int batch_size = static_cast<int>(input->dims()[0]);
...@@ -263,9 +268,9 @@ class GemmConvGradKernel : public framework::OpKernel<T> { ...@@ -263,9 +268,9 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
if (filter_shape_vec.size() == 2) { if (filter_shape_vec.size() == 2) {
math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im; math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
col2im(context.device_context(), in_grad_slice, col, strides[0], col2im(context.device_context(), in_grad_slice, col, dilations[0],
strides[1], paddings[0], paddings[0], paddings[1], dilations[1], strides[0], strides[1], paddings[0],
paddings[1]); paddings[0], paddings[1], paddings[1]);
} else if (filter_shape_vec.size() == 3) { } else if (filter_shape_vec.size() == 3) {
math::Col2VolFunctor<Place, T> col2vol; math::Col2VolFunctor<Place, T> col2vol;
...@@ -295,9 +300,9 @@ class GemmConvGradKernel : public framework::OpKernel<T> { ...@@ -295,9 +300,9 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
if (filter_shape_vec.size() == 2) { if (filter_shape_vec.size() == 2) {
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col; math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
im2col(context.device_context(), in_slice, col, strides[0], im2col(context.device_context(), in_slice, col, dilations[0],
strides[1], paddings[0], paddings[0], paddings[1], dilations[1], strides[0], strides[1], paddings[0],
paddings[1]); paddings[0], paddings[1], paddings[1]);
} else if (filter_shape_vec.size() == 3) { } else if (filter_shape_vec.size() == 3) {
math::Vol2ColFunctor<Place, T> vol2col; math::Vol2ColFunctor<Place, T> vol2col;
vol2col(context.device_context(), in_slice, col, strides[0], vol2col(context.device_context(), in_slice, col, strides[0],
......
...@@ -69,6 +69,9 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> { ...@@ -69,6 +69,9 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
// TODO(Zhuoyuan): Paddings can be added in future. // TODO(Zhuoyuan): Paddings can be added in future.
// groups will alway be disabled in conv2dtranspose. // groups will alway be disabled in conv2dtranspose.
int dilation_h = 1;
int dilation_w = 1;
const int batch_size = static_cast<int>(input->dims()[0]); const int batch_size = static_cast<int>(input->dims()[0]);
// input_shape_vec: {h, w} or {d, h, w} // input_shape_vec: {h, w} or {d, h, w}
...@@ -140,8 +143,8 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> { ...@@ -140,8 +143,8 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
// from (c * k_h * k_w, h * w) to (c, o_h, o_w) // from (c * k_h * k_w, h * w) to (c, o_h, o_w)
math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im; math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
col2im(context.device_context(), output_batch, col, strides[0], col2im(context.device_context(), output_batch, col, dilation_h,
strides[1], 0, 0, 0, 0); dilation_w, strides[0], strides[1], 0, 0, 0, 0);
} else if (filter_shape_vec.size() == 3) { } else if (filter_shape_vec.size() == 3) {
// col2vol: col_matrix -> dy // col2vol: col_matrix -> dy
// from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w) // from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w)
...@@ -174,6 +177,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> { ...@@ -174,6 +177,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// Actually, no paddings and groups allowed in conv transpose. // Actually, no paddings and groups allowed in conv transpose.
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
int dilation_h = 1;
int dilation_w = 1;
const int batch_size = static_cast<int>(input->dims()[0]); const int batch_size = static_cast<int>(input->dims()[0]);
// input_shape_vec: {h, w} or {d, h, w} // input_shape_vec: {h, w} or {d, h, w}
...@@ -248,9 +254,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> { ...@@ -248,9 +254,9 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// im2col: dy -> col matrix // im2col: dy -> col matrix
// from (c, o_h, o_w) to (c * k_h * k_w, h * w) // from (c, o_h, o_w) to (c * k_h * k_w, h * w)
math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col; math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
im2col(context.device_context(), output_grad_batch, col, strides[0], im2col(context.device_context(), output_grad_batch, col, dilation_h,
strides[1], paddings[0], paddings[0], paddings[1], dilation_w, strides[0], strides[1], paddings[0], paddings[0],
paddings[1]); paddings[1], paddings[1]);
} else if (filter_shape_vec.size() == 3) { } else if (filter_shape_vec.size() == 3) {
// vol2col: dy -> col_matrix // vol2col: dy -> col_matrix
// from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w) // from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w)
......
...@@ -95,6 +95,9 @@ class ContextProjectFunctor { ...@@ -95,6 +95,9 @@ class ContextProjectFunctor {
math::Im2ColFunctor<math::ColFormat::kOCF, Place, float> im2col_ocf; math::Im2ColFunctor<math::ColFormat::kOCF, Place, float> im2col_ocf;
int dilation_h = 1;
int dilation_w = 1;
int input_row_begin, input_row_end; int input_row_begin, input_row_end;
int sequence_height, sequence_width; int sequence_height, sequence_width;
sequence_width = in.dims()[1]; sequence_width = in.dims()[1];
...@@ -124,7 +127,7 @@ class ContextProjectFunctor { ...@@ -124,7 +127,7 @@ class ContextProjectFunctor {
sequence_width}); // input_channels, input_height, input_width sequence_width}); // input_channels, input_height, input_width
in_t.Resize(framework::make_ddim(input_shape)); in_t.Resize(framework::make_ddim(input_shape));
im2col_ocf(context, in_t, out_t, im2col_ocf(context, in_t, out_t, dilation_h, dilation_w,
/*stride_height*/ context_stride, /*stride_width*/ 1, up_pad, /*stride_height*/ context_stride, /*stride_width*/ 1, up_pad,
down_pad, 0, 0); down_pad, 0, 0);
out_t.Resize({sequence_height, context_length * sequence_width}); out_t.Resize({sequence_height, context_length * sequence_width});
...@@ -204,6 +207,9 @@ class ContextProjectGradFunctor { ...@@ -204,6 +207,9 @@ class ContextProjectGradFunctor {
math::Col2ImFunctor<math::ColFormat::kOCF, Place, float> col2im_ocf; math::Col2ImFunctor<math::ColFormat::kOCF, Place, float> col2im_ocf;
int dilation_h = 1;
int dilation_w = 1;
int input_row_begin, input_row_end; int input_row_begin, input_row_end;
int sequence_height, sequence_width; int sequence_height, sequence_width;
sequence_width = in.dims()[1]; sequence_width = in.dims()[1];
...@@ -234,7 +240,7 @@ class ContextProjectGradFunctor { ...@@ -234,7 +240,7 @@ class ContextProjectGradFunctor {
sequence_width}); // input_channels, input_height, input_width sequence_width}); // input_channels, input_height, input_width
in_t.Resize(framework::make_ddim(input_shape)); in_t.Resize(framework::make_ddim(input_shape));
col2im_ocf(context, in_t, out_t, col2im_ocf(context, in_t, out_t, dilation_h, dilation_w,
/*stride_height*/ context_stride, /*stride_width*/ 1, /*stride_height*/ context_stride, /*stride_width*/ 1,
up_pad, down_pad, 0, 0); up_pad, down_pad, 0, 0);
out_t.Resize({sequence_height, context_length * sequence_width}); out_t.Resize({sequence_height, context_length * sequence_width});
......
...@@ -29,35 +29,36 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -29,35 +29,36 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col, const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_up, int dilation_h, int dilation_w, int stride_height,
int padding_down, int padding_left, int padding_right) { int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int im_channels = im.dims()[0];
int input_height = im.dims()[1]; int im_height = im.dims()[1];
int input_width = im.dims()[2]; int im_width = im.dims()[2];
int filter_height = col.dims()[1]; int filter_height = col.dims()[1];
int filter_width = col.dims()[2]; int filter_width = col.dims()[2];
int output_height = col.dims()[3]; int col_height = col.dims()[3];
int output_width = col.dims()[4]; int col_width = col.dims()[4];
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down -
(input_height + padding_up + padding_down - filter_height) / ((dilation_h * (filter_height - 1) + 1))) /
stride_height + stride_height +
1, 1,
output_height, col_height,
"Output_height and padding(padding_up, padding_down) are " "Output_height and padding(padding_up, padding_down) are "
"inconsistent."); "inconsistent.");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right -
(input_width + padding_left + padding_right - filter_width) / ((dilation_w * (filter_width - 1) + 1))) /
stride_width + stride_width +
1, 1,
output_width, col_width,
"output_width and padding(padding_left, padding_right) are " "col_width and padding(padding_left, padding_right) are "
"inconsistent."); "inconsistent.");
int channels_col = input_channels * filter_height * filter_width; int channels_col = im_channels * filter_height * filter_width;
const T* im_data = im.data<T>(); const T* im_data = im.data<T>();
T* col_data = col.data<T>(); T* col_data = col.data<T>();
...@@ -66,19 +67,19 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -66,19 +67,19 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
int w_offset = c % filter_width; int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height; int h_offset = (c / filter_width) % filter_height;
int c_im = c / filter_width / filter_height; int c_im = c / filter_width / filter_height;
for (int h = 0; h < output_height; ++h) { for (int h = 0; h < col_height; ++h) {
for (int w = 0; w < output_width; ++w) { for (int w = 0; w < col_width; ++w) {
int im_row_idx = h * stride_height + h_offset - padding_up; int im_row_idx =
int im_col_idx = w * stride_width + w_offset - padding_left; h * stride_height - padding_up + h_offset * dilation_h;
int im_col_idx =
w * stride_width - padding_left + w_offset * dilation_w;
if (im_row_idx < 0 || im_row_idx >= input_height || im_col_idx < 0 || col_data[(c * col_height + h) * col_width + w] =
im_col_idx >= input_width) { (im_row_idx < 0 || im_row_idx >= im_height || im_col_idx < 0 ||
col_data[(c * output_height + h) * output_width + w] = T(0); im_col_idx >= im_width)
} else { ? static_cast<T>(0)
im_row_idx += c_im * input_height; : im_data[(im_row_idx + c_im * im_height) * im_width +
col_data[(c * output_height + h) * output_width + w] = im_col_idx];
im_data[im_row_idx * input_width + im_col_idx];
}
} }
} }
} }
...@@ -95,35 +96,35 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -95,35 +96,35 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
platform::CPUPlace, T> { platform::CPUPlace, T> {
public: public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im, void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height, const framework::Tensor& col, int dilation_h, int dilation_w,
int stride_width, int padding_up, int padding_down, int stride_height, int stride_width, int padding_up,
int padding_left, int padding_right) { int padding_down, int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int im_channels = im.dims()[0];
int input_height = im.dims()[1]; int im_height = im.dims()[1];
int input_width = im.dims()[2]; int im_width = im.dims()[2];
int filter_height = col.dims()[1]; int filter_height = col.dims()[1];
int filter_width = col.dims()[2]; int filter_width = col.dims()[2];
int output_height = col.dims()[3]; int col_height = col.dims()[3];
int output_width = col.dims()[4]; int col_width = col.dims()[4];
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down -
(input_height + padding_up + padding_down - filter_height) / ((dilation_h * (filter_height - 1) + 1))) /
stride_height + stride_height +
1, 1,
output_height, col_height,
"Output_height and padding(padding_up, padding_down) are " "Output_height and padding(padding_up, padding_down) are "
"inconsistent."); "inconsistent.");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right -
(input_width + padding_left + padding_right - filter_width) / ((dilation_w * (filter_width - 1) + 1))) /
stride_width + stride_width +
1, 1,
output_width, col_width,
"output_width and padding(padding_left, padding_right) are " "col_width and padding(padding_left, padding_right) are "
"inconsistent."); "inconsistent.");
int channels_col = input_channels * filter_height * filter_width; int channels_col = im_channels * filter_height * filter_width;
T* im_data = im.data<T>(); T* im_data = im.data<T>();
const T* col_data = col.data<T>(); const T* col_data = col.data<T>();
...@@ -132,16 +133,18 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -132,16 +133,18 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
int w_offset = c % filter_width; int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height; int h_offset = (c / filter_width) % filter_height;
int c_im = c / filter_width / filter_height; int c_im = c / filter_width / filter_height;
for (int h = 0; h < output_height; ++h) { for (int h = 0; h < col_height; ++h) {
for (int w = 0; w < output_width; ++w) { for (int w = 0; w < col_width; ++w) {
int im_row_idx = h * stride_height + h_offset - padding_up; int im_row_idx =
int im_col_idx = w * stride_width + w_offset - padding_left; h * stride_height - padding_up + h_offset * dilation_h;
int im_col_idx =
w * stride_width - padding_left + w_offset * dilation_w;
if ((im_row_idx) >= 0 && (im_row_idx) < input_height && if ((im_row_idx) >= 0 && (im_row_idx) < im_height &&
(im_col_idx) >= 0 && (im_col_idx) < input_width) { (im_col_idx) >= 0 && (im_col_idx) < im_width) {
im_row_idx += c_im * input_height; im_row_idx += c_im * im_height;
im_data[im_row_idx * input_width + im_col_idx] += im_data[im_row_idx * im_width + im_col_idx] +=
col_data[(c * output_height + h) * output_width + w]; col_data[(c * col_height + h) * col_width + w];
} }
} }
} }
...@@ -169,39 +172,38 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -169,39 +172,38 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col, const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_up, int dilation_h, int dilation_w, int stride_height,
int padding_down, int padding_left, int padding_right) { int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int im_channels = im.dims()[0];
int input_height = im.dims()[1]; int im_height = im.dims()[1];
int input_width = im.dims()[2]; int im_width = im.dims()[2];
int filter_height = col.dims()[3]; int filter_height = col.dims()[3];
int filter_width = col.dims()[4]; int filter_width = col.dims()[4];
int output_height = col.dims()[0]; int col_height = col.dims()[0];
int output_width = col.dims()[1]; int col_width = col.dims()[1];
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - filter_height) /
(input_height + padding_up + padding_down - filter_height) / stride_height +
stride_height + 1,
1, col_height,
output_height, "Output_height and padding(padding_up, padding_down) are "
"Output_height and padding(padding_up, padding_down) are " "inconsistent.");
"inconsistent."); PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - filter_width) /
PADDLE_ENFORCE_EQ( stride_width +
(input_width + padding_left + padding_right - filter_width) / 1,
stride_width + col_width,
1, "col_width and padding(padding_left, padding_right) are "
output_width, "inconsistent.");
"output_width and padding(padding_left, padding_right) are "
"inconsistent.");
const T* im_data = im.data<T>(); const T* im_data = im.data<T>();
T* col_data = col.data<T>(); T* col_data = col.data<T>();
for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) { for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) {
for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) { for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) {
for (int channel = 0; channel < input_channels; ++channel) { for (int channel = 0; channel < im_channels; ++channel) {
for (int filter_row_idx = 0; filter_row_idx < filter_height; for (int filter_row_idx = 0; filter_row_idx < filter_height;
++filter_row_idx) { ++filter_row_idx) {
for (int filter_col_idx = 0; filter_col_idx < filter_width; for (int filter_col_idx = 0; filter_col_idx < filter_width;
...@@ -210,22 +212,21 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -210,22 +212,21 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
col_row_idx * stride_height + filter_row_idx - padding_up; col_row_idx * stride_height + filter_row_idx - padding_up;
int im_col_offset = int im_col_offset =
col_col_idx * stride_width + filter_col_idx - padding_left; col_col_idx * stride_width + filter_col_idx - padding_left;
int col_offset = ((((col_row_idx)*output_width + col_col_idx) * int col_offset =
input_channels + ((((col_row_idx)*col_width + col_col_idx) * im_channels +
channel) * channel) *
filter_height + filter_height +
filter_row_idx) * filter_row_idx) *
filter_width + filter_width +
filter_col_idx; filter_col_idx;
if (im_row_offset < 0 || im_row_offset >= input_height ||
im_col_offset < 0 || im_col_offset >= input_width) { int im_offset = (channel * im_height + im_row_offset) * im_width +
col_data[col_offset] = T(0); im_col_offset;
} else { col_data[col_offset] =
int im_offset = (im_row_offset < 0 || im_row_offset >= im_height ||
(channel * input_height + im_row_offset) * input_width + im_col_offset < 0 || im_col_offset >= im_width)
im_col_offset; ? static_cast<T>(0)
col_data[col_offset] = im_data[im_offset]; : im_data[im_offset];
}
} }
} }
} }
...@@ -244,40 +245,38 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -244,40 +245,38 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
platform::CPUPlace, T> { platform::CPUPlace, T> {
public: public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im, void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height, const framework::Tensor& col, int dilation_h, int dilation_w,
int stride_width, int padding_up, int padding_down, int stride_height, int stride_width, int padding_up,
int padding_left, int padding_right) { int padding_down, int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int im_channels = im.dims()[0];
int input_height = im.dims()[1]; int im_height = im.dims()[1];
int input_width = im.dims()[2]; int im_width = im.dims()[2];
int filter_height = col.dims()[3]; int filter_height = col.dims()[3];
int filter_width = col.dims()[4]; int filter_width = col.dims()[4];
int output_height = col.dims()[0]; int col_height = col.dims()[0];
int output_width = col.dims()[1]; int col_width = col.dims()[1];
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ((im_height + padding_up + padding_down - filter_height) /
(input_height + padding_up + padding_down - filter_height) / stride_height +
stride_height + 1,
1, col_height,
output_height, "Output_height and padding(padding_up, padding_down) are "
"Output_height and padding(padding_up, padding_down) are " "inconsistent.");
"inconsistent."); PADDLE_ENFORCE_EQ((im_width + padding_left + padding_right - filter_width) /
PADDLE_ENFORCE_EQ( stride_width +
(input_width + padding_left + padding_right - filter_width) / 1,
stride_width + col_width,
1, "col_width and padding(padding_left, padding_right) are "
output_width, "inconsistent.");
"output_width and padding(padding_left, padding_right) are "
"inconsistent.");
T* im_data = im.data<T>(); T* im_data = im.data<T>();
const T* col_data = col.data<T>(); const T* col_data = col.data<T>();
for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) { for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) {
for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) { for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) {
for (int channel = 0; channel < input_channels; ++channel) { for (int channel = 0; channel < im_channels; ++channel) {
for (int filter_row_idx = 0; filter_row_idx < filter_height; for (int filter_row_idx = 0; filter_row_idx < filter_height;
++filter_row_idx) { ++filter_row_idx) {
for (int filter_col_idx = 0; filter_col_idx < filter_width; for (int filter_col_idx = 0; filter_col_idx < filter_width;
...@@ -286,17 +285,17 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -286,17 +285,17 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
col_row_idx * stride_height + filter_row_idx - padding_up; col_row_idx * stride_height + filter_row_idx - padding_up;
int im_col_offset = int im_col_offset =
col_col_idx * stride_width + filter_col_idx - padding_left; col_col_idx * stride_width + filter_col_idx - padding_left;
int col_offset = (((col_row_idx * output_width + col_col_idx) * int col_offset =
input_channels + (((col_row_idx * col_width + col_col_idx) * im_channels +
channel) * channel) *
filter_height + filter_height +
filter_row_idx) * filter_row_idx) *
filter_width + filter_width +
filter_col_idx; filter_col_idx;
if (im_row_offset >= 0 && im_row_offset < input_height && if (im_row_offset >= 0 && im_row_offset < im_height &&
im_col_offset >= 0 && im_col_offset < input_width) { im_col_offset >= 0 && im_col_offset < im_width) {
int im_offset = int im_offset =
(channel * input_height + im_row_offset) * input_width + (channel * im_height + im_row_offset) * im_width +
im_col_offset; im_col_offset;
im_data[im_offset] += col_data[col_offset]; im_data[im_offset] += col_data[col_offset];
} }
......
此差异已折叠。
...@@ -74,17 +74,18 @@ class Im2ColFunctor { ...@@ -74,17 +74,18 @@ class Im2ColFunctor {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col, const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_up, int dilation_h, int dilation_w, int stride_height,
int padding_down, int padding_left, int padding_right); int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right);
}; };
template <ColFormat Format, typename Place, typename T> template <ColFormat Format, typename Place, typename T>
class Col2ImFunctor { class Col2ImFunctor {
public: public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im, void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height, const framework::Tensor& col, int dilation_h, int dilation_w,
int stride_width, int padding_up, int padding_down, int stride_height, int stride_width, int padding_up,
int padding_left, int padding_right); int padding_down, int padding_left, int padding_right);
}; };
} // namespace math } // namespace math
......
...@@ -47,6 +47,8 @@ void testIm2col() { ...@@ -47,6 +47,8 @@ void testIm2col() {
int filter_size = 2; int filter_size = 2;
int stride = 1; int stride = 1;
int padding = 0; int padding = 0;
int dilation_h = 1;
int dilation_w = 1;
int output_height = (input_height - filter_size + 2 * padding) / stride + 1; int output_height = (input_height - filter_size + 2 * padding) / stride + 1;
int output_width = (input_width - filter_size + 2 * padding) / stride + 1; int output_width = (input_width - filter_size + 2 * padding) / stride + 1;
float* input_ptr = input_tmp.mutable_data<float>( float* input_ptr = input_tmp.mutable_data<float>(
...@@ -85,10 +87,10 @@ void testIm2col() { ...@@ -85,10 +87,10 @@ void testIm2col() {
paddle::operators::math::ColFormat::kOCF, Place, float> paddle::operators::math::ColFormat::kOCF, Place, float>
im2col_ocf; im2col_ocf;
im2col(*context, input, output_cfo, stride, stride, padding, padding, padding, im2col(*context, input, output_cfo, dilation_h, dilation_w, stride, stride,
padding); padding, padding, padding, padding);
im2col_ocf(*context, input, output_ocf, stride, stride, padding, padding, im2col_ocf(*context, input, output_ocf, dilation_h, dilation_w, stride,
padding, padding); stride, padding, padding, padding, padding);
float out_cfo_data[] = {0, 1, 1, 2, 3, 4, 4, 5}; float out_cfo_data[] = {0, 1, 1, 2, 3, 4, 4, 5};
float out_ocf_data[] = {0, 1, 3, 4, 1, 2, 4, 5}; float out_ocf_data[] = {0, 1, 3, 4, 1, 2, 4, 5};
...@@ -131,8 +133,8 @@ void testIm2col() { ...@@ -131,8 +133,8 @@ void testIm2col() {
input.CopyFrom(input_tmp, *place, *context); input.CopyFrom(input_tmp, *place, *context);
} }
col2im(*context, input, output_cfo, stride, stride, padding, padding, padding, col2im(*context, input, output_cfo, dilation_h, dilation_w, stride, stride,
padding); padding, padding, padding, padding);
float* in_ptr; float* in_ptr;
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
...@@ -153,8 +155,8 @@ void testIm2col() { ...@@ -153,8 +155,8 @@ void testIm2col() {
input.CopyFrom(input_tmp, *place, *context); input.CopyFrom(input_tmp, *place, *context);
} }
col2im_ocf(*context, input, output_ocf, stride, stride, padding, padding, col2im_ocf(*context, input, output_ocf, dilation_h, dilation_w, stride,
padding, padding); stride, padding, padding, padding, padding);
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
in_ptr = input.data<float>(); in_ptr = input.data<float>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册