未验证 提交 1a8cc15e 编写于 作者: S shentanyue 提交者: GitHub

[XPU] EmbeddingWithEltwiseAddXpuKernel support FP16 (#51426)

上级 cb7fd370
...@@ -81,7 +81,7 @@ XPUOpMap& get_kl1_ops() { ...@@ -81,7 +81,7 @@ XPUOpMap& get_kl1_ops() {
{"elementwise_sub_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"elementwise_sub_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"elementwise_sub", XPUKernelSet({phi::DataType::FLOAT32})}, {"elementwise_sub", XPUKernelSet({phi::DataType::FLOAT32})},
{"embedding_with_eltwise_add_xpu", {"embedding_with_eltwise_add_xpu",
XPUKernelSet({phi::DataType::FLOAT32})}, XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"equal", XPUKernelSet({phi::DataType::INT64})}, {"equal", XPUKernelSet({phi::DataType::INT64})},
{"expand_as_v2", {"expand_as_v2",
XPUKernelSet({phi::DataType::INT32, XPUKernelSet({phi::DataType::INT32,
......
...@@ -225,7 +225,7 @@ XPUOpMap& get_kl2_ops() { ...@@ -225,7 +225,7 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT64, phi::DataType::INT64,
phi::DataType::INT32})}, phi::DataType::INT32})},
{"embedding_with_eltwise_add_xpu", {"embedding_with_eltwise_add_xpu",
XPUKernelSet({phi::DataType::FLOAT32})}, XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"empty", {"empty",
XPUKernelSet({phi::DataType::INT64, XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32, phi::DataType::INT32,
......
...@@ -25,6 +25,7 @@ void EmbeddingWithEltwiseAddXpuKernel( ...@@ -25,6 +25,7 @@ void EmbeddingWithEltwiseAddXpuKernel(
const std::vector<const DenseTensor*>& tables, const std::vector<const DenseTensor*>& tables,
int64_t padding_idx, int64_t padding_idx,
DenseTensor* out) { DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
auto& id_dims = ids[0]->dims(); auto& id_dims = ids[0]->dims();
int idx_len = id_dims[0] * id_dims[1]; int idx_len = id_dims[0] * id_dims[1];
int emb_layer_num = ids.size(); int emb_layer_num = ids.size();
...@@ -47,7 +48,20 @@ void EmbeddingWithEltwiseAddXpuKernel( ...@@ -47,7 +48,20 @@ void EmbeddingWithEltwiseAddXpuKernel(
table_dims[1], table_dims[1],
embed_dim)); embed_dim));
table_lens_cpu.push_back(table_dims[0]); table_lens_cpu.push_back(table_dims[0]);
arg_tables.push_back(table->data<float>()); if (std::is_same<T, phi::dtype::float16>::value) {
DenseTensor table_data_fp32_t;
ctx.template Alloc<float>(&table_data_fp32_t,
table->numel() * sizeof(float));
int r = xpu::cast<XPUType, float>(
ctx.x_context(),
reinterpret_cast<const XPUType*>(table->data<T>()),
table_data_fp32_t.data<float>(),
table->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
arg_tables.push_back(table_data_fp32_t.data<float>());
} else {
arg_tables.push_back(table->data<float>());
}
} }
std::vector<std::vector<int>> int_idx(emb_layer_num, std::vector<std::vector<int>> int_idx(emb_layer_num,
std::vector<int>(idx_len, 0)); std::vector<int>(idx_len, 0));
...@@ -70,17 +84,39 @@ void EmbeddingWithEltwiseAddXpuKernel( ...@@ -70,17 +84,39 @@ void EmbeddingWithEltwiseAddXpuKernel(
arg_ids.push_back( arg_ids.push_back(
xpu::VectorParam<int>{int_idx[i].data(), idx_len, nullptr}); xpu::VectorParam<int>{int_idx[i].data(), idx_len, nullptr});
} }
ctx.template Alloc<T>(out); ctx.template Alloc<T>(out);
int r = xpu::multi_embedding_fusion<float, float, int>( if (std::is_same<T, phi::dtype::float16>::value) {
ctx.x_context(), DenseTensor out_fp32_t;
arg_tables, /* tables */ ctx.template Alloc<float>(&out_fp32_t, out->numel() * sizeof(float));
out->data<T>(), int r = xpu::multi_embedding_fusion<float, float, int>(
arg_ids, ctx.x_context(),
table_lens_cpu, arg_tables, /* tables */
embed_dim, out_fp32_t.data<float>(),
std::vector<float>(table_lens_cpu.size(), 1.0f), arg_ids,
std::vector<int>(table_lens_cpu.size(), padding_idx)); table_lens_cpu,
PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding_with_eltwise_add_xpu"); embed_dim,
std::vector<float>(table_lens_cpu.size(), 1.0f),
std::vector<int>(table_lens_cpu.size(), padding_idx));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding_with_eltwise_add_xpu");
r = xpu::cast(ctx.x_context(),
out_fp32_t.data<float>(),
reinterpret_cast<float16*>(out->data<T>()),
out->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
} else {
int r = xpu::multi_embedding_fusion<float, float, int>(
ctx.x_context(),
arg_tables, /* tables */
out->data<float>(),
arg_ids,
table_lens_cpu,
embed_dim,
std::vector<float>(table_lens_cpu.size(), 1.0f),
std::vector<int>(table_lens_cpu.size(), padding_idx));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding_with_eltwise_add_xpu");
}
} }
} // namespace fusion } // namespace fusion
...@@ -90,6 +126,7 @@ PD_REGISTER_KERNEL(embedding_with_eltwise_add_xpu, ...@@ -90,6 +126,7 @@ PD_REGISTER_KERNEL(embedding_with_eltwise_add_xpu,
XPU, XPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::fusion::EmbeddingWithEltwiseAddXpuKernel, phi::fusion::EmbeddingWithEltwiseAddXpuKernel,
float) { float,
phi::dtype::float16) {
kernel->InputAt(0).SetBackend(phi::Backend::CPU); kernel->InputAt(0).SetBackend(phi::Backend::CPU);
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册