Unverified commit 3c121040, authored by shaojie_wang, committed by GitHub

[Bfloat16]register bfloat16 datatype for squared l2 norm (#50908)

* register bfloat16 datatype for squared l2 norm

* register bfloat16 datatype for softmax with upper triangular mask

* register bfloat16 for tril triu cuda kernel
Parent 5d322ced
@@ -67,6 +67,11 @@ __device__ __inline__ void load_data_upper_tri(plat::float16* dst,
   *(reinterpret_cast<float2*>(dst)) = *(reinterpret_cast<const float2*>(src));
 }
+__device__ __inline__ void load_data_upper_tri(plat::bfloat16* dst,
+                                               const plat::bfloat16* src) {
+  *(reinterpret_cast<float2*>(dst)) = *(reinterpret_cast<const float2*>(src));
+}
 __device__ __inline__ void load_data_upper_tri(float* dst, const float* src) {
   *(reinterpret_cast<float4*>(dst)) = *(reinterpret_cast<const float4*>(src));
 }
@@ -75,6 +80,10 @@ __device__ __inline__ void load_zero_vector_upper_tri(plat::float16* dst) {
   *(reinterpret_cast<float2*>(dst)) = make_float2(0.0f, 0.0f);
 }
+__device__ __inline__ void load_zero_vector_upper_tri(plat::bfloat16* dst) {
+  *(reinterpret_cast<float2*>(dst)) = make_float2(0.0f, 0.0f);
+}
 __device__ __inline__ void load_zero_vector_upper_tri(float* dst) {
   *(reinterpret_cast<float4*>(dst)) = make_float4(0.0f, 0.0f, 0.0f, 0.0f);
 }
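The bfloat16 overloads reuse the float2 path rather than the float4 path because these helpers move four elements per vectorized access, and four 16-bit values occupy the same 8 bytes as one float2. A minimal sketch of that size arithmetic (illustrative only, not part of the patch):

```cpp
#include <cuda_runtime.h>  // float2 / float4 vector types

// Four 16-bit elements (float16 or bfloat16) pack into one 8-byte float2;
// four 32-bit floats pack into one 16-byte float4. Either way, each
// vectorized load/store moves four logical elements.
static_assert(4 * 2 == sizeof(float2), "4 x 16-bit elems fit one float2");
static_assert(4 * 4 == sizeof(float4), "4 x 32-bit elems fit one float4");
```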
@@ -596,8 +605,11 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(
     fused_softmax_mask_upper_triangle,
     ops::SoftmaxMaskFuseUpperTriangleKernel<phi::GPUContext, plat::float16>,
+    ops::SoftmaxMaskFuseUpperTriangleKernel<phi::GPUContext, plat::bfloat16>,
     ops::SoftmaxMaskFuseUpperTriangleKernel<phi::GPUContext, float>);
 REGISTER_OP_CUDA_KERNEL(
     fused_softmax_mask_upper_triangle_grad,
     ops::SoftmaxMaskFuseUpperTriangleGradKernel<phi::GPUContext, plat::float16>,
+    ops::SoftmaxMaskFuseUpperTriangleGradKernel<phi::GPUContext,
+                                                plat::bfloat16>,
     ops::SoftmaxMaskFuseUpperTriangleGradKernel<phi::GPUContext, float>);
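With both the forward and gradient entries registered, bfloat16 attention tensors can dispatch to the fused op directly. For orientation, the masking semantics implied by the op name reduce to the usual causal rule (an illustrative predicate, not the kernel body shown in this diff):

```cpp
#include <math.h>  // INFINITY

// Illustrative only: the upper-triangular mask forces logits at key
// positions beyond the query position to -inf before the softmax. The rule
// is dtype-independent, which is why bfloat16 support here is essentially a
// registration change plus the vectorized-load overloads above.
__device__ __forceinline__ float mask_upper_tri_logit(float logit,
                                                      int query_pos,
                                                      int key_pos) {
  return (key_pos > query_pos) ? -INFINITY : logit;
}
```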
@@ -57,4 +57,5 @@ PD_REGISTER_KERNEL(squared_l2_norm_grad,
                    phi::SquaredL2NormGradKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
@@ -41,4 +41,5 @@ PD_REGISTER_KERNEL(squared_l2_norm,
                    phi::SquaredL2NormKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
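squared_l2_norm reduces a tensor to the scalar sum of x_i * x_i, which is what global-norm gradient clipping evaluates per parameter; registering bfloat16 lets that path run without first casting the gradients. A naive CUDA sketch of the reduction (illustrative only; float accumulation is an assumption here, and Paddle's actual kernel uses its own reduce primitives):

```cpp
#include <cuda_bf16.h>

// Naive grid-stride reduction: *out += sum(x[i] * x[i]).
// Accumulating in float (assumed) avoids compounding bfloat16's ~8-bit
// mantissa rounding across many additions.
__global__ void squared_l2_norm_naive(const __nv_bfloat16* x,
                                      float* out,
                                      int n) {
  float acc = 0.0f;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    float v = __bfloat162float(x[i]);
    acc += v * v;
  }
  atomicAdd(out, acc);  // coarse but correct for a sketch
}
```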
@@ -25,7 +25,8 @@ PD_REGISTER_KERNEL(tril_grad,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 
 PD_REGISTER_KERNEL(triu_grad,
                    GPU,
@@ -36,7 +37,8 @@ PD_REGISTER_KERNEL(triu_grad,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 
 PD_REGISTER_KERNEL(tril_triu_grad,
                    GPU,
@@ -47,4 +49,5 @@ PD_REGISTER_KERNEL(tril_triu_grad,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
@@ -25,7 +25,8 @@ PD_REGISTER_KERNEL(tril_triu,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 
 PD_REGISTER_KERNEL(triu,
                    GPU,
@@ -36,7 +37,8 @@ PD_REGISTER_KERNEL(triu,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 
 PD_REGISTER_KERNEL(tril,
                    GPU,
@@ -47,4 +49,5 @@ PD_REGISTER_KERNEL(tril,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}