From 7f0bdf0787909929e44dbf9a1fc88ade3ca22522 Mon Sep 17 00:00:00 2001 From: zbt78 <1095497213@qq.com> Date: Mon, 14 Aug 2023 14:30:28 +0800 Subject: [PATCH] write the common functions p_norm_kernel.cu and p_norm_grad_kernel.cu to p_norm_utils.h (#56191) --- paddle/phi/kernels/funcs/p_norm_utils.h | 63 +++++++++++++++++++++++++ paddle/phi/kernels/gpu/p_norm_kernel.cu | 41 +--------------- 2 files changed, 64 insertions(+), 40 deletions(-) create mode 100644 paddle/phi/kernels/funcs/p_norm_utils.h diff --git a/paddle/phi/kernels/funcs/p_norm_utils.h b/paddle/phi/kernels/funcs/p_norm_utils.h new file mode 100644 index 00000000000..b614afa3630 --- /dev/null +++ b/paddle/phi/kernels/funcs/p_norm_utils.h @@ -0,0 +1,63 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+
+namespace phi {
+__device__ __forceinline__ dtype::float16 inline_abs(dtype::float16 x) {
+  return static_cast<dtype::float16>(abs(static_cast<float>(x)));
+}
+
+__device__ __forceinline__ dtype::bfloat16 inline_abs(dtype::bfloat16 x) {
+  return static_cast<dtype::bfloat16>(abs(static_cast<float>(x)));
+}
+
+__device__ __forceinline__ float inline_abs(float x) { return abs(x); }
+
+__device__ __forceinline__ double inline_abs(double x) { return abs(x); }
+
+template <typename T>
+__device__ __forceinline__ int sgn(T val) {
+  return (T(0) < val) - (val < T(0));
+}
+
+__device__ __forceinline__ int inline_sign(dtype::float16 x) {
+  return sgn<dtype::float16>(x);
+}
+
+__device__ __forceinline__ int inline_sign(dtype::bfloat16 x) {
+  return sgn<dtype::bfloat16>(x);
+}
+
+__device__ __forceinline__ int inline_sign(float x) { return sgn<float>(x); }
+
+__device__ __forceinline__ int inline_sign(double x) { return sgn<double>(x); }
+
+__device__ __forceinline__ dtype::float16 inline_pow(dtype::float16 base,
+                                                     dtype::float16 exponent) {
+  return static_cast<dtype::float16>(
+      pow(static_cast<float>(base), static_cast<float>(exponent)));
+}
+__device__ __forceinline__ dtype::bfloat16 inline_pow(
+    dtype::bfloat16 base, dtype::bfloat16 exponent) {
+  return static_cast<dtype::bfloat16>(
+      pow(static_cast<float>(base), static_cast<float>(exponent)));
+}
+__device__ __forceinline__ float inline_pow(float base, float exponent) {
+  return pow(base, exponent);
+}
+__device__ __forceinline__ double inline_pow(double base, double exponent) {
+  return pow(base, exponent);
+}
+}  // namespace phi
diff --git a/paddle/phi/kernels/gpu/p_norm_kernel.cu b/paddle/phi/kernels/gpu/p_norm_kernel.cu
index 556a6308ff4..cb93772dea1 100644
--- a/paddle/phi/kernels/gpu/p_norm_kernel.cu
+++ b/paddle/phi/kernels/gpu/p_norm_kernel.cu
@@ -17,50 +17,11 @@
 #include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/elementwise_base.h"
+#include "paddle/phi/kernels/funcs/p_norm_utils.h"
 #include "paddle/phi/kernels/funcs/reduce_function.h"
 #include "paddle/phi/kernels/gpu/reduce.h"
 namespace phi {
-
-template <typename T>
-__device__ __forceinline__ int sgn(T val) {
-  return (T(0) < val) - (val < T(0));
-}
-
-__device__ __forceinline__ dtype::float16 inline_abs(dtype::float16 x) {
-  return static_cast<dtype::float16>(abs(static_cast<float>(x)));
-}
-
-__device__ __forceinline__ dtype::bfloat16 inline_abs(dtype::bfloat16 x) {
-  return static_cast<dtype::bfloat16>(abs(static_cast<float>(x)));
-}
-
-__device__ __forceinline__ float inline_abs(float x) { return abs(x); }
-__device__ __forceinline__ double inline_abs(double x) { return abs(x); }
-
-__device__ __forceinline__ int inline_sign(dtype::float16 x) {
-  return sgn<dtype::float16>(x);
-}
-__device__ __forceinline__ int inline_sign(float x) { return sgn<float>(x); }
-__device__ __forceinline__ int inline_sign(double x) { return sgn<double>(x); }
-
-__device__ __forceinline__ dtype::float16 inline_pow(dtype::float16 base,
-                                                     dtype::float16 exponent) {
-  return static_cast<dtype::float16>(
-      pow(static_cast<float>(base), static_cast<float>(exponent)));
-}
-__device__ __forceinline__ dtype::bfloat16 inline_pow(
-    dtype::bfloat16 base, dtype::bfloat16 exponent) {
-  return static_cast<dtype::bfloat16>(
-      pow(static_cast<float>(base), static_cast<float>(exponent)));
-}
-__device__ __forceinline__ float inline_pow(float base, float exponent) {
-  return pow(base, exponent);
-}
-__device__ __forceinline__ double inline_pow(double base, double exponent) {
-  return pow(base, exponent);
-}
-
 template <typename T>
 struct NonzeroFunctor {
   HOSTDEVICE explicit inline NonzeroFunctor() = default;
-- 
GitLab