diff --git a/paddle/fluid/operators/conv_cudnn_op.cu b/paddle/fluid/operators/conv_cudnn_op.cu
index dff60afd74c02f458b5b3c7428c2703197b61af0..2055bf560e69ca0ed354aadd00cdca331c22c76e 100644
--- a/paddle/fluid/operators/conv_cudnn_op.cu
+++ b/paddle/fluid/operators/conv_cudnn_op.cu
@@ -25,10 +25,10 @@ limitations under the License. */
 #include "paddle/fluid/operators/conv_cudnn_helper.h"
 #endif
 #include "paddle/fluid/operators/conv_op.h"
-#include "paddle/fluid/operators/math/padding.h"
 #include "paddle/fluid/platform/cudnn_workspace_helper.h"
 #include "paddle/fluid/platform/float16.h"
 #include "paddle/fluid/platform/profiler/event_tracing.h"
+#include "paddle/phi/kernels/funcs/padding.h"
 
 DECLARE_bool(cudnn_deterministic);
 DECLARE_uint64(conv_workspace_size_limit);
@@ -148,7 +148,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
                              in_data_dims, strides, ksize);
 
     int data_dim = strides.size();  // 2d or 3d
-    bool is_sys_pad = math::IsSymmetricPadding(paddings, data_dim);
+    bool is_sys_pad = phi::funcs::IsSymmetricPadding(paddings, data_dim);
 
     Tensor transformed_input;
     std::vector<int> padding_common(data_dim, 0);
@@ -196,13 +196,13 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
       T pad_value(0.0);
       switch (rank) {
         case 4: {
-          math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
-              ctx, input_pad, transformed_input_channel, pad_value,
+          phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
+              dev_ctx, input_pad, transformed_input_channel, pad_value,
               &transformed_input);
         } break;
         case 5: {
-          math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
-              ctx, input_pad, transformed_input_channel, pad_value,
+          phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
+              dev_ctx, input_pad, transformed_input_channel, pad_value,
               &transformed_input);
         } break;
         default:
@@ -488,7 +488,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
     // cuDNN only supports padding the same amount on every dimension.
     // So we create a new padded input tensor.
     int data_dim = strides.size();  // 2d or 3d
-    bool is_sys_pad = math::IsSymmetricPadding(paddings, data_dim);
+    bool is_sys_pad = phi::funcs::IsSymmetricPadding(paddings, data_dim);
     Tensor transformed_input(input->type());
     Tensor transformed_input_grad(input->type());
     std::vector<int> padding_common(data_dim, 0);
@@ -544,13 +544,13 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
       T pad_value(0.0);
       switch (rank) {
         case 4: {
-          math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
-              ctx, input_pad, transformed_input_channel, pad_value,
+          phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
+              dev_ctx, input_pad, transformed_input_channel, pad_value,
               &transformed_input);
         } break;
         case 5: {
-          math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
-              ctx, input_pad, transformed_input_channel, pad_value,
+          phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
+              dev_ctx, input_pad, transformed_input_channel, pad_value,
               &transformed_input);
         } break;
         default:
@@ -956,7 +956,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
                              in_data_dims, strides, ksize);
 
     int data_dim = strides.size();  // 2d or 3d
-    bool is_sys_pad = math::IsSymmetricPadding(paddings, data_dim);
+    bool is_sys_pad = phi::funcs::IsSymmetricPadding(paddings, data_dim);
 
     Tensor transformed_X(X->type());
     Tensor transformed_ddX(X->type());
@@ -1004,20 +1004,22 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
       T pad_value(0.0);
       switch (rank) {
         case 4: {
-          math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
-              ctx, input_pad, transformed_X_channel, pad_value, &transformed_X);
+          phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
+              dev_ctx, input_pad, transformed_X_channel, pad_value,
+              &transformed_X);
           if (ddX) {
-            math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
-                ctx, input_pad, transformed_ddX_channel, pad_value,
+            phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
+                dev_ctx, input_pad, transformed_ddX_channel, pad_value,
                 &transformed_ddX);
           }
         } break;
         case 5: {
-          math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
-              ctx, input_pad, transformed_X_channel, pad_value, &transformed_X);
+          phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
+              dev_ctx, input_pad, transformed_X_channel, pad_value,
+              &transformed_X);
           if (ddX) {
-            math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
-                ctx, input_pad, transformed_ddX_channel, pad_value,
+            phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
+                dev_ctx, input_pad, transformed_ddX_channel, pad_value,
                 &transformed_ddX);
           }
         } break;
diff --git a/paddle/fluid/operators/conv_transpose_cudnn_op.cu b/paddle/fluid/operators/conv_transpose_cudnn_op.cu
index 4b8f9d7e6ca8d2f1dae99f1d034c53daf948f922..141a99f60f104c3bf32e16a1254d0f5eec623645 100644
--- a/paddle/fluid/operators/conv_transpose_cudnn_op.cu
+++ b/paddle/fluid/operators/conv_transpose_cudnn_op.cu
@@ -21,8 +21,8 @@ limitations under the License. */
 #include "paddle/fluid/operators/conv_cudnn_helper.h"
 #endif
 #include "paddle/fluid/operators/conv_transpose_op.h"
-#include "paddle/fluid/operators/math/padding.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/funcs/padding.h"
 
 namespace paddle {
 namespace operators {
@@ -108,7 +108,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
                              in_data_dims, strides, ksize);
 
     int data_dim = strides.size();  // 2d or 3d
-    bool is_sys_pad = math::IsSymmetricPadding(paddings, data_dim);
+    bool is_sys_pad = phi::funcs::IsSymmetricPadding(paddings, data_dim);
 
     std::vector<int> input_pad(input_transpose.dims().size() * 2, 0);
     Tensor transformed_input;
@@ -139,12 +139,14 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
       T pad_value(0.0);
       switch (rank) {
         case 4: {
-          math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
-              ctx, input_pad, input_transpose, pad_value, &transformed_input);
+          phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
+              dev_ctx, input_pad, input_transpose, pad_value,
+              &transformed_input);
         } break;
         case 5: {
-          math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
-              ctx, input_pad, input_transpose, pad_value, &transformed_input);
+          phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
+              dev_ctx, input_pad, input_transpose, pad_value,
+              &transformed_input);
         } break;
         default:
           PADDLE_THROW(platform::errors::InvalidArgument(
@@ -375,7 +377,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
                              in_data_dims, strides, ksize);
 
     int data_dim = strides.size();  // 2d or 3d
-    bool is_sys_pad = math::IsSymmetricPadding(paddings, data_dim);
+    bool is_sys_pad = phi::funcs::IsSymmetricPadding(paddings, data_dim);
 
     std::vector<int> input_pad(input_transpose.dims().size() * 2, 0);
     Tensor transformed_output_grad;
@@ -407,13 +409,13 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
       T pad_value(0.0);
       switch (rank) {
         case 4: {
-          math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
-              ctx, input_pad, output_grad_transpose, pad_value,
+          phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
+              dev_ctx, input_pad, output_grad_transpose, pad_value,
               &transformed_output_grad);
         } break;
         case 5: {
-          math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
-              ctx, input_pad, output_grad_transpose, pad_value,
+          phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
+              dev_ctx, input_pad, output_grad_transpose, pad_value,
               &transformed_output_grad);
         } break;
         default:
@@ -735,7 +737,7 @@ class CUDNNConvTransposeDoubleGradOpKernel : public framework::OpKernel<T> {
                              in_data_dims, strides, ksize);
 
     int data_dim = strides.size();  // 2d or 3d
-    bool is_sys_pad = math::IsSymmetricPadding(paddings, data_dim);
+    bool is_sys_pad = phi::funcs::IsSymmetricPadding(paddings, data_dim);
 
     Tensor transformed_X(X->type());
     Tensor transformed_ddX(X->type());
@@ -794,26 +796,28 @@ class CUDNNConvTransposeDoubleGradOpKernel : public framework::OpKernel<T> {
       T pad_value(0.0);
       switch (rank) {
         case 4: {
-          math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
-              ctx, input_pad, transformed_X_channel, pad_value, &transformed_X);
+          phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
+              dev_ctx, input_pad, transformed_X_channel, pad_value,
+              &transformed_X);
           if (dO) {
-            math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
-                ctx, input_pad, transformed_dO_channel, pad_value,
+            phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
+                dev_ctx, input_pad, transformed_dO_channel, pad_value,
                 &transformed_dO);
           }
           if (ddX) {
-            math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
-                ctx, input_pad, transformed_ddX_channel, pad_value,
+            phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
+                dev_ctx, input_pad, transformed_ddX_channel, pad_value,
                 &transformed_ddX);
           }
         } break;
         case 5: {
-          math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
-              ctx, input_pad, transformed_X_channel, pad_value, &transformed_X);
+          phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
+              dev_ctx, input_pad, transformed_X_channel, pad_value,
+              &transformed_X);
           if (ddX) {
-            math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
-                ctx, input_pad, transformed_ddX_channel, pad_value,
+            phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
+                dev_ctx, input_pad, transformed_ddX_channel, pad_value,
                 &transformed_ddX);
           }
         } break;
         default:
diff --git a/paddle/fluid/operators/fused/conv_fusion_op.cu b/paddle/fluid/operators/fused/conv_fusion_op.cu
index bb5b363fe83995faf69f61b0a1a1693ff758fa37..5dbf4fb88b2a78838ce0fe95be653f68f4805416 100644
--- a/paddle/fluid/operators/fused/conv_fusion_op.cu
+++ b/paddle/fluid/operators/fused/conv_fusion_op.cu
@@ -17,8 +17,8 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/conv_cudnn_op_cache.h"
 #include "paddle/fluid/operators/conv_op.h"
-#include "paddle/fluid/operators/math/padding.h"
 #include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
+#include "paddle/phi/kernels/funcs/padding.h"
 
 DECLARE_int64(cudnn_exhaustive_search_times);
 
@@ -86,7 +86,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
                              in_data_dims, strides, ksize);
 
     int data_dim = strides.size();  // 2d or 3d
-    bool is_sys_pad = math::IsSymmetricPadding(paddings, data_dim);
+    bool is_sys_pad = phi::funcs::IsSymmetricPadding(paddings, data_dim);
 
     Tensor transformed_input;
     std::vector<int> padding_common(data_dim, 0);
@@ -118,13 +118,13 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
       T pad_value(0.0);
       switch (rank) {
         case 4: {
-          math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
-              ctx, input_pad, transformed_input_channel, pad_value,
+          phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
+              dev_ctx, input_pad, transformed_input_channel, pad_value,
               &transformed_input);
         } break;
         case 5: {
-          math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
-              ctx, input_pad, transformed_input_channel, pad_value,
+          phi::funcs::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
+              dev_ctx, input_pad, transformed_input_channel, pad_value,
               &transformed_input);
         } break;
         default:
diff --git a/paddle/fluid/operators/pad_constant_like_op.h b/paddle/fluid/operators/pad_constant_like_op.h
index 5df167fdf726345074cdc40afd0c5b394467578f..0aedd800e1a237d4baf0092eef9bac9f7dbe862d 100644
--- a/paddle/fluid/operators/pad_constant_like_op.h
+++ b/paddle/fluid/operators/pad_constant_like_op.h
@@ -20,7 +20,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/framework/tensor_util.h"
-#include "paddle/fluid/operators/math/padding.h"
+#include "paddle/phi/kernels/funcs/padding.h"
 
 namespace paddle {
 namespace operators {
@@ -50,8 +50,9 @@ class PadConstantLikeKernel : public framework::OpKernel<T> {
       pads[j * 2 + 1] = static_cast<int>(in_x->dims()[j] - in_y->dims()[j]);
     }
 
-    math::PaddingFunctor<DeviceContext, T>(rank, context, pads, pad_value,
-                                           *in_y, out);
+    phi::funcs::PaddingFunctor<DeviceContext, T>(
+        rank, context.template device_context<DeviceContext>(), pads, pad_value,
+        *in_y, out);
   }
 };
 
@@ -82,8 +83,9 @@ class PadConstantLikeGradKernel : public framework::OpKernel<T> {
       pads[j * 2 + 1] = static_cast<int>(in_dout->dims()[j] - in_y->dims()[j]);
     }
 
-    math::PaddingGradFunctor<DeviceContext, T>(rank, context, pads, *in_dout,
-                                               d_y);
+    phi::funcs::PaddingGradFunctor<DeviceContext, T>(
+        rank, context.template device_context<DeviceContext>(), pads, *in_dout,
+        d_y);
   }
 };
 
diff --git a/paddle/fluid/operators/pad_op.cc b/paddle/fluid/operators/pad_op.cc
index 39acba7e58aba51942d7d8de2d89e2783fd591f9..229e61ac9fe79d3c171d1f0612f22f3590587231 100644
--- a/paddle/fluid/operators/pad_op.cc
+++ b/paddle/fluid/operators/pad_op.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/pad_op.h"
 #include <memory>
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/complex.h"
 
 namespace paddle {
@@ -167,40 +167,3 @@ REGISTER_OPERATOR(pad, ops::PadOp, ops::PadOpMaker,
 REGISTER_OPERATOR(pad_grad, ops::PadOpGrad,
                   ops::PadOpDoubleGradMaker<paddle::framework::OpDesc>,
                   ops::PadOpDoubleGradMaker<paddle::imperative::OpBase>);
-
-REGISTER_OP_CPU_KERNEL(
-    pad, ops::PadKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::PadKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::PadKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::PadKernel<paddle::platform::CPUDeviceContext, int64_t>,
-    ops::PadKernel<paddle::platform::CPUDeviceContext,
-                   paddle::platform::complex<float>>,
-    ops::PadKernel<paddle::platform::CPUDeviceContext,
-                   paddle::platform::complex<double>>);
-REGISTER_OP_CPU_KERNEL(
-    pad_grad, ops::PadGradKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::PadGradKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::PadGradKernel<paddle::platform::CPUDeviceContext,
-                       paddle::platform::complex<float>>,
-    ops::PadGradKernel<paddle::platform::CPUDeviceContext,
-                       paddle::platform::complex<double>>);
-
-REGISTER_OP_CUDA_KERNEL(
-    pad, ops::PadKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::PadKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::PadKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::PadKernel<paddle::platform::CUDADeviceContext, int64_t>,
-    ops::PadKernel<paddle::platform::CUDADeviceContext, plat::float16>,
-    ops::PadKernel<paddle::platform::CUDADeviceContext,
-                   paddle::platform::complex<float>>,
-    ops::PadKernel<paddle::platform::CUDADeviceContext,
-                   paddle::platform::complex<double>>);
-REGISTER_OP_CUDA_KERNEL(
-    pad_grad, ops::PadGradKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::PadGradKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::PadGradKernel<paddle::platform::CUDADeviceContext,
-                       paddle::platform::float16>,
-    ops::PadGradKernel<paddle::platform::CUDADeviceContext,
-                       paddle::platform::complex<float>>,
-    ops::PadGradKernel<paddle::platform::CUDADeviceContext,
-                       paddle::platform::complex<double>>);
diff --git a/paddle/fluid/operators/pad_op.h b/paddle/fluid/operators/pad_op.h
deleted file mode 100644
index d494c954e1ef73b585761acf7490a5e35beccac4..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/pad_op.h
+++ /dev/null
@@ -1,63 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include <memory>
-#include <utility>
-#include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/math/padding.h"
-
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-
-template <typename DeviceContext, typename T>
-class PadKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto pads = context.Attr<std::vector<int>>("paddings");
-    float pad_value = context.Attr<float>("pad_value");
-    auto* x = context.Input<Tensor>("X");
-    auto* out = context.Output<Tensor>("Out");
-    out->mutable_data<T>(context.GetPlace());
-
-    int rank = x->dims().size();
-    math::PaddingFunctor<DeviceContext, T>(rank, context, pads,
-                                           static_cast<T>(pad_value), *x, out);
-  }
-};
-
-template <typename DeviceContext, typename T>
-class PadGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto pads = context.Attr<std::vector<int>>("paddings");
-    auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
-    auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
-    if (d_x == nullptr) {
-      return;
-    }
-
-    d_x->mutable_data<T>(context.GetPlace());
-    int rank = d_out->dims().size();
-    math::PaddingGradFunctor<DeviceContext, T>(rank, context, pads, *d_out,
-                                               d_x);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/fluid/operators/spectral_op.h b/paddle/fluid/operators/spectral_op.h
index 2bc5124843c38152d2f5d3ffcef5a5ca24534bfd..a60ec5a4df52b8275a17185a63c8a7d27dd8132b 100644
--- a/paddle/fluid/operators/spectral_op.h
+++ b/paddle/fluid/operators/spectral_op.h
@@ -23,9 +23,9 @@
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/operators/conj_op.h"
 #include "paddle/fluid/operators/eigen/eigen_function.h"
-#include "paddle/fluid/operators/math/padding.h"
 #include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/for_range.h"
+#include "paddle/phi/kernels/funcs/padding.h"
 #if defined(__NVCC__) || defined(__HIPCC__)
 #include "thrust/device_vector.h"
 #endif
@@ -389,8 +389,9 @@ class FFTR2CGradKernel : public framework::OpKernel<T> {
     std::vector<int> pads(rank * 2, 0);
     pads[axes.back() * 2 + 1] = zero_length;
 
-    paddle::operators::math::PaddingFunctor<DeviceContext, C>(
-        rank, ctx, pads, static_cast<C>(0), *dy, &full_dy);
+    phi::funcs::PaddingFunctor<DeviceContext, C>(
+        rank, ctx.template device_context<DeviceContext>(), pads,
+        static_cast<C>(0), *dy, &full_dy);
     fft_c2c_func(dev_ctx, &full_dy, &complex_dx, axes, normalization,
                  !forward);
   }
diff --git a/paddle/phi/kernels/cpu/pad_grad_kernel.cc b/paddle/phi/kernels/cpu/pad_grad_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..67e6da7d0e06a572d6d1b5f5353f3fdecc122eaa
--- /dev/null
+++ b/paddle/phi/kernels/cpu/pad_grad_kernel.cc
@@ -0,0 +1,28 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/pad_grad_kernel.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/pad_grad_kernel_impl.h"
+
+PD_REGISTER_KERNEL(pad_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::PadGradKernel,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/cpu/pad_kernel.cc b/paddle/phi/kernels/cpu/pad_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f4a0acdcca267050b04f8063147c140e5631e27b
--- /dev/null
+++ b/paddle/phi/kernels/cpu/pad_kernel.cc
@@ -0,0 +1,30 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/pad_kernel.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/pad_kernel_impl.h"
+
+PD_REGISTER_KERNEL(pad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::PadKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/fluid/operators/math/padding.h b/paddle/phi/kernels/funcs/padding.h
similarity index 67%
rename from paddle/fluid/operators/math/padding.h
rename to paddle/phi/kernels/funcs/padding.h
index 529d39c9ba50f016434b0b14c4d85c84483bad7f..6d10ff2dfcf39c6b57084e99eb31fc1d888f5f75 100644
--- a/paddle/fluid/operators/math/padding.h
+++ b/paddle/phi/kernels/funcs/padding.h
@@ -15,21 +15,26 @@ limitations under the License. */
 #pragma once
 #include <utility>
 #include <vector>
-#include "paddle/fluid/framework/tensor.h"
-#include "paddle/fluid/operators/eigen/eigen_function.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/enforce.h"
 
-namespace paddle {
-namespace operators {
-namespace math {
+namespace phi {
+namespace funcs {
 
-template <typename T, size_t D, int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
+template <typename T,
+          size_t D,
+          int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenTensor = EigenTensor<T, D, MajorType, IndexType>;
 
 template <typename DeviceContext, typename T, size_t D>
-void PadFunction(const framework::ExecutionContext& context,
-                 const std::vector<int>& pads, const framework::Tensor& src,
-                 T pad_value, framework::Tensor* out) {
+void PadFunction(const DeviceContext& context,
+                 const std::vector<int>& pads,
+                 const DenseTensor& src,
+                 T pad_value,
+                 DenseTensor* out) {
   std::array<std::pair<int64_t, int64_t>, D> paddings;
 
   for (size_t i = 0; i < paddings.size(); ++i) {
@@ -40,16 +45,16 @@ void PadFunction(const framework::ExecutionContext& context,
   auto src_tensor = EigenTensor<T, D>::From(src);
   auto out_tensor = EigenTensor<T, D>::From(*out);
 
-  auto& place =
-      *context.template device_context<DeviceContext>().eigen_device();
+  auto& place = *(context.eigen_device());
   EigenPad<std::decay_t<decltype(place)>, T, D>::Eval(
       place, out_tensor, src_tensor, paddings, pad_value);
 }
 
 template <typename DeviceContext, typename T, size_t D>
-void PadGradFunction(const framework::ExecutionContext& context,
-                     const std::vector<int>& pads, const framework::Tensor& src,
-                     framework::Tensor* d_out) {
+void PadGradFunction(const DeviceContext& context,
+                     const std::vector<int>& pads,
+                     const DenseTensor& src,
+                     DenseTensor* d_out) {
   std::array<std::pair<int64_t, int64_t>, D> paddings;
   for (size_t i = 0; i < paddings.size(); ++i) {
     paddings[i].first = -pads[i * 2];
@@ -58,16 +63,18 @@ void PadGradFunction(const framework::ExecutionContext& context,
   auto d_out_tensor = EigenTensor<T, D>::From(*d_out);
   auto src_tensor = EigenTensor<T, D>::From(src);
-  auto& place =
-      *context.template device_context<DeviceContext>().eigen_device();
+  auto& place = *(context.eigen_device());
   EigenPad<std::decay_t<decltype(place)>, T, D>::Eval(
       place, d_out_tensor, src_tensor, paddings, static_cast<T>(0));
 }
 
 template <typename DeviceContext, typename T>
-void PaddingFunctor(int rank, const framework::ExecutionContext& context,
-                    const std::vector<int>& pads, T pad_value,
-                    const framework::Tensor& src, framework::Tensor* out) {
+void PaddingFunctor(int rank,
+                    const DeviceContext& context,
+                    const std::vector<int>& pads,
+                    T pad_value,
+                    const DenseTensor& src,
+                    DenseTensor* out) {
   switch (rank) {
     case 1:
       PadFunction<DeviceContext, T, 1>(context, pads, src, pad_value, out);
       break;
@@ -88,16 +95,18 @@ void PaddingFunctor(int rank, const framework::ExecutionContext& context,
       PadFunction<DeviceContext, T, 6>(context, pads, src, pad_value, out);
       break;
     default:
-      PADDLE_THROW(platform::errors::Unimplemented(
-          "PadOp only support tensors with no more"
-          " than 6 dimensions currently."));
+      PADDLE_THROW(
+          phi::errors::Unimplemented("PadOp only support tensors with no more"
+                                     " than 6 dimensions currently."));
   }
 }
 
 template <typename DeviceContext, typename T>
-void PaddingGradFunctor(int rank, const framework::ExecutionContext& context,
+void PaddingGradFunctor(int rank,
+                        const DeviceContext& context,
                         const std::vector<int>& pads,
-                        const framework::Tensor& src, framework::Tensor* out) {
+                        const DenseTensor& src,
+                        DenseTensor* out) {
   switch (rank) {
     case 1:
       PadGradFunction<DeviceContext, T, 1>(context, pads, src, out);
      break;
@@ -118,9 +127,9 @@ void PaddingGradFunctor(int rank, const framework::ExecutionContext& context,
       PadGradFunction<DeviceContext, T, 6>(context, pads, src, out);
       break;
     default:
-      PADDLE_THROW(platform::errors::Unimplemented(
-          "PadOp only support tensors with no more"
-          " than 6 dimensions currently."));
+      PADDLE_THROW(
+          phi::errors::Unimplemented("PadOp only support tensors with no more"
with no more" + " than 6 dimensions currently.")); } } @@ -137,6 +146,5 @@ inline bool IsSymmetricPadding(const std::vector& pads, } return is_sys_pad; } -} // namespace math -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/gpu/pad_grad_kernel.cu b/paddle/phi/kernels/gpu/pad_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..a25472d122b837fcc3928af15e0678f0362abf0c --- /dev/null +++ b/paddle/phi/kernels/gpu/pad_grad_kernel.cu @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/pad_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/pad_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(pad_grad, + GPU, + ALL_LAYOUT, + phi::PadGradKernel, + float, + double, + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/pad_kernel.cu b/paddle/phi/kernels/gpu/pad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..2b77a5f1aeb6cb3f24f274beb6939c480022fe49 --- /dev/null +++ b/paddle/phi/kernels/gpu/pad_kernel.cu @@ -0,0 +1,31 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/platform/complex.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/pad_kernel_impl.h" +#include "paddle/phi/kernels/pad_kernel.h" + +PD_REGISTER_KERNEL(pad, + GPU, + ALL_LAYOUT, + phi::PadKernel, + float, + double, + int, + int64_t, + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/impl/pad_grad_kernel_impl.h b/paddle/phi/kernels/impl/pad_grad_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..91f198f9fb681e4fabf7029fcc22343bb81953fd --- /dev/null +++ b/paddle/phi/kernels/impl/pad_grad_kernel_impl.h @@ -0,0 +1,33 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/kernels/funcs/padding.h"
+namespace phi {
+template <typename T, typename Context>
+void PadGradKernel(const Context& dev_ctx,
+                   const DenseTensor& d_out,
+                   const std::vector<int>& paddings,
+                   float pad_value,
+                   DenseTensor* d_x) {
+  if (d_x == nullptr) {
+    return;
+  }
+  dev_ctx.template Alloc<T>(d_x);
+  int rank = d_out.dims().size();
+  phi::funcs::PaddingGradFunctor<Context, T>(
+      rank, dev_ctx, paddings, d_out, d_x);
+}
+}  // namespace phi
diff --git a/paddle/phi/kernels/impl/pad_kernel_impl.h b/paddle/phi/kernels/impl/pad_kernel_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..8e3ebb0dfe03b2f13e2a321bb813f7d10e306b7a
--- /dev/null
+++ b/paddle/phi/kernels/impl/pad_kernel_impl.h
@@ -0,0 +1,32 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <utility>
+#include <vector>
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/kernels/funcs/padding.h"
+namespace phi {
+template <typename T, typename Context>
+void PadKernel(const Context& dev_ctx,
+               const DenseTensor& x,
+               const std::vector<int>& paddings,
+               float pad_value,
+               DenseTensor* out) {
+  dev_ctx.template Alloc<T>(out);
+  int rank = x.dims().size();
+  funcs::PaddingFunctor<Context, T>(
+      rank, dev_ctx, paddings, static_cast<T>(pad_value), x, out);
+}
+}  // namespace phi
diff --git a/paddle/phi/kernels/pad_grad_kernel.h b/paddle/phi/kernels/pad_grad_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..f39d87e5c0ef6503d772e5f9ee95e307a13eda13
--- /dev/null
+++ b/paddle/phi/kernels/pad_grad_kernel.h
@@ -0,0 +1,28 @@
+
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void PadGradKernel(const Context& dev_ctx,
+                   const DenseTensor& d_out,
+                   const std::vector<int>& paddings,
+                   float pad_value,
+                   DenseTensor* d_x);
+}  // namespace phi
diff --git a/paddle/phi/kernels/pad_kernel.h b/paddle/phi/kernels/pad_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..511e8cf73df97ffb250b1106aa98155de33a97d1
--- /dev/null
+++ b/paddle/phi/kernels/pad_kernel.h
@@ -0,0 +1,28 @@
+
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void PadKernel(const Context& dev_ctx,
+               const DenseTensor& x,
+               const std::vector<int>& paddings,
+               float pad_value,
+               DenseTensor* out);
+}  // namespace phi
diff --git a/paddle/phi/ops/compat/pad_sig.cc b/paddle/phi/ops/compat/pad_sig.cc
new file mode 100644
index 0000000000000000000000000000000000000000..4eadbfa98beded121c4e6738384487a9ec10be42
--- /dev/null
+++ b/paddle/phi/ops/compat/pad_sig.cc
@@ -0,0 +1,28 @@
+
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature PadGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature("pad_grad",
+                         {GradVarName("Out")},
+                         {"paddings", "pad_value"},
+                         {GradVarName("X")});
+}
+
+}  // namespace phi
+PD_REGISTER_ARG_MAPPING_FN(pad_grad, phi::PadGradOpArgumentMapping);
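
Note (not part of the patch): after this move, the padding helpers are header-only templates in phi::funcs that take a phi device context directly instead of digging one out of a framework::ExecutionContext. Below is a minimal sketch of the migrated call path. It assumes the caller already has a phi::CPUContext and that `out` has been resized to the padded shape and is allocatable (the phi PadKernel gets the output shape from InferMeta before Alloc); `PadWithOnes` is a hypothetical helper written for illustration, not code from this PR.

#include <vector>

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/padding.h"

// Hypothetical standalone caller of the migrated helpers.
void PadWithOnes(const phi::CPUContext& dev_ctx,
                 const phi::DenseTensor& src,
                 phi::DenseTensor* out) {
  // `pads` holds one (before, after) pair per dimension of `src`,
  // i.e. 2 * rank entries -- the same layout the old fluid helper used.
  const int rank = src.dims().size();
  const std::vector<int> pads(rank * 2, 1);  // pad every edge by 1
  // Same check the conv kernels above rely on: true when each dimension
  // is padded equally on both sides (so cuDNN could absorb the padding).
  const bool symmetric = phi::funcs::IsSymmetricPadding(pads, rank);
  (void)symmetric;
  // Dispatches on rank (1..6) to PadFunction<phi::CPUContext, float, D>,
  // which runs the Eigen pad expression on dev_ctx's eigen_device().
  phi::funcs::PaddingFunctor<phi::CPUContext, float>(
      rank, dev_ctx, pads, /*pad_value=*/1.0f, src, out);
}

One design note on pad_sig.cc: only `pad_grad` needs an explicit KernelSignature, mapping GradVarName("Out")/GradVarName("X") onto the phi kernel's arguments; the forward `pad` op maps automatically because its input, attribute, and output names already line up with phi::PadKernel.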