Unverified commit 087c23a9, authored by Guoxia Wang, committed by GitHub

support fp16 (#35888)

Parent 799f3861
@@ -41,12 +41,16 @@ namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(
     elementwise_max,
+    ops::ElementwiseMaxKernel<paddle::platform::CUDADeviceContext,
+                              paddle::platform::float16>,
     ops::ElementwiseMaxKernel<paddle::platform::CUDADeviceContext, float>,
     ops::ElementwiseMaxKernel<paddle::platform::CUDADeviceContext, double>,
     ops::ElementwiseMaxKernel<paddle::platform::CUDADeviceContext, int>,
     ops::ElementwiseMaxKernel<paddle::platform::CUDADeviceContext, int64_t>);
 REGISTER_OP_CUDA_KERNEL(
     elementwise_max_grad,
+    ops::ElementwiseMaxGradKernel<paddle::platform::CUDADeviceContext,
+                                  paddle::platform::float16>,
     ops::ElementwiseMaxGradKernel<paddle::platform::CUDADeviceContext, float>,
     ops::ElementwiseMaxGradKernel<paddle::platform::CUDADeviceContext, double>,
     ops::ElementwiseMaxGradKernel<paddle::platform::CUDADeviceContext, int>,
@@ -39,14 +39,14 @@ class ElementwiseMaxKernel : public framework::OpKernel<T> {
 template <typename T>
 struct MaxGradDx {
   HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
-    return dout * (x > y);
+    return dout * static_cast<T>(x > y);
   }
 };
 
 template <typename T>
 struct MaxGradDy {
   HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
-    return dout * (x <= y);
+    return dout * static_cast<T>(x <= y);
  }
 };
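The grad functors above use the comparison result as a 0/1 mask on the upstream gradient; the explicit static_cast<T> turns the boolean into a value of the element type, which keeps the product well defined and consistent now that T can also be paddle::platform::float16. A minimal host-side sketch of the same masking pattern, with plain float standing in for T (the MaxGradDx/MaxGradDy functions below are local to this example, not the Paddle functors):

#include <cstdio>

// Gradient of max(x, y): route the upstream gradient dout to whichever input
// was selected in the forward pass. The comparison yields bool, so it is cast
// to T before multiplying, mirroring the change above.
template <typename T>
T MaxGradDx(T x, T y, T dout) {
  return dout * static_cast<T>(x > y);  // 1 where x won, 0 otherwise
}

template <typename T>
T MaxGradDy(T x, T y, T dout) {
  return dout * static_cast<T>(x <= y);  // complementary mask for y
}

int main() {
  float x = 3.0f, y = 1.0f, dout = 0.5f;
  std::printf("dx=%g dy=%g\n", MaxGradDx(x, y, dout), MaxGradDy(x, y, dout));
  // Prints dx=0.5 dy=0: the gradient flows only to the larger input.
  return 0;
}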
@@ -20,7 +20,9 @@ limitations under the License. */
 #include <hipcub/hipcub.hpp>
 namespace cub = hipcub;
 #endif
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/p_norm_op.h"
+#include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
 namespace operators {
@@ -30,12 +32,23 @@ __device__ __forceinline__ int sgn(T val) {
   return (T(0) < val) - (val < T(0));
 }
 
+__device__ __forceinline__ platform::float16 inline_abs(platform::float16 x) {
+  return static_cast<platform::float16>(abs(static_cast<float>(x)));
+}
 __device__ __forceinline__ float inline_abs(float x) { return abs(x); }
 __device__ __forceinline__ double inline_abs(double x) { return abs(x); }
 
+__device__ __forceinline__ int inline_sign(platform::float16 x) {
+  return sgn<platform::float16>(x);
+}
 __device__ __forceinline__ int inline_sign(float x) { return sgn<float>(x); }
 __device__ __forceinline__ int inline_sign(double x) { return sgn<double>(x); }
 
+__device__ __forceinline__ platform::float16 inline_pow(
+    platform::float16 base, platform::float16 exponent) {
+  return static_cast<platform::float16>(
+      pow(static_cast<float>(base), static_cast<float>(exponent)));
+}
 __device__ __forceinline__ float inline_pow(float base, float exponent) {
   return pow(base, exponent);
 }
@@ -47,21 +60,23 @@ template <typename T, int BlockDim>
 __global__ void Pnorm(const T* x, const int pre,
                       const int axis_n,  // dim in axis
                       const int post, float porder, T* out_norm) {
-  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
+  using MT = typename details::MPTypeTrait<T>::Type;
+  typedef cub::BlockReduce<MT, BlockDim> BlockReduce;
   __shared__ typename BlockReduce::TempStorage temp_storage;
   int num = pre * post;
-  auto porder_t = static_cast<T>(porder);
-  auto porder_inv = static_cast<T>(1.0 / porder);
+  auto porder_t = static_cast<MT>(porder);
+  auto porder_inv = static_cast<MT>(1.0 / porder);
 
   for (int i = blockIdx.x; i < num; i += gridDim.x) {
     int base = (i / post) * post * axis_n + (i % post);
-    T sum = 0.0;
+    MT sum = static_cast<MT>(0.0);
     for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
-      const T x_ij = x[base + j * post];
+      const MT x_ij = static_cast<MT>(x[base + j * post]);
       sum += inline_pow(inline_abs(x_ij), porder_t);
     }
-    T reduce_result = BlockReduce(temp_storage).Sum(sum);
-    if (threadIdx.x == 0) out_norm[i] = inline_pow(reduce_result, porder_inv);
+    MT reduce_result = BlockReduce(temp_storage).Sum(sum);
+    if (threadIdx.x == 0)
+      out_norm[i] = static_cast<T>(inline_pow(reduce_result, porder_inv));
   }
 }
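The key change in Pnorm (and in ZeorNorm below) is that the block reduction no longer accumulates in T: details::MPTypeTrait<T>::Type maps float16 to float, so |x|^p is summed in single precision and only the finished norm is narrowed back to the storage type. A standalone host-side sketch of that widen-accumulate-narrow pattern, where HalfLike and MPTypeTraitSketch are stand-ins invented for this example rather than Paddle APIs:

#include <cmath>
#include <cstdio>
#include <vector>

// Toy half-precision stand-in: only used to show the type plumbing.
struct HalfLike {
  float v;  // a real float16 would hold 16 bits of storage
  explicit HalfLike(float f) : v(f) {}
  explicit operator float() const { return v; }
};

template <typename T> struct MPTypeTraitSketch { using Type = T; };
template <> struct MPTypeTraitSketch<HalfLike> { using Type = float; };

template <typename T>
T PNormSketch(const std::vector<T>& x, float porder) {
  using MT = typename MPTypeTraitSketch<T>::Type;
  MT sum = static_cast<MT>(0.0);
  for (const T& xi : x) {
    // widen each element, then accumulate |x|^p in the wider type MT
    sum += std::pow(std::abs(static_cast<MT>(xi)), static_cast<MT>(porder));
  }
  // narrow only the final result back to the storage type
  return static_cast<T>(std::pow(sum, static_cast<MT>(1.0f / porder)));
}

int main() {
  std::vector<HalfLike> x{HalfLike(3.0f), HalfLike(-4.0f)};
  std::printf("p=2 norm: %g\n", static_cast<float>(PNormSketch(x, 2.0f)));
  // Prints 5: sqrt(3^2 + 4^2), accumulated in float, stored back as HalfLike.
  return 0;
}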
@@ -69,18 +84,19 @@ template <typename T, int BlockDim>
 __global__ void ZeorNorm(const T* x, const int pre,
                          const int axis_n,  // dim in axis
                          const int post, T* out_norm) {
-  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
+  using MT = typename details::MPTypeTrait<T>::Type;
+  typedef cub::BlockReduce<MT, BlockDim> BlockReduce;
   __shared__ typename BlockReduce::TempStorage temp_storage;
   int num = pre * post;
   for (int i = blockIdx.x; i < num; i += gridDim.x) {
     int base = (i / post) * post * axis_n + (i % post);
-    T sum = 0.0;
+    MT sum = static_cast<MT>(0.0);
     for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
-      const T x_ij = x[base + j * post];
-      sum += static_cast<T>(x_ij != 0);
+      const MT x_ij = static_cast<MT>(x[base + j * post]);
+      sum += static_cast<MT>(static_cast<double>(x_ij) != 0);
     }
-    T reduce_result = BlockReduce(temp_storage).Sum(sum);
-    if (threadIdx.x == 0) out_norm[i] = reduce_result;
+    MT reduce_result = BlockReduce(temp_storage).Sum(sum);
+    if (threadIdx.x == 0) out_norm[i] = static_cast<T>(reduce_result);
   }
 }
@@ -172,27 +188,29 @@ __global__ void PnormGradient(const T* x, const T* x_norm, const T* y_grad,
                               const float porder, const int pre,
                               const int axis_n, const int post, const T eps,
                               T* x_grad) {
+  using MT = typename details::MPTypeTrait<T>::Type;
   // dx = (x/pnorm_broadcast).pow(p-1) * norm_dy.broadcast * sign(x)
   int num = pre * post;
-  auto porder_grad = static_cast<T>(porder - 1.0f);
+  auto porder_grad = static_cast<MT>(porder - 1.0f);
   for (int i = blockIdx.x; i < num; i += gridDim.x) {
-    __shared__ T pnorm_i;
-    __shared__ T yout_i;
+    __shared__ MT pnorm_i;
+    __shared__ MT yout_i;
     auto base = (i / post) * post * axis_n + (i % post);
     if (threadIdx.x == 0) {
-      pnorm_i = x_norm[i];
-      yout_i = y_grad[i];
+      pnorm_i = static_cast<MT>(x_norm[i]);
+      yout_i = static_cast<MT>(y_grad[i]);
     }
     __syncthreads();
 
     for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
       int index = base + j * post;
-      const T x_ij = inline_abs(x[index]);
-      x_grad[index] = inline_pow(x_ij, porder_grad) /
-                      (inline_pow(pnorm_i, porder_grad) + eps) * yout_i *
-                      inline_sign(x[index]);
+      const MT x_ij = static_cast<MT>(inline_abs(x[index]));
+      x_grad[index] = static_cast<T>(
+          inline_pow(x_ij, porder_grad) /
+          (inline_pow(pnorm_i, porder_grad) + static_cast<MT>(eps)) * yout_i *
+          static_cast<MT>(inline_sign(x[index])));
     }
   }
 }
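The per-element expression in PnormGradient (now evaluated in MT) is the derivative of the p-norm, with eps added in the denominator only to guard against a zero norm. Writing the incoming gradient y_grad as dL/d‖x‖_p:

\|x\|_p = \Big(\sum_j |x_j|^p\Big)^{1/p}
\;\Longrightarrow\;
\frac{\partial \|x\|_p}{\partial x_i} = \frac{|x_i|^{p-1}\,\mathrm{sign}(x_i)}{\|x\|_p^{\,p-1}},
\qquad
\frac{\partial L}{\partial x_i} = \frac{|x_i|^{p-1}}{\|x\|_p^{\,p-1} + \varepsilon}\,
\frac{\partial L}{\partial \|x\|_p}\,\mathrm{sign}(x_i).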
@@ -216,7 +234,7 @@ __global__ void InfNormGradient(const T* x, const T* x_norm, const T* y_grad,
       int index = base + j * post;
       const T x_ij = inline_abs(x[index]);
       if (x_ij == pnorm_i) {
-        x_grad[index] = inline_sign(x[index]) * yout_i;
+        x_grad[index] = static_cast<T>(inline_sign(x[index])) * yout_i;
       } else {
         x_grad[index] = static_cast<T>(0);
       }
@@ -278,7 +296,11 @@ class PnormGradCUDAKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 using CUDA = paddle::platform::CUDADeviceContext;
 
-REGISTER_OP_CUDA_KERNEL(p_norm, ops::PnormCUDAKernel<CUDA, float>,
+REGISTER_OP_CUDA_KERNEL(p_norm,
+                        ops::PnormCUDAKernel<CUDA, paddle::platform::float16>,
+                        ops::PnormCUDAKernel<CUDA, float>,
                         ops::PnormCUDAKernel<CUDA, double>);
-REGISTER_OP_CUDA_KERNEL(p_norm_grad, ops::PnormGradCUDAKernel<CUDA, float>,
-                        ops::PnormGradCUDAKernel<CUDA, double>);
+REGISTER_OP_CUDA_KERNEL(
+    p_norm_grad, ops::PnormGradCUDAKernel<CUDA, paddle::platform::float16>,
+    ops::PnormGradCUDAKernel<CUDA, float>,
+    ops::PnormGradCUDAKernel<CUDA, double>);
@@ -86,7 +86,8 @@ def normalize(x, p=2, axis=1, epsilon=1e-12, name=None):
     check_type(p, 'p', (float, int), 'normalize')
     check_type(axis, 'axis', (int), 'normalize')
-    check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'normalize')
+    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
+                             'normalize')
     if len(x.shape) == 1 and axis != 0 and axis != -1:
         raise ValueError(
             "Axis must be 0 or -1 when x is a 1-D tensor, but received axis = {}".