diff --git a/paddle/operators/gemm_conv2d_op.h b/paddle/operators/gemm_conv2d_op.h index f9215b46fcd9f091f88865df721481698ac0c8e7..5c9e81732aa72211c2021382cf9a907880c53c17 100644 --- a/paddle/operators/gemm_conv2d_op.h +++ b/paddle/operators/gemm_conv2d_op.h @@ -84,8 +84,8 @@ class GemmConv2DKernel : public framework::OpKernel { for (int g = 0; g < groups; g++) { // im2col Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); - im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1], - context.device_context()); + im2col(context.device_context(), in_slice, col, strides[0], strides[1], + paddings[0], paddings[1]); // gemm Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); @@ -185,8 +185,8 @@ class GemmConvGrad2DKernel : public framework::OpKernel { // col2im Tensor in_grad_slice = in_grad_batch.Slice(g * in_step, (g + 1) * in_step); - col2im(in_grad_slice, col, strides[0], strides[1], paddings[0], - paddings[1], context.device_context()); + col2im(context.device_context(), in_grad_slice, col, strides[0], + strides[1], paddings[0], paddings[1]); } } } @@ -207,8 +207,8 @@ class GemmConvGrad2DKernel : public framework::OpKernel { Tensor out_grad_slice = out_grad_batch.Slice(g * out_step, (g + 1) * out_step); Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); - im2col(in_slice, col, strides[0], strides[1], paddings[0], - paddings[1], context.device_context()); + im2col(context.device_context(), in_slice, col, strides[0], + strides[1], paddings[0], paddings[1]); // gemm Tensor filter_grad_slice = diff --git a/paddle/operators/math/im2col.cc b/paddle/operators/math/im2col.cc index 36a07f7a3169be2a7df7192cf7680fc6262cb4a1..c08a3380f042886cd400df0d840e61856274619c 100644 --- a/paddle/operators/math/im2col.cc +++ b/paddle/operators/math/im2col.cc @@ -27,9 +27,10 @@ template class Im2ColFunctor { public: - void operator()(const framework::Tensor& im, framework::Tensor& col, + void operator()(const platform::DeviceContext& context, + const framework::Tensor& im, framework::Tensor& col, int stride_height, int stride_width, int padding_height, - int padding_width, const platform::DeviceContext& context) { + int padding_width) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); @@ -79,9 +80,9 @@ template class Col2ImFunctor { public: - void operator()(framework::Tensor& im, const framework::Tensor& col, - int stride_height, int stride_width, int padding_height, - int padding_width, const platform::DeviceContext& context) { + void operator()(const platform::DeviceContext& context, framework::Tensor& im, + const framework::Tensor& col, int stride_height, + int stride_width, int padding_height, int padding_width) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); int input_channels = im.dims()[0]; @@ -137,9 +138,10 @@ template class Im2ColFunctor { public: - void operator()(const framework::Tensor& im, framework::Tensor& col, + void operator()(const platform::DeviceContext& context, + const framework::Tensor& im, framework::Tensor& col, int stride_height, int stride_width, int padding_height, - int padding_width, const platform::DeviceContext& context) { + int padding_width) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); int input_channels = im.dims()[0]; @@ -197,9 +199,9 @@ template class Col2ImFunctor { public: - void operator()(framework::Tensor& im, const framework::Tensor& col, - int stride_height, int stride_width, int padding_height, - int padding_width, const platform::DeviceContext& context) { + void operator()(const platform::DeviceContext& context, framework::Tensor& im, + const framework::Tensor& col, int stride_height, + int stride_width, int padding_height, int padding_width) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); int input_channels = im.dims()[0]; diff --git a/paddle/operators/math/im2col.cu b/paddle/operators/math/im2col.cu index b433c8f8e84565351822d3a1eab27d77fc255f40..01f60bfe70f844fdcfd5aa481c27d9f12ec51305 100644 --- a/paddle/operators/math/im2col.cu +++ b/paddle/operators/math/im2col.cu @@ -64,9 +64,10 @@ template class Im2ColFunctor { public: - void operator()(const framework::Tensor& im, framework::Tensor& col, + void operator()(const platform::DeviceContext& context, + const framework::Tensor& im, framework::Tensor& col, int stride_height, int stride_width, int padding_height, - int padding_width, const platform::DeviceContext& context) { + int padding_width) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); @@ -149,9 +150,9 @@ template class Col2ImFunctor { public: - void operator()(framework::Tensor& im, const framework::Tensor& col, - int stride_height, int stride_width, int padding_height, - int padding_width, const platform::DeviceContext& context) { + void operator()(const platform::DeviceContext& context, framework::Tensor& im, + const framework::Tensor& col, int stride_height, + int stride_width, int padding_height, int padding_width) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); @@ -235,9 +236,10 @@ template class Im2ColFunctor { public: - void operator()(const framework::Tensor& im, framework::Tensor& col, + void operator()(const platform::DeviceContext& context, + const framework::Tensor& im, framework::Tensor& col, int stride_height, int stride_width, int padding_height, - int padding_width, const platform::DeviceContext& context) { + int padding_width) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); int input_channels = im.dims()[0]; @@ -318,9 +320,9 @@ template class Col2ImFunctor { public: - void operator()(framework::Tensor& im, const framework::Tensor& col, - int stride_height, int stride_width, int padding_height, - int padding_width, const platform::DeviceContext& context) { + void operator()(const platform::DeviceContext& context, framework::Tensor& im, + const framework::Tensor& col, int stride_height, + int stride_width, int padding_height, int padding_width) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); int input_channels = im.dims()[0]; diff --git a/paddle/operators/math/im2col.h b/paddle/operators/math/im2col.h index 9a119c6894f2e7722866f5bd47586a10af2b23d8..7b717e1603c94cd77c74cb0d86f1d23e2692f9d8 100644 --- a/paddle/operators/math/im2col.h +++ b/paddle/operators/math/im2col.h @@ -72,17 +72,18 @@ enum class ColFormat { kCFO = 0, kOCF = 1 }; template class Im2ColFunctor { public: - void operator()(const framework::Tensor& im, framework::Tensor& col, + void operator()(const platform::DeviceContext& context, + const framework::Tensor& im, framework::Tensor& col, int stride_height, int stride_width, int padding_height, - int padding_width, const platform::DeviceContext& context); + int padding_width); }; template class Col2ImFunctor { public: - void operator()(framework::Tensor& im, const framework::Tensor& col, - int stride_height, int stride_width, int padding_height, - int padding_width, const platform::DeviceContext& context); + void operator()(const platform::DeviceContext& context, framework::Tensor& im, + const framework::Tensor& col, int stride_height, + int stride_width, int padding_height, int padding_width); }; } // namespace math diff --git a/paddle/operators/math/im2col_test.cc b/paddle/operators/math/im2col_test.cc index e0943c0379bd5ee88c17f4eb97cc2bc877083cc0..f0b8c885918afe7f80edc465c6d9be7c11ac066f 100644 --- a/paddle/operators/math/im2col_test.cc +++ b/paddle/operators/math/im2col_test.cc @@ -78,8 +78,8 @@ void testIm2col() { PADDLE_THROW("no GPU support"); #endif // PADDLE_ONLY_CPU } - im2col(input, output_cfo, stride, stride, padding, padding, *context); - im2col_ocf(input, output_ocf, stride, stride, padding, padding, *context); + im2col(*context, input, output_cfo, stride, stride, padding, padding); + im2col_ocf(*context, input, output_ocf, stride, stride, padding, padding); float* out_cfo_ptr; if (paddle::platform::is_cpu_place(*place)) {