未验证 提交 de8962bd 编写于 作者: L Li Min 提交者: GitHub

add data_type support for phi embedding op. (#40964)

上级 1c6dcfd9
...@@ -130,9 +130,11 @@ void EmbeddingGradKernel(const Context& ctx, ...@@ -130,9 +130,11 @@ void EmbeddingGradKernel(const Context& ctx,
functor.template apply<int>(); functor.template apply<int>();
} else if (input.dtype() == phi::DataType::INT64) { } else if (input.dtype() == phi::DataType::INT64) {
functor.template apply<int64_t>(); functor.template apply<int64_t>();
} else if (input.dtype() == phi::DataType::INT16) {
functor.template apply<int16_t>();
} else { } else {
PADDLE_THROW(phi::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"emebdding input only support int32 and int64")); "emebdding input only support int16, int32 and int64"));
} }
} }
...@@ -233,9 +235,10 @@ void EmbeddingSparseGradKernel(const Context& ctx, ...@@ -233,9 +235,10 @@ void EmbeddingSparseGradKernel(const Context& ctx,
functor.template apply<int>(); functor.template apply<int>();
} else if (input.dtype() == phi::DataType::INT64) { } else if (input.dtype() == phi::DataType::INT64) {
functor.template apply<int64_t>(); functor.template apply<int64_t>();
} else { } else if (input.dtype() == phi::DataType::INT16) {
functor.template apply<int16_t>();
PADDLE_THROW(phi::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"emebdding input only support int32 and int64")); "emebdding input only support int16, int32 and int64"));
} }
} }
......
...@@ -109,9 +109,11 @@ void EmbeddingKernel(const Context &ctx, ...@@ -109,9 +109,11 @@ void EmbeddingKernel(const Context &ctx,
functor.template apply<int32_t>(); functor.template apply<int32_t>();
} else if (input.dtype() == phi::DataType::INT64) { } else if (input.dtype() == phi::DataType::INT64) {
functor.template apply<int64_t>(); functor.template apply<int64_t>();
} else if (input.dtype() == phi::DataType::INT16) {
functor.template apply<int16_t>();
} else { } else {
PADDLE_THROW(phi::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"emebdding input only support int32 and int64")); "emebdding input only support int16, int32 and int64"));
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册