未验证 提交 de8962bd 编写于 作者: L Li Min 提交者: GitHub

add data_type support for phi embedding op. (#40964)

上级 1c6dcfd9
......@@ -130,9 +130,11 @@ void EmbeddingGradKernel(const Context& ctx,
functor.template apply<int>();
} else if (input.dtype() == phi::DataType::INT64) {
functor.template apply<int64_t>();
} else if (input.dtype() == phi::DataType::INT16) {
functor.template apply<int16_t>();
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"emebdding input only support int32 and int64"));
"emebdding input only support int16, int32 and int64"));
}
}
......@@ -233,9 +235,10 @@ void EmbeddingSparseGradKernel(const Context& ctx,
functor.template apply<int>();
} else if (input.dtype() == phi::DataType::INT64) {
functor.template apply<int64_t>();
} else {
} else if (input.dtype() == phi::DataType::INT16) {
functor.template apply<int16_t>();
PADDLE_THROW(phi::errors::Unimplemented(
"emebdding input only support int32 and int64"));
"emebdding input only support int16, int32 and int64"));
}
}
......
......@@ -109,9 +109,11 @@ void EmbeddingKernel(const Context &ctx,
functor.template apply<int32_t>();
} else if (input.dtype() == phi::DataType::INT64) {
functor.template apply<int64_t>();
} else if (input.dtype() == phi::DataType::INT16) {
functor.template apply<int16_t>();
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"emebdding input only support int32 and int64"));
"emebdding input only support int16, int32 and int64"));
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册