提交 3017f460 编写于 作者: T tensor-tang

add more test cases

上级 8d6be4fb
...@@ -108,8 +108,8 @@ inline void im2col_sh1sw1dh1dw1(const framework::Tensor& im, ...@@ -108,8 +108,8 @@ inline void im2col_sh1sw1dh1dw1(const framework::Tensor& im,
int filter_width = col->dims()[2]; int filter_width = col->dims()[2];
int output_height = col->dims()[3]; int output_height = col->dims()[3];
int output_width = col->dims()[4]; int output_width = col->dims()[4];
const int sh = 1; constexpr int sh = 1;
const int sw = 1; constexpr int sw = 1;
const T* im_data = im.data<T>(); const T* im_data = im.data<T>();
T* col_data = col->data<T>(); T* col_data = col->data<T>();
......
...@@ -208,7 +208,7 @@ void testIm2colCPU(int ic, int ih, int iw, int fh, int fw, int ph, int pw) { ...@@ -208,7 +208,7 @@ void testIm2colCPU(int ic, int ih, int iw, int fh, int fw, int ph, int pw) {
void benchIm2col(int ic, int ih, int iw, int fh, int fw, int ph, int pw) { void benchIm2col(int ic, int ih, int iw, int fh, int fw, int ph, int pw) {
PREPARE_IM2COL_CPU; PREPARE_IM2COL_CPU;
constexpr int repeat = 30; constexpr int repeat = 100;
auto GetCurrentMs = []() -> double { auto GetCurrentMs = []() -> double {
struct timeval time; struct timeval time;
gettimeofday(&time, NULL); gettimeofday(&time, NULL);
...@@ -231,17 +231,39 @@ void benchIm2col(int ic, int ih, int iw, int fh, int fw, int ph, int pw) { ...@@ -231,17 +231,39 @@ void benchIm2col(int ic, int ih, int iw, int fh, int fw, int ph, int pw) {
} }
TEST(math, im2col_cputest) { TEST(math, im2col_cputest) {
testIm2colCPU(/*ic*/ 2, /*ih*/ 5, /*iw*/ 4, /*fh*/ 3, /*fw*/ 3, /*ph*/ 0, // padding_h == padding_w
/*pw*/ 0); for (int p = 0; p < 4; ++p) {
testIm2colCPU(/*ic*/ 2, /*ih*/ 5, /*iw*/ 4, /*fh*/ 3, /*fw*/ 3, /*ph*/ 1, // width == height
/*pw*/ 1); testIm2colCPU(/*ic*/ 2, /*ih*/ 5, /*iw*/ 5, /*fh*/ 4, /*fw*/ 4, /*ph*/ p,
/*pw*/ p);
testIm2colCPU(/*ic*/ 2, /*ih*/ 4, /*iw*/ 4, /*fh*/ 3, /*fw*/ 3, /*ph*/ p,
/*pw*/ p);
testIm2colCPU(/*ic*/ 2, /*ih*/ 4, /*iw*/ 4, /*fh*/ 2, /*fw*/ 2, /*ph*/ p,
/*pw*/ p);
benchIm2col(/*ic*/ 3, /*ih*/ 224, /*iw*/ 224, /*fh*/ 3, /*fw*/ 3, /*ph*/ 1, // height != width
/*pw*/ 1); testIm2colCPU(/*ic*/ 2, /*ih*/ 5, /*iw*/ 4, /*fh*/ 2, /*fw*/ 3, /*ph*/ p,
/*pw*/ p);
// filter == 1
testIm2colCPU(/*ic*/ 3, /*ih*/ 4, /*iw*/ 4, /*fh*/ 1, /*fw*/ 1, /*ph*/ p,
/*pw*/ p);
testIm2colCPU(/*ic*/ 3, /*ih*/ 3, /*iw*/ 4, /*fh*/ 1, /*fw*/ 1, /*ph*/ p,
/*pw*/ p);
}
// padding_h != padding_w
testIm2colCPU(/*ic*/ 2, /*ih*/ 4, /*iw*/ 4, /*fh*/ 2, /*fw*/ 3, /*ph*/ 1,
/*pw*/ 2);
// benchmark
LOG(INFO) << "padding == 0";
benchIm2col(/*ic*/ 3, /*ih*/ 224, /*iw*/ 224, /*fh*/ 3, /*fw*/ 3, /*ph*/ 0, benchIm2col(/*ic*/ 3, /*ih*/ 224, /*iw*/ 224, /*fh*/ 3, /*fw*/ 3, /*ph*/ 0,
/*pw*/ 0); /*pw*/ 0);
benchIm2col(/*ic*/ 3, /*ih*/ 224, /*iw*/ 224, /*fh*/ 5, /*fw*/ 5, /*ph*/ 1,
/*pw*/ 1);
benchIm2col(/*ic*/ 3, /*ih*/ 224, /*iw*/ 224, /*fh*/ 5, /*fw*/ 5, /*ph*/ 0, benchIm2col(/*ic*/ 3, /*ih*/ 224, /*iw*/ 224, /*fh*/ 5, /*fw*/ 5, /*ph*/ 0,
/*pw*/ 0); /*pw*/ 0);
LOG(INFO) << "padding == 1";
benchIm2col(/*ic*/ 3, /*ih*/ 224, /*iw*/ 224, /*fh*/ 3, /*fw*/ 3, /*ph*/ 1,
/*pw*/ 1);
benchIm2col(/*ic*/ 3, /*ih*/ 224, /*iw*/ 224, /*fh*/ 5, /*fw*/ 5, /*ph*/ 1,
/*pw*/ 1);
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册