提交 65d418f0 编写于 作者: T tensor-tang

complete im2col with padding==1 and speedup filter width==1

上级 52eb86e3
......@@ -40,10 +40,12 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
dilation[1] == 1) {
if (padding[0] == 0 && padding[1] == 0) {
im2col_sh1sw1dh1dw1ph0pw0<T>(im, col);
} else {
im2col_sh1sw1dh1dw1<T>(im, padding, col);
}
return;
} else if (padding[0] == 1 && padding[1] == 1) {
im2col_sh1sw1dh1dw1ph1pw1<T>(im, col);
return;
}
// TODO(TJ): complete padding >=2
}
im2col_common<T>(im, dilation, stride, padding, col);
}
......
......@@ -21,7 +21,7 @@ namespace paddle {
namespace operators {
namespace math {
/*
/**
* The most common im2col algorithm.
* Support dilation, stride and padding.
*/
......@@ -61,9 +61,9 @@ inline void im2col_common(const framework::Tensor& im,
}
}
/*
/**
* im2col algorithm with strides == 1, dilations == 1, paddings == 0
* */
*/
template <typename T>
inline void im2col_sh1sw1dh1dw1ph0pw0(const framework::Tensor& im,
framework::Tensor* col) {
......@@ -96,10 +96,12 @@ inline void im2col_sh1sw1dh1dw1ph0pw0(const framework::Tensor& im,
}
}
// further optimize: padding == 1 need special
/**
* im2col algorithm with strides == 1, dilations == 1, paddings == 1
* and filter_width == 1 have a special implementation
*/
template <typename T>
inline void im2col_sh1sw1dh1dw1(const framework::Tensor& im,
const std::vector<int>& padding,
inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im,
framework::Tensor* col) {
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
......@@ -108,122 +110,103 @@ inline void im2col_sh1sw1dh1dw1(const framework::Tensor& im,
int filter_width = col->dims()[2];
int output_height = col->dims()[3];
int output_width = col->dims()[4];
constexpr int sh = 1;
constexpr int sw = 1;
constexpr int plh = 1;
constexpr int prh = 1;
constexpr int plw = 1;
constexpr int prw = 1;
const T* im_data = im.data<T>();
T* col_data = col->data<T>();
int col_matrix_width = output_width * output_height;
int im_size = im_height * im_width;
int plh = padding[0];
int plw = padding[1];
int prh = (output_height - 1) * sh + filter_height - im_height - plh;
int prw = (output_width - 1) * sw + filter_width - im_width - plw;
// fill height padding : 0 ~ plh-1, (oh-prh) ~ (oh-1)
// TODO(TJ): refine ph*xxx
assert(plh == prh); // because stride_h == 1
int col_matrix_width = output_width * output_height;
int col_block_fh = filter_width * col_matrix_width; // fw*oh*ow
int col_block_ic = filter_height * col_block_fh; // fh*fw*oh*ow
for (int ph = 0; ph < plh; ++ph) {
int sz = output_width * (plh - ph);
size_t copy_sz = sizeof(T) * sz;
T* col_start_l = col_data + ph * col_block_fh;
T* col_start_r = col_data + (filter_height - ph - 1) * col_block_fh +
col_matrix_width - sz;
// fill height padding
{
size_t copy_size = sizeof(T) * output_width;
T* col_start_l = col_data;
T* col_start_r = col_data + (filter_height - 1) * col_block_fh +
col_matrix_width - output_width;
for (int ic = 0; ic < im_channels; ++ic) {
// TODO(TJ): move * outside
T* dst_data_l = col_start_l + ic * col_block_ic;
T* dst_data_r = col_start_r + ic * col_block_ic;
for (int kw = 0; kw < filter_width; ++kw) {
std::memset(dst_data_l, 0, copy_sz);
std::memset(dst_data_r, 0, copy_sz);
std::memset(dst_data_l, 0, copy_size);
std::memset(dst_data_r, 0, copy_size);
dst_data_l = dst_data_l + col_matrix_width;
dst_data_r = dst_data_r + col_matrix_width;
}
}
}
auto pad = static_cast<T>(0);
if (filter_width == 1) {
// fill width padding
assert(plw == prw); // because stride_w == 1
if (plw == 1) {
auto pad = static_cast<T>(0); // padding zero
for (int ic = 0; ic < im_channels; ++ic) {
// TODO(TJ): use add and resue stride
// TODO(TJ): move * outside
T* dst_data_ic = col_data + ic * col_block_ic;
for (int kh = 0; kh < filter_height; ++kh) {
T* dst_data_kh = dst_data_ic + kh * col_block_fh;
for (T* dst_data :
{dst_data_kh, dst_data_kh +
(filter_width - prw) * col_matrix_width +
output_width - 1}) {
// TODO(TJ): from plh, saving repeated assignment
// TODO(TJ): move * outside
T* dst_data = dst_data_ic + kh * col_block_fh;
for (int oh = 0; oh < output_height; ++oh) {
*dst_data = pad;
dst_data = dst_data + output_width;
dst_data = dst_data + output_width - 1;
*dst_data = pad;
++dst_data;
}
}
}
// fill core
size_t copy_size = sizeof(T) * (output_width - plw - prw);
for (int oh = 0; oh < output_height; ++oh) {
const T* im_data_start =
im_data + (oh - plh > 0 ? oh - plh : 0) * im_width;
T* dst_data = col_data + oh * output_width;
for (int ic = 0; ic < im_channels; ++ic) {
const T* src_data = im_data_start + ic * im_size;
for (int kh = 0; kh < filter_height; ++kh) {
if ((oh < plh && kh < plh) || (oh > (output_height - prh - 1) &&
kh > (filter_height - prh - 1))) {
dst_data = dst_data + col_matrix_width;
continue;
}
std::memcpy(dst_data + plw, src_data, copy_size);
dst_data = dst_data + col_matrix_width;
src_data = src_data + im_width;
}
}
} else {
// padding_size > 1
}
return;
}
// filter_width != 1
// fill width padding
for (int ic = 0; ic < im_channels; ++ic) {
// TODO(TJ): use add and resue stride
// TODO(TJ): move * outside
T* dst_data_ic = col_data + ic * col_block_ic;
for (int kh = 0; kh < filter_height; ++kh) {
// TODO(TJ): move * outside
T* dst_data_kh = dst_data_ic + kh * col_block_fh;
for (int kw = 0; kw < plw; ++kw) {
// TODO(TJ): reuse array outside this for
size_t sz = sizeof(T) * (plw - kw);
T* dst_data = dst_data_kh + kw * col_matrix_width;
// TODO(TJ): from plh, saving repeated assignment
for (int oh = 0; oh < output_height; ++oh) {
std::memset(dst_data, 0, sz);
dst_data = dst_data + output_width;
}
}
// TODO(TJ): use reverse to save cache
for (int kw = 0; kw < prw; ++kw) {
// TODO(TJ): reuse array outside this for
auto num = (prw - kw);
size_t sz = sizeof(T) * num;
T* dst_data = dst_data_kh +
(filter_width - 1 - kw) * col_matrix_width +
output_width - num;
for (T* dst_data :
{dst_data_kh, dst_data_kh + (filter_width - prw) * col_matrix_width +
output_width - 1}) {
// TODO(TJ): from plh, saving repeated assignment
for (int oh = 0; oh < output_height; ++oh) {
std::memset(dst_data, 0, sz);
*dst_data = pad;
dst_data = dst_data + output_width;
}
}
}
}
}
// fill im_data
// padding cover two cases:
// 1. kw > 2*pw: kw = 3, pw = 1
// 0 x x x x ... x x x x 0
// 1 1 1 1 1 1
// ==>
// 0 x ... x x
// x x ... x x
// x x ... x 0
// 2. kw < 2*pw: kw = 3, pw = 2
// 0 0 x x x ... x x x 0 0
// 1 1 1 1 1 1
// ==>
// 0 0 x ... x x x
// 0 x x ... x x 0
// x x x ... x 0 0
// TODO(TJ): use array like: size_t copy_size[kw]={sizeof(T) *
// (output_width-1)}
// length of copy_size is equal kw.
if (plw + prw < filter_width) {
for (int oh = 0; oh < output_height; ++oh) {
const T* im_data_start =
im_data + (oh - plh > 0 ? oh - plh : 0) * im_width;
const T* im_data_start = im_data + (oh - plh > 0 ? oh - plh : 0) * im_width;
T* dst_data = col_data + oh * output_width;
for (int ic = 0; ic < im_channels; ++ic) {
const T* src_data = im_data_start + ic * im_size;
......@@ -255,9 +238,6 @@ inline void im2col_sh1sw1dh1dw1(const framework::Tensor& im,
}
}
}
} else {
LOG(FATAL) << "Not implement yet";
}
}
} // namespace math
......
......@@ -227,7 +227,8 @@ void benchIm2col(int ic, int ih, int iw, int fh, int fw, int ph, int pw) {
auto t3 = GetCurrentMs();
LOG(INFO) << "before: " << (t3 - t2) / repeat
<< ",after: " << (t2 - t1) / repeat;
<< ",after: " << (t2 - t1) / repeat
<< ",boost: " << ((t3 - t2) / (t2 - t1) - 1) * 100 << "%";
}
TEST(math, im2col_cputest) {
......@@ -244,6 +245,10 @@ TEST(math, im2col_cputest) {
// height != width
testIm2colCPU(/*ic*/ 2, /*ih*/ 5, /*iw*/ 4, /*fh*/ 2, /*fw*/ 3, /*ph*/ p,
/*pw*/ p);
testIm2colCPU(/*ic*/ 2, /*ih*/ 5, /*iw*/ 4, /*fh*/ 1, /*fw*/ 3, /*ph*/ p,
/*pw*/ p);
testIm2colCPU(/*ic*/ 2, /*ih*/ 4, /*iw*/ 5, /*fh*/ 3, /*fw*/ 1, /*ph*/ p,
/*pw*/ p);
// filter == 1
testIm2colCPU(/*ic*/ 3, /*ih*/ 4, /*iw*/ 4, /*fh*/ 1, /*fw*/ 1, /*ph*/ p,
......@@ -251,13 +256,14 @@ TEST(math, im2col_cputest) {
testIm2colCPU(/*ic*/ 3, /*ih*/ 3, /*iw*/ 4, /*fh*/ 1, /*fw*/ 1, /*ph*/ p,
/*pw*/ p);
}
// padding_h != padding_w
testIm2colCPU(/*ic*/ 2, /*ih*/ 4, /*iw*/ 4, /*fh*/ 2, /*fw*/ 3, /*ph*/ 1,
/*pw*/ 2);
// benchmark
for (int p : {0, 1, 2}) {
for (int k : {3, 5}) {
for (int p : {0, 1}) {
for (int k : {1, 3, 5}) {
LOG(INFO) << "padding == " << p << ", filter == " << k;
benchIm2col(/*ic*/ 3, /*ih*/ 224, /*iw*/ 224, /*fh*/ k, /*fw*/ k,
/*ph*/ p, /*pw*/ p);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册