From d73eb38c3e518f06bfac8385bc17474d1b3d88d2 Mon Sep 17 00:00:00 2001 From: Aganlengzi Date: Fri, 6 May 2022 20:04:58 +0800 Subject: [PATCH] [NPU] support model PPO (#42484) --- .../elementwise/elementwise_div_op_npu.cc | 47 ++++++++++++++----- 1 file changed, 36 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op_npu.cc b/paddle/fluid/operators/elementwise/elementwise_div_op_npu.cc index fb9dbc7fc8..ea0d13f406 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op_npu.cc +++ b/paddle/fluid/operators/elementwise/elementwise_div_op_npu.cc @@ -118,20 +118,45 @@ class ElementwiseDivGradNPUKernel : public framework::OpKernel { if (dy) { dy->mutable_data(place); - Tensor neg_out(y->type()); - neg_out.mutable_data(y->dims(), place); + Tensor neg_out(out->type()); + neg_out.mutable_data(out->dims(), place); const auto& runner_neg_out = NpuOpRunner("Neg", {*out}, {neg_out}, {}); runner_neg_out.Run(stream); - Tensor y_grad_w(y->type()); - y_grad_w.mutable_data(y->dims(), place); - const auto& runner_y_grad_w = - NpuOpRunner("Div", {neg_out, *y}, {y_grad_w}, {}); - runner_y_grad_w.Run(stream); - - const auto& runner_y_grad = - NpuOpRunner("Mul", {y_grad_w, *dout}, {*dy}, {}); - runner_y_grad.Run(stream); + Tensor tmp_mul(out->type()); + tmp_mul.mutable_data(out->dims(), place); + const auto& runner_mul = + NpuOpRunner("Mul", {neg_out, *dout}, {tmp_mul}, {}); + runner_mul.Run(stream); + + if (dy->dims() != dout->dims()) { + Tensor reduced_tmp_mul(y->type()); + reduced_tmp_mul.mutable_data(y->dims(), place); + + std::vector axes; + int64_t diff = dout->dims().size() - dy->dims().size(); + for (int64_t i = 0; i < dout->dims().size(); ++i) { + if (i < diff) { + axes.push_back(i); + continue; + } + if (dout->dims()[i] > dy->dims()[i - diff]) { + axes.push_back(i); + } + } + const auto& runner_reduce = + NpuOpRunner("ReduceSumD", {tmp_mul}, {reduced_tmp_mul}, + {{"axes", axes}, {"keep_dims", false}}); + runner_reduce.Run(stream); + + const auto& runner_y_grad = + NpuOpRunner("Div", {reduced_tmp_mul, *y}, {*dy}, {}); + runner_y_grad.Run(stream); + } else { + const auto& runner_y_grad = + NpuOpRunner("Div", {tmp_mul, *y}, {*dy}, {}); + runner_y_grad.Run(stream); + } } } }; -- GitLab