// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/p_norm_kernel.h"

#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/gpu/reduce.h"

namespace phi {

template <typename T>
__device__ __forceinline__ int sgn(T val) {
  return (T(0) < val) - (val < T(0));
}

// abs/sign/pow helpers; the 16-bit floating-point overloads compute in
// float and cast the result back.
__device__ __forceinline__ dtype::float16 inline_abs(dtype::float16 x) {
  return static_cast<dtype::float16>(abs(static_cast<float>(x)));
}

__device__ __forceinline__ dtype::bfloat16 inline_abs(dtype::bfloat16 x) {
  return static_cast<dtype::bfloat16>(abs(static_cast<float>(x)));
}

__device__ __forceinline__ float inline_abs(float x) { return abs(x); }
__device__ __forceinline__ double inline_abs(double x) { return abs(x); }

__device__ __forceinline__ int inline_sign(dtype::float16 x) {
  return sgn<dtype::float16>(x);
}
__device__ __forceinline__ int inline_sign(float x) { return sgn<float>(x); }
__device__ __forceinline__ int inline_sign(double x) { return sgn<double>(x); }

__device__ __forceinline__ dtype::float16 inline_pow(dtype::float16 base,
                                                     dtype::float16 exponent) {
  return static_cast<dtype::float16>(
      pow(static_cast<float>(base), static_cast<float>(exponent)));
}
__device__ __forceinline__ dtype::bfloat16 inline_pow(
    dtype::bfloat16 base, dtype::bfloat16 exponent) {
  return static_cast<dtype::bfloat16>(
      pow(static_cast<float>(base), static_cast<float>(exponent)));
}
__device__ __forceinline__ float inline_pow(float base, float exponent) {
  return pow(base, exponent);
}
__device__ __forceinline__ double inline_pow(double base, double exponent) {
  return pow(base, exponent);
}

// Maps x to 1 if x != 0, else 0; summing these yields the L0 "norm".
template <typename T>
struct NonzeroFunctor {
  HOSTDEVICE explicit inline NonzeroFunctor() {}
  HOSTDEVICE inline T operator()(const T x) const {
    return static_cast<T>(static_cast<double>(x) != 0);
  }
};

template <typename T>
struct AbsFunctor {
  HOSTDEVICE explicit inline AbsFunctor() {}
  HOSTDEVICE inline T operator()(const T x) const {
    return static_cast<T>(inline_abs(x));
  }
};

// Maps x to |x|^porder.
template <typename T>
struct UnsignedPowFunctor {
  HOSTDEVICE explicit inline UnsignedPowFunctor(float porder) {
    this->porder = porder;
  }
  HOSTDEVICE inline T operator()(const T x) const {
    return static_cast<T>(inline_pow(inline_abs(x), static_cast<T>(porder)));
  }
  float porder;
};
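// Computes the p-norm of `x` along `axis` (over all elements when `asvector`
// is true):
//   porder == 0         -> number of nonzero elements
//   porder == INFINITY  -> max(|x_i|)
//   porder == -INFINITY -> min(|x_i|)
//   otherwise           -> (sum |x_i|^porder)^(1/porder)
// The general case runs in two passes: a sum-reduction of |x_i|^porder,
// followed by an elementwise t^(1/porder) applied to the reduced tensor.
// Note: `epsilon` and `keepdim` are unused in this kernel body; the output
// shape is presumably resolved earlier by the op's InferMeta.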
template <typename T, typename Context>
void PNormKernel(const Context& dev_ctx,
                 const DenseTensor& x,
                 float porder,
                 int axis,
                 float epsilon,
                 bool keepdim,
                 bool asvector,
                 DenseTensor* out) {
  auto* in_x = &x;
  auto* out_norm = out;
  T* norm = dev_ctx.template Alloc<T>(out);
  auto xdim = in_x->dims();
  std::vector<int> axis_dims = {static_cast<int>(axis)};
  std::vector<int> reduce_axis =
      funcs::details::GetReduceDim(axis_dims, xdim.size(), asvector);

  using MT = typename dtype::MPTypeTrait<T>::Type;
  if (porder == 0) {
    // L0 "norm": count the nonzero elements along the reduced dimensions.
    phi::funcs::ReduceKernel<T, T, kps::AddFunctor, NonzeroFunctor<T>>(
        dev_ctx, *in_x, out_norm, NonzeroFunctor<T>(), reduce_axis);
  } else if (porder == INFINITY) {
    // +inf norm: max of |x_i|.
    phi::funcs::ReduceKernel<T, T, kps::MaxFunctor, AbsFunctor<T>>(
        dev_ctx, *in_x, out_norm, AbsFunctor<T>(), reduce_axis);
  } else if (porder == -INFINITY) {
    // -inf norm: min of |x_i|.
    phi::funcs::ReduceKernel<T, T, kps::MinFunctor, AbsFunctor<T>>(
        dev_ctx, *in_x, out_norm, AbsFunctor<T>(), reduce_axis);
  } else {
    // General case: sum |x_i|^porder, then raise the result to 1/porder.
    phi::funcs::ReduceKernel<T, T, kps::AddFunctor, UnsignedPowFunctor<T>>(
        dev_ctx, *in_x, out_norm, UnsignedPowFunctor<T>(porder), reduce_axis);

    const DenseTensor* tmp_norm = out_norm;
    std::vector<const DenseTensor*> ins = {tmp_norm};
    std::vector<DenseTensor*> outs = {out_norm};
    phi::funcs::ElementwiseKernel<T>(
        dev_ctx, ins, &outs, UnsignedPowFunctor<T>(1. / porder));
  }
}
}  // namespace phi

PD_REGISTER_KERNEL(p_norm,
                   GPU,
                   ALL_LAYOUT,
                   phi::PNormKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}
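// Reference semantics for the general-porder path (a minimal host-side
// sketch for illustration only; `pnorm_reference` is a hypothetical helper,
// not part of this file or of the phi API):
//
//   #include <cmath>
//   #include <vector>
//
//   // Mirrors the GPU pipeline: sum-reduce |x_i|^p, then take the 1/p power.
//   float pnorm_reference(const std::vector<float>& x, float p) {
//     float acc = 0.f;
//     for (float v : x) acc += std::pow(std::abs(v), p);
//     return std::pow(acc, 1.f / p);
//   }
//
//   // e.g. pnorm_reference({3.f, -4.f}, 2.f) == 5.f, since 3^2 + 4^2 = 25
//   // and 25^(1/2) = 5.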