提交 c85d777f 编写于 作者: C chengduoZH

follow comments

上级 3db3a106
if(WITH_GPU) if(WITH_GPU)
nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu vol2col.cc vol2col.cu pooling.cc pooling.cu DEPS cblas device_context operator) nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu pooling.cc pooling.cu DEPS cblas device_context operator)
nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator) nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator)
nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator) nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator)
nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS cblas device_context operator)
else() else()
cc_library(math_function SRCS math_function.cc im2col.cc vol2col.cc pooling.cc DEPS cblas device_context operator) cc_library(math_function SRCS math_function.cc im2col.cc pooling.cc DEPS cblas device_context operator)
cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
cc_library(softmax SRCS softmax.cc DEPS operator) cc_library(softmax SRCS softmax.cc DEPS operator)
cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator) cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator)
cc_library(vol2col SRCS vol2col.cc DEPS cblas device_context operator)
endif() endif()
cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor) cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor)
cc_test(vol2col_test SRCS vol2col_test.cc DEPS math_function tensor) cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col tensor)
...@@ -67,7 +67,7 @@ class Vol2ColFunctor<platform::CPUPlace, T> { ...@@ -67,7 +67,7 @@ class Vol2ColFunctor<platform::CPUPlace, T> {
((c * output_depth + d) * output_height + h) * output_width + w; ((c * output_depth + d) * output_height + h) * output_width + w;
if (h_pad < 0 || h_pad >= input_height || w_pad < 0 || if (h_pad < 0 || h_pad >= input_height || w_pad < 0 ||
w_pad >= input_width || d_pad < 0 || d_pad >= input_depth) { w_pad >= input_width || d_pad < 0 || d_pad >= input_depth) {
col_data[col_idx] = T(0); col_data[col_idx] = static_cast<T>(0);
} else { } else {
int vol_idx = int vol_idx =
((c_in * input_depth + d_pad) * input_height + h_pad) * ((c_in * input_depth + d_pad) * input_height + h_pad) *
......
...@@ -30,12 +30,12 @@ void testVol2col() { ...@@ -30,12 +30,12 @@ void testVol2col() {
context = context =
new paddle::platform::CPUDeviceContext(paddle::platform::CPUPlace()); new paddle::platform::CPUDeviceContext(paddle::platform::CPUPlace());
} else { } else {
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
context = context =
new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace()); new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace());
#else #else
PADDLE_THROW("no GPU support"); PADDLE_THROW("no GPU support");
#endif // PADDLE_ONLY_CPU #endif // PADDLE_WITH_CUDA
} }
/** /**
...@@ -89,6 +89,7 @@ void testVol2col() { ...@@ -89,6 +89,7 @@ void testVol2col() {
vol2col(*context, input, output_cfo, stride, stride, stride, padding, padding, vol2col(*context, input, output_cfo, stride, stride, stride, padding, padding,
padding); padding);
float vol_2_col[] = {0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 7, 8, 9, 10, 10, 11};
float* out_cfo_ptr; float* out_cfo_ptr;
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
out_cfo_ptr = output_cfo.data<float>(); out_cfo_ptr = output_cfo.data<float>();
...@@ -97,24 +98,12 @@ void testVol2col() { ...@@ -97,24 +98,12 @@ void testVol2col() {
out_cfo_ptr = output_tmp.data<float>(); out_cfo_ptr = output_tmp.data<float>();
} }
EXPECT_EQ(out_cfo_ptr[0], 0); for (int i = 0; i < 16; ++i) {
EXPECT_EQ(out_cfo_ptr[1], 1); EXPECT_EQ(out_cfo_ptr[i], vol_2_col[i]);
EXPECT_EQ(out_cfo_ptr[2], 1); }
EXPECT_EQ(out_cfo_ptr[3], 2);
EXPECT_EQ(out_cfo_ptr[4], 3);
EXPECT_EQ(out_cfo_ptr[5], 4);
EXPECT_EQ(out_cfo_ptr[6], 4);
EXPECT_EQ(out_cfo_ptr[7], 5);
EXPECT_EQ(out_cfo_ptr[8], 6);
EXPECT_EQ(out_cfo_ptr[9], 7);
EXPECT_EQ(out_cfo_ptr[10], 7);
EXPECT_EQ(out_cfo_ptr[11], 8);
EXPECT_EQ(out_cfo_ptr[12], 9);
EXPECT_EQ(out_cfo_ptr[13], 10);
EXPECT_EQ(out_cfo_ptr[14], 10);
EXPECT_EQ(out_cfo_ptr[15], 11);
// Col2Vol test // Col2Vol test
float col_2_vol[] = {0, 2, 2, 3, 8, 5, 6, 14, 8, 9, 20, 11};
memset(input_ptr, 0, 12 * sizeof(float)); memset(input_ptr, 0, 12 * sizeof(float));
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
input = input_tmp; input = input_tmp;
...@@ -134,18 +123,9 @@ void testVol2col() { ...@@ -134,18 +123,9 @@ void testVol2col() {
in_cfo_ptr = input_tmp.data<float>(); in_cfo_ptr = input_tmp.data<float>();
} }
EXPECT_EQ(in_cfo_ptr[0], 0); for (int i = 0; i < 12; ++i) {
EXPECT_EQ(in_cfo_ptr[1], 2); EXPECT_EQ(in_cfo_ptr[i], col_2_vol[i]);
EXPECT_EQ(in_cfo_ptr[2], 2); }
EXPECT_EQ(in_cfo_ptr[3], 3);
EXPECT_EQ(in_cfo_ptr[4], 8);
EXPECT_EQ(in_cfo_ptr[5], 5);
EXPECT_EQ(in_cfo_ptr[6], 6);
EXPECT_EQ(in_cfo_ptr[7], 14);
EXPECT_EQ(in_cfo_ptr[8], 8);
EXPECT_EQ(in_cfo_ptr[9], 9);
EXPECT_EQ(in_cfo_ptr[10], 20);
EXPECT_EQ(in_cfo_ptr[11], 11);
} }
TEST(math, vol2col) { TEST(math, vol2col) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册