提交 d2c1408f 编写于 作者: C chengduoZH

fix im2col kocf for sequence projection

上级 8fe7bf38
...@@ -140,8 +140,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -140,8 +140,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col, const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height, int stride_height, int stride_width, int up_pad,
int padding_width) { int down_pad) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int input_channels = im.dims()[0];
...@@ -149,13 +149,25 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -149,13 +149,25 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
int input_width = im.dims()[2]; int input_width = im.dims()[2];
int filter_height = col.dims()[3]; int filter_height = col.dims()[3];
int filter_width = col.dims()[4]; int filter_width = col.dims()[4];
int output_height = col.dims()[0]; // int output_height = col.dims()[0];
int output_width = col.dims()[1]; int output_width = col.dims()[1];
int row_begin, row_end;
int padding_height = std::max(up_pad, down_pad);
int padding_width = 0;
if (up_pad >= down_pad) {
row_begin = 0;
} else {
row_begin = down_pad - up_pad;
}
row_end = row_begin + ((input_height + up_pad + down_pad - filter_height) /
stride_height +
1);
const T* im_data = im.data<T>(); const T* im_data = im.data<T>();
T* col_data = col.data<T>(); T* col_data = col.data<T>();
for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) { for (int col_row_idx = row_begin; col_row_idx < row_end; ++col_row_idx) {
for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) { for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) {
for (int channel = 0; channel < input_channels; ++channel) { for (int channel = 0; channel < input_channels; ++channel) {
for (int filter_row_idx = 0; filter_row_idx < filter_height; for (int filter_row_idx = 0; filter_row_idx < filter_height;
...@@ -166,7 +178,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -166,7 +178,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
col_row_idx * stride_height + filter_row_idx - padding_height; col_row_idx * stride_height + filter_row_idx - padding_height;
int im_col_offset = int im_col_offset =
col_col_idx * stride_width + filter_col_idx - padding_width; col_col_idx * stride_width + filter_col_idx - padding_width;
int col_offset = (((col_row_idx * output_width + col_col_idx) * int col_offset =
((((col_row_idx - row_begin) * output_width + col_col_idx) *
input_channels + input_channels +
channel) * channel) *
filter_height + filter_height +
...@@ -201,7 +214,7 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -201,7 +214,7 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
public: public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im, void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height, const framework::Tensor& col, int stride_height,
int stride_width, int padding_height, int padding_width) { int stride_width, int up_pad, int down_pad) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int input_channels = im.dims()[0];
...@@ -209,24 +222,37 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -209,24 +222,37 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
int input_width = im.dims()[2]; int input_width = im.dims()[2];
int filter_height = col.dims()[3]; int filter_height = col.dims()[3];
int filter_width = col.dims()[4]; int filter_width = col.dims()[4];
int output_height = col.dims()[0]; // int output_height = col.dims()[0];
int output_width = col.dims()[1]; int output_width = col.dims()[1];
int row_begin, row_end;
int padding_height = std::max(up_pad, down_pad);
int padding_width = 0;
if (up_pad >= down_pad) {
row_begin = 0;
} else {
row_begin = down_pad - up_pad;
}
row_end = row_begin + ((input_height + up_pad + down_pad - filter_height) /
stride_height +
1);
T* im_data = im.data<T>(); T* im_data = im.data<T>();
const T* col_data = col.data<T>(); const T* col_data = col.data<T>();
for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) { for (int col_row_idx = row_begin; col_row_idx < row_end; ++col_row_idx) {
for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) { for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) {
for (int channel = 0; channel < input_channels; ++channel) { for (int channel = 0; channel < input_channels; ++channel) {
for (int filter_row_idx = 0; filter_row_idx < filter_height; for (int filter_row_idx = 0; filter_row_idx < filter_height;
++filter_row_idx) { ++filter_row_idx) {
for (int filter_col_idx = 0; filter_col_idx < filter_width; for (int filter_col_idx = 0; filter_col_idx < filter_width;
++filter_col_idx) { ++filter_col_idx) {
int im_row_offset = int im_row_offset = // change or not ???
col_row_idx * stride_height + filter_row_idx - padding_height; col_row_idx * stride_height + filter_row_idx - padding_height;
int im_col_offset = int im_col_offset =
col_col_idx * stride_width + filter_col_idx - padding_width; col_col_idx * stride_width + filter_col_idx - padding_width;
int col_offset = (((col_row_idx * output_width + col_col_idx) * int col_offset =
((((col_row_idx - row_begin) * output_width + col_col_idx) *
input_channels + input_channels +
channel) * channel) *
filter_height + filter_height +
......
...@@ -199,7 +199,8 @@ __global__ void im2colOCF(const T* im_data, T* col_data, int input_channels, ...@@ -199,7 +199,8 @@ __global__ void im2colOCF(const T* im_data, T* col_data, int input_channels,
int input_height, int input_width, int filter_height, int input_height, int input_width, int filter_height,
int filter_width, int stride_height, int stride_width, int filter_width, int stride_height, int stride_width,
int padding_height, int padding_width, int padding_height, int padding_width,
int output_height, int output_width) { int output_height, int output_width, int row_begin,
int row_end) {
int swid = blockIdx.x; int swid = blockIdx.x;
int shid = blockIdx.y; int shid = blockIdx.y;
for (int channelid = threadIdx.z; channelid < input_channels; for (int channelid = threadIdx.z; channelid < input_channels;
...@@ -207,7 +208,8 @@ __global__ void im2colOCF(const T* im_data, T* col_data, int input_channels, ...@@ -207,7 +208,8 @@ __global__ void im2colOCF(const T* im_data, T* col_data, int input_channels,
for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) { for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) {
for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) { for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) {
int width_offset = idx + swid * stride_width - padding_width; int width_offset = idx + swid * stride_width - padding_width;
int height_offset = idy + shid * stride_height - padding_height; int height_offset =
idy + (shid + row_begin) * stride_height - padding_height;
int im_offset = width_offset + height_offset * input_width + int im_offset = width_offset + height_offset * input_width +
channelid * input_height * input_width; channelid * input_height * input_width;
...@@ -238,8 +240,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -238,8 +240,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col, const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height, int stride_height, int stride_width, int up_pad,
int padding_width) { int down_pad) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int input_channels = im.dims()[0];
...@@ -247,7 +249,20 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -247,7 +249,20 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
int input_width = im.dims()[2]; int input_width = im.dims()[2];
int filter_height = col.dims()[3]; int filter_height = col.dims()[3];
int filter_width = col.dims()[4]; int filter_width = col.dims()[4];
int output_height = col.dims()[0];
int row_begin, row_end;
int padding_height = std::max(up_pad, down_pad);
int padding_width = 0;
if (up_pad >= down_pad) {
row_begin = 0;
} else {
row_begin = down_pad - up_pad;
}
row_end = row_begin + ((input_height + up_pad + down_pad - filter_height) /
stride_height +
1);
int output_height = row_end - row_begin; // col.dims()[0];
int output_width = col.dims()[1]; int output_width = col.dims()[1];
int block_dim_x = 0; int block_dim_x = 0;
...@@ -275,7 +290,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -275,7 +290,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
.stream()>>>( .stream()>>>(
im.data<T>(), col.data<T>(), input_channels, input_height, input_width, im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
filter_height, filter_width, stride_height, stride_width, filter_height, filter_width, stride_height, stride_width,
padding_height, padding_width, output_height, output_width); padding_height, padding_width, output_height, output_width, row_begin,
row_end);
} }
}; };
...@@ -284,7 +300,8 @@ __global__ void col2imOCF(T* im_data, const T* col_data, int input_channels, ...@@ -284,7 +300,8 @@ __global__ void col2imOCF(T* im_data, const T* col_data, int input_channels,
int input_height, int input_width, int filter_height, int input_height, int input_width, int filter_height,
int filter_width, int stride_height, int stride_width, int filter_width, int stride_height, int stride_width,
int padding_height, int padding_width, int padding_height, int padding_width,
int output_height, int output_width) { int output_height, int output_width, int row_begin,
int row_end) {
int swid = blockIdx.x; int swid = blockIdx.x;
int shid = blockIdx.y; int shid = blockIdx.y;
for (int channelid = threadIdx.z; channelid < input_channels; for (int channelid = threadIdx.z; channelid < input_channels;
...@@ -292,7 +309,8 @@ __global__ void col2imOCF(T* im_data, const T* col_data, int input_channels, ...@@ -292,7 +309,8 @@ __global__ void col2imOCF(T* im_data, const T* col_data, int input_channels,
for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) { for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) {
for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) { for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) {
int width_offset = idx + swid * stride_width - padding_width; int width_offset = idx + swid * stride_width - padding_width;
int height_offset = idy + shid * stride_height - padding_height; int height_offset =
idy + (shid + row_begin) * stride_height - padding_height;
int im_offset = width_offset + height_offset * input_width + int im_offset = width_offset + height_offset * input_width +
channelid * input_height * input_width; channelid * input_height * input_width;
...@@ -322,7 +340,7 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -322,7 +340,7 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
public: public:
void operator()(const platform::DeviceContext& context, framework::Tensor& im, void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height, const framework::Tensor& col, int stride_height,
int stride_width, int padding_height, int padding_width) { int stride_width, int up_pad, int down_pad) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int input_channels = im.dims()[0];
...@@ -330,7 +348,20 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -330,7 +348,20 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
int input_width = im.dims()[2]; int input_width = im.dims()[2];
int filter_height = col.dims()[3]; int filter_height = col.dims()[3];
int filter_width = col.dims()[4]; int filter_width = col.dims()[4];
int output_height = col.dims()[0];
int row_begin, row_end;
int padding_height = std::max(up_pad, down_pad);
int padding_width = 0;
if (up_pad >= down_pad) {
row_begin = 0;
} else {
row_begin = down_pad - up_pad;
}
row_end = row_begin + ((input_height + up_pad + down_pad - filter_height) /
stride_height +
1);
int output_height = row_end - row_begin; // col.dims()[0];
int output_width = col.dims()[1]; int output_width = col.dims()[1];
int block_dim_x = 0; int block_dim_x = 0;
...@@ -358,7 +389,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -358,7 +389,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
.stream()>>>( .stream()>>>(
im.data<T>(), col.data<T>(), input_channels, input_height, input_width, im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
filter_height, filter_width, stride_height, stride_width, filter_height, filter_width, stride_height, stride_width,
padding_height, padding_width, output_height, output_width); padding_height, padding_width, output_height, output_width, row_begin,
row_end);
} }
}; };
......
...@@ -35,6 +35,12 @@ void testIm2col() { ...@@ -35,6 +35,12 @@ void testIm2col() {
* *
* output_ocf = [0, 1, 3, 4 * output_ocf = [0, 1, 3, 4
* 1, 2, 4, 5] * 1, 2, 4, 5]
*
* col2im_cfo = [0, 2, 2
* 3, 4, 5]
*
* col2im_ocf = [0, 2, 2
* 3, 4, 5]
*/ */
int input_height = 2; int input_height = 2;
int input_width = 3; int input_width = 3;
...@@ -59,7 +65,7 @@ void testIm2col() { ...@@ -59,7 +65,7 @@ void testIm2col() {
new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace()); new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace());
#else #else
PADDLE_THROW("no GPU support"); PADDLE_THROW("no GPU support");
#endif // PADDLE_ONLY_CPU #endif // PADDLE_WITH_CUDA
} }
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
input = input_tmp; input = input_tmp;
...@@ -71,6 +77,7 @@ void testIm2col() { ...@@ -71,6 +77,7 @@ void testIm2col() {
output_ocf.mutable_data<float>( output_ocf.mutable_data<float>(
{output_height, output_width, 1, filter_size, filter_size}, *place); {output_height, output_width, 1, filter_size, filter_size}, *place);
// Im2Col
paddle::operators::math::Im2ColFunctor< paddle::operators::math::Im2ColFunctor<
paddle::operators::math::ColFormat::kCFO, Place, float> paddle::operators::math::ColFormat::kCFO, Place, float>
im2col; im2col;
...@@ -79,7 +86,12 @@ void testIm2col() { ...@@ -79,7 +86,12 @@ void testIm2col() {
im2col_ocf; im2col_ocf;
im2col(*context, input, output_cfo, stride, stride, padding, padding); im2col(*context, input, output_cfo, stride, stride, padding, padding);
im2col_ocf(*context, input, output_ocf, stride, stride, padding, padding); im2col_ocf(*context, input, output_ocf, /*stride_height*/ stride,
/*stride_width*/ stride, /*up_pad*/ padding,
/*down_pad*/ padding);
float out_cfo_data[] = {0, 1, 1, 2, 3, 4, 4, 5};
float out_ocf_data[] = {0, 1, 3, 4, 1, 2, 4, 5};
float* out_cfo_ptr; float* out_cfo_ptr;
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
...@@ -89,14 +101,9 @@ void testIm2col() { ...@@ -89,14 +101,9 @@ void testIm2col() {
*context); *context);
out_cfo_ptr = output_tmp.data<float>(); out_cfo_ptr = output_tmp.data<float>();
} }
EXPECT_EQ(out_cfo_ptr[0], 0); for (int i = 0; i < 6; ++i) {
EXPECT_EQ(out_cfo_ptr[1], 1); EXPECT_EQ(out_cfo_ptr[i], out_cfo_data[i]);
EXPECT_EQ(out_cfo_ptr[2], 1); }
EXPECT_EQ(out_cfo_ptr[3], 2);
EXPECT_EQ(out_cfo_ptr[4], 3);
EXPECT_EQ(out_cfo_ptr[5], 4);
EXPECT_EQ(out_cfo_ptr[6], 4);
EXPECT_EQ(out_cfo_ptr[7], 5);
float* out_ocf_ptr; float* out_ocf_ptr;
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
...@@ -106,14 +113,60 @@ void testIm2col() { ...@@ -106,14 +113,60 @@ void testIm2col() {
*context); *context);
out_ocf_ptr = output_tmp.data<float>(); out_ocf_ptr = output_tmp.data<float>();
} }
EXPECT_EQ(out_ocf_ptr[0], 0); for (int i = 0; i < 6; ++i) {
EXPECT_EQ(out_ocf_ptr[1], 1); EXPECT_EQ(out_ocf_ptr[i], out_ocf_data[i]);
EXPECT_EQ(out_ocf_ptr[2], 3); }
EXPECT_EQ(out_ocf_ptr[3], 4);
EXPECT_EQ(out_ocf_ptr[4], 1); // Col2Im: kCFO
EXPECT_EQ(out_ocf_ptr[5], 2); paddle::operators::math::Col2ImFunctor<
EXPECT_EQ(out_ocf_ptr[6], 4); paddle::operators::math::ColFormat::kCFO, Place, float>
EXPECT_EQ(out_ocf_ptr[7], 5); col2im;
paddle::operators::math::Col2ImFunctor<
paddle::operators::math::ColFormat::kOCF, Place, float>
col2im_ocf;
float col2im_data[] = {0, 2, 2, 3, 8, 5};
memset(input_ptr, 0, 6 * sizeof(float));
if (paddle::platform::is_cpu_place(*place)) {
input = input_tmp;
} else {
input.CopyFrom<float>(input_tmp, *place, *context);
}
col2im(*context, input, output_cfo, stride, stride, padding, padding);
float* in_ptr;
if (paddle::platform::is_cpu_place(*place)) {
in_ptr = input.data<float>();
} else {
input_tmp.CopyFrom<float>(input, paddle::platform::CPUPlace(), *context);
in_ptr = input_tmp.data<float>();
}
for (int i = 0; i < 6; ++i) {
EXPECT_EQ(in_ptr[i], col2im_data[i]);
}
// Col2Im: kOCF
memset(input_ptr, 0, 6 * sizeof(float));
if (paddle::platform::is_cpu_place(*place)) {
input = input_tmp;
} else {
input.CopyFrom<float>(input_tmp, *place, *context);
}
col2im_ocf(*context, input, output_ocf, /*stride_height*/ stride,
/*stride_width*/ stride, /*up_pad*/ padding,
/*down_pad*/ padding);
if (paddle::platform::is_cpu_place(*place)) {
in_ptr = input.data<float>();
} else {
input_tmp.CopyFrom<float>(input, paddle::platform::CPUPlace(), *context);
in_ptr = input_tmp.data<float>();
}
for (int i = 0; i < 6; ++i) {
EXPECT_EQ(in_ptr[i], col2im_data[i]);
}
} }
TEST(math, im2col) { TEST(math, im2col) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册