From de8962bd90f3437bcbb59378e19ce4dd99bef0e6 Mon Sep 17 00:00:00 2001 From: Li Min <11663212+limin2021@users.noreply.github.com> Date: Sun, 27 Mar 2022 15:05:38 +0800 Subject: [PATCH] add data_type support for phi embedding op. (#40964) --- paddle/phi/kernels/gpu/embedding_grad_kernel.cu | 9 ++++++--- paddle/phi/kernels/gpu/embedding_kernel.cu | 4 +++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/paddle/phi/kernels/gpu/embedding_grad_kernel.cu b/paddle/phi/kernels/gpu/embedding_grad_kernel.cu index a970348760c..47b1b304f5e 100644 --- a/paddle/phi/kernels/gpu/embedding_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/embedding_grad_kernel.cu @@ -130,9 +130,11 @@ void EmbeddingGradKernel(const Context& ctx, functor.template apply(); } else if (input.dtype() == phi::DataType::INT64) { functor.template apply(); + } else if (input.dtype() == phi::DataType::INT16) { + functor.template apply(); } else { PADDLE_THROW(phi::errors::Unimplemented( - "emebdding input only support int32 and int64")); + "emebdding input only support int16, int32 and int64")); } } @@ -233,9 +235,10 @@ void EmbeddingSparseGradKernel(const Context& ctx, functor.template apply(); } else if (input.dtype() == phi::DataType::INT64) { functor.template apply(); - } else { + } else if (input.dtype() == phi::DataType::INT16) { + functor.template apply(); PADDLE_THROW(phi::errors::Unimplemented( - "emebdding input only support int32 and int64")); + "emebdding input only support int16, int32 and int64")); } } diff --git a/paddle/phi/kernels/gpu/embedding_kernel.cu b/paddle/phi/kernels/gpu/embedding_kernel.cu index 7f3a31ba544..14a40abefff 100644 --- a/paddle/phi/kernels/gpu/embedding_kernel.cu +++ b/paddle/phi/kernels/gpu/embedding_kernel.cu @@ -109,9 +109,11 @@ void EmbeddingKernel(const Context &ctx, functor.template apply(); } else if (input.dtype() == phi::DataType::INT64) { functor.template apply(); + } else if (input.dtype() == phi::DataType::INT16) { + functor.template apply(); } else { PADDLE_THROW(phi::errors::Unimplemented( - "emebdding input only support int32 and int64")); + "emebdding input only support int16, int32 and int64")); } } -- GitLab