From 190cf44f6762b33cf5b24d833bc2d24989fc433b Mon Sep 17 00:00:00 2001 From: fwenguang <95677191+fwenguang@users.noreply.github.com> Date: Thu, 12 May 2022 17:11:45 +0800 Subject: [PATCH] [MLU] fix cnnl error when index is 2D (#42669) --- paddle/fluid/operators/gather_op_mlu.cc | 39 +++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/gather_op_mlu.cc b/paddle/fluid/operators/gather_op_mlu.cc index 220d045952..cf35e051ed 100644 --- a/paddle/fluid/operators/gather_op_mlu.cc +++ b/paddle/fluid/operators/gather_op_mlu.cc @@ -27,11 +27,28 @@ class GatherOpMLUKernel : public framework::OpKernel { auto *index = ctx.Input("Index"); auto axis = ctx.Attr("axis"); + const auto index_dims = index->dims(); + if (index_dims.size() == 2) { + PADDLE_ENFORCE_EQ( + index_dims[1], 1, + platform::errors::InvalidArgument( + "The last dim of index should be 1 when it is 2D, but we get %d", + index_dims[1])); + } else { + PADDLE_ENFORCE_EQ( + index_dims.size(), 1, + platform::errors::InvalidArgument( + "The index should be 1D, when it is not 2D, but we get %d", + index_dims.size())); + } + auto *out = ctx.Output("Out"); out->mutable_data(ctx.GetPlace()); MLUCnnlTensorDesc x_desc(*x); - MLUCnnlTensorDesc index_desc(*index); + int index_shape_1d[1] = {static_cast(index_dims[0])}; + MLUCnnlTensorDesc index_desc(1, index_shape_1d, + ToCnnlDataType(index->dtype())); MLUCnnlTensorDesc out_desc(*out); MLUCnnl::GatherFunctor(ctx, axis, 0 /*batch_dims*/, x_desc.get(), GetBasePtr(x), index_desc.get(), GetBasePtr(index), @@ -46,6 +63,22 @@ class GatherGradOpMLUKernel : public framework::OpKernel { auto *index = ctx.Input("Index"); auto *dout = ctx.Input(framework::GradVarName("Out")); auto *dx = ctx.Output(framework::GradVarName("X")); + + const auto index_dims = index->dims(); + if (index_dims.size() == 2) { + PADDLE_ENFORCE_EQ( + index_dims[1], 1, + platform::errors::InvalidArgument( + "The last dim of index should be 1 when it is 2D, but we get %d", + index_dims[1])); + } else { + PADDLE_ENFORCE_EQ( + index_dims.size(), 1, + platform::errors::InvalidArgument( + "The index should be 1D, when it is not 2D, but we get %d", + index_dims.size())); + } + dx->mutable_data(ctx.GetPlace()); MLUCnnlTensorDesc dx_desc(*dx); @@ -53,7 +86,9 @@ class GatherGradOpMLUKernel : public framework::OpKernel { MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &value, dx_desc.get(), GetBasePtr(dx)); - MLUCnnlTensorDesc index_desc(*index); + int index_shape_1d[1] = {static_cast(index_dims[0])}; + MLUCnnlTensorDesc index_desc(1, index_shape_1d, + ToCnnlDataType(index->dtype())); MLUCnnlTensorDesc dout_desc(*dout); const cnnlScatterRefMode_t mode = CNNL_SCATTERREF_UPDATE; MLUCnnl::ScatterFunctor(ctx, dx_desc.get(), GetBasePtr(dx), dout_desc.get(), -- GitLab