Unverified commit 3c121040 authored by shaojie_wang, committed by GitHub

[Bfloat16] register bfloat16 datatype for squared l2 norm (#50908)

* register bfloat16 datatype for squared l2 norm

* register bfloat16 datatype for softmax with upper triangular mask

* register bfloat16 for tril triu cuda kernel
Parent 5d322ced
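For readers outside the Paddle codebase: REGISTER_OP_CUDA_KERNEL and PD_REGISTER_KERNEL record one explicit instantiation of a kernel template per listed dtype in a global registry keyed by operator, backend, and dtype, so a dtype missing from the list fails dispatch at runtime rather than at compile time. Below is a hypothetical miniature of that dispatch idea, not Paddle's actual registry code:

```cuda
// Toy dtype-keyed kernel registry (assumption: illustrative only; the
// names Register/Dispatch and the map layout are invented for this sketch).
#include <functional>
#include <iostream>
#include <map>
#include <string>

using KernelFn = std::function<void()>;
static std::map<std::pair<std::string, std::string>, KernelFn> g_registry;

void Register(const std::string& op, const std::string& dtype, KernelFn fn) {
  g_registry[{op, dtype}] = std::move(fn);
}

void Dispatch(const std::string& op, const std::string& dtype) {
  auto it = g_registry.find({op, dtype});
  if (it == g_registry.end()) {
    std::cout << op << ": no kernel registered for " << dtype << "\n";
  } else {
    it->second();
  }
}

int main() {
  // After this patch, the bf16 entry exists and dispatch succeeds.
  Register("squared_l2_norm", "bfloat16",
           [] { std::cout << "running bf16 kernel\n"; });
  Dispatch("squared_l2_norm", "bfloat16");   // found
  Dispatch("squared_l2_norm", "complex64");  // still unregistered
}
```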
@@ -67,6 +67,11 @@ __device__ __inline__ void load_data_upper_tri(plat::float16* dst,
   *(reinterpret_cast<float2*>(dst)) = *(reinterpret_cast<const float2*>(src));
 }
 
+__device__ __inline__ void load_data_upper_tri(plat::bfloat16* dst,
+                                               const plat::bfloat16* src) {
+  *(reinterpret_cast<float2*>(dst)) = *(reinterpret_cast<const float2*>(src));
+}
+
 __device__ __inline__ void load_data_upper_tri(float* dst, const float* src) {
   *(reinterpret_cast<float4*>(dst)) = *(reinterpret_cast<const float4*>(src));
 }
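The float2 store in the new overload is a width trick, not float arithmetic: float2 is 8 bytes and bfloat16 is 2 bytes, so one transaction moves four elements, matching the four-float float4 path in the overload below it. A standalone sketch of the same trick, assuming CUDA 11+ with __nv_bfloat16 from <cuda_bf16.h> standing in for Paddle's plat::bfloat16, and a hypothetical kernel name copy4:

```cuda
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <cstdio>

__global__ void copy4(__nv_bfloat16* dst, const __nv_bfloat16* src) {
  // One 8-byte transaction instead of four 2-byte ones.
  *reinterpret_cast<float2*>(dst) = *reinterpret_cast<const float2*>(src);
}

int main() {
  __nv_bfloat16 *src, *dst;
  cudaMallocManaged(&src, 4 * sizeof(__nv_bfloat16));
  cudaMallocManaged(&dst, 4 * sizeof(__nv_bfloat16));
  for (int i = 0; i < 4; ++i) src[i] = __float2bfloat16(float(i));
  copy4<<<1, 1>>>(dst, src);
  cudaDeviceSynchronize();
  for (int i = 0; i < 4; ++i) printf("%.1f\n", __bfloat162float(dst[i]));
  cudaFree(src);
  cudaFree(dst);
}
```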
@@ -75,6 +80,10 @@ __device__ __inline__ void load_zero_vector_upper_tri(plat::float16* dst) {
   *(reinterpret_cast<float2*>(dst)) = make_float2(0.0f, 0.0f);
 }
 
+__device__ __inline__ void load_zero_vector_upper_tri(plat::bfloat16* dst) {
+  *(reinterpret_cast<float2*>(dst)) = make_float2(0.0f, 0.0f);
+}
+
 __device__ __inline__ void load_zero_vector_upper_tri(float* dst) {
   *(reinterpret_cast<float4*>(dst)) = make_float4(0.0f, 0.0f, 0.0f, 0.0f);
 }
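Zero-filling through make_float2 is safe for the same reason: IEEE 754 +0.0f is the all-zeros bit pattern and so is bfloat16 zero, so writing eight zero bytes yields four bf16 zeros. A host-side sanity sketch under the same __nv_bfloat16 assumption:

```cuda
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <cassert>
#include <cstring>

int main() {
  __nv_bfloat16 v[4];
  float2 z = make_float2(0.0f, 0.0f);
  static_assert(sizeof(z) == sizeof(v), "one float2 spans four bfloat16s");
  std::memcpy(v, &z, sizeof(z));  // the same bytes the kernel stores
  for (int i = 0; i < 4; ++i) assert(__bfloat162float(v[i]) == 0.0f);
  return 0;
}
```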
@@ -596,8 +605,11 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(
     fused_softmax_mask_upper_triangle,
     ops::SoftmaxMaskFuseUpperTriangleKernel<phi::GPUContext, plat::float16>,
+    ops::SoftmaxMaskFuseUpperTriangleKernel<phi::GPUContext, plat::bfloat16>,
     ops::SoftmaxMaskFuseUpperTriangleKernel<phi::GPUContext, float>);
 REGISTER_OP_CUDA_KERNEL(
     fused_softmax_mask_upper_triangle_grad,
     ops::SoftmaxMaskFuseUpperTriangleGradKernel<phi::GPUContext, plat::float16>,
+    ops::SoftmaxMaskFuseUpperTriangleGradKernel<phi::GPUContext,
+                                                plat::bfloat16>,
     ops::SoftmaxMaskFuseUpperTriangleGradKernel<phi::GPUContext, float>);
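These registration lines only add template instantiations; the numerics live in the kernel, which follows the usual mixed-precision pattern of widening bf16 to float for the exp and sum, then narrowing on store. A single-thread toy of that pattern (an assumed shape, not Paddle's actual kernel body, with the upper-triangular mask omitted):

```cuda
#include <cuda_bf16.h>
#include <math.h>

// Numerically stable softmax over a 4-element row: widen, compute, narrow.
__global__ void softmax_row4(__nv_bfloat16* x) {
  float v[4], m = -INFINITY, s = 0.0f;
  for (int i = 0; i < 4; ++i) {
    v[i] = __bfloat162float(x[i]);  // widen to float
    m = fmaxf(m, v[i]);             // running max for stability
  }
  for (int i = 0; i < 4; ++i) {
    v[i] = expf(v[i] - m);
    s += v[i];
  }
  for (int i = 0; i < 4; ++i) x[i] = __float2bfloat16(v[i] / s);  // narrow
}
```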
@@ -57,4 +57,5 @@ PD_REGISTER_KERNEL(squared_l2_norm_grad,
                    phi::SquaredL2NormGradKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
@@ -41,4 +41,5 @@ PD_REGISTER_KERNEL(squared_l2_norm,
                    phi::SquaredL2NormKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
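squared_l2_norm computes sum_i x_i^2. With bf16 inputs the accumulator should stay in float, since bfloat16 keeps only 8 mantissa bits and a long sum would drop low-order terms. A single-thread toy of that shape (hypothetical name squared_l2_norm_toy; the real SquaredL2NormKernel reduces across threads and blocks):

```cuda
#include <cuda_bf16.h>

// sum_i x_i^2 with a float accumulator over bf16 input.
__global__ void squared_l2_norm_toy(const __nv_bfloat16* x,
                                    int n,
                                    float* out) {
  float acc = 0.0f;
  for (int i = 0; i < n; ++i) {
    float v = __bfloat162float(x[i]);  // widen before the multiply
    acc += v * v;
  }
  *out = acc;
}
```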
@@ -25,7 +25,8 @@ PD_REGISTER_KERNEL(tril_grad,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 
 PD_REGISTER_KERNEL(triu_grad,
                    GPU,
@@ -36,7 +37,8 @@ PD_REGISTER_KERNEL(triu_grad,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 
 PD_REGISTER_KERNEL(tril_triu_grad,
                    GPU,
@@ -47,4 +49,5 @@ PD_REGISTER_KERNEL(tril_triu_grad,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
@@ -25,7 +25,8 @@ PD_REGISTER_KERNEL(tril_triu,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 
 PD_REGISTER_KERNEL(triu,
                    GPU,
@@ -36,7 +37,8 @@ PD_REGISTER_KERNEL(triu,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 
 PD_REGISTER_KERNEL(tril,
                    GPU,
@@ -47,4 +49,5 @@ PD_REGISTER_KERNEL(tril,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
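tril/triu do no arithmetic on the values: each element is copied or zeroed purely by its (row, col) position, which is why bf16 support here is registration-only. A simplified sketch of the selection rule with a hypothetical triu_toy kernel, again using __nv_bfloat16 for the bf16 type:

```cuda
#include <cuda_bf16.h>

// Keep elements on or above the `diagonal`-th diagonal; zero the rest.
template <typename T>
__global__ void triu_toy(const T* in, T* out, int rows, int cols,
                         int diagonal) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= rows * cols) return;
  int r = idx / cols, c = idx % cols;
  out[idx] = (c - r >= diagonal) ? in[idx] : T(0.0f);  // no dtype math needed
}

// One explicit instantiation per dtype mirrors what registration buys.
template __global__ void triu_toy<__nv_bfloat16>(
    const __nv_bfloat16*, __nv_bfloat16*, int, int, int);
```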