From 92afe14658db2008aab640db09e4f733427ff267 Mon Sep 17 00:00:00 2001 From: zhiboniu <31800336+zhiboniu@users.noreply.github.com> Date: Thu, 24 Mar 2022 20:39:37 +0800 Subject: [PATCH] p_norm transfer to phi kernels (#40819) --- paddle/fluid/operators/p_norm_op.cc | 79 ++----- paddle/fluid/operators/p_norm_op.cu | 222 ------------------- paddle/fluid/operators/p_norm_op.h | 138 ------------ paddle/fluid/operators/p_norm_op_npu.cc | 2 +- paddle/phi/infermeta/unary.cc | 57 +++++ paddle/phi/infermeta/unary.h | 8 + paddle/phi/kernels/cpu/p_norm_grad_kernel.cc | 101 +++++++++ paddle/phi/kernels/cpu/p_norm_kernel.cc | 90 ++++++++ paddle/phi/kernels/gpu/p_norm_grad_kernel.cu | 112 ++++++++++ paddle/phi/kernels/gpu/p_norm_kernel.cu | 138 ++++++++++++ paddle/phi/kernels/p_norm_grad_kernel.h | 32 +++ paddle/phi/kernels/p_norm_kernel.h | 31 +++ paddle/phi/ops/compat/p_norm_sig.cc | 26 +++ 13 files changed, 610 insertions(+), 426 deletions(-) delete mode 100644 paddle/fluid/operators/p_norm_op.cu delete mode 100644 paddle/fluid/operators/p_norm_op.h create mode 100644 paddle/phi/kernels/cpu/p_norm_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/p_norm_kernel.cc create mode 100644 paddle/phi/kernels/gpu/p_norm_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/p_norm_kernel.cu create mode 100644 paddle/phi/kernels/p_norm_grad_kernel.h create mode 100644 paddle/phi/kernels/p_norm_kernel.h create mode 100644 paddle/phi/ops/compat/p_norm_sig.cc diff --git a/paddle/fluid/operators/p_norm_op.cc b/paddle/fluid/operators/p_norm_op.cc index f287755040..c7c8ebf562 100644 --- a/paddle/fluid/operators/p_norm_op.cc +++ b/paddle/fluid/operators/p_norm_op.cc @@ -11,12 +11,15 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#include "paddle/fluid/operators/p_norm_op.h" #include #include #include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -81,68 +84,11 @@ where, $\sum_i $ is calculated along the `axis` dimension. class PnormOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "p_norm"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "p_norm"); - auto x_dim = ctx->GetInputDim("X"); - auto x_rank = x_dim.size(); - int axis = ctx->Attrs().Get("axis"); - bool keepdim = ctx->Attrs().Get("keepdim"); - - PADDLE_ENFORCE_GE(axis, -x_rank, - platform::errors::InvalidArgument( - "Attr(axis) value should be in range [-R, R-1], R is " - "the rank of Input(X). But received axis: %d, R: %d. " - "Current Input(X)'s shape is=[%s].", - axis, x_rank, x_dim)); - PADDLE_ENFORCE_LT(axis, x_rank, - platform::errors::InvalidArgument( - "Attr(axis) value should be in range [-R, R-1], R is " - "the rank of Input(X). But received axis: %d, R: %d. " - "Current Input(X)'s shape is=[%s].", - axis, x_rank, x_dim)); - - std::vector reduce_dims; - bool asvector = ctx->Attrs().Get("asvector"); - if (asvector) { - reduce_dims.emplace_back(1); - if (keepdim) { - for (int i = 1; i < x_dim.size(); ++i) { - reduce_dims.emplace_back(1); - } - x_dim = phi::make_ddim(reduce_dims); - } - } else { - if (axis < 0) axis = x_dim.size() + axis; - for (int i = 0; i < x_dim.size(); ++i) { - if (i != axis) reduce_dims.emplace_back(x_dim[i]); - } - if (reduce_dims.size() == 0) { - reduce_dims.emplace_back(1); - } - } - x_dim[axis] = 1; - - if (keepdim) { - ctx->SetOutputDim("Out", x_dim); - } else { - ctx->SetOutputDim("Out", phi::make_ddim(reduce_dims)); - } - } }; class PnormOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "p_norm"); - OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "p_norm"); - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", - "Out@GRAD", "p_norm"); - OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output", - "X@GRAD", "p_norm"); - ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); - } }; template @@ -167,14 +113,17 @@ class PnormOpGradOpMaker : public framework::SingleGradOpMaker { namespace ops = paddle::operators; using CPU = paddle::platform::CPUDeviceContext; +DECLARE_INFER_SHAPE_FUNCTOR(p_norm, PNormInferShapeFunctor, + PD_INFER_META(phi::PNormInferMeta)); +DECLARE_INFER_SHAPE_FUNCTOR(p_norm_grad, PNormGradInferShapeFunctor, + PD_INFER_META(phi::GeneralUnaryGradInferMeta)); + REGISTER_OPERATOR(p_norm, ops::PnormOp, ops::PnormOpMaker, ops::PnormOpGradOpMaker, - ops::PnormOpGradOpMaker); -REGISTER_OPERATOR(p_norm_grad, ops::PnormOpGrad); -REGISTER_OP_CPU_KERNEL(p_norm, ops::PnormKernel, - ops::PnormKernel); -REGISTER_OP_CPU_KERNEL(p_norm_grad, ops::PnormGradKernel, - ops::PnormGradKernel); + ops::PnormOpGradOpMaker, + PNormInferShapeFunctor); +REGISTER_OPERATOR(p_norm_grad, ops::PnormOpGrad, PNormGradInferShapeFunctor); + REGISTER_OP_VERSION(p_norm) .AddCheckpoint( R"ROC( diff --git a/paddle/fluid/operators/p_norm_op.cu b/paddle/fluid/operators/p_norm_op.cu deleted file mode 100644 index d0b78b9b06..0000000000 --- a/paddle/fluid/operators/p_norm_op.cu +++ /dev/null @@ -1,222 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -Indicesou may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#ifdef __NVCC__ -#include "cub/cub.cuh" -#endif -#ifdef __HIPCC__ -#include -namespace cub = hipcub; -#endif -#include "paddle/fluid/operators/amp/fp16_type_traits.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" -#include "paddle/fluid/operators/fc_op.h" -#include "paddle/fluid/operators/p_norm_op.h" -#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" -#include "paddle/fluid/operators/reduce_ops/reduce_op.h" -#include "paddle/fluid/platform/float16.h" - -namespace paddle { -namespace operators { - -template -__device__ __forceinline__ int sgn(T val) { - return (T(0) < val) - (val < T(0)); -} - -__device__ __forceinline__ platform::float16 inline_abs(platform::float16 x) { - return static_cast(abs(static_cast(x))); -} - -__device__ __forceinline__ platform::bfloat16 inline_abs(platform::bfloat16 x) { - return static_cast(abs(static_cast(x))); -} - -__device__ __forceinline__ float inline_abs(float x) { return abs(x); } -__device__ __forceinline__ double inline_abs(double x) { return abs(x); } - -__device__ __forceinline__ int inline_sign(platform::float16 x) { - return sgn(x); -} -__device__ __forceinline__ int inline_sign(float x) { return sgn(x); } -__device__ __forceinline__ int inline_sign(double x) { return sgn(x); } - -__device__ __forceinline__ platform::float16 inline_pow( - platform::float16 base, platform::float16 exponent) { - return static_cast( - pow(static_cast(base), static_cast(exponent))); -} -__device__ __forceinline__ platform::bfloat16 inline_pow( - platform::bfloat16 base, platform::bfloat16 exponent) { - return static_cast( - pow(static_cast(base), static_cast(exponent))); -} -__device__ __forceinline__ float inline_pow(float base, float exponent) { - return pow(base, exponent); -} -__device__ __forceinline__ double inline_pow(double base, double exponent) { - return pow(base, exponent); -} - -template -struct NonzeroFunctor { - HOSTDEVICE explicit inline NonzeroFunctor() {} - HOSTDEVICE inline T operator()(const T x) const { - return static_cast(static_cast(x) != 0); - } -}; - -template -struct AbsFunctor { - HOSTDEVICE explicit inline AbsFunctor() {} - HOSTDEVICE inline T operator()(const T x) const { - return static_cast(inline_abs(x)); - } -}; - -template -struct UnsignedPowFunctor { - HOSTDEVICE explicit inline UnsignedPowFunctor(float porder) { - this->porder = porder; - } - HOSTDEVICE inline T operator()(const T x) const { - return static_cast(inline_pow(inline_abs(x), static_cast(porder))); - } - float porder; -}; - -template -class PnormCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* in_x = ctx.Input("X"); - auto* out_norm = ctx.Output("Out"); - const T* x = in_x->data(); - T* norm = out_norm->mutable_data(ctx.GetPlace()); - auto xdim = in_x->dims(); - float porder = ctx.Attr("porder"); - bool asvector = ctx.Attr("asvector"); - int axis = ctx.Attr("axis"); - std::vector reduce_axis = {axis}; - reduce_axis = GetReduceDim(reduce_axis, xdim.size(), asvector); - auto stream = ctx.cuda_device_context().stream(); - - using MT = typename details::MPTypeTrait::Type; - if (porder == 0) { - TensorReduceImpl>( - ctx.cuda_device_context(), *in_x, out_norm, NonzeroFunctor(), - reduce_axis, stream); - } else if (porder == INFINITY) { - TensorReduceImpl>( - ctx.cuda_device_context(), *in_x, out_norm, AbsFunctor(), - reduce_axis, stream); - } else if (porder == -INFINITY) { - TensorReduceImpl>( - ctx.cuda_device_context(), *in_x, out_norm, AbsFunctor(), - reduce_axis, stream); - } else { - TensorReduceImpl>( - ctx.cuda_device_context(), *in_x, out_norm, - UnsignedPowFunctor(porder), reduce_axis, stream); - - const framework::Tensor* tmp_norm = out_norm; - std::vector ins = {tmp_norm}; - std::vector outs = {out_norm}; - const auto& cuda_ctx = - ctx.template device_context(); - paddle::operators::LaunchSameDimsElementwiseCudaKernel( - cuda_ctx, ins, &outs, UnsignedPowFunctor(1. / porder)); - } - } -}; - -template -struct AbsMaxAndMinGradFunctor { - template - void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy, - const Dim& dim, int size) { - dx->device(place) = dy->broadcast(dim) * (*x).sign() * - ((*x).abs() == y->broadcast(dim)).template cast(); - } -}; - -template -struct PNormGradFunctor { - HOSTDEVICE explicit inline PNormGradFunctor(float porder) { - this->porder = static_cast(porder - 1.); - } - template - void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy, - const Dim& dim, int size) { - dx->device(place) = (*x).abs().pow(this->porder) * (*x).sign() * - dy->broadcast(dim) * - (*y).pow(-this->porder).broadcast(dim); - } - T porder; -}; - -template -class PnormGradCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* in_x = ctx.Input("X"); - auto* in_norm = ctx.Input("Out"); - auto* in_norm_dy = - ctx.Input(framework::GradVarName("Out")); - auto* out_dx = ctx.Output(framework::GradVarName("X")); - T* dx = out_dx->mutable_data(ctx.GetPlace()); - - auto xdim = in_x->dims(); - float porder = ctx.Attr("porder"); - int axis = ctx.Attr("axis"); - bool reduce_all = (in_norm->numel() == 1); - if (axis < 0) axis = xdim.size() + axis; - const std::vector dims = {axis}; - - auto& cuda_ctx = ctx.template device_context(); - - if (porder == 0) { - phi::funcs::SetConstant set_zero; - set_zero(cuda_ctx, out_dx, static_cast(0)); - } else if (porder == INFINITY || porder == -INFINITY) { - AbsMaxAndMinGradFunctor functor; - LaunchReduceGradKernel>( - ctx, in_x, in_norm, in_norm_dy, out_dx, functor, dims, reduce_all); - } else { - auto functor = PNormGradFunctor(porder); - LaunchReduceGradKernel>( - ctx, in_x, in_norm, in_norm_dy, out_dx, functor, dims, reduce_all); - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -using CUDA = paddle::platform::CUDADeviceContext; - -REGISTER_OP_CUDA_KERNEL(p_norm, - ops::PnormCUDAKernel, - ops::PnormCUDAKernel, - ops::PnormCUDAKernel, - ops::PnormCUDAKernel); -REGISTER_OP_CUDA_KERNEL( - p_norm_grad, ops::PnormGradCUDAKernel, - ops::PnormGradCUDAKernel, - ops::PnormGradCUDAKernel, - ops::PnormGradCUDAKernel); diff --git a/paddle/fluid/operators/p_norm_op.h b/paddle/fluid/operators/p_norm_op.h deleted file mode 100644 index f2da1af8cc..0000000000 --- a/paddle/fluid/operators/p_norm_op.h +++ /dev/null @@ -1,138 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -Indicesou may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { - -inline void GetDims(const framework::DDim& dim, int axis, int* pre, int* n, - int* post, bool asvector) { - *pre = 1; - *post = 1; - *n = dim[axis]; - if (asvector) { - *n = product(dim); - } else { - for (int i = 0; i < axis; ++i) { - (*pre) *= dim[i]; - } - for (int i = axis + 1; i < dim.size(); ++i) { - (*post) *= dim[i]; - } - } -} - -template -class PnormKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* in_x = ctx.Input("X"); - auto* out_norm = ctx.Output("Out"); - out_norm->mutable_data(ctx.GetPlace()); - - auto xdim = in_x->dims(); - float porder = ctx.Attr("porder"); - int axis = ctx.Attr("axis"); - bool asvector = ctx.Attr("asvector"); - if (axis < 0) axis = xdim.size() + axis; - int pre, n, post; - GetDims(xdim, axis, &pre, &n, &post, asvector); - - auto* place = ctx.template device_context().eigen_device(); - - Eigen::DSizes shape(pre, n, post); - Eigen::DSizes norm_shape(pre, post); - - auto x_e = framework::EigenVector::Flatten(*in_x); - auto norm_e = framework::EigenVector::Flatten(*out_norm); - - auto x = x_e.reshape(shape); - auto norm = norm_e.reshape(norm_shape); - - // p=0 means number of non-zero elements of (x) - // p=inf means the maximum of |x| - // p=-inf means the minimum of |x| - // otherwise, Lp-norm = pow(sum(pow(|x|, p)), 1/p) - Eigen::DSizes rdim(1); - if (porder == 0) { - norm.device(*place) = (x != x.constant(0)).template cast().sum(rdim); - } else if (porder == INFINITY) { - norm.device(*place) = x.abs().maximum(rdim); - } else if (porder == -INFINITY) { - norm.device(*place) = x.abs().minimum(rdim); - } else { - norm.device(*place) = x.abs().pow(porder).sum(rdim).pow(1.0f / porder); - } - } -}; - -template -class PnormGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* in_x = ctx.Input("X"); - auto* in_norm = ctx.Input("Out"); - auto* in_norm_dy = - ctx.Input(framework::GradVarName("Out")); - auto* out_dx = ctx.Output(framework::GradVarName("X")); - out_dx->mutable_data(ctx.GetPlace()); - - T eps = static_cast(ctx.Attr("epsilon")); - auto xdim = in_x->dims(); - float porder = ctx.Attr("porder"); - - int axis = ctx.Attr("axis"); - bool asvector = ctx.Attr("asvector"); - if (axis < 0) axis = xdim.size() + axis; - int pre, n, post; - GetDims(xdim, axis, &pre, &n, &post, asvector); - Eigen::DSizes shape(pre, n, post); - Eigen::DSizes rshape(pre, 1, post); - - auto* place = ctx.template device_context().eigen_device(); - - auto x_e = framework::EigenVector::Flatten(*in_x); - auto dx_e = framework::EigenVector::Flatten(*out_dx); - auto norm_e = framework::EigenVector::Flatten(*in_norm); - auto norm_dy_e = framework::EigenVector::Flatten(*in_norm_dy); - - auto x = x_e.reshape(shape); - auto dx = dx_e.reshape(shape); - auto norm = norm_e.reshape(rshape); - auto norm_dy = norm_dy_e.reshape(rshape); - - Eigen::DSizes rdim(1); - Eigen::DSizes bcast(1, n, 1); - - if (porder == 0) { - phi::funcs::SetConstant set_zero; - auto& dev_ctx = ctx.template device_context(); - set_zero(dev_ctx, out_dx, static_cast(0)); - } else if (porder == INFINITY || porder == -INFINITY) { - dx.device(*place) = - (x.abs() == norm.broadcast(bcast)).template cast() * x.sign() * - norm_dy.broadcast(bcast); - } else { - dx.device(*place) = - (x.abs()).pow(porder - 1.0f) / - ((norm.broadcast(bcast)).pow(porder - 1.0f) + x.constant(eps)); - dx.device(*place) = dx * norm_dy.broadcast(bcast) * x.sign(); - } - } -}; -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/p_norm_op_npu.cc b/paddle/fluid/operators/p_norm_op_npu.cc index f842114daa..dfc927740f 100644 --- a/paddle/fluid/operators/p_norm_op_npu.cc +++ b/paddle/fluid/operators/p_norm_op_npu.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/p_norm_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" namespace paddle { diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index b76661d49b..a67cc270c2 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1012,6 +1012,63 @@ void PixelShuffleInferMeta(const MetaTensor& x, out->set_dims(output_dims); } +void PNormInferMeta(const MetaTensor& x, + float porder, + int axis, + float epsilon, + bool keepdim, + bool asvector, + MetaTensor* out) { + auto x_dim = x.dims(); + auto x_rank = x_dim.size(); + + PADDLE_ENFORCE_GE(axis, + -x_rank, + errors::InvalidArgument( + "Attr(axis) value should be in range [-R, R-1], R is " + "the rank of Input(X). But received axis: %d, R: %d. " + "Current Input(X)'s shape is=[%s].", + axis, + x_rank, + x_dim)); + PADDLE_ENFORCE_LT(axis, + x_rank, + errors::InvalidArgument( + "Attr(axis) value should be in range [-R, R-1], R is " + "the rank of Input(X). But received axis: %d, R: %d. " + "Current Input(X)'s shape is=[%s].", + axis, + x_rank, + x_dim)); + + std::vector reduce_dims; + if (asvector) { + reduce_dims.emplace_back(1); + if (keepdim) { + for (int i = 1; i < x_dim.size(); ++i) { + reduce_dims.emplace_back(1); + } + x_dim = phi::make_ddim(reduce_dims); + } + } else { + if (axis < 0) axis = x_dim.size() + axis; + for (int i = 0; i < x_dim.size(); ++i) { + if (i != axis) reduce_dims.emplace_back(x_dim[i]); + } + if (reduce_dims.size() == 0) { + reduce_dims.emplace_back(1); + } + } + x_dim[axis] = 1; + + if (keepdim) { + out->set_dims(x_dim); + } else { + out->set_dims(phi::make_ddim(reduce_dims)); + } + out->set_dtype(x.dtype()); +} + void PoolInferMeta(const MetaTensor& x, const std::vector& kernel_size, const std::vector& strides, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 8e254965ab..697926b76a 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -166,6 +166,14 @@ void PixelShuffleInferMeta(const MetaTensor& x, const std::string& data_format, MetaTensor* out); +void PNormInferMeta(const MetaTensor& x, + float porder, + int axis, + float epsilon, + bool keepdim, + bool asvector, + MetaTensor* out); + void PoolInferMeta(const MetaTensor& x, const std::vector& kernel_size, const std::vector& strides, diff --git a/paddle/phi/kernels/cpu/p_norm_grad_kernel.cc b/paddle/phi/kernels/cpu/p_norm_grad_kernel.cc new file mode 100644 index 0000000000..44ab050408 --- /dev/null +++ b/paddle/phi/kernels/cpu/p_norm_grad_kernel.cc @@ -0,0 +1,101 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/p_norm_grad_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +inline void GetDims(const phi::DDim& dim, + int axis, + int* pre, + int* n, + int* post, + bool asvector) { + *pre = 1; + *post = 1; + *n = dim[axis]; + if (asvector) { + *n = product(dim); + } else { + for (int i = 0; i < axis; ++i) { + (*pre) *= dim[i]; + } + for (int i = axis + 1; i < dim.size(); ++i) { + (*post) *= dim[i]; + } + } +} + +template +void PNormGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out, + const DenseTensor& out_grad, + float porder, + int axis, + float epsilon, + bool keepdim, + bool asvector, + DenseTensor* x_grad) { + auto* in_x = &x; + auto* in_norm = &out; + auto* in_norm_dy = &out_grad; + auto* out_dx = x_grad; + dev_ctx.template Alloc(out_dx); + + T eps = static_cast(epsilon); + auto xdim = in_x->dims(); + + if (axis < 0) axis = xdim.size() + axis; + int pre, n, post; + GetDims(xdim, axis, &pre, &n, &post, asvector); + Eigen::DSizes shape(pre, n, post); + Eigen::DSizes rshape(pre, 1, post); + + auto* place = dev_ctx.eigen_device(); + + auto x_e = phi::EigenVector::Flatten(*in_x); + auto dx_e = phi::EigenVector::Flatten(*out_dx); + auto norm_e = phi::EigenVector::Flatten(*in_norm); + auto norm_dy_e = phi::EigenVector::Flatten(*in_norm_dy); + + auto xr = x_e.reshape(shape); + auto dx = dx_e.reshape(shape); + auto norm = norm_e.reshape(rshape); + auto norm_dy = norm_dy_e.reshape(rshape); + + Eigen::DSizes rdim(1); + Eigen::DSizes bcast(1, n, 1); + + if (porder == 0) { + phi::funcs::SetConstant set_zero; + set_zero(dev_ctx, out_dx, static_cast(0)); + } else if (porder == INFINITY || porder == -INFINITY) { + dx.device(*place) = (xr.abs() == norm.broadcast(bcast)).template cast() * + xr.sign() * norm_dy.broadcast(bcast); + } else { + dx.device(*place) = + (xr.abs()).pow(porder - 1.0f) / + ((norm.broadcast(bcast)).pow(porder - 1.0f) + xr.constant(eps)); + dx.device(*place) = dx * norm_dy.broadcast(bcast) * xr.sign(); + } +} +} // namespace phi +PD_REGISTER_KERNEL( + p_norm_grad, CPU, ALL_LAYOUT, phi::PNormGradKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/p_norm_kernel.cc b/paddle/phi/kernels/cpu/p_norm_kernel.cc new file mode 100644 index 0000000000..9da7fdbb29 --- /dev/null +++ b/paddle/phi/kernels/cpu/p_norm_kernel.cc @@ -0,0 +1,90 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/p_norm_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/common_shape.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +inline void GetDims(const phi::DDim& dim, + int axis, + int* pre, + int* n, + int* post, + bool asvector) { + *pre = 1; + *post = 1; + *n = dim[axis]; + if (asvector) { + *n = product(dim); + } else { + for (int i = 0; i < axis; ++i) { + (*pre) *= dim[i]; + } + for (int i = axis + 1; i < dim.size(); ++i) { + (*post) *= dim[i]; + } + } +} + +template +void PNormKernel(const Context& dev_ctx, + const DenseTensor& x, + float porder, + int axis, + float epsilon, + bool keepdim, + bool asvector, + DenseTensor* out) { + auto* in_x = &x; + dev_ctx.template Alloc(out); + + auto xdim = in_x->dims(); + if (axis < 0) axis = xdim.size() + axis; + int pre, n, post; + GetDims(xdim, axis, &pre, &n, &post, asvector); + + auto* place = dev_ctx.eigen_device(); + + Eigen::DSizes shape(pre, n, post); + Eigen::DSizes norm_shape(pre, post); + + auto x_e = phi::EigenVector::Flatten(*in_x); + auto norm_e = phi::EigenVector::Flatten(*out); + + auto xr = x_e.reshape(shape); + auto norm = norm_e.reshape(norm_shape); + + // p=0 means number of non-zero elements of (xr) + // p=inf means the maximum of |xr| + // p=-inf means the minimum of |xr| + // otherwise, Lp-norm = pow(sum(pow(|xr|, p)), 1/p) + Eigen::DSizes rdim(1); + if (porder == 0) { + norm.device(*place) = (xr != xr.constant(0)).template cast().sum(rdim); + } else if (porder == INFINITY) { + norm.device(*place) = xr.abs().maximum(rdim); + } else if (porder == -INFINITY) { + norm.device(*place) = xr.abs().minimum(rdim); + } else { + norm.device(*place) = xr.abs().pow(porder).sum(rdim).pow(1.0f / porder); + } +} +} // namespace phi +PD_REGISTER_KERNEL(p_norm, CPU, ALL_LAYOUT, phi::PNormKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/p_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/p_norm_grad_kernel.cu new file mode 100644 index 0000000000..9b0e43d25a --- /dev/null +++ b/paddle/phi/kernels/gpu/p_norm_grad_kernel.cu @@ -0,0 +1,112 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/p_norm_grad_kernel.h" + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/reduce_grad_functions.h" + +namespace phi { + +template +struct AbsMaxAndMinGradFunctor { + template + void operator()(const Context& place, + X* x, + Y* y, + DX* dx, + DY* dy, + const Dim& dim, + int size) { + dx->device(place) = dy->broadcast(dim) * (*x).sign() * + ((*x).abs() == y->broadcast(dim)).template cast(); + } +}; + +template +struct PNormGradFunctor { + HOSTDEVICE explicit inline PNormGradFunctor(float porder) { + this->porder = static_cast(porder - 1.); + } + template + void operator()(const Context& place, + X* x, + Y* y, + DX* dx, + DY* dy, + const Dim& dim, + int size) { + dx->device(place) = (*x).abs().pow(this->porder) * (*x).sign() * + dy->broadcast(dim) * + (*y).pow(-this->porder).broadcast(dim); + } + T porder; +}; + +template +void PNormGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out, + const DenseTensor& out_grad, + float porder, + int axis, + float epsilon, + bool keepdim, + bool asvector, + DenseTensor* x_grad) { + auto* in_x = &x; + auto* in_norm = &out; + auto* in_norm_dy = &out_grad; + auto* out_dx = x_grad; + dev_ctx.template Alloc(out_dx); + + auto xdim = in_x->dims(); + bool reduce_all = (in_norm->numel() == 1); + if (axis < 0) axis = xdim.size() + axis; + const std::vector dims = {axis}; + + if (porder == 0) { + phi::funcs::SetConstant set_zero; + set_zero(dev_ctx, out_dx, static_cast(0)); + } else if (porder == INFINITY || porder == -INFINITY) { + AbsMaxAndMinGradFunctor functor; + funcs::LaunchReduceGradKernel>( + dev_ctx, in_x, in_norm, in_norm_dy, out_dx, functor, dims, reduce_all); + + } else { + auto functor = PNormGradFunctor(porder); + funcs::LaunchReduceGradKernel>( + dev_ctx, in_x, in_norm, in_norm_dy, out_dx, functor, dims, reduce_all); + } +} +} // namespace phi +PD_REGISTER_KERNEL(p_norm_grad, + GPU, + ALL_LAYOUT, + phi::PNormGradKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/p_norm_kernel.cu b/paddle/phi/kernels/gpu/p_norm_kernel.cu new file mode 100644 index 0000000000..80ef97d9cf --- /dev/null +++ b/paddle/phi/kernels/gpu/p_norm_kernel.cu @@ -0,0 +1,138 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/p_norm_kernel.h" + +#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/reduce_function.h" +#include "paddle/phi/kernels/gpu/reduce.h" + +namespace phi { + +template +__device__ __forceinline__ int sgn(T val) { + return (T(0) < val) - (val < T(0)); +} + +__device__ __forceinline__ dtype::float16 inline_abs(dtype::float16 x) { + return static_cast(abs(static_cast(x))); +} + +__device__ __forceinline__ dtype::bfloat16 inline_abs(dtype::bfloat16 x) { + return static_cast(abs(static_cast(x))); +} + +__device__ __forceinline__ float inline_abs(float x) { return abs(x); } +__device__ __forceinline__ double inline_abs(double x) { return abs(x); } + +__device__ __forceinline__ int inline_sign(dtype::float16 x) { + return sgn(x); +} +__device__ __forceinline__ int inline_sign(float x) { return sgn(x); } +__device__ __forceinline__ int inline_sign(double x) { return sgn(x); } + +__device__ __forceinline__ dtype::float16 inline_pow(dtype::float16 base, + dtype::float16 exponent) { + return static_cast( + pow(static_cast(base), static_cast(exponent))); +} +__device__ __forceinline__ dtype::bfloat16 inline_pow( + dtype::bfloat16 base, dtype::bfloat16 exponent) { + return static_cast( + pow(static_cast(base), static_cast(exponent))); +} +__device__ __forceinline__ float inline_pow(float base, float exponent) { + return pow(base, exponent); +} +__device__ __forceinline__ double inline_pow(double base, double exponent) { + return pow(base, exponent); +} + +template +struct NonzeroFunctor { + HOSTDEVICE explicit inline NonzeroFunctor() {} + HOSTDEVICE inline T operator()(const T x) const { + return static_cast(static_cast(x) != 0); + } +}; + +template +struct AbsFunctor { + HOSTDEVICE explicit inline AbsFunctor() {} + HOSTDEVICE inline T operator()(const T x) const { + return static_cast(inline_abs(x)); + } +}; + +template +struct UnsignedPowFunctor { + HOSTDEVICE explicit inline UnsignedPowFunctor(float porder) { + this->porder = porder; + } + HOSTDEVICE inline T operator()(const T x) const { + return static_cast(inline_pow(inline_abs(x), static_cast(porder))); + } + float porder; +}; + +template +void PNormKernel(const Context& dev_ctx, + const DenseTensor& x, + float porder, + int axis, + float epsilon, + bool keepdim, + bool asvector, + DenseTensor* out) { + auto* in_x = &x; + auto* out_norm = out; + T* norm = dev_ctx.template Alloc(out); + auto xdim = in_x->dims(); + std::vector axis_dims = {static_cast(axis)}; + std::vector reduce_axis = + funcs::details::GetReduceDim(axis_dims, xdim.size(), asvector); + + using MT = typename dtype::MPTypeTrait::Type; + if (porder == 0) { + phi::funcs::ReduceKernel>( + dev_ctx, *in_x, out_norm, NonzeroFunctor(), reduce_axis); + } else if (porder == INFINITY) { + phi::funcs::ReduceKernel>( + dev_ctx, *in_x, out_norm, AbsFunctor(), reduce_axis); + } else if (porder == -INFINITY) { + phi::funcs::ReduceKernel>( + dev_ctx, *in_x, out_norm, AbsFunctor(), reduce_axis); + } else { + phi::funcs::ReduceKernel>( + dev_ctx, *in_x, out_norm, UnsignedPowFunctor(porder), reduce_axis); + + const DenseTensor* tmp_norm = out_norm; + std::vector ins = {tmp_norm}; + std::vector outs = {out_norm}; + phi::funcs::ElementwiseKernel( + dev_ctx, ins, &outs, UnsignedPowFunctor(1. / porder)); + } +} +} // namespace phi + +PD_REGISTER_KERNEL(p_norm, + GPU, + ALL_LAYOUT, + phi::PNormKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/p_norm_grad_kernel.h b/paddle/phi/kernels/p_norm_grad_kernel.h new file mode 100644 index 0000000000..a64c8ceee4 --- /dev/null +++ b/paddle/phi/kernels/p_norm_grad_kernel.h @@ -0,0 +1,32 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void PNormGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out, + const DenseTensor& out_grad, + float porder, + int axis, + float epsilon, + bool keepdim, + bool asvector, + DenseTensor* x_grad); +} // namespace phi diff --git a/paddle/phi/kernels/p_norm_kernel.h b/paddle/phi/kernels/p_norm_kernel.h new file mode 100644 index 0000000000..8e9af01ba3 --- /dev/null +++ b/paddle/phi/kernels/p_norm_kernel.h @@ -0,0 +1,31 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void PNormKernel(const Context& dev_ctx, + const DenseTensor& x, + float porder, + int axis, + float epsilon, + bool keepdim, + bool asvector, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/ops/compat/p_norm_sig.cc b/paddle/phi/ops/compat/p_norm_sig.cc new file mode 100644 index 0000000000..d3bff55346 --- /dev/null +++ b/paddle/phi/ops/compat/p_norm_sig.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { +KernelSignature PNormGradOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("p_norm_grad", + {"X", "Out", GradVarName("Out")}, + {"porder", "axis", "epsilon", "keepdim", "asvector"}, + {GradVarName("X")}); +} +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(p_norm_grad, phi::PNormGradOpArgumentMapping); -- GitLab