未验证 提交 0d17c047 编写于 作者: Z zhaoying9105 提交者: GitHub

[MLU](bugfix): fix MLUCnnl::ScatterFunctor function declare bug (#43778)

上级 03972d5a
......@@ -33,26 +33,43 @@ class ScatterMLUKernel : public framework::OpKernel<T> {
cnnlScatterRefMode_t mode;
if (overwrite) {
mode = CNNL_SCATTERREF_UPDATE;
MLUCnnl::ScatterFunctor(ctx, x_desc.get(), GetBasePtr(x),
updates_desc.get(), GetBasePtr(updates),
indices_desc.get(), GetBasePtr(indices), mode);
MLUCnnl::ScatterRefFunctor(ctx,
x_desc.get(),
GetBasePtr(x),
updates_desc.get(),
GetBasePtr(updates),
indices_desc.get(),
GetBasePtr(indices),
mode);
} else {
Tensor tensor_zeros(updates->type());
tensor_zeros.mutable_data<T>(updates->dims(), ctx.GetPlace());
MLUCnnlTensorDesc tensor_zeros_desc(tensor_zeros);
float value = 0.0;
auto value_t = static_cast<T>(value);
MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &value_t,
tensor_zeros_desc.get(), GetBasePtr(&tensor_zeros));
MLUCnnl::Fill(ctx,
CNNL_POINTER_MODE_HOST,
&value_t,
tensor_zeros_desc.get(),
GetBasePtr(&tensor_zeros));
mode = CNNL_SCATTERREF_UPDATE;
MLUCnnl::ScatterFunctor(ctx, x_desc.get(), GetBasePtr(x),
tensor_zeros_desc.get(),
GetBasePtr(&tensor_zeros), indices_desc.get(),
GetBasePtr(indices), mode);
MLUCnnl::ScatterRefFunctor(ctx,
x_desc.get(),
GetBasePtr(x),
tensor_zeros_desc.get(),
GetBasePtr(&tensor_zeros),
indices_desc.get(),
GetBasePtr(indices),
mode);
mode = CNNL_SCATTERREF_ADD;
MLUCnnl::ScatterFunctor(ctx, x_desc.get(), GetBasePtr(x),
updates_desc.get(), GetBasePtr(updates),
indices_desc.get(), GetBasePtr(indices), mode);
MLUCnnl::ScatterRefFunctor(ctx,
x_desc.get(),
GetBasePtr(x),
updates_desc.get(),
GetBasePtr(updates),
indices_desc.get(),
GetBasePtr(indices),
mode);
}
paddle::framework::TensorCopy(*x, place, out);
}
......@@ -62,5 +79,6 @@ class ScatterMLUKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_MLU_KERNEL(scatter, ops::ScatterMLUKernel<float>,
REGISTER_OP_MLU_KERNEL(scatter,
ops::ScatterMLUKernel<float>,
ops::ScatterMLUKernel<paddle::platform::float16>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册