未验证 提交 190cf44f 编写于 作者: F fwenguang 提交者: GitHub

[MLU] fix cnnl error when index is 2D (#42669)

上级 9ac736c2
......@@ -27,11 +27,28 @@ class GatherOpMLUKernel : public framework::OpKernel<T> {
auto *index = ctx.Input<Tensor>("Index");
auto axis = ctx.Attr<int>("axis");
const auto index_dims = index->dims();
if (index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(
index_dims[1], 1,
platform::errors::InvalidArgument(
"The last dim of index should be 1 when it is 2D, but we get %d",
index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
index_dims.size(), 1,
platform::errors::InvalidArgument(
"The index should be 1D, when it is not 2D, but we get %d",
index_dims.size()));
}
auto *out = ctx.Output<Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
MLUCnnlTensorDesc x_desc(*x);
MLUCnnlTensorDesc index_desc(*index);
int index_shape_1d[1] = {static_cast<int>(index_dims[0])};
MLUCnnlTensorDesc index_desc(1, index_shape_1d,
ToCnnlDataType(index->dtype()));
MLUCnnlTensorDesc out_desc(*out);
MLUCnnl::GatherFunctor(ctx, axis, 0 /*batch_dims*/, x_desc.get(),
GetBasePtr(x), index_desc.get(), GetBasePtr(index),
......@@ -46,6 +63,22 @@ class GatherGradOpMLUKernel : public framework::OpKernel<T> {
auto *index = ctx.Input<Tensor>("Index");
auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
const auto index_dims = index->dims();
if (index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(
index_dims[1], 1,
platform::errors::InvalidArgument(
"The last dim of index should be 1 when it is 2D, but we get %d",
index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
index_dims.size(), 1,
platform::errors::InvalidArgument(
"The index should be 1D, when it is not 2D, but we get %d",
index_dims.size()));
}
dx->mutable_data<T>(ctx.GetPlace());
MLUCnnlTensorDesc dx_desc(*dx);
......@@ -53,7 +86,9 @@ class GatherGradOpMLUKernel : public framework::OpKernel<T> {
MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &value, dx_desc.get(),
GetBasePtr(dx));
MLUCnnlTensorDesc index_desc(*index);
int index_shape_1d[1] = {static_cast<int>(index_dims[0])};
MLUCnnlTensorDesc index_desc(1, index_shape_1d,
ToCnnlDataType(index->dtype()));
MLUCnnlTensorDesc dout_desc(*dout);
const cnnlScatterRefMode_t mode = CNNL_SCATTERREF_UPDATE;
MLUCnnl::ScatterFunctor(ctx, dx_desc.get(), GetBasePtr(dx), dout_desc.get(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册