From f5ac335046feb81529e85cd0c386379746771157 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Wed, 11 Oct 2017 11:02:26 +0800 Subject: [PATCH] follow comments --- paddle/operators/math/CMakeLists.txt | 5 ++- paddle/operators/math/vol2col_test.cc | 47 +++++++++++++-------------- 2 files changed, 25 insertions(+), 27 deletions(-) diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index d32924db85e..2fd559e90a2 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -3,14 +3,13 @@ if(WITH_GPU) nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator) nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator) - nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context operator) + nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context) else() cc_library(math_function SRCS math_function.cc im2col.cc pooling.cc DEPS cblas device_context operator) cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) cc_library(softmax SRCS softmax.cc DEPS operator) cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator) - cc_library(vol2col SRCS vol2col.cc DEPS device_context operator) - + cc_library(vol2col SRCS vol2col.cc DEPS device_context) endif() cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/vol2col_test.cc b/paddle/operators/math/vol2col_test.cc index e3c599da87c..81225e9a980 100644 --- a/paddle/operators/math/vol2col_test.cc +++ b/paddle/operators/math/vol2col_test.cc @@ -18,10 +18,9 @@ limitations under the License. */ template void testVol2col() { - paddle::framework::Tensor input_tmp; paddle::framework::Tensor input; - paddle::framework::Tensor output_cfo; - paddle::framework::Tensor output_ocf; + paddle::framework::Tensor input_tmp; + paddle::framework::Tensor output; paddle::framework::Tensor output_tmp; auto* place = new Place(); @@ -44,14 +43,14 @@ void testVol2col() { * [6, 7, 8, * 9, 10, 11]] * - * output_cfo = [0, 1 - * 1, 2 - * 3, 4 - * 4, 5 - * 6, 7 - * 7, 8 - * 9, 10 - * 10, 11] + * output = [0, 1 + * 1, 2 + * 3, 4 + * 4, 5 + * 6, 7 + * 7, 8 + * 9, 10 + * 10, 11] * * col2vol = [[0, 2, 2, * 3, 8, 5] @@ -81,20 +80,20 @@ void testVol2col() { } else { input.CopyFrom(input_tmp, *place); } - output_cfo.mutable_data({1, filter_size, filter_size, filter_size, - output_depth, output_height, output_width}, - *place); + output.mutable_data({1, filter_size, filter_size, filter_size, + output_depth, output_height, output_width}, + *place); paddle::operators::math::Vol2ColFunctor vol2col; - vol2col(*context, input, output_cfo, stride, stride, stride, padding, padding, + vol2col(*context, input, output, stride, stride, stride, padding, padding, padding); float vol_2_col[] = {0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 7, 8, 9, 10, 10, 11}; float* out_cfo_ptr; if (paddle::platform::is_cpu_place(*place)) { - out_cfo_ptr = output_cfo.data(); + out_cfo_ptr = output.data(); } else { - output_tmp.CopyFrom(output_cfo, paddle::platform::CPUPlace()); + output_tmp.CopyFrom(output, paddle::platform::CPUPlace()); out_cfo_ptr = output_tmp.data(); } @@ -112,25 +111,25 @@ void testVol2col() { } paddle::operators::math::Col2VolFunctor col2vol; - col2vol(*context, input, output_cfo, stride, stride, stride, padding, padding, + col2vol(*context, input, output, stride, stride, stride, padding, padding, padding); - float* in_cfo_ptr; + float* in_ptr; if (paddle::platform::is_cpu_place(*place)) { - in_cfo_ptr = input.data(); + in_ptr = input.data(); } else { input_tmp.CopyFrom(input, paddle::platform::CPUPlace()); - in_cfo_ptr = input_tmp.data(); + in_ptr = input_tmp.data(); } for (int i = 0; i < 12; ++i) { - EXPECT_EQ(in_cfo_ptr[i], col_2_vol[i]); + EXPECT_EQ(in_ptr[i], col_2_vol[i]); } } TEST(math, vol2col) { testVol2col(); -#ifndef PADDLE_ONLY_CPU +#ifdef PADDLE_WITH_CUDA testVol2col(); -#endif +#endif // PADDLE_WITH_CUDA } -- GitLab