From 61dccb06f9d35e92c1fb7af5a8fbc80f0e25bf85 Mon Sep 17 00:00:00 2001 From: phlrain Date: Sat, 12 Mar 2022 13:33:25 +0000 Subject: [PATCH] fix cpu bf16 bug; test=develop --- paddle/phi/kernels/cpu/embedding_grad_kernel.cc | 4 ++-- paddle/phi/kernels/cpu/embedding_kernel.cc | 2 +- paddle/phi/kernels/cpu/sparse_weight_embedding_grad_kernel.cc | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/paddle/phi/kernels/cpu/embedding_grad_kernel.cc b/paddle/phi/kernels/cpu/embedding_grad_kernel.cc index 67c28eefc87..4a6b1014277 100644 --- a/paddle/phi/kernels/cpu/embedding_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/embedding_grad_kernel.cc @@ -197,7 +197,7 @@ PD_REGISTER_KERNEL(embedding_grad, phi::EmbeddingGradKernel, float, double, - phi::dtype::float16) {} + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(embedding_sparse_grad, CPU, @@ -205,4 +205,4 @@ PD_REGISTER_KERNEL(embedding_sparse_grad, phi::EmbeddingSparseGradKernel, float, double, - phi::dtype::float16) {} + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/cpu/embedding_kernel.cc b/paddle/phi/kernels/cpu/embedding_kernel.cc index 63ea7004d42..6c92e9a660a 100644 --- a/paddle/phi/kernels/cpu/embedding_kernel.cc +++ b/paddle/phi/kernels/cpu/embedding_kernel.cc @@ -105,4 +105,4 @@ PD_REGISTER_KERNEL(embedding, phi::EmbeddingKernel, float, double, - phi::dtype::float16) {} + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/cpu/sparse_weight_embedding_grad_kernel.cc b/paddle/phi/kernels/cpu/sparse_weight_embedding_grad_kernel.cc index 743faa3e43e..89237d3f6e8 100644 --- a/paddle/phi/kernels/cpu/sparse_weight_embedding_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/sparse_weight_embedding_grad_kernel.cc @@ -197,7 +197,7 @@ PD_REGISTER_KERNEL(sparse_weight_embedding_grad, phi::SparseWeightEmbeddingGradKernel, float, double, - phi::dtype::float16) {} + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(sparse_weight_embedding_sparse_grad, CPU, @@ -205,4 +205,4 @@ PD_REGISTER_KERNEL(sparse_weight_embedding_sparse_grad, phi::SparseWeightEmbeddingSparseGradKernel, float, double, - phi::dtype::float16) {} + phi::dtype::bfloat16) {} -- GitLab