提交 dc7d0735 编写于 作者: C chengduoZH

add padding up, down, left, right

上级 d2c1408f
...@@ -116,7 +116,7 @@ class GemmConv2DKernel : public framework::OpKernel<T> { ...@@ -116,7 +116,7 @@ class GemmConv2DKernel : public framework::OpKernel<T> {
// im2col // im2col
Tensor in_slice = in_batch.Slice<T>(g * in_step, (g + 1) * in_step); Tensor in_slice = in_batch.Slice<T>(g * in_step, (g + 1) * in_step);
im2col(context.device_context(), in_slice, col, strides[0], strides[1], im2col(context.device_context(), in_slice, col, strides[0], strides[1],
paddings[0], paddings[1]); paddings[0], paddings[0], paddings[1], paddings[1]);
// gemm // gemm
Tensor out_slice = out_batch.Slice<T>(g * out_step, (g + 1) * out_step); Tensor out_slice = out_batch.Slice<T>(g * out_step, (g + 1) * out_step);
...@@ -217,7 +217,8 @@ class GemmConvGrad2DKernel : public framework::OpKernel<T> { ...@@ -217,7 +217,8 @@ class GemmConvGrad2DKernel : public framework::OpKernel<T> {
Tensor in_grad_slice = Tensor in_grad_slice =
in_grad_batch.Slice<T>(g * in_step, (g + 1) * in_step); in_grad_batch.Slice<T>(g * in_step, (g + 1) * in_step);
col2im(context.device_context(), in_grad_slice, col, strides[0], col2im(context.device_context(), in_grad_slice, col, strides[0],
strides[1], paddings[0], paddings[1]); strides[1], paddings[0], paddings[0], paddings[1],
paddings[1]);
} }
} }
} }
...@@ -239,7 +240,8 @@ class GemmConvGrad2DKernel : public framework::OpKernel<T> { ...@@ -239,7 +240,8 @@ class GemmConvGrad2DKernel : public framework::OpKernel<T> {
out_grad_batch.Slice<T>(g * out_step, (g + 1) * out_step); out_grad_batch.Slice<T>(g * out_step, (g + 1) * out_step);
Tensor in_slice = in_batch.Slice<T>(g * in_step, (g + 1) * in_step); Tensor in_slice = in_batch.Slice<T>(g * in_step, (g + 1) * in_step);
im2col(context.device_context(), in_slice, col, strides[0], im2col(context.device_context(), in_slice, col, strides[0],
strides[1], paddings[0], paddings[1]); strides[1], paddings[0], paddings[0], paddings[1],
paddings[1]);
// gemm // gemm
Tensor filter_grad_slice = Tensor filter_grad_slice =
......
...@@ -29,8 +29,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -29,8 +29,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col, const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height, int stride_height, int stride_width, int padding_up,
int padding_width) { int padding_down, int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
...@@ -41,6 +41,16 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -41,6 +41,16 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
int filter_width = col.dims()[2]; int filter_width = col.dims()[2];
int output_height = col.dims()[3]; int output_height = col.dims()[3];
int output_width = col.dims()[4]; int output_width = col.dims()[4];
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
stride_height +
1 ==
output_height);
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
stride_width +
1 ==
output_width);
int channels_col = input_channels * filter_height * filter_width; int channels_col = input_channels * filter_height * filter_width;
const T* im_data = im.data<T>(); const T* im_data = im.data<T>();
...@@ -54,14 +64,14 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -54,14 +64,14 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
for (int w = 0; w < output_width; ++w) { for (int w = 0; w < output_width; ++w) {
int im_row_idx = h * stride_height + h_offset; int im_row_idx = h * stride_height + h_offset;
int im_col_idx = w * stride_width + w_offset; int im_col_idx = w * stride_width + w_offset;
if ((im_row_idx - padding_height) < 0 || if ((im_row_idx - padding_up) < 0 ||
(im_row_idx - padding_height) >= input_height || (im_row_idx - padding_up) >= input_height ||
(im_col_idx - padding_width) < 0 || (im_col_idx - padding_left) < 0 ||
(im_col_idx - padding_width) >= input_width) { (im_col_idx - padding_left) >= input_width) {
col_data[(c * output_height + h) * output_width + w] = T(0); col_data[(c * output_height + h) * output_width + w] = T(0);
} else { } else {
im_row_idx += c_im * input_height - padding_height; im_row_idx += c_im * input_height - padding_up;
im_col_idx -= padding_width; im_col_idx -= padding_left;
col_data[(c * output_height + h) * output_width + w] = col_data[(c * output_height + h) * output_width + w] =
im_data[im_row_idx * input_width + im_col_idx]; im_data[im_row_idx * input_width + im_col_idx];
} }
...@@ -82,7 +92,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -82,7 +92,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
public: public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im, void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height, const framework::Tensor& col, int stride_height,
int stride_width, int padding_height, int padding_width) { int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int input_channels = im.dims()[0];
...@@ -92,6 +103,16 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -92,6 +103,16 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
int filter_width = col.dims()[2]; int filter_width = col.dims()[2];
int output_height = col.dims()[3]; int output_height = col.dims()[3];
int output_width = col.dims()[4]; int output_width = col.dims()[4];
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
stride_height +
1 ==
output_height);
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
stride_width +
1 ==
output_width);
int channels_col = input_channels * filter_height * filter_width; int channels_col = input_channels * filter_height * filter_width;
T* im_data = im.data<T>(); T* im_data = im.data<T>();
...@@ -105,12 +126,12 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -105,12 +126,12 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
for (int w = 0; w < output_width; ++w) { for (int w = 0; w < output_width; ++w) {
int im_row_idx = h * stride_height + h_offset; int im_row_idx = h * stride_height + h_offset;
int im_col_idx = w * stride_width + w_offset; int im_col_idx = w * stride_width + w_offset;
if ((im_row_idx - padding_height) >= 0 && if ((im_row_idx - padding_up) >= 0 &&
(im_row_idx - padding_height) < input_height && (im_row_idx - padding_up) < input_height &&
(im_col_idx - padding_width) >= 0 && (im_col_idx - padding_left) >= 0 &&
(im_col_idx - padding_width) < input_width) { (im_col_idx - padding_left) < input_width) {
im_row_idx += c_im * input_height - padding_height; im_row_idx += c_im * input_height - padding_up;
im_col_idx -= padding_width; im_col_idx -= padding_left;
im_data[im_row_idx * input_width + im_col_idx] += im_data[im_row_idx * input_width + im_col_idx] +=
col_data[(c * output_height + h) * output_width + w]; col_data[(c * output_height + h) * output_width + w];
} }
...@@ -140,8 +161,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -140,8 +161,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col, const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int up_pad, int stride_height, int stride_width, int padding_up,
int down_pad) { int padding_down, int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int input_channels = im.dims()[0];
...@@ -149,25 +170,22 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -149,25 +170,22 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
int input_width = im.dims()[2]; int input_width = im.dims()[2];
int filter_height = col.dims()[3]; int filter_height = col.dims()[3];
int filter_width = col.dims()[4]; int filter_width = col.dims()[4];
// int output_height = col.dims()[0]; int output_height = col.dims()[0];
int output_width = col.dims()[1]; int output_width = col.dims()[1];
int row_begin, row_end; PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
int padding_height = std::max(up_pad, down_pad); stride_height +
int padding_width = 0; 1 ==
if (up_pad >= down_pad) { output_height);
row_begin = 0; PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
} else { stride_width +
row_begin = down_pad - up_pad; 1 ==
} output_width);
row_end = row_begin + ((input_height + up_pad + down_pad - filter_height) /
stride_height +
1);
const T* im_data = im.data<T>(); const T* im_data = im.data<T>();
T* col_data = col.data<T>(); T* col_data = col.data<T>();
for (int col_row_idx = row_begin; col_row_idx < row_end; ++col_row_idx) { for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) {
for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) { for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) {
for (int channel = 0; channel < input_channels; ++channel) { for (int channel = 0; channel < input_channels; ++channel) {
for (int filter_row_idx = 0; filter_row_idx < filter_height; for (int filter_row_idx = 0; filter_row_idx < filter_height;
...@@ -175,17 +193,16 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -175,17 +193,16 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
for (int filter_col_idx = 0; filter_col_idx < filter_width; for (int filter_col_idx = 0; filter_col_idx < filter_width;
++filter_col_idx) { ++filter_col_idx) {
int im_row_offset = int im_row_offset =
col_row_idx * stride_height + filter_row_idx - padding_height; col_row_idx * stride_height + filter_row_idx - padding_up;
int im_col_offset = int im_col_offset =
col_col_idx * stride_width + filter_col_idx - padding_width; col_col_idx * stride_width + filter_col_idx - padding_left;
int col_offset = int col_offset = ((((col_row_idx)*output_width + col_col_idx) *
((((col_row_idx - row_begin) * output_width + col_col_idx) * input_channels +
input_channels + channel) *
channel) * filter_height +
filter_height + filter_row_idx) *
filter_row_idx) * filter_width +
filter_width + filter_col_idx;
filter_col_idx;
if (im_row_offset < 0 || im_row_offset >= input_height || if (im_row_offset < 0 || im_row_offset >= input_height ||
im_col_offset < 0 || im_col_offset >= input_width) { im_col_offset < 0 || im_col_offset >= input_width) {
col_data[col_offset] = T(0); col_data[col_offset] = T(0);
...@@ -214,7 +231,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -214,7 +231,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
public: public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im, void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height, const framework::Tensor& col, int stride_height,
int stride_width, int up_pad, int down_pad) { int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int input_channels = im.dims()[0];
...@@ -222,25 +240,22 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -222,25 +240,22 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
int input_width = im.dims()[2]; int input_width = im.dims()[2];
int filter_height = col.dims()[3]; int filter_height = col.dims()[3];
int filter_width = col.dims()[4]; int filter_width = col.dims()[4];
// int output_height = col.dims()[0]; int output_height = col.dims()[0];
int output_width = col.dims()[1]; int output_width = col.dims()[1];
int row_begin, row_end; PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
int padding_height = std::max(up_pad, down_pad); stride_height +
int padding_width = 0; 1 ==
if (up_pad >= down_pad) { output_height);
row_begin = 0; PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
} else { stride_width +
row_begin = down_pad - up_pad; 1 ==
} output_width);
row_end = row_begin + ((input_height + up_pad + down_pad - filter_height) /
stride_height +
1);
T* im_data = im.data<T>(); T* im_data = im.data<T>();
const T* col_data = col.data<T>(); const T* col_data = col.data<T>();
for (int col_row_idx = row_begin; col_row_idx < row_end; ++col_row_idx) { for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) {
for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) { for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) {
for (int channel = 0; channel < input_channels; ++channel) { for (int channel = 0; channel < input_channels; ++channel) {
for (int filter_row_idx = 0; filter_row_idx < filter_height; for (int filter_row_idx = 0; filter_row_idx < filter_height;
...@@ -248,17 +263,16 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -248,17 +263,16 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
for (int filter_col_idx = 0; filter_col_idx < filter_width; for (int filter_col_idx = 0; filter_col_idx < filter_width;
++filter_col_idx) { ++filter_col_idx) {
int im_row_offset = // change or not ??? int im_row_offset = // change or not ???
col_row_idx * stride_height + filter_row_idx - padding_height; col_row_idx * stride_height + filter_row_idx - padding_up;
int im_col_offset = int im_col_offset =
col_col_idx * stride_width + filter_col_idx - padding_width; col_col_idx * stride_width + filter_col_idx - padding_left;
int col_offset = int col_offset = (((col_row_idx * output_width + col_col_idx) *
((((col_row_idx - row_begin) * output_width + col_col_idx) * input_channels +
input_channels + channel) *
channel) * filter_height +
filter_height + filter_row_idx) *
filter_row_idx) * filter_width +
filter_width + filter_col_idx;
filter_col_idx;
if (im_row_offset >= 0 && im_row_offset < input_height && if (im_row_offset >= 0 && im_row_offset < input_height &&
im_col_offset >= 0 && im_col_offset < input_width) { im_col_offset >= 0 && im_col_offset < input_width) {
int im_offset = int im_offset =
......
...@@ -66,8 +66,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -66,8 +66,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col, const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height, int stride_height, int stride_width, int padding_up,
int padding_width) { int padding_down, int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
...@@ -79,6 +79,15 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -79,6 +79,15 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
int output_height = col.dims()[3]; int output_height = col.dims()[3];
int output_width = col.dims()[4]; int output_width = col.dims()[4];
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
stride_height +
1 ==
output_height);
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
stride_width +
1 ==
output_width);
int num_outputs = input_channels * output_height * output_width; int num_outputs = input_channels * output_height * output_width;
int blocks = (num_outputs + 1024 - 1) / 1024; int blocks = (num_outputs + 1024 - 1) / 1024;
int block_x = 512; int block_x = 512;
...@@ -89,8 +98,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -89,8 +98,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
reinterpret_cast<const platform::CUDADeviceContext&>(context) reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>( .stream()>>>(
im.data<T>(), num_outputs, input_height, input_width, filter_height, im.data<T>(), num_outputs, input_height, input_width, filter_height,
filter_width, stride_height, stride_width, padding_height, filter_width, stride_height, stride_width, padding_up, padding_left,
padding_width, output_height, output_width, col.data<T>()); output_height, output_width, col.data<T>());
} }
}; };
...@@ -152,7 +161,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -152,7 +161,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
public: public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im, void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height, const framework::Tensor& col, int stride_height,
int stride_width, int padding_height, int padding_width) { int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
...@@ -164,8 +174,18 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -164,8 +174,18 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
int output_height = col.dims()[3]; int output_height = col.dims()[3];
int output_width = col.dims()[4]; int output_width = col.dims()[4];
size_t num_kernels = input_channels * (input_height + 2 * padding_height) * PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
(input_width + 2 * padding_width); stride_height +
1 ==
output_height);
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
stride_width +
1 ==
output_width);
size_t num_kernels = input_channels *
(input_height + padding_up + padding_down) *
(input_width + padding_left + padding_right);
size_t blocks = (num_kernels + 1024 - 1) / 1024; size_t blocks = (num_kernels + 1024 - 1) / 1024;
size_t block_x = 512; size_t block_x = 512;
...@@ -178,10 +198,10 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -178,10 +198,10 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
col2im<T><<<grid, threads, 0, col2im<T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context) reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>( .stream()>>>(
num_kernels, col.data<T>(), input_height + 2 * padding_height, num_kernels, col.data<T>(), input_height + padding_up + padding_down,
input_width + 2 * padding_width, input_channels, filter_height, input_width + padding_left + padding_left, input_channels,
filter_width, stride_height, stride_width, padding_height, filter_height, filter_width, stride_height, stride_width, padding_up,
padding_width, output_height, output_width, im.data<T>()); padding_left, output_height, output_width, im.data<T>());
} }
}; };
...@@ -199,8 +219,7 @@ __global__ void im2colOCF(const T* im_data, T* col_data, int input_channels, ...@@ -199,8 +219,7 @@ __global__ void im2colOCF(const T* im_data, T* col_data, int input_channels,
int input_height, int input_width, int filter_height, int input_height, int input_width, int filter_height,
int filter_width, int stride_height, int stride_width, int filter_width, int stride_height, int stride_width,
int padding_height, int padding_width, int padding_height, int padding_width,
int output_height, int output_width, int row_begin, int output_height, int output_width) {
int row_end) {
int swid = blockIdx.x; int swid = blockIdx.x;
int shid = blockIdx.y; int shid = blockIdx.y;
for (int channelid = threadIdx.z; channelid < input_channels; for (int channelid = threadIdx.z; channelid < input_channels;
...@@ -208,8 +227,7 @@ __global__ void im2colOCF(const T* im_data, T* col_data, int input_channels, ...@@ -208,8 +227,7 @@ __global__ void im2colOCF(const T* im_data, T* col_data, int input_channels,
for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) { for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) {
for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) { for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) {
int width_offset = idx + swid * stride_width - padding_width; int width_offset = idx + swid * stride_width - padding_width;
int height_offset = int height_offset = idy + shid * stride_height - padding_height;
idy + (shid + row_begin) * stride_height - padding_height;
int im_offset = width_offset + height_offset * input_width + int im_offset = width_offset + height_offset * input_width +
channelid * input_height * input_width; channelid * input_height * input_width;
...@@ -240,8 +258,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -240,8 +258,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col, const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int up_pad, int stride_height, int stride_width, int padding_up,
int down_pad) { int padding_down, int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int input_channels = im.dims()[0];
...@@ -249,22 +267,18 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -249,22 +267,18 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
int input_width = im.dims()[2]; int input_width = im.dims()[2];
int filter_height = col.dims()[3]; int filter_height = col.dims()[3];
int filter_width = col.dims()[4]; int filter_width = col.dims()[4];
int output_height = col.dims()[0];
int row_begin, row_end;
int padding_height = std::max(up_pad, down_pad);
int padding_width = 0;
if (up_pad >= down_pad) {
row_begin = 0;
} else {
row_begin = down_pad - up_pad;
}
row_end = row_begin + ((input_height + up_pad + down_pad - filter_height) /
stride_height +
1);
int output_height = row_end - row_begin; // col.dims()[0];
int output_width = col.dims()[1]; int output_width = col.dims()[1];
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
stride_height +
1 ==
output_height);
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
stride_width +
1 ==
output_width);
int block_dim_x = 0; int block_dim_x = 0;
int block_dim_y = 0; int block_dim_y = 0;
if (filter_height <= 4 && filter_width <= 4) { if (filter_height <= 4 && filter_width <= 4) {
...@@ -289,9 +303,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -289,9 +303,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
reinterpret_cast<const platform::CUDADeviceContext&>(context) reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>( .stream()>>>(
im.data<T>(), col.data<T>(), input_channels, input_height, input_width, im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
filter_height, filter_width, stride_height, stride_width, filter_height, filter_width, stride_height, stride_width, padding_up,
padding_height, padding_width, output_height, output_width, row_begin, padding_left, output_height, output_width);
row_end);
} }
}; };
...@@ -300,8 +313,7 @@ __global__ void col2imOCF(T* im_data, const T* col_data, int input_channels, ...@@ -300,8 +313,7 @@ __global__ void col2imOCF(T* im_data, const T* col_data, int input_channels,
int input_height, int input_width, int filter_height, int input_height, int input_width, int filter_height,
int filter_width, int stride_height, int stride_width, int filter_width, int stride_height, int stride_width,
int padding_height, int padding_width, int padding_height, int padding_width,
int output_height, int output_width, int row_begin, int output_height, int output_width) {
int row_end) {
int swid = blockIdx.x; int swid = blockIdx.x;
int shid = blockIdx.y; int shid = blockIdx.y;
for (int channelid = threadIdx.z; channelid < input_channels; for (int channelid = threadIdx.z; channelid < input_channels;
...@@ -309,8 +321,7 @@ __global__ void col2imOCF(T* im_data, const T* col_data, int input_channels, ...@@ -309,8 +321,7 @@ __global__ void col2imOCF(T* im_data, const T* col_data, int input_channels,
for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) { for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) {
for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) { for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) {
int width_offset = idx + swid * stride_width - padding_width; int width_offset = idx + swid * stride_width - padding_width;
int height_offset = int height_offset = idy + shid * stride_height - padding_height;
idy + (shid + row_begin) * stride_height - padding_height;
int im_offset = width_offset + height_offset * input_width + int im_offset = width_offset + height_offset * input_width +
channelid * input_height * input_width; channelid * input_height * input_width;
...@@ -340,7 +351,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -340,7 +351,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
public: public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im, void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height, const framework::Tensor& col, int stride_height,
int stride_width, int up_pad, int down_pad) { int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int input_channels = im.dims()[0];
...@@ -348,22 +360,18 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -348,22 +360,18 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
int input_width = im.dims()[2]; int input_width = im.dims()[2];
int filter_height = col.dims()[3]; int filter_height = col.dims()[3];
int filter_width = col.dims()[4]; int filter_width = col.dims()[4];
int output_height = col.dims()[0];
int row_begin, row_end;
int padding_height = std::max(up_pad, down_pad);
int padding_width = 0;
if (up_pad >= down_pad) {
row_begin = 0;
} else {
row_begin = down_pad - up_pad;
}
row_end = row_begin + ((input_height + up_pad + down_pad - filter_height) /
stride_height +
1);
int output_height = row_end - row_begin; // col.dims()[0];
int output_width = col.dims()[1]; int output_width = col.dims()[1];
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
stride_height +
1 ==
output_height);
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
stride_width +
1 ==
output_width);
int block_dim_x = 0; int block_dim_x = 0;
int block_dim_y = 0; int block_dim_y = 0;
if (filter_height <= 4 && filter_width <= 4) { if (filter_height <= 4 && filter_width <= 4) {
...@@ -388,9 +396,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -388,9 +396,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
reinterpret_cast<const platform::CUDADeviceContext&>(context) reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>( .stream()>>>(
im.data<T>(), col.data<T>(), input_channels, input_height, input_width, im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
filter_height, filter_width, stride_height, stride_width, filter_height, filter_width, stride_height, stride_width, padding_up,
padding_height, padding_width, output_height, output_width, row_begin, padding_left, output_height, output_width);
row_end);
} }
}; };
......
...@@ -74,8 +74,8 @@ class Im2ColFunctor { ...@@ -74,8 +74,8 @@ class Im2ColFunctor {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col, const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height, int stride_height, int stride_width, int padding_up,
int padding_width); int padding_down, int padding_left, int padding_right);
}; };
template <ColFormat Format, typename Place, typename T> template <ColFormat Format, typename Place, typename T>
...@@ -83,7 +83,8 @@ class Col2ImFunctor { ...@@ -83,7 +83,8 @@ class Col2ImFunctor {
public: public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im, void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height, const framework::Tensor& col, int stride_height,
int stride_width, int padding_height, int padding_width); int stride_width, int padding_up, int padding_down,
int padding_left, int padding_right);
}; };
} // namespace math } // namespace math
......
...@@ -85,10 +85,10 @@ void testIm2col() { ...@@ -85,10 +85,10 @@ void testIm2col() {
paddle::operators::math::ColFormat::kOCF, Place, float> paddle::operators::math::ColFormat::kOCF, Place, float>
im2col_ocf; im2col_ocf;
im2col(*context, input, output_cfo, stride, stride, padding, padding); im2col(*context, input, output_cfo, stride, stride, padding, padding, padding,
im2col_ocf(*context, input, output_ocf, /*stride_height*/ stride, padding);
/*stride_width*/ stride, /*up_pad*/ padding, im2col_ocf(*context, input, output_ocf, stride, stride, padding, padding,
/*down_pad*/ padding); padding, padding);
float out_cfo_data[] = {0, 1, 1, 2, 3, 4, 4, 5}; float out_cfo_data[] = {0, 1, 1, 2, 3, 4, 4, 5};
float out_ocf_data[] = {0, 1, 3, 4, 1, 2, 4, 5}; float out_ocf_data[] = {0, 1, 3, 4, 1, 2, 4, 5};
...@@ -133,7 +133,8 @@ void testIm2col() { ...@@ -133,7 +133,8 @@ void testIm2col() {
input.CopyFrom<float>(input_tmp, *place, *context); input.CopyFrom<float>(input_tmp, *place, *context);
} }
col2im(*context, input, output_cfo, stride, stride, padding, padding); col2im(*context, input, output_cfo, stride, stride, padding, padding, padding,
padding);
float* in_ptr; float* in_ptr;
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
...@@ -154,9 +155,8 @@ void testIm2col() { ...@@ -154,9 +155,8 @@ void testIm2col() {
input.CopyFrom<float>(input_tmp, *place, *context); input.CopyFrom<float>(input_tmp, *place, *context);
} }
col2im_ocf(*context, input, output_ocf, /*stride_height*/ stride, col2im_ocf(*context, input, output_ocf, stride, stride, padding, padding,
/*stride_width*/ stride, /*up_pad*/ padding, padding, padding);
/*down_pad*/ padding);
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
in_ptr = input.data<float>(); in_ptr = input.data<float>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册