diff --git a/paddle/operators/math/im2col.cc b/paddle/operators/math/im2col.cc index 8124e322cbf305b772f15c683c6904c673fcdee0..bcc18af0364978369f84ac9f80721f218f02e89d 100644 --- a/paddle/operators/math/im2col.cc +++ b/paddle/operators/math/im2col.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include "paddle/operators/math/im2col.h" namespace paddle { +namespace math { /* * im = [input_channels, input_height, input_width] @@ -26,7 +27,7 @@ class Im2ColFunctor { public: void operator()(const framework::Tensor& im, framework::Tensor& col, int stride_height, int stride_width, int padding_height, - int padding_width) { + int padding_width, platform::DeviceContext* context) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); @@ -77,7 +78,7 @@ class Col2ImFunctor { public: void operator()(framework::Tensor& im, const framework::Tensor& col, int stride_height, int stride_width, int padding_height, - int padding_width) { + int padding_width, platform::DeviceContext* context) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); int input_channels = im.dims()[0]; @@ -130,7 +131,7 @@ class Im2ColFunctor { public: void operator()(const framework::Tensor& im, framework::Tensor& col, int stride_height, int stride_width, int padding_height, - int padding_width) { + int padding_width, platform::DeviceContext* context) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); int input_channels = im.dims()[0]; @@ -189,7 +190,7 @@ class Col2ImFunctor { public: void operator()(framework::Tensor& im, const framework::Tensor& col, int stride_height, int stride_width, int padding_height, - int padding_width) { + int padding_width, platform::DeviceContext* context) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); int input_channels = im.dims()[0]; @@ -241,4 +242,5 @@ template class Im2ColFunctor; template class Col2ImFunctor; template class Col2ImFunctor; +} // namespace math } // namespace paddle diff --git a/paddle/operators/math/im2col.cu b/paddle/operators/math/im2col.cu index 875989af58239e6c3199ed776aed778911552312..2caa7c5ec268013f11561f2aa62a8e2bfe021590 100644 --- a/paddle/operators/math/im2col.cu +++ b/paddle/operators/math/im2col.cu @@ -16,6 +16,7 @@ limitations under the License. */ #include "paddle/platform/cuda_helper.h" namespace paddle { +namespace math { template __global__ void im2col(const T* data_im, int num_outs, int height, int width, @@ -63,7 +64,7 @@ class Im2ColFunctor { public: void operator()(const framework::Tensor& im, framework::Tensor& col, int stride_height, int stride_width, int padding_height, - int padding_width) { + int padding_width, platform::DeviceContext* context) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); @@ -81,6 +82,7 @@ class Im2ColFunctor { int block_y = (blocks + 512 - 1) / 512; dim3 threads(1024, 1); dim3 grid(block_x, block_y); + // TODO(hedaoyuan): launch kernel on specified stream im2col<<>>( im.data(), num_outputs, input_height, input_width, filter_height, filter_width, stride_height, stride_width, padding_height, @@ -145,7 +147,7 @@ class Col2ImFunctor { public: void operator()(framework::Tensor& im, const framework::Tensor& col, int stride_height, int stride_width, int padding_height, - int padding_width) { + int padding_width, platform::DeviceContext* context) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); @@ -168,6 +170,7 @@ class Col2ImFunctor { // To avoid involving atomic operations, we will launch one kernel per // bottom dimension, and then in the kernel add up the top dimensions. + // TODO(hedaoyuan): launch kernel on specified stream col2im<<>>( num_kernels, col.data(), input_height + 2 * padding_height, input_width + 2 * padding_width, input_channels, filter_height, @@ -224,7 +227,7 @@ class Im2ColFunctor { public: void operator()(const framework::Tensor& im, framework::Tensor& col, int stride_height, int stride_width, int padding_height, - int padding_width) { + int padding_width, platform::DeviceContext* context) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); int input_channels = im.dims()[0]; @@ -255,6 +258,7 @@ class Im2ColFunctor { dim3 threads(block_dim_x, block_dim_y, std::min(block_dim_z, input_channels)); dim3 grid(output_width, output_height); + // TODO(hedaoyuan): launch kernel on specified stream im2colOCF<<>>( im.data(), col.data(), input_channels, input_height, input_width, filter_height, filter_width, stride_height, stride_width, @@ -304,7 +308,7 @@ class Col2ImFunctor { public: void operator()(framework::Tensor& im, const framework::Tensor& col, int stride_height, int stride_width, int padding_height, - int padding_width) { + int padding_width, platform::DeviceContext* context) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); int input_channels = im.dims()[0]; @@ -335,7 +339,8 @@ class Col2ImFunctor { dim3 threads(block_dim_x, block_dim_y, std::min(block_dim_z, input_channels)); dim3 grid(output_width, output_height); - col2imOCF<<>>( + // TODO(hedaoyuan): launch kernel on specified stream + col2imOCF<<>>( im.data(), col.data(), input_channels, input_height, input_width, filter_height, filter_width, stride_height, stride_width, padding_height, padding_width, output_height, output_width); @@ -347,4 +352,5 @@ template class Im2ColFunctor; template class Col2ImFunctor; template class Col2ImFunctor; +} // namespace math } // namespace paddle diff --git a/paddle/operators/math/im2col.h b/paddle/operators/math/im2col.h index da51bc69a2c3ecc932abfa175009e241879f6737..f6b428e2894bc83e0619b237a9cf7368f4588e09 100644 --- a/paddle/operators/math/im2col.h +++ b/paddle/operators/math/im2col.h @@ -18,6 +18,7 @@ limitations under the License. */ #include "paddle/platform/device_context.h" namespace paddle { +namespace math { /* The storage format of the coldata in the Im2ColFunctor and Col2ImFunctor. */ enum ColFormat { kCFO = 0, kOCF = 1 }; @@ -72,7 +73,7 @@ class Im2ColFunctor { public: void operator()(const framework::Tensor& im, framework::Tensor& col, int stride_height, int stride_width, int padding_height, - int padding_width); + int padding_width, platform::DeviceContext* context); }; template @@ -80,7 +81,8 @@ class Col2ImFunctor { public: void operator()(framework::Tensor& im, const framework::Tensor& col, int stride_height, int stride_width, int padding_height, - int padding_width); + int padding_width, platform::DeviceContext* context); }; +} // namespace math } // namespace paddle