From c85d777f879e128a3a9b00ddfc243879a747f5da Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Tue, 10 Oct 2017 22:35:55 +0800 Subject: [PATCH] follow comments --- paddle/operators/math/CMakeLists.txt | 8 ++++-- paddle/operators/math/vol2col.cc | 2 +- paddle/operators/math/vol2col_test.cc | 40 +++++++-------------------- 3 files changed, 16 insertions(+), 34 deletions(-) diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index d6e8373210a..575e89eed80 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -1,15 +1,17 @@ if(WITH_GPU) - nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu vol2col.cc vol2col.cu pooling.cc pooling.cu DEPS cblas device_context operator) + nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu pooling.cc pooling.cu DEPS cblas device_context operator) nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator) nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator) + nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS cblas device_context operator) else() - cc_library(math_function SRCS math_function.cc im2col.cc vol2col.cc pooling.cc DEPS cblas device_context operator) + cc_library(math_function SRCS math_function.cc im2col.cc pooling.cc DEPS cblas device_context operator) cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) cc_library(softmax SRCS softmax.cc DEPS operator) cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator) + cc_library(vol2col SRCS vol2col.cc DEPS cblas device_context operator) endif() cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor) -cc_test(vol2col_test SRCS vol2col_test.cc DEPS math_function tensor) +cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col tensor) diff --git a/paddle/operators/math/vol2col.cc b/paddle/operators/math/vol2col.cc index 5bad2e8073f..e9718a04738 100644 --- a/paddle/operators/math/vol2col.cc +++ b/paddle/operators/math/vol2col.cc @@ -67,7 +67,7 @@ class Vol2ColFunctor { ((c * output_depth + d) * output_height + h) * output_width + w; if (h_pad < 0 || h_pad >= input_height || w_pad < 0 || w_pad >= input_width || d_pad < 0 || d_pad >= input_depth) { - col_data[col_idx] = T(0); + col_data[col_idx] = static_cast(0); } else { int vol_idx = ((c_in * input_depth + d_pad) * input_height + h_pad) * diff --git a/paddle/operators/math/vol2col_test.cc b/paddle/operators/math/vol2col_test.cc index 107a94511f2..e3c599da87c 100644 --- a/paddle/operators/math/vol2col_test.cc +++ b/paddle/operators/math/vol2col_test.cc @@ -30,12 +30,12 @@ void testVol2col() { context = new paddle::platform::CPUDeviceContext(paddle::platform::CPUPlace()); } else { -#ifndef PADDLE_ONLY_CPU +#ifdef PADDLE_WITH_CUDA context = new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace()); #else PADDLE_THROW("no GPU support"); -#endif // PADDLE_ONLY_CPU +#endif // PADDLE_WITH_CUDA } /** @@ -89,6 +89,7 @@ void testVol2col() { vol2col(*context, input, output_cfo, stride, stride, stride, padding, padding, padding); + float vol_2_col[] = {0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 7, 8, 9, 10, 10, 11}; float* out_cfo_ptr; if (paddle::platform::is_cpu_place(*place)) { out_cfo_ptr = output_cfo.data(); @@ -97,24 +98,12 @@ void testVol2col() { out_cfo_ptr = output_tmp.data(); } - EXPECT_EQ(out_cfo_ptr[0], 0); - EXPECT_EQ(out_cfo_ptr[1], 1); - EXPECT_EQ(out_cfo_ptr[2], 1); - EXPECT_EQ(out_cfo_ptr[3], 2); - EXPECT_EQ(out_cfo_ptr[4], 3); - EXPECT_EQ(out_cfo_ptr[5], 4); - EXPECT_EQ(out_cfo_ptr[6], 4); - EXPECT_EQ(out_cfo_ptr[7], 5); - EXPECT_EQ(out_cfo_ptr[8], 6); - EXPECT_EQ(out_cfo_ptr[9], 7); - EXPECT_EQ(out_cfo_ptr[10], 7); - EXPECT_EQ(out_cfo_ptr[11], 8); - EXPECT_EQ(out_cfo_ptr[12], 9); - EXPECT_EQ(out_cfo_ptr[13], 10); - EXPECT_EQ(out_cfo_ptr[14], 10); - EXPECT_EQ(out_cfo_ptr[15], 11); + for (int i = 0; i < 16; ++i) { + EXPECT_EQ(out_cfo_ptr[i], vol_2_col[i]); + } // Col2Vol test + float col_2_vol[] = {0, 2, 2, 3, 8, 5, 6, 14, 8, 9, 20, 11}; memset(input_ptr, 0, 12 * sizeof(float)); if (paddle::platform::is_cpu_place(*place)) { input = input_tmp; @@ -134,18 +123,9 @@ void testVol2col() { in_cfo_ptr = input_tmp.data(); } - EXPECT_EQ(in_cfo_ptr[0], 0); - EXPECT_EQ(in_cfo_ptr[1], 2); - EXPECT_EQ(in_cfo_ptr[2], 2); - EXPECT_EQ(in_cfo_ptr[3], 3); - EXPECT_EQ(in_cfo_ptr[4], 8); - EXPECT_EQ(in_cfo_ptr[5], 5); - EXPECT_EQ(in_cfo_ptr[6], 6); - EXPECT_EQ(in_cfo_ptr[7], 14); - EXPECT_EQ(in_cfo_ptr[8], 8); - EXPECT_EQ(in_cfo_ptr[9], 9); - EXPECT_EQ(in_cfo_ptr[10], 20); - EXPECT_EQ(in_cfo_ptr[11], 11); + for (int i = 0; i < 12; ++i) { + EXPECT_EQ(in_cfo_ptr[i], col_2_vol[i]); + } } TEST(math, vol2col) { -- GitLab