提交 f5ac3350 编写于 作者: C chengduoZH

follow comments

上级 1d41a6d4
...@@ -3,14 +3,13 @@ if(WITH_GPU) ...@@ -3,14 +3,13 @@ if(WITH_GPU)
nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator) nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator)
nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator) nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator)
nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context operator) nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context)
else() else()
cc_library(math_function SRCS math_function.cc im2col.cc pooling.cc DEPS cblas device_context operator) cc_library(math_function SRCS math_function.cc im2col.cc pooling.cc DEPS cblas device_context operator)
cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
cc_library(softmax SRCS softmax.cc DEPS operator) cc_library(softmax SRCS softmax.cc DEPS operator)
cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator) cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator)
cc_library(vol2col SRCS vol2col.cc DEPS device_context operator) cc_library(vol2col SRCS vol2col.cc DEPS device_context)
endif() endif()
cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor) cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor)
......
...@@ -18,10 +18,9 @@ limitations under the License. */ ...@@ -18,10 +18,9 @@ limitations under the License. */
template <typename Place> template <typename Place>
void testVol2col() { void testVol2col() {
paddle::framework::Tensor input_tmp;
paddle::framework::Tensor input; paddle::framework::Tensor input;
paddle::framework::Tensor output_cfo; paddle::framework::Tensor input_tmp;
paddle::framework::Tensor output_ocf; paddle::framework::Tensor output;
paddle::framework::Tensor output_tmp; paddle::framework::Tensor output_tmp;
auto* place = new Place(); auto* place = new Place();
...@@ -44,14 +43,14 @@ void testVol2col() { ...@@ -44,14 +43,14 @@ void testVol2col() {
* [6, 7, 8, * [6, 7, 8,
* 9, 10, 11]] * 9, 10, 11]]
* *
* output_cfo = [0, 1 * output = [0, 1
* 1, 2 * 1, 2
* 3, 4 * 3, 4
* 4, 5 * 4, 5
* 6, 7 * 6, 7
* 7, 8 * 7, 8
* 9, 10 * 9, 10
* 10, 11] * 10, 11]
* *
* col2vol = [[0, 2, 2, * col2vol = [[0, 2, 2,
* 3, 8, 5] * 3, 8, 5]
...@@ -81,20 +80,20 @@ void testVol2col() { ...@@ -81,20 +80,20 @@ void testVol2col() {
} else { } else {
input.CopyFrom<float>(input_tmp, *place); input.CopyFrom<float>(input_tmp, *place);
} }
output_cfo.mutable_data<float>({1, filter_size, filter_size, filter_size, output.mutable_data<float>({1, filter_size, filter_size, filter_size,
output_depth, output_height, output_width}, output_depth, output_height, output_width},
*place); *place);
paddle::operators::math::Vol2ColFunctor<Place, float> vol2col; paddle::operators::math::Vol2ColFunctor<Place, float> vol2col;
vol2col(*context, input, output_cfo, stride, stride, stride, padding, padding, vol2col(*context, input, output, stride, stride, stride, padding, padding,
padding); padding);
float vol_2_col[] = {0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 7, 8, 9, 10, 10, 11}; float vol_2_col[] = {0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 7, 8, 9, 10, 10, 11};
float* out_cfo_ptr; float* out_cfo_ptr;
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
out_cfo_ptr = output_cfo.data<float>(); out_cfo_ptr = output.data<float>();
} else { } else {
output_tmp.CopyFrom<float>(output_cfo, paddle::platform::CPUPlace()); output_tmp.CopyFrom<float>(output, paddle::platform::CPUPlace());
out_cfo_ptr = output_tmp.data<float>(); out_cfo_ptr = output_tmp.data<float>();
} }
...@@ -112,25 +111,25 @@ void testVol2col() { ...@@ -112,25 +111,25 @@ void testVol2col() {
} }
paddle::operators::math::Col2VolFunctor<Place, float> col2vol; paddle::operators::math::Col2VolFunctor<Place, float> col2vol;
col2vol(*context, input, output_cfo, stride, stride, stride, padding, padding, col2vol(*context, input, output, stride, stride, stride, padding, padding,
padding); padding);
float* in_cfo_ptr; float* in_ptr;
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
in_cfo_ptr = input.data<float>(); in_ptr = input.data<float>();
} else { } else {
input_tmp.CopyFrom<float>(input, paddle::platform::CPUPlace()); input_tmp.CopyFrom<float>(input, paddle::platform::CPUPlace());
in_cfo_ptr = input_tmp.data<float>(); in_ptr = input_tmp.data<float>();
} }
for (int i = 0; i < 12; ++i) { for (int i = 0; i < 12; ++i) {
EXPECT_EQ(in_cfo_ptr[i], col_2_vol[i]); EXPECT_EQ(in_ptr[i], col_2_vol[i]);
} }
} }
TEST(math, vol2col) { TEST(math, vol2col) {
testVol2col<paddle::platform::CPUPlace>(); testVol2col<paddle::platform::CPUPlace>();
#ifndef PADDLE_ONLY_CPU #ifdef PADDLE_WITH_CUDA
testVol2col<paddle::platform::GPUPlace>(); testVol2col<paddle::platform::GPUPlace>();
#endif #endif // PADDLE_WITH_CUDA
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册