diff --git a/paddle/fluid/operators/gather_op_npu.cc b/paddle/fluid/operators/gather_op_npu.cc index 2d7b5b93ad65163e99c4f235e03503364b4b764b..cf0d9cda34231ce1a439aeea82df74c0a54323a5 100644 --- a/paddle/fluid/operators/gather_op_npu.cc +++ b/paddle/fluid/operators/gather_op_npu.cc @@ -51,10 +51,13 @@ class GatherGradOpNPUKernel : public framework::OpKernel { auto *dx = ctx.Output(framework::GradVarName("X")); // step1: Unsqueeze index + framework::Tensor tmp_tensor(index->type()); const auto index_dims = index->dims(); if (index_dims.size() == 1) { - framework::Tensor tmp_index = UnsqueezeTo(*index, 2); - index = &tmp_index; + tmp_tensor.ShareDataWith(*index); + std::vector new_dim = {index_dims[0], 1}; + tmp_tensor.Resize(framework::make_ddim(new_dim)); + index = &tmp_tensor; } auto stream = diff --git a/paddle/fluid/operators/gather_op_npu_test.cc b/paddle/fluid/operators/gather_op_npu_test.cc index 4cd46da6f26f8d922a8e0b806ed622f36204f52a..de067e45585d91ce0efa2269909f9a1052a895ac 100644 --- a/paddle/fluid/operators/gather_op_npu_test.cc +++ b/paddle/fluid/operators/gather_op_npu_test.cc @@ -109,17 +109,17 @@ void CompareGrad(f::Scope* scope, const p::DeviceContext& ctx, auto dout = scope->Var("DOut"); auto tensor_dout = dout->GetMutable(); - std::vector init_index = {0, 1, 2, 0}; + std::vector init_index = {0, 1}; paddle::framework::TensorFromVector(init_index, ctx, tensor_index); - tensor_index->Resize(paddle::framework::make_ddim({2, 2})); + tensor_index->Resize(paddle::framework::make_ddim({2})); std::vector init_x = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0}; TensorFromVector(init_x, ctx, tensor_x); tensor_x->Resize(paddle::framework::make_ddim({3, 2})); - std::vector init_dout = {5.0, 10.0}; + std::vector init_dout = {5.0, 10.0, 2.0, 3.0}; TensorFromVector(init_dout, ctx, tensor_dout); - tensor_dout->Resize(paddle::framework::make_ddim({2})); + tensor_dout->Resize(paddle::framework::make_ddim({2, 2})); ctx.Wait(); @@ -143,7 +143,7 @@ void CompareGrad(f::Scope* scope, const p::DeviceContext& ctx, uint32_t expected_size = 3 * 2; EXPECT_EQ((uint32_t)dx_vec.size(), expected_size); - std::vector expected_dx_vec = {0.0, 5.0, 0.0, 0.0, 10.0, 0.0}; + std::vector expected_dx_vec = {5.0, 10.0, 2.0, 3.0, 0.0, 0.0}; for (uint32_t i = 0; i < dx_vec.size(); i++) { VLOG(3) << "dx_vec[i]=" << dx_vec[i]; EXPECT_EQ(dx_vec[i], expected_dx_vec[i]);