diff --git a/paddle/phi/kernels/gpu/embedding_grad_kernel.cu b/paddle/phi/kernels/gpu/embedding_grad_kernel.cu index a970348760c18e2c67e9c7b366cdc2f5e18e3abd..47b1b304f5ec99206253b8bc5f3ab6c99b9c429f 100644 --- a/paddle/phi/kernels/gpu/embedding_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/embedding_grad_kernel.cu @@ -130,9 +130,11 @@ void EmbeddingGradKernel(const Context& ctx, functor.template apply(); } else if (input.dtype() == phi::DataType::INT64) { functor.template apply(); + } else if (input.dtype() == phi::DataType::INT16) { + functor.template apply(); } else { PADDLE_THROW(phi::errors::Unimplemented( - "emebdding input only support int32 and int64")); + "emebdding input only support int16, int32 and int64")); } } @@ -233,9 +235,10 @@ void EmbeddingSparseGradKernel(const Context& ctx, functor.template apply(); } else if (input.dtype() == phi::DataType::INT64) { functor.template apply(); - } else { + } else if (input.dtype() == phi::DataType::INT16) { + functor.template apply(); PADDLE_THROW(phi::errors::Unimplemented( - "emebdding input only support int32 and int64")); + "emebdding input only support int16, int32 and int64")); } } diff --git a/paddle/phi/kernels/gpu/embedding_kernel.cu b/paddle/phi/kernels/gpu/embedding_kernel.cu index 7f3a31ba544d88534d8a606fba53e017a155023c..14a40abefffd23db5ef8dc38ded321757768b063 100644 --- a/paddle/phi/kernels/gpu/embedding_kernel.cu +++ b/paddle/phi/kernels/gpu/embedding_kernel.cu @@ -109,9 +109,11 @@ void EmbeddingKernel(const Context &ctx, functor.template apply(); } else if (input.dtype() == phi::DataType::INT64) { functor.template apply(); + } else if (input.dtype() == phi::DataType::INT16) { + functor.template apply(); } else { PADDLE_THROW(phi::errors::Unimplemented( - "emebdding input only support int32 and int64")); + "emebdding input only support int16, int32 and int64")); } }