From 3e170163a10a1742c7b8a076aeac9c347f4fa146 Mon Sep 17 00:00:00 2001 From: lyq <30404405+affectionlu@users.noreply.github.com> Date: Mon, 25 Jul 2022 10:39:11 +0800 Subject: [PATCH] [Phi] Migrate squared_l2_norm_op to phi (#44492) --- .../new_executor/standalone_executor_test.cc | 3 +- paddle/fluid/operators/inplace_abn_op.cu | 1 - paddle/fluid/operators/optimizers/lamb_op.h | 17 ++--- paddle/fluid/operators/squared_l2_norm_op.cc | 51 +++++-------- paddle/fluid/operators/squared_l2_norm_op.cu | 24 ------- paddle/fluid/operators/squared_l2_norm_op.h | 71 ------------------- .../fluid/operators/squared_l2_norm_op_mlu.cc | 1 - .../fluid/operators/squared_l2_norm_op_npu.cc | 2 +- paddle/phi/api/yaml/legacy_api.yaml | 9 +++ paddle/phi/api/yaml/legacy_backward.yaml | 10 +++ paddle/phi/infermeta/unary.cc | 4 ++ paddle/phi/infermeta/unary.h | 2 + .../cpu/squared_l2_norm_grad_kernel.cc | 26 +++++++ .../phi/kernels/cpu/squared_l2_norm_kernel.cc | 23 ++++++ .../kernels/funcs}/squared_l2_norm.h | 33 ++++----- .../gpu/squared_l2_norm_grad_kernel.cu | 26 +++++++ .../phi/kernels/gpu/squared_l2_norm_kernel.cu | 23 ++++++ .../impl/squared_l2_norm_grad_kernel_impl.h | 41 +++++++++++ .../impl/squared_l2_norm_kernel_impl.h | 32 +++++++++ .../phi/kernels/squared_l2_norm_grad_kernel.h | 26 +++++++ paddle/phi/kernels/squared_l2_norm_kernel.h | 25 +++++++ paddle/phi/ops/compat/squared_l2_norm_sig.cc | 35 +++++++++ python/paddle/fluid/clip.py | 4 +- .../unittests/test_squared_l2_norm_op.py | 11 ++- 24 files changed, 340 insertions(+), 160 deletions(-) delete mode 100644 paddle/fluid/operators/squared_l2_norm_op.cu delete mode 100644 paddle/fluid/operators/squared_l2_norm_op.h create mode 100644 paddle/phi/kernels/cpu/squared_l2_norm_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/squared_l2_norm_kernel.cc rename paddle/{fluid/operators/math => phi/kernels/funcs}/squared_l2_norm.h (75%) create mode 100644 paddle/phi/kernels/gpu/squared_l2_norm_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/squared_l2_norm_kernel.cu create mode 100644 paddle/phi/kernels/impl/squared_l2_norm_grad_kernel_impl.h create mode 100644 paddle/phi/kernels/impl/squared_l2_norm_kernel_impl.h create mode 100644 paddle/phi/kernels/squared_l2_norm_grad_kernel.h create mode 100644 paddle/phi/kernels/squared_l2_norm_kernel.h create mode 100644 paddle/phi/ops/compat/squared_l2_norm_sig.cc diff --git a/paddle/fluid/framework/new_executor/standalone_executor_test.cc b/paddle/fluid/framework/new_executor/standalone_executor_test.cc index 1816b0942f..701c1edcaf 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor_test.cc +++ b/paddle/fluid/framework/new_executor/standalone_executor_test.cc @@ -57,7 +57,7 @@ USE_OP_ITSELF(sqrt); USE_OP_ITSELF(elementwise_max); USE_OP_ITSELF(elementwise_div); USE_OP_ITSELF(sgd); -USE_OP(squared_l2_norm); +USE_OP_ITSELF(squared_l2_norm); USE_OP_ITSELF(memcpy_h2d); USE_OP_ITSELF(memcpy_d2h); USE_OP_ITSELF(fetch_v2); @@ -87,6 +87,7 @@ PD_DECLARE_KERNEL(mean, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(mean_grad, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(sigmoid, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(sigmoid_grad, GPU, ALL_LAYOUT); +PD_DECLARE_KERNEL(squared_l2_norm, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(reshape_grad, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(add_grad, GPU, ALL_LAYOUT); PD_DECLARE_KERNEL(matmul_grad, GPU, ALL_LAYOUT); diff --git a/paddle/fluid/operators/inplace_abn_op.cu b/paddle/fluid/operators/inplace_abn_op.cu index a74150a330..a63cd8b007 100644 --- a/paddle/fluid/operators/inplace_abn_op.cu +++ b/paddle/fluid/operators/inplace_abn_op.cu @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/inplace_abn_op.h" -#include #include "paddle/fluid/operators/batch_norm_op.h" #include "paddle/phi/kernels/batch_norm_grad_kernel.h" #include "paddle/phi/kernels/batch_norm_kernel.h" diff --git a/paddle/fluid/operators/optimizers/lamb_op.h b/paddle/fluid/operators/optimizers/lamb_op.h index 3beb78b656..0415bb7df0 100644 --- a/paddle/fluid/operators/optimizers/lamb_op.h +++ b/paddle/fluid/operators/optimizers/lamb_op.h @@ -22,11 +22,11 @@ limitations under the License. */ #include "paddle/fluid/memory/buffer.h" #include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" -#include "paddle/fluid/operators/math/squared_l2_norm.h" #include "paddle/fluid/operators/tensor_to_string.h" #include "paddle/fluid/platform/for_range.h" #include "paddle/phi/kernels/funcs/algorithm.h" #include "paddle/phi/kernels/funcs/eigen/extensions.h" +#include "paddle/phi/kernels/funcs/squared_l2_norm.h" namespace paddle { namespace operators { @@ -756,13 +756,14 @@ class LambOpKernel : public framework::OpKernel { // TODO(zengjinle): remove the following Eigen operations when // *skip_update == true. memory::Buffer buffer(dev_ctx.GetPlace()); - math::SquaredL2Norm(dev_ctx, - reinterpret_cast( - IsMultiPrecision ? master_param_ptr : param_ptr), - p_norm_ptr, - numel, - &buffer); - math::SquaredL2Norm( + phi::funcs::SquaredL2Norm( + dev_ctx, + reinterpret_cast(IsMultiPrecision ? master_param_ptr + : param_ptr), + p_norm_ptr, + numel, + &buffer); + phi::funcs::SquaredL2Norm( dev_ctx, trust_ratio_div_ptr, trust_ratio_div_norm_ptr, numel, &buffer); if (VLOG_IS_ON(1)) { diff --git a/paddle/fluid/operators/squared_l2_norm_op.cc b/paddle/fluid/operators/squared_l2_norm_op.cc index f6792baa1f..4653cc0cc2 100644 --- a/paddle/fluid/operators/squared_l2_norm_op.cc +++ b/paddle/fluid/operators/squared_l2_norm_op.cc @@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/squared_l2_norm_op.h" - -#include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -24,13 +25,6 @@ using framework::Tensor; class SquaredL2NormOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SquaredL2NormOp"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SquaredL2NormOp"); - - ctx->SetOutputDim("Out", {1}); - } }; template @@ -54,20 +48,6 @@ class SquaredL2NormGradOpMaker : public framework::SingleGradOpMaker { class SquaredL2NormGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SquaredL2NormGradOp"); - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), - "Input", - "Out@GRAD", - "SquaredL2NormGradOp"); - OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), - "Output", - "X@GRAD", - "SquaredL2NormGradOp"); - - ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); - } }; class SquaredL2NormOpMaker : public framework::OpProtoAndCheckerMaker { @@ -90,15 +70,22 @@ $$Out = \sum_{i} X_{i}^2$$ } // namespace paddle namespace ops = paddle::operators; + +DECLARE_INFER_SHAPE_FUNCTOR(squared_l2_norm, + SquaredL2NormInferShapeFunctor, + PD_INFER_META(phi::SquaredL2NormInferMeta)); + +DECLARE_INFER_SHAPE_FUNCTOR(squared_l2_norm_grad, + SquaredL2NormGradInferShapeFunctor, + PD_INFER_META(phi::UnchangedInferMeta)); + REGISTER_OPERATOR(squared_l2_norm, ops::SquaredL2NormOp, ops::SquaredL2NormOpMaker, ops::SquaredL2NormGradOpMaker, - ops::SquaredL2NormGradOpMaker); -REGISTER_OPERATOR(squared_l2_norm_grad, ops::SquaredL2NormGradOp); -REGISTER_OP_CPU_KERNEL(squared_l2_norm, - ops::SquaredL2NormKernel, - ops::SquaredL2NormKernel); -REGISTER_OP_CPU_KERNEL(squared_l2_norm_grad, - ops::SquaredL2NormGradKernel, - ops::SquaredL2NormGradKernel); + ops::SquaredL2NormGradOpMaker, + SquaredL2NormInferShapeFunctor); + +REGISTER_OPERATOR(squared_l2_norm_grad, + ops::SquaredL2NormGradOp, + SquaredL2NormGradInferShapeFunctor); diff --git a/paddle/fluid/operators/squared_l2_norm_op.cu b/paddle/fluid/operators/squared_l2_norm_op.cu deleted file mode 100644 index b51e56af8e..0000000000 --- a/paddle/fluid/operators/squared_l2_norm_op.cu +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -#include "paddle/fluid/operators/squared_l2_norm_op.h" - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - squared_l2_norm, - ops::SquaredL2NormKernel, - ops::SquaredL2NormKernel); -REGISTER_OP_CUDA_KERNEL( - squared_l2_norm_grad, - ops::SquaredL2NormGradKernel, - ops::SquaredL2NormGradKernel); diff --git a/paddle/fluid/operators/squared_l2_norm_op.h b/paddle/fluid/operators/squared_l2_norm_op.h deleted file mode 100644 index 147f0cc530..0000000000 --- a/paddle/fluid/operators/squared_l2_norm_op.h +++ /dev/null @@ -1,71 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/squared_l2_norm.h" - -namespace paddle { -namespace operators { - -// Out = sum(square(X)) -template -class SquaredL2NormKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &context) const override { - const framework::Tensor *x = context.Input("X"); - const auto *x_ptr = x->data(); - auto numel = x->numel(); - - framework::Tensor *out = context.Output("Out"); - auto *out_ptr = out->mutable_data(context.GetPlace()); - - math::SquaredL2Norm(context.template device_context(), - x_ptr, - out_ptr, - numel); - } -}; - -// dX = X -template -class SquaredL2NormGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &context) const override { - const framework::Tensor *X = context.Input("X"); - const framework::Tensor *dOut = - context.Input(framework::GradVarName("Out")); - PADDLE_ENFORCE_EQ( - dOut->numel(), - 1, - platform::errors::InvalidArgument( - "Input(GRAD@Out) of SquaredL2NormGradOP should be a scalar.")); - framework::Tensor *dX = - context.Output(framework::GradVarName("X")); - dX->mutable_data(context.GetPlace()); - - auto x = framework::EigenVector::Flatten(*X); - auto dout = framework::EigenVector::Flatten(*dOut); - auto dx = framework::EigenVector::Flatten(*dX); - auto *place = - context.template device_context().eigen_device(); - - Eigen::DSizes x_dsize(X->numel()); - dx.device(*place) = (dout.broadcast(x_dsize) * x) * static_cast(2.0); - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/squared_l2_norm_op_mlu.cc b/paddle/fluid/operators/squared_l2_norm_op_mlu.cc index 34f699d68a..741d23540b 100644 --- a/paddle/fluid/operators/squared_l2_norm_op_mlu.cc +++ b/paddle/fluid/operators/squared_l2_norm_op_mlu.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/squared_l2_norm_op.h" // #include "paddle/fluid/platform/device/npu/npu_op_runner.h" #include "paddle/fluid/operators/mlu/mlu_baseop.h" diff --git a/paddle/fluid/operators/squared_l2_norm_op_npu.cc b/paddle/fluid/operators/squared_l2_norm_op_npu.cc index 3104d0cd2a..56fae36570 100644 --- a/paddle/fluid/operators/squared_l2_norm_op_npu.cc +++ b/paddle/fluid/operators/squared_l2_norm_op_npu.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/squared_l2_norm_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" namespace paddle { diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index ad93a7c607..464ea1f00b 100644 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -2062,6 +2062,15 @@ func : square backward : square_grad +- api : squared_l2_norm + args : (Tensor x) + output : Tensor + infer_meta : + func : SquaredL2NormInferMeta + kernel : + func : squared_l2_norm + backward : squared_l2_norm_grad + - api : squeeze args : (Tensor x, int[] axes) output : Tensor(out), Tensor(xshape) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 61eeec6c84..423958399a 100644 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -2009,6 +2009,16 @@ backward : square_double_grad inplace : (out_grad -> x_grad) +- backward_api : squared_l2_norm_grad + forward : squared_l2_norm(Tensor x) -> Tensor(out) + args : (Tensor x, Tensor out_grad) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param: [x] + kernel : + func : squared_l2_norm_grad + - backward_api : squeeze_double_grad forward : squeeze_grad(Tensor xshape, Tensor grad_out, int[] axes) -> Tensor(grad_x) args : (Tensor grad_x_grad, int[] axes) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 9b7dd1f45f..307cd0ad29 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -2489,6 +2489,10 @@ void SplitInferMeta(const MetaTensor& x, } } +void SquaredL2NormInferMeta(const MetaTensor& x, MetaTensor* out) { + out->set_dims({1}); +} + void SqueezeInferMeta(const MetaTensor& x, const std::vector& axes, MetaTensor* out) { diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 805fa3a56d..03f7b09fc8 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -345,6 +345,8 @@ void SplitInferMeta(const MetaTensor& x_meta, std::vector out, MetaConfig config = MetaConfig()); +void SquaredL2NormInferMeta(const MetaTensor& x, MetaTensor* out); + void SqueezeInferMeta(const MetaTensor& x, const std::vector& axes, MetaTensor* out); diff --git a/paddle/phi/kernels/cpu/squared_l2_norm_grad_kernel.cc b/paddle/phi/kernels/cpu/squared_l2_norm_grad_kernel.cc new file mode 100644 index 0000000000..2e2725fa9d --- /dev/null +++ b/paddle/phi/kernels/cpu/squared_l2_norm_grad_kernel.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/squared_l2_norm_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/squared_l2_norm_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(squared_l2_norm_grad, + CPU, + ALL_LAYOUT, + phi::SquaredL2NormGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/cpu/squared_l2_norm_kernel.cc b/paddle/phi/kernels/cpu/squared_l2_norm_kernel.cc new file mode 100644 index 0000000000..a1e00851cf --- /dev/null +++ b/paddle/phi/kernels/cpu/squared_l2_norm_kernel.cc @@ -0,0 +1,23 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/squared_l2_norm_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/squared_l2_norm_kernel_impl.h" + +PD_REGISTER_KERNEL( + squared_l2_norm, CPU, ALL_LAYOUT, phi::SquaredL2NormKernel, float, double) { +} diff --git a/paddle/fluid/operators/math/squared_l2_norm.h b/paddle/phi/kernels/funcs/squared_l2_norm.h similarity index 75% rename from paddle/fluid/operators/math/squared_l2_norm.h rename to paddle/phi/kernels/funcs/squared_l2_norm.h index 3054d5f8f0..21deecd5a7 100644 --- a/paddle/fluid/operators/math/squared_l2_norm.h +++ b/paddle/phi/kernels/funcs/squared_l2_norm.h @@ -14,13 +14,12 @@ #pragma once -#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/memory/buffer.h" -#include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/platform/enforce.h" +#include "paddle/phi/core/device_context.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" #if defined(__NVCC__) || defined(__HIPCC__) -#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h" +#include "paddle/phi/kernels/primitive/functor_primitives.h" #ifdef __NVCC__ #include "cub/cub.cuh" #else @@ -29,20 +28,19 @@ namespace cub = hipcub; #endif #endif -namespace paddle { -namespace operators { -namespace math { +namespace phi { +namespace funcs { template void SquaredL2Norm(const phi::CPUContext& ctx, const T1* x, T2* y, size_t numel, - memory::Buffer* buffer = nullptr) { + paddle::memory::Buffer* buffer = nullptr) { if (std::is_same::value) { - using EigenT = typename framework::EigenTensor::Type; - using ConstEigenT = typename framework::EigenTensor::ConstType; - using EigenDim = typename framework::EigenDim<1>::Type; + using EigenT = typename phi::EigenTensor::Type; + using ConstEigenT = typename phi::EigenTensor::ConstType; + using EigenDim = typename phi::EigenDim<1>::Type; ConstEigenT input(x, EigenDim(numel)); EigenT output(reinterpret_cast(y), EigenDim(1)); output.device(*ctx.eigen_device()) = input.square().sum(); @@ -58,17 +56,17 @@ void SquaredL2Norm(const phi::CPUContext& ctx, #if defined(__NVCC__) || defined(__HIPCC__) template -void SquaredL2Norm(const platform::CUDADeviceContext& ctx, +void SquaredL2Norm(const phi::GPUContext& ctx, const T1* x, T2* y, size_t numel, - memory::Buffer* buffer = nullptr) { + paddle::memory::Buffer* buffer = nullptr) { if (UNLIKELY(buffer == nullptr)) { - memory::Buffer tmp_buffer(ctx.GetPlace()); + paddle::memory::Buffer tmp_buffer(ctx.GetPlace()); return SquaredL2Norm(ctx, x, y, numel, &tmp_buffer); } - using FunctorT = kernel_primitives::SquareFunctor; + using FunctorT = phi::kps::SquareFunctor; cub::TransformInputIterator iter(x, FunctorT()); size_t temp_storage_bytes = 0; void* d_temp_storage = nullptr; @@ -89,6 +87,5 @@ void SquaredL2Norm(const platform::CUDADeviceContext& ctx, } #endif -} // namespace math -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/gpu/squared_l2_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/squared_l2_norm_grad_kernel.cu new file mode 100644 index 0000000000..908a7557d1 --- /dev/null +++ b/paddle/phi/kernels/gpu/squared_l2_norm_grad_kernel.cu @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/squared_l2_norm_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/squared_l2_norm_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(squared_l2_norm_grad, + GPU, + ALL_LAYOUT, + phi::SquaredL2NormGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/gpu/squared_l2_norm_kernel.cu b/paddle/phi/kernels/gpu/squared_l2_norm_kernel.cu new file mode 100644 index 0000000000..d585d209b4 --- /dev/null +++ b/paddle/phi/kernels/gpu/squared_l2_norm_kernel.cu @@ -0,0 +1,23 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/squared_l2_norm_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/squared_l2_norm_kernel_impl.h" + +PD_REGISTER_KERNEL( + squared_l2_norm, GPU, ALL_LAYOUT, phi::SquaredL2NormKernel, float, double) { +} diff --git a/paddle/phi/kernels/impl/squared_l2_norm_grad_kernel_impl.h b/paddle/phi/kernels/impl/squared_l2_norm_grad_kernel_impl.h new file mode 100644 index 0000000000..3a367fc000 --- /dev/null +++ b/paddle/phi/kernels/impl/squared_l2_norm_grad_kernel_impl.h @@ -0,0 +1,41 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" + +namespace phi { +template +void SquaredL2NormGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& dout, + DenseTensor* dx) { + dev_ctx.template Alloc(dx); + + PADDLE_ENFORCE_EQ( + dout.numel(), + 1, + phi::errors::InvalidArgument( + "Input(GRAD@Out) of SquaredL2NormGradOP should be a scalar.")); + + auto input = phi::EigenVector::Flatten(x); + auto d_out = phi::EigenVector::Flatten(dout); + auto d_x = phi::EigenVector::Flatten(*dx); + auto* place = dev_ctx.eigen_device(); + Eigen::DSizes x_dsize(x.numel()); + d_x.device(*place) = (d_out.broadcast(x_dsize) * input) * static_cast(2.0); +} +} // namespace phi diff --git a/paddle/phi/kernels/impl/squared_l2_norm_kernel_impl.h b/paddle/phi/kernels/impl/squared_l2_norm_kernel_impl.h new file mode 100644 index 0000000000..30805ef9ac --- /dev/null +++ b/paddle/phi/kernels/impl/squared_l2_norm_kernel_impl.h @@ -0,0 +1,32 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/squared_l2_norm.h" + +namespace phi { + +template +void SquaredL2NormKernel(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out) { + dev_ctx.template Alloc(out); + auto x_ptr = x.template data(); + auto numel = x.numel(); + return phi::funcs::SquaredL2Norm(dev_ctx, x_ptr, out->data(), numel); +} + +} // namespace phi diff --git a/paddle/phi/kernels/squared_l2_norm_grad_kernel.h b/paddle/phi/kernels/squared_l2_norm_grad_kernel.h new file mode 100644 index 0000000000..78928ae3bf --- /dev/null +++ b/paddle/phi/kernels/squared_l2_norm_grad_kernel.h @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void SquaredL2NormGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& dout, + DenseTensor* dx); +} // namespace phi diff --git a/paddle/phi/kernels/squared_l2_norm_kernel.h b/paddle/phi/kernels/squared_l2_norm_kernel.h new file mode 100644 index 0000000000..e1dbbaa3cf --- /dev/null +++ b/paddle/phi/kernels/squared_l2_norm_kernel.h @@ -0,0 +1,25 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void SquaredL2NormKernel(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out); +} // namespace phi diff --git a/paddle/phi/ops/compat/squared_l2_norm_sig.cc b/paddle/phi/ops/compat/squared_l2_norm_sig.cc new file mode 100644 index 0000000000..7b228008f2 --- /dev/null +++ b/paddle/phi/ops/compat/squared_l2_norm_sig.cc @@ -0,0 +1,35 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature SquaredL2NormOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("squared_l2_norm", {"X"}, {}, {"Out"}); +} + +KernelSignature SquaredL2NormGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "squared_l2_norm_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(squared_l2_norm, + phi::SquaredL2NormOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(squared_l2_norm_grad, + phi::SquaredL2NormGradOpArgumentMapping); diff --git a/python/paddle/fluid/clip.py b/python/paddle/fluid/clip.py index df48de8ea2..dbe31bdf6f 100644 --- a/python/paddle/fluid/clip.py +++ b/python/paddle/fluid/clip.py @@ -73,8 +73,8 @@ def _squared_l2_norm(x): if in_dygraph_mode(): if x.is_selected_rows(): new_x = paddle.to_tensor(x.numpy()) - return _C_ops.squared_l2_norm(new_x) - return _C_ops.squared_l2_norm(x) + return _C_ops.final_state_squared_l2_norm(new_x) + return _C_ops.final_state_squared_l2_norm(x) else: if _in_legacy_dygraph(): return _C_ops.squared_l2_norm(x) diff --git a/python/paddle/fluid/tests/unittests/test_squared_l2_norm_op.py b/python/paddle/fluid/tests/unittests/test_squared_l2_norm_op.py index ee8f724563..452b0ac542 100644 --- a/python/paddle/fluid/tests/unittests/test_squared_l2_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_squared_l2_norm_op.py @@ -20,6 +20,14 @@ from numpy import linalg as LA from op_test import OpTest import paddle from paddle import _C_ops +from paddle.framework import in_dygraph_mode + + +def test_squared_l2_norm(x): + if in_dygraph_mode(): + return _C_ops.final_state_squared_l2_norm(x) + else: + return _C_ops.squared_l2_norm(x) class TestL2LossOp(OpTest): @@ -27,6 +35,7 @@ class TestL2LossOp(OpTest): """ def setUp(self): + self.python_api = test_squared_l2_norm self.op_type = "squared_l2_norm" self.max_relative_error = 0.05 @@ -36,7 +45,7 @@ class TestL2LossOp(OpTest): self.outputs = {'Out': np.square(LA.norm(X))} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): self.check_grad(['X'], -- GitLab