diff --git a/paddle/fluid/operators/conv_op.h b/paddle/fluid/operators/conv_op.h index 79d07887fb0e06e7c8411e0e72e5f11f7b06a05d..924ed1fcf7d35a81c108ee08f1471079ff28c98a 100644 --- a/paddle/fluid/operators/conv_op.h +++ b/paddle/fluid/operators/conv_op.h @@ -22,8 +22,8 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/layout_utils.h" -#include "paddle/fluid/operators/math/im2col.h" #include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/im2col.h" #include "paddle/phi/kernels/funcs/vol2col.h" namespace paddle { diff --git a/paddle/fluid/operators/im2sequence_op.h b/paddle/fluid/operators/im2sequence_op.h index 9fcf02e999d1743fbc626dc64a727c684ed3eb0f..afb4db0f3c633b033c0b66cea8f07a3308ecfe50 100644 --- a/paddle/fluid/operators/im2sequence_op.h +++ b/paddle/fluid/operators/im2sequence_op.h @@ -20,7 +20,7 @@ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/eigen/eigen_function.h" -#include "paddle/fluid/operators/math/im2col.h" +#include "paddle/phi/kernels/funcs/im2col.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace paddle { @@ -101,7 +101,8 @@ class Im2SequenceKernel : public framework::OpKernel { kernels[1]}); offset_out += output_height[i] * output_width[i]; - math::Im2ColFunctor f; + phi::funcs::Im2ColFunctor + f; auto& dev_ctx = ctx.template device_context(); f(dev_ctx, src, dilations, strides, paddings, &dst); } @@ -135,7 +136,8 @@ class Im2SequenceKernel : public framework::OpKernel { kernels[0], kernels[1]}); - math::Im2ColFunctor f; + phi::funcs::Im2ColFunctor + f; auto& dev_ctx = ctx.template device_context(); f(dev_ctx, src, dilations, strides, paddings, &dst); } @@ -190,7 +192,8 @@ class Im2SequenceGradKernel : public framework::OpKernel { d_x->Slice(i, i + 1).Resize({img_channels, img_height, img_width}); const Tensor src = d_out->Slice(i, i + 1).Resize( {output_height, output_width, img_channels, kernels[0], kernels[1]}); - math::Col2ImFunctor f; + phi::funcs::Col2ImFunctor + f; auto& dev_ctx = ctx.template device_context(); f(dev_ctx, src, dilations, strides, paddings, &dst); } diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index 1f5dd8a9b22842d54e1eed3d6db33bd9f525c2fc..e2a62273d03282809a4e795d98aaa2a0a9250536 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -23,7 +23,6 @@ endif() math_library(context_project DEPS im2col math_function) math_library(cos_sim_functor) math_library(depthwise_conv) -math_library(im2col) math_library(sample_prob) math_library(sampler DEPS generator) diff --git a/paddle/fluid/operators/math/context_project.h b/paddle/fluid/operators/math/context_project.h index f7f2cfb64aa347eabd308d76e70d291db3da83da..832be9b0efd2f6bd63c3a8d98bbcbba1705e1d15 100644 --- a/paddle/fluid/operators/math/context_project.h +++ b/paddle/fluid/operators/math/context_project.h @@ -18,8 +18,8 @@ limitations under the License. */ #include #include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/operators/math/im2col.h" #include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/im2col.h" namespace paddle { namespace operators { @@ -100,7 +100,8 @@ class ContextProjectFunctor { phi::DenseTensor* col) { auto lod_level_0 = in.lod()[0]; - math::Im2ColFunctor im2col_ocf; + phi::funcs::Im2ColFunctor + im2col_ocf; std::vector dilation({1, 1}); std::vector padding({up_pad, 0, down_pad, 0}); @@ -230,7 +231,8 @@ class ContextProjectGradFunctor { phi::DenseTensor* col) { auto lod_level_0 = in.lod()[0]; - math::Col2ImFunctor col2im_ocf; + phi::funcs::Col2ImFunctor + col2im_ocf; std::vector dilation({1, 1}); std::vector padding({up_pad, 0, down_pad, 0}); diff --git a/paddle/fluid/operators/math/im2col_test.cc b/paddle/fluid/operators/math/im2col_test.cc index 70ac7a225d6a37a98b8d31436101a644c2b00969..1fa0cb1aeb1616d59a27c4a308acab7c043c58e6 100644 --- a/paddle/fluid/operators/math/im2col_test.cc +++ b/paddle/fluid/operators/math/im2col_test.cc @@ -12,13 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/math/im2col.h" +#include "paddle/phi/kernels/funcs/im2col.h" #include -#include "paddle/fluid/operators/math/im2col_cfo_cpu.h" +#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/place.h" +#include "paddle/phi/kernels/funcs/im2col_cfo_cpu.h" template void testIm2col() { @@ -76,15 +77,9 @@ void testIm2col() { {output_height, output_width, 1, filter_size, filter_size}, *place); // Im2Col - paddle::operators::math::Im2ColFunctor< - paddle::operators::math::ColFormat::kCFO, - DeviceContext, - float> + phi::funcs::Im2ColFunctor im2col; - paddle::operators::math::Im2ColFunctor< - paddle::operators::math::ColFormat::kOCF, - DeviceContext, - float> + phi::funcs::Im2ColFunctor im2col_ocf; im2col(*context, input, dilation, stride, padding, &output_cfo); @@ -119,15 +114,9 @@ void testIm2col() { } // Col2Im: kCFO - paddle::operators::math::Col2ImFunctor< - paddle::operators::math::ColFormat::kCFO, - DeviceContext, - float> + phi::funcs::Col2ImFunctor col2im; - paddle::operators::math::Col2ImFunctor< - paddle::operators::math::ColFormat::kOCF, - DeviceContext, - float> + phi::funcs::Col2ImFunctor col2im_ocf; float col2im_data[] = {0, 2, 2, 3, 8, 5}; @@ -237,15 +226,9 @@ void testIm2col() { {output_height, output_width, 1, filter_size, filter_size}, *place); // Im2Col - paddle::operators::math::Im2ColFunctor< - paddle::operators::math::ColFormat::kCFO, - phi::GPUContext, - float> + phi::funcs::Im2ColFunctor im2col; - paddle::operators::math::Im2ColFunctor< - paddle::operators::math::ColFormat::kOCF, - phi::GPUContext, - float> + phi::funcs::Im2ColFunctor im2col_ocf; im2col(*context, input, dilation, stride, padding, &output_cfo); @@ -280,15 +263,9 @@ void testIm2col() { } // Col2Im: kCFO - paddle::operators::math::Col2ImFunctor< - paddle::operators::math::ColFormat::kCFO, - phi::GPUContext, - float> + phi::funcs::Col2ImFunctor col2im; - paddle::operators::math::Col2ImFunctor< - paddle::operators::math::ColFormat::kOCF, - phi::GPUContext, - float> + phi::funcs::Col2ImFunctor col2im_ocf; float col2im_data[] = {0, 2, 2, 3, 8, 5}; @@ -363,18 +340,15 @@ TEST(math, im2col) { int output_width = (iw - fw + padding[1] * 2) / stride[1] + 1; \ out.mutable_data({ic, fh, fw, output_height, output_width}, place); \ ref.mutable_data({ic, fh, fw, output_height, output_width}, place); \ - paddle::operators::math::Im2ColFunctor< \ - paddle::operators::math::ColFormat::kCFO, \ - phi::CPUContext, \ - float> \ - im2col + phi::funcs:: \ + Im2ColFunctor \ + im2col void testIm2colCPU(int ic, int ih, int iw, int fh, int fw, int ph, int pw) { PREPARE_IM2COL_CPU; im2col(context, input, dilation, stride, padding, &out); - paddle::operators::math::im2col_common( - input, dilation, stride, padding, &ref); + phi::funcs::im2col_common(input, dilation, stride, padding, &ref); float* ref_data = ref.data(); float* out_data = out.data(); @@ -398,8 +372,7 @@ void benchIm2col(int ic, int ih, int iw, int fh, int fw, int ph, int pw) { auto t2 = GetCurrentMs(); for (int i = 0; i < repeat; ++i) { - paddle::operators::math::im2col_common( - input, dilation, stride, padding, &ref); + phi::funcs::im2col_common(input, dilation, stride, padding, &ref); } auto t3 = GetCurrentMs(); diff --git a/paddle/phi/kernels/funcs/CMakeLists.txt b/paddle/phi/kernels/funcs/CMakeLists.txt index 41c6cf677717ded023a7f7cbf3fef1cb855eaf39..d429d4a8dad2c36a7faae7f637e609e6e75dca7e 100644 --- a/paddle/phi/kernels/funcs/CMakeLists.txt +++ b/paddle/phi/kernels/funcs/CMakeLists.txt @@ -17,6 +17,7 @@ math_library(segment_pooling) math_library(sequence2batch) math_library(matrix_solve DEPS dense_tensor eigen3 blas math_function) math_library(cross_entropy) +math_library(im2col) math_library(vol2col) cc_library( diff --git a/paddle/fluid/operators/math/im2col.cc b/paddle/phi/kernels/funcs/im2col.cc similarity index 79% rename from paddle/fluid/operators/math/im2col.cc rename to paddle/phi/kernels/funcs/im2col.cc index 39b0312e67766e7976b651ba2226b44d72b16976..71d9c49e347d0df6466796f6f4f8a4ac952fe602 100644 --- a/paddle/fluid/operators/math/im2col.cc +++ b/paddle/phi/kernels/funcs/im2col.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,17 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/math/im2col.h" - -#include "paddle/fluid/operators/math/im2col_cfo_cpu.h" +#include "paddle/phi/kernels/funcs/im2col.h" +#include "paddle/phi/kernels/funcs/im2col_cfo_cpu.h" namespace phi { class CPUContext; } // namespace phi -namespace paddle { -namespace operators { -namespace math { +namespace phi { +namespace funcs { /* * im = [input_channels, input_height, input_width] @@ -30,9 +28,7 @@ namespace math { * [input_channels, filter_height, filter_width, output_height, output_width] */ template -class Im2ColFunctor { +class Im2ColFunctor { public: void operator()(const DeviceContext& context, const phi::DenseTensor& im, @@ -43,13 +39,13 @@ class Im2ColFunctordims().size(), 5, - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The dimension of tensor 'col' should be 5. But got " "the dims of tensor 'col' is [%s].", col->dims())); @@ -77,9 +73,7 @@ class Im2ColFunctor -class Col2ImFunctor { +class Col2ImFunctor { public: void operator()(const DeviceContext& context, const phi::DenseTensor& col, @@ -90,13 +84,13 @@ class Col2ImFunctordims().size(), 3, - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The dimension of tensor 'im' should be 3. But got " "the dims of tensor 'im' is [%s].", im->dims())); PADDLE_ENFORCE_EQ(col.dims().size(), 5, - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The dimension of tensor 'col' should be 5. But got " "the dims of tensor 'col' is [%s].", col.dims())); @@ -111,22 +105,22 @@ class Col2ImFunctor; -template class Im2ColFunctor; -template class Col2ImFunctor; -template class Col2ImFunctor; @@ -179,9 +173,7 @@ template class Col2ImFunctor -class Im2ColFunctor { +class Im2ColFunctor { public: void operator()(const DeviceContext& context, const phi::DenseTensor& im, @@ -192,13 +184,13 @@ class Im2ColFunctordims().size(), 5, - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The dimension of tensor 'col' should be 5. But got " "the dims of tensor 'col' is [%s].", col->dims())); @@ -254,9 +246,7 @@ class Im2ColFunctor -class Col2ImFunctor { +class Col2ImFunctor { public: void operator()(const DeviceContext& context, const phi::DenseTensor& col, @@ -267,13 +257,13 @@ class Col2ImFunctordims().size(), 3, - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The dimension of tensor 'im' should be 3. But got " "the dims of tensor 'im' is [%s].", im->dims())); PADDLE_ENFORCE_EQ(col.dims().size(), 5, - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The dimension of tensor 'col' should be 5. But got " "the dims of tensor 'col' is [%s].", col.dims())); @@ -288,14 +278,14 @@ class Col2ImFunctordata(); const T* col_data = col.data(); @@ -335,18 +325,17 @@ class Col2ImFunctor; -template class Im2ColFunctor; -template class Col2ImFunctor; -template class Col2ImFunctor; -} // namespace math -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi diff --git a/paddle/fluid/operators/math/im2col.cu b/paddle/phi/kernels/funcs/im2col.cu similarity index 87% rename from paddle/fluid/operators/math/im2col.cu rename to paddle/phi/kernels/funcs/im2col.cu index 5c7038714e93c68b2f03d7ea96a159a5f3f4b28b..78a0b345e389fbb85e286c16a09b43a2748b7f16 100644 --- a/paddle/fluid/operators/math/im2col.cu +++ b/paddle/phi/kernels/funcs/im2col.cu @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,14 +15,13 @@ limitations under the License. */ #include #include -#include "paddle/fluid/operators/math/im2col.h" -#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/backends/gpu/gpu_primitives.h" +#include "paddle/phi/kernels/funcs/im2col.h" -namespace paddle { -namespace operators { -namespace math { +namespace phi { +namespace funcs { template __global__ void im2col(const T* data_im, @@ -86,9 +85,7 @@ __global__ void im2col(const T* data_im, * [input_channels, filter_height, filter_width, output_height, output_width] */ template -class Im2ColFunctor { +class Im2ColFunctor { public: void operator()(const DeviceContext& context, const phi::DenseTensor& im, @@ -99,13 +96,13 @@ class Im2ColFunctordims().size(), 5, - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The dimension of tensor 'col' should be 5. But got " "the dims of tensor 'col' is [%s].", col->dims())); @@ -124,7 +121,7 @@ class Im2ColFunctor -class Col2ImFunctor { +class Col2ImFunctor { public: void operator()(const DeviceContext& context, const phi::DenseTensor& col, @@ -236,13 +231,13 @@ class Col2ImFunctordims().size(), 3, - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The dimension of tensor 'im' should be 3. But got " "the dims of tensor 'im' is [%s].", im->dims())); PADDLE_ENFORCE_EQ(col.dims().size(), 5, - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The dimension of tensor 'col' should be 5. But got " "the dims of tensor 'col' is [%s].", col.dims())); @@ -258,28 +253,28 @@ class Col2ImFunctor; -template class Im2ColFunctor; -template class Col2ImFunctor; -template class Col2ImFunctor; @@ -367,9 +362,7 @@ __global__ void im2colOCF(const T* im_data, * [output_height, output_width, input_channels, filter_height, filter_width] */ template -class Im2ColFunctor { +class Im2ColFunctor { public: void operator()(const DeviceContext& context, const phi::DenseTensor& im, @@ -380,13 +373,13 @@ class Im2ColFunctordims().size(), 5, - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The dimension of tensor 'col' should be 5. But got " "the dims of tensor 'col' is [%s].", col->dims())); @@ -479,9 +472,7 @@ __global__ void col2imOCF(const T* col_data, * [output_height, output_width, input_channels, filter_height, filter_width] */ template -class Col2ImFunctor { +class Col2ImFunctor { public: void operator()(const DeviceContext& context, const phi::DenseTensor& col, @@ -492,13 +483,13 @@ class Col2ImFunctordims().size(), 3, - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The dimension of tensor 'im' should be 3. But got " "the dims of tensor 'im' is [%s].", im->dims())); PADDLE_ENFORCE_EQ(col.dims().size(), 5, - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The dimension of tensor 'col' should be 5. But got " "the dims of tensor 'col' is [%s].", col.dims())); @@ -511,22 +502,22 @@ class Col2ImFunctor; -template class Im2ColFunctor; -template class Col2ImFunctor; -template class Col2ImFunctor; -} // namespace math -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi diff --git a/paddle/fluid/operators/math/im2col.h b/paddle/phi/kernels/funcs/im2col.h similarity index 91% rename from paddle/fluid/operators/math/im2col.h rename to paddle/phi/kernels/funcs/im2col.h index 3ce785a8901a6296494bc7a298f61d238d274deb..73b2866924d1e9264c474561834d8c0c1dd35aa4 100644 --- a/paddle/fluid/operators/math/im2col.h +++ b/paddle/phi/kernels/funcs/im2col.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,13 +16,13 @@ limitations under the License. */ #include -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/framework/tensor_util.h" -#include "paddle/fluid/platform/device_context.h" +#include "paddle/phi/common/layout.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/errors.h" -namespace paddle { -namespace operators { -namespace math { +namespace phi { +namespace funcs { using DataLayout = phi::DataLayout; @@ -107,6 +107,5 @@ class Col2ImFunctor { const DataLayout data_layout = DataLayout::kNCHW); }; -} // namespace math -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi diff --git a/paddle/fluid/operators/math/im2col_cfo_cpu.h b/paddle/phi/kernels/funcs/im2col_cfo_cpu.h similarity index 98% rename from paddle/fluid/operators/math/im2col_cfo_cpu.h rename to paddle/phi/kernels/funcs/im2col_cfo_cpu.h index bef9e0a8449f64e3ed47328f9b3993f53bd0d45c..c901cc9f551440e068e50c2ff5add75b118ca87c 100644 --- a/paddle/fluid/operators/math/im2col_cfo_cpu.h +++ b/paddle/phi/kernels/funcs/im2col_cfo_cpu.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,11 +16,10 @@ limitations under the License. */ #include -#include "paddle/fluid/framework/tensor.h" +#include "paddle/phi/core/dense_tensor.h" -namespace paddle { -namespace operators { -namespace math { +namespace phi { +namespace funcs { /** * The most common im2col algorithm. @@ -317,6 +316,5 @@ inline void im2col_sh1sw1dh1dw1ph1pw1(const phi::DenseTensor& im, } } -} // namespace math -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/impl/conv_grad_kernel_impl.h b/paddle/phi/kernels/impl/conv_grad_kernel_impl.h index e66a870c3aa25f20c5a0f0807ec012ef50bd84d4..ec75952aaae8e24a39c7562f77c1be9452d3c5d2 100644 --- a/paddle/phi/kernels/impl/conv_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/conv_grad_kernel_impl.h @@ -14,10 +14,10 @@ #pragma once -#include "paddle/fluid/operators/math/im2col.h" #include "paddle/phi/kernels/cpu/conv_util.h" #include "paddle/phi/kernels/funcs/batch_norm_utils.h" #include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/im2col.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/vol2col.h" @@ -147,10 +147,8 @@ void ConvGradKernel(const Context& dev_ctx, if (is_expand) { set_zero(dev_ctx, &transformed_input_grad, static_cast(0)); } + phi::funcs::Col2ImFunctor col2im; phi::funcs::Col2VolFunctor col2vol; - paddle::operators::math:: - Col2ImFunctor - col2im; for (int i = 0; i < batch_size; i++) { DenseTensor out_grad_batch = @@ -203,9 +201,7 @@ void ConvGradKernel(const Context& dev_ctx, Tensor filter_grad_ = *filter_grad; filter_grad_.Resize(filter_matrix_shape); set_zero(dev_ctx, filter_grad, static_cast(0)); - paddle::operators::math:: - Im2ColFunctor - im2col; + phi::funcs::Im2ColFunctor im2col; phi::funcs::Vol2ColFunctor vol2col; for (int i = 0; i < batch_size; i++) { DenseTensor out_grad_batch = @@ -381,10 +377,8 @@ void ConvGradGradKernel(const Context& dev_ctx, if (is_expand) { set_zero(dev_ctx, &transformed_dX, static_cast(0)); } + phi::funcs::Col2ImFunctor col2im; phi::funcs::Col2VolFunctor col2vol; - paddle::operators::math:: - Col2ImFunctor - col2im; for (int i = 0; i < batch_size; i++) { DenseTensor dy_batch = @@ -428,9 +422,7 @@ void ConvGradGradKernel(const Context& dev_ctx, set_zero(dev_ctx, dW, static_cast(0)); DenseTensor dW_arr = *dW; dW_arr.Resize(filter_matrix_shape); - paddle::operators::math:: - Im2ColFunctor - im2col; + phi::funcs::Im2ColFunctor im2col; phi::funcs::Vol2ColFunctor vol2col; for (int i = 0; i < batch_size; ++i) { DenseTensor dy_batch = @@ -477,9 +469,7 @@ void ConvGradGradKernel(const Context& dev_ctx, } set_zero(dev_ctx, &transformed_ddY, static_cast(0)); - paddle::operators::math:: - Im2ColFunctor - im2col; + phi::funcs::Im2ColFunctor im2col; phi::funcs::Vol2ColFunctor vol2col; for (int i = 0; i < batch_size; ++i) { DenseTensor ddy_batch = diff --git a/paddle/phi/kernels/impl/conv_kernel_impl.h b/paddle/phi/kernels/impl/conv_kernel_impl.h index 59bea1d0564c601a5019688d7a2dd87f8c5f722c..06ba3104a81124c1dcfd4440ab1011fdc0f37618 100644 --- a/paddle/phi/kernels/impl/conv_kernel_impl.h +++ b/paddle/phi/kernels/impl/conv_kernel_impl.h @@ -14,11 +14,11 @@ #pragma once -#include "paddle/fluid/operators/math/im2col.h" #include "paddle/phi/kernels/conv_kernel.h" #include "paddle/phi/kernels/cpu/conv_util.h" #include "paddle/phi/kernels/funcs/batch_norm_utils.h" #include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/im2col.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/vol2col.h" @@ -133,10 +133,8 @@ void ConvKernelImpl(const Context& dev_ctx, int in_step = static_cast(transformed_input.dims()[1]) / groups; int out_step = static_cast(transformed_output.dims()[1]) / groups; + phi::funcs::Im2ColFunctor im2col; phi::funcs::Vol2ColFunctor vol2col; - paddle::operators::math:: - Im2ColFunctor - im2col; auto blas = phi::funcs::GetBlas(dev_ctx); for (int i = 0; i < batch_size; i++) { diff --git a/paddle/phi/kernels/impl/conv_transpose_grad_kernel_impl.h b/paddle/phi/kernels/impl/conv_transpose_grad_kernel_impl.h index e25a6fd56ee2aafaf9da0f5d4028247fb027de35..64810d82f003438c6be1de410f816654aeb8029a 100644 --- a/paddle/phi/kernels/impl/conv_transpose_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/conv_transpose_grad_kernel_impl.h @@ -14,13 +14,13 @@ #pragma once -#include "paddle/fluid/operators/math/im2col.h" #include "paddle/phi/common/layout.h" #include "paddle/phi/core/ddim.h" #include "paddle/phi/kernels/conv_transpose_grad_kernel.h" #include "paddle/phi/kernels/cpu/conv_util.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/concat_and_split_functor.h" +#include "paddle/phi/kernels/funcs/im2col.h" #include "paddle/phi/kernels/funcs/slice.h" #include "paddle/phi/kernels/funcs/vol2col.h" @@ -143,9 +143,7 @@ void ConvTransposeGradRawKernel(const Context& ctx, DenseTensor dfilter_; funcs::SetConstant set_zero; - paddle::operators::math:: - Im2ColFunctor - im2col; + phi::funcs::Im2ColFunctor im2col; phi::funcs::Vol2ColFunctor vol2col; funcs::ConcatFunctor concat_functor; diff --git a/paddle/phi/kernels/impl/conv_transpose_kernel_impl.h b/paddle/phi/kernels/impl/conv_transpose_kernel_impl.h index a854bf3ee70dea2c3102069920c56c02792dff13..819b1afcf6bb6d0c28071cefe3475f2977f1bcff 100644 --- a/paddle/phi/kernels/impl/conv_transpose_kernel_impl.h +++ b/paddle/phi/kernels/impl/conv_transpose_kernel_impl.h @@ -14,13 +14,13 @@ #pragma once -#include "paddle/fluid/operators/math/im2col.h" #include "paddle/phi/common/layout.h" #include "paddle/phi/core/ddim.h" #include "paddle/phi/kernels/conv_transpose_kernel.h" #include "paddle/phi/kernels/cpu/conv_util.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/concat_and_split_functor.h" +#include "paddle/phi/kernels/funcs/im2col.h" #include "paddle/phi/kernels/funcs/slice.h" #include "paddle/phi/kernels/funcs/vol2col.h" @@ -136,9 +136,7 @@ void ConvTransposeRawKernel(const Context& ctx, (data_layout != DataLayout::kNHWC ? static_cast(out_dims[1]) / groups : static_cast(out_dims[out_dims.size() - 1]) / groups); - paddle::operators::math:: - Col2ImFunctor - col2im; + phi::funcs::Col2ImFunctor col2im; phi::funcs::Col2VolFunctor col2vol; funcs::ConcatFunctor concat_functor; diff --git a/paddle/phi/kernels/impl/fold_grad_kernel_impl.h b/paddle/phi/kernels/impl/fold_grad_kernel_impl.h index b9320eab85046f0cf353ac464d89fab8be581a45..7118de3174f7db4d75ec3f80c73956eb06ccb7c2 100644 --- a/paddle/phi/kernels/impl/fold_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/fold_grad_kernel_impl.h @@ -16,8 +16,8 @@ #include -#include "paddle/fluid/operators/math/im2col.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/im2col.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/unfold_functor.h" @@ -60,9 +60,7 @@ void FoldGradKernel(const Context& ctx, output_height, output_width}); - paddle::operators::math:: - Im2ColFunctor - im2col; + phi::funcs::Im2ColFunctor im2col; for (int i = 0; i < batch_size; i++) { DenseTensor out_grad_batch = out_grad.Slice(i, i + 1).Resize(out_shape); diff --git a/paddle/phi/kernels/impl/fold_kernel_impl.h b/paddle/phi/kernels/impl/fold_kernel_impl.h index 415beca7bd928367736b6365f29ae7309d680abc..21864b00cae765279218fb547e1bc9c273f83ee2 100644 --- a/paddle/phi/kernels/impl/fold_kernel_impl.h +++ b/paddle/phi/kernels/impl/fold_kernel_impl.h @@ -16,9 +16,9 @@ #include -#include "paddle/fluid/operators/math/im2col.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/enforce.h" +#include "paddle/phi/kernels/funcs/im2col.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/unfold_functor.h" @@ -36,9 +36,7 @@ void FoldKernel(const Context& ctx, const int batch_size = static_cast(x.dims()[0]); ctx.template Alloc(out); - paddle::operators::math:: - Col2ImFunctor - col2im; + phi::funcs::Col2ImFunctor col2im; const auto& x_dims = x.dims(); int output_height = (output_sizes[0] + 2 * paddings[0] - diff --git a/paddle/phi/kernels/impl/unfold_grad_kernel_impl.h b/paddle/phi/kernels/impl/unfold_grad_kernel_impl.h index 66fa2a4dc04f556f8de9cae23e34ba12f539e25e..78bd068041dd5909c466e125e143d86a37ca8dc0 100644 --- a/paddle/phi/kernels/impl/unfold_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/unfold_grad_kernel_impl.h @@ -16,8 +16,8 @@ #include -#include "paddle/fluid/operators/math/im2col.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/im2col.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/unfold_functor.h" @@ -56,9 +56,7 @@ void UnfoldGradKernel(const Context& ctx, DDim out_matrix_shape = make_ddim( {x_dims[1], kernel_sizes[0], kernel_sizes[1], out_height, out_width}); - paddle::operators::math:: - Col2ImFunctor - col2im; + phi::funcs::Col2ImFunctor col2im; phi::funcs::SetConstant set_zero; set_zero(ctx, x_grad, static_cast(0)); diff --git a/paddle/phi/kernels/impl/unfold_kernel_impl.h b/paddle/phi/kernels/impl/unfold_kernel_impl.h index 3b75e149f48e28728810c38b9e51e23910753b07..7b7e9923d0004de60866257aa4a01538f0b0e614 100644 --- a/paddle/phi/kernels/impl/unfold_kernel_impl.h +++ b/paddle/phi/kernels/impl/unfold_kernel_impl.h @@ -16,8 +16,8 @@ #include -#include "paddle/fluid/operators/math/im2col.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/im2col.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/unfold_functor.h" @@ -34,9 +34,7 @@ void UnfoldKernel(const Context& ctx, const int batch_size = static_cast(x.dims()[0]); ctx.template Alloc(out); - paddle::operators::math:: - Im2ColFunctor - im2col; + phi::funcs::Im2ColFunctor im2col; const auto& x_dims = x.dims(); int out_height = phi::funcs::CalcOutputSize(x_dims[2],