未验证 提交 753fa844 编写于 作者: M mayang002 提交者: GitHub

[xpu] fix bugs of split/embedding_with_wltwise_add/beam_search_decode kernel (#51052)

上级 a348a423
......@@ -101,6 +101,8 @@ class BeamSearchDecodeXPUKernel : public framework::OpKernel<T> {
xpu::Error_t::SUCCESS,
platform::errors::External(
"Execute function CopyTensorByXPU failed by [%d]", r));
sentenceIds_temp->set_lod(sentenceIds->lod());
sentenceScores_temp->set_lod(sentenceScores->lod());
}
}
};
......
......@@ -53,8 +53,19 @@ void EmbeddingWithEltwiseAddXpuKernel(
std::vector<int>(idx_len, 0));
std::vector<xpu::VectorParam<int>> arg_ids;
for (int i = 0; i < emb_layer_num; i++) {
PADDLE_ENFORCE_EQ(
ids[i]->dtype() == phi::DataType::INT64 ||
ids[i]->dtype() == phi::DataType::INT32,
true,
errors::InvalidArgument(
"The data type of ids should be int64 or int32, but got %s.",
ids[i]->dtype()));
for (int j = 0; j < idx_len; j++) {
int_idx[i][j] = static_cast<int>(ids[i]->data<int64_t>()[j]);
if (ids[i]->dtype() == phi::DataType::INT64) {
int_idx[i][j] = static_cast<int>(ids[i]->data<int64_t>()[j]);
} else if (ids[i]->dtype() == phi::DataType::INT32) {
int_idx[i][j] = ids[i]->data<int>()[j];
}
}
arg_ids.push_back(
xpu::VectorParam<int>{int_idx[i].data(), idx_len, nullptr});
......
......@@ -36,6 +36,9 @@ void SplitKernel(const Context& dev_ctx,
out_ptrs.push_back(reinterpret_cast<XPUType*>(outs[j]->data<T>()));
split_lists.push_back(outs[j]->dims()[axis]);
}
if (x.numel() == 0) {
return;
}
int r = xpu::split<XPUType>(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
out_ptrs,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册