未验证 提交 7f0bdf07 编写于 作者: 周波涛's avatar 周波涛 提交者: GitHub

write the common functions p_norm_kernel.cu and p_norm_grad_kernel.cu to p_norm_utils.h (#56191)

上级 08d726f5
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace phi {
__device__ __forceinline__ dtype::float16 inline_abs(dtype::float16 x) {
return static_cast<dtype::float16>(abs(static_cast<float>(x)));
}
__device__ __forceinline__ dtype::bfloat16 inline_abs(dtype::bfloat16 x) {
return static_cast<dtype::bfloat16>(abs(static_cast<float>(x)));
}
__device__ __forceinline__ float inline_abs(float x) { return abs(x); }
__device__ __forceinline__ double inline_abs(double x) { return abs(x); }
template <typename T>
__device__ __forceinline__ int sgn(T val) {
return (T(0) < val) - (val < T(0));
}
__device__ __forceinline__ int inline_sign(dtype::float16 x) {
return sgn<dtype::float16>(x);
}
__device__ __forceinline__ int inline_sign(dtype::bfloat16 x) {
return sgn<dtype::bfloat16>(x);
}
__device__ __forceinline__ int inline_sign(float x) { return sgn<float>(x); }
__device__ __forceinline__ int inline_sign(double x) { return sgn<double>(x); }
__device__ __forceinline__ dtype::float16 inline_pow(dtype::float16 base,
dtype::float16 exponent) {
return static_cast<dtype::float16>(
pow(static_cast<float>(base), static_cast<float>(exponent)));
}
__device__ __forceinline__ dtype::bfloat16 inline_pow(
dtype::bfloat16 base, dtype::bfloat16 exponent) {
return static_cast<dtype::bfloat16>(
pow(static_cast<float>(base), static_cast<float>(exponent)));
}
__device__ __forceinline__ float inline_pow(float base, float exponent) {
return pow(base, exponent);
}
__device__ __forceinline__ double inline_pow(double base, double exponent) {
return pow(base, exponent);
}
} // namespace phi
......@@ -17,50 +17,11 @@
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/p_norm_utils.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/gpu/reduce.h"
namespace phi {
template <typename T>
__device__ __forceinline__ int sgn(T val) {
return (T(0) < val) - (val < T(0));
}
__device__ __forceinline__ dtype::float16 inline_abs(dtype::float16 x) {
return static_cast<dtype::float16>(abs(static_cast<float>(x)));
}
__device__ __forceinline__ dtype::bfloat16 inline_abs(dtype::bfloat16 x) {
return static_cast<dtype::bfloat16>(abs(static_cast<float>(x)));
}
__device__ __forceinline__ float inline_abs(float x) { return abs(x); }
__device__ __forceinline__ double inline_abs(double x) { return abs(x); }
__device__ __forceinline__ int inline_sign(dtype::float16 x) {
return sgn<dtype::float16>(x);
}
__device__ __forceinline__ int inline_sign(float x) { return sgn<float>(x); }
__device__ __forceinline__ int inline_sign(double x) { return sgn<double>(x); }
__device__ __forceinline__ dtype::float16 inline_pow(dtype::float16 base,
dtype::float16 exponent) {
return static_cast<dtype::float16>(
pow(static_cast<float>(base), static_cast<float>(exponent)));
}
__device__ __forceinline__ dtype::bfloat16 inline_pow(
dtype::bfloat16 base, dtype::bfloat16 exponent) {
return static_cast<dtype::bfloat16>(
pow(static_cast<float>(base), static_cast<float>(exponent)));
}
__device__ __forceinline__ float inline_pow(float base, float exponent) {
return pow(base, exponent);
}
__device__ __forceinline__ double inline_pow(double base, double exponent) {
return pow(base, exponent);
}
template <typename T>
struct NonzeroFunctor {
HOSTDEVICE explicit inline NonzeroFunctor() = default;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册