diff --git a/paddle/fluid/operators/matmul_v2_op_npu.cc b/paddle/fluid/operators/matmul_v2_op_npu.cc index 353eab5bc5264146399d6d6549a32cd085902688..d3022056a47ded99e63aa05c1aca8e9b31ccc3fe 100644 --- a/paddle/fluid/operators/matmul_v2_op_npu.cc +++ b/paddle/fluid/operators/matmul_v2_op_npu.cc @@ -135,21 +135,8 @@ class MatMulV2GradNPUKernel : public framework::OpKernel { } if (dy) { dy->mutable_data(ctx.GetPlace()); - framework::Tensor dout_; - TensorCopySync(*dout, ctx.GetPlace(), &dout_); - std::vector vec_dim = framework::vectorize(dout_.dims()); - std::vector vec_dim_v{vec_dim[0] * vec_dim[1], vec_dim[2]}; - dout_.Resize(framework::make_ddim(vec_dim_v)); - - framework::Tensor x_; - TensorCopySync(*x, ctx.GetPlace(), &x_); - std::vector vec_dim_x = framework::vectorize(x_.dims()); - std::vector vec_dim_x_v{vec_dim_x[0] * vec_dim_x[1], - vec_dim_x[2]}; - x_.Resize(framework::make_ddim(vec_dim_x_v)); - auto runner_dy = - NpuOpRunner("MatMul", {x_, dout_}, {*dy}, - {{"transpose_x1", true}, {"transpose_x2", false}}); + auto runner_dy = NpuOpRunner("BatchMatMul", {*x, *dout}, {*dy}, + {{"adj_x1", true}, {"adj_x2", false}}); runner_dy.Run(stream); } }