diff --git a/paddle/operators/math/im2col.cu b/paddle/operators/math/im2col.cu index 23d3854610555dfff70a93759a85cf03c756795f..64ecd4e2157436a1671e3504d7ad0705fcbc6252 100644 --- a/paddle/operators/math/im2col.cu +++ b/paddle/operators/math/im2col.cu @@ -83,8 +83,9 @@ class Im2ColFunctor { int block_y = (blocks + 512 - 1) / 512; dim3 threads(1024, 1); dim3 grid(block_x, block_y); - // TODO(hedaoyuan): launch kernel on specified stream - im2col<<>>( + im2col<<< + grid, threads, 0, + reinterpret_cast(context)->stream()>>>( im.data(), num_outputs, input_height, input_width, filter_height, filter_width, stride_height, stride_width, padding_height, padding_width, output_height, output_width, col.data()); @@ -171,8 +172,9 @@ class Col2ImFunctor { // To avoid involving atomic operations, we will launch one kernel per // bottom dimension, and then in the kernel add up the top dimensions. - // TODO(hedaoyuan): launch kernel on specified stream - col2im<<>>( + col2im<<< + grid, threads, 0, + reinterpret_cast(context)->stream()>>>( num_kernels, col.data(), input_height + 2 * padding_height, input_width + 2 * padding_width, input_channels, filter_height, filter_width, stride_height, stride_width, padding_height, @@ -259,8 +261,9 @@ class Im2ColFunctor { dim3 threads(block_dim_x, block_dim_y, std::min(block_dim_z, input_channels)); dim3 grid(output_width, output_height); - // TODO(hedaoyuan): launch kernel on specified stream - im2colOCF<<>>( + im2colOCF<<< + grid, threads, 0, + reinterpret_cast(context)->stream()>>>( im.data(), col.data(), input_channels, input_height, input_width, filter_height, filter_width, stride_height, stride_width, padding_height, padding_width, output_height, output_width); @@ -340,8 +343,9 @@ class Col2ImFunctor { dim3 threads(block_dim_x, block_dim_y, std::min(block_dim_z, input_channels)); dim3 grid(output_width, output_height); - // TODO(hedaoyuan): launch kernel on specified stream - col2imOCF<<>>( + col2imOCF<<< + grid, threads, 0, + reinterpret_cast(context)->stream()>>>( im.data(), col.data(), input_channels, input_height, input_width, filter_height, filter_width, stride_height, stride_width, padding_height, padding_width, output_height, output_width); diff --git a/paddle/operators/math/im2col_test.cc b/paddle/operators/math/im2col_test.cc index 4a9deb721027423c0b538421bf124b73bcc1a660..ee5fb98acdfd109369b8953ee1a19599d914e290 100644 --- a/paddle/operators/math/im2col_test.cc +++ b/paddle/operators/math/im2col_test.cc @@ -16,19 +16,13 @@ limitations under the License. */ #include #include -TEST(math, im2col) { +template +void testIm2col() { + paddle::framework::Tensor input_tmp; paddle::framework::Tensor input; paddle::framework::Tensor output_cfo; paddle::framework::Tensor output_ocf; - paddle::framework::Tensor input_check; - - int input_height = 2; - int input_width = 3; - int filter_size = 2; - int stride = 1; - int padding = 0; - int output_height = (input_height - filter_size + 2 * padding) / stride + 1; - int output_width = (input_width - filter_size + 2 * padding) / stride + 1; + paddle::framework::Tensor output_tmp; /** * input = [0, 1, 2, @@ -42,31 +36,54 @@ TEST(math, im2col) { * output_ocf = [0, 1, 3, 4 * 1, 2, 4, 5] */ - auto* cpu_place = new paddle::platform::CPUPlace(); - float* input_ptr = - input.mutable_data({1, input_height, input_width}, *cpu_place); + int input_height = 2; + int input_width = 3; + int filter_size = 2; + int stride = 1; + int padding = 0; + int output_height = (input_height - filter_size + 2 * padding) / stride + 1; + int output_width = (input_width - filter_size + 2 * padding) / stride + 1; + float* input_ptr = input_tmp.mutable_data( + {1, input_height, input_width}, paddle::platform::CPUPlace()); float arr[6] = {0, 1, 2, 3, 4, 5}; memcpy(input_ptr, arr, 6 * sizeof(float)); + + auto* place = new Place(); + if (paddle::platform::is_cpu_place(*place)) { + input = input_tmp; + } else { + input.CopyFrom(input_tmp, *place); + } output_cfo.mutable_data( - {1, filter_size, filter_size, output_height, output_width}, *cpu_place); + {1, filter_size, filter_size, output_height, output_width}, *place); output_ocf.mutable_data( - {output_height, output_width, 1, filter_size, filter_size}, *cpu_place); + {output_height, output_width, 1, filter_size, filter_size}, *place); paddle::operators::math::Im2ColFunctor< - paddle::operators::math::ColFormat::kCFO, paddle::platform::CPUPlace, - float> + paddle::operators::math::ColFormat::kCFO, Place, float> im2col; paddle::operators::math::Im2ColFunctor< - paddle::operators::math::ColFormat::kOCF, paddle::platform::CPUPlace, - float> + paddle::operators::math::ColFormat::kOCF, Place, float> im2col_ocf; - paddle::platform::DeviceContext* context = - new paddle::platform::CPUDeviceContext(*cpu_place); + paddle::platform::DeviceContext* context; + if (paddle::platform::is_cpu_place(*place)) { + context = + new paddle::platform::CPUDeviceContext(paddle::platform::CPUPlace()); + } else { + context = + new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace()); + } im2col(input, output_cfo, stride, stride, padding, padding, context); im2col_ocf(input, output_ocf, stride, stride, padding, padding, context); - float* out_cfo_ptr = output_cfo.data(); + float* out_cfo_ptr; + if (paddle::platform::is_cpu_place(*place)) { + out_cfo_ptr = output_cfo.data(); + } else { + output_tmp.CopyFrom(output_cfo, paddle::platform::CPUPlace()); + out_cfo_ptr = output_tmp.data(); + } EXPECT_EQ(out_cfo_ptr[0], 0); EXPECT_EQ(out_cfo_ptr[1], 1); EXPECT_EQ(out_cfo_ptr[2], 1); @@ -76,7 +93,13 @@ TEST(math, im2col) { EXPECT_EQ(out_cfo_ptr[6], 4); EXPECT_EQ(out_cfo_ptr[7], 5); - float* out_ocf_ptr = output_ocf.data(); + float* out_ocf_ptr; + if (paddle::platform::is_cpu_place(*place)) { + out_ocf_ptr = output_ocf.data(); + } else { + output_tmp.CopyFrom(output_ocf, paddle::platform::CPUPlace()); + out_ocf_ptr = output_tmp.data(); + } EXPECT_EQ(out_ocf_ptr[0], 0); EXPECT_EQ(out_ocf_ptr[1], 1); EXPECT_EQ(out_ocf_ptr[2], 3); @@ -86,3 +109,10 @@ TEST(math, im2col) { EXPECT_EQ(out_ocf_ptr[6], 4); EXPECT_EQ(out_ocf_ptr[7], 5); } + +TEST(math, im2col) { + testIm2col(); +#ifndef PADDLE_ONLY_CPU + testIm2col(); +#endif +} \ No newline at end of file