// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include "paddle/fluid/prim/api/all.h" #include "paddle/fluid/prim/api/generated_prim/prim_generated_api.h" #include "paddle/phi/common/int_array.h" #include "paddle/phi/core/ddim.h" namespace paddle { namespace prim { using Tensor = paddle::Tensor; using IntArray = paddle::experimental::IntArrayBase; // This function should have as same signature as phi, which defined in // paddle/phi/api/backward/backward_api.h template void cast_grad(const Tensor& out_grad, DataType dtype, Tensor* x_grad) { if (x_grad) { auto res = cast(out_grad, dtype); set_output(res, x_grad); } } template void gather_grad(const Tensor& x, const Tensor& index, const Tensor& out_grad, const Scalar& axis, bool overwrite, Tensor* grad_x) { auto zero_tensor = full(phi::vectorize(x.dims()), 0.0, x.dtype()); std::vector tmp_perm; // change axis to rank 0 int axis_value = axis.to(); tmp_perm.push_back(axis_value); // make other ranks for (int i = 0; i < x.dims().size(); ++i) { if (i != axis_value) { tmp_perm.push_back(i); } } std::vector reverse_perm(tmp_perm); // make origin ranks for (int i = 0; i < static_cast(tmp_perm.size()); ++i) { reverse_perm[tmp_perm[i]] = i; } // transpose out_grad and zero grad to target rank. auto tmp_zero_x_grad = transpose(zero_tensor, tmp_perm); auto tmp_out_grad = transpose(out_grad, tmp_perm); // scatter grad to grad_x auto tmp_grad_x = scatter(tmp_zero_x_grad, index, tmp_out_grad, false); auto tmp_grad_x_tranposed = transpose(tmp_grad_x, reverse_perm); set_output(tmp_grad_x_tranposed, grad_x); } template void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) { if (!grad_x) return; auto grad_x_tmp = grad_out * (1 - out * out); set_output(grad_x_tmp, grad_x); } template void reshape_grad(const Tensor& x, const Tensor& grad_out, Tensor* grad_x) { if (grad_x) { auto grad_x_tmp = reshape(grad_out, phi::vectorize(x.dims())); set_output(grad_x_tmp, grad_x); } } template void transpose_grad(const Tensor& grad_out, const std::vector& perm, Tensor* grad_x) { if (grad_x) { std::vector reverse_perm(perm); // make origin ranks for (int i = 0; i < static_cast(perm.size()); ++i) { if (perm[i] >= 0) { reverse_perm[perm[i]] = i; } else { reverse_perm[perm[i] + perm.size()] = i; } } auto grad_x_tmp = transpose(grad_out, reverse_perm); set_output(grad_x_tmp, grad_x); } } template void subtract_grad(const Tensor& x, const Tensor& y, const Tensor& out_grad, int axis, Tensor* dx, Tensor* dy) { if (dy) { auto scale_out_grad = scale(out_grad, -1.0, 0.0, true); if (x.dims() != y.dims()) { // Maybe need reduce here phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims()); if (!reduce_dim.size()) { by_pass(scale_out_grad, dy); } else { auto dy_reduce_res = scale_out_grad.sum(phi::vectorize(reduce_dim), y.dtype(), false); auto dy_tmp = reshape(dy_reduce_res, phi::vectorize(y.dims())); set_output(dy_tmp, dy); } } else { by_pass(scale_out_grad, dy); } } if (dx) { if (y.dims() != x.dims()) { // Maybe need reduce here auto reduce_dim = get_reduce_dims(x.dims(), y.dims()); if (!reduce_dim.size()) { by_pass(out_grad, dx); } else { auto dx_reduce_res = out_grad.sum(phi::vectorize(reduce_dim), x.dtype(), false); auto dx_tmp = reshape(dx_reduce_res, phi::vectorize(x.dims())); set_output(dx_tmp, dx); } } else { by_pass(out_grad, dx); } } } template void add_grad(const Tensor& x, const Tensor& y, const Tensor& out_grad, int axis, Tensor* dx, Tensor* dy) { if (dy) { if (x.dims() != y.dims()) { // Maybe need reduce here phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims()); if (!reduce_dim.size()) { by_pass(out_grad, dy); } else { auto dy_reduce_res = out_grad.sum(phi::vectorize(reduce_dim), y.dtype(), false); auto dy_tmp = reshape(dy_reduce_res, phi::vectorize(y.dims())); set_output(dy_tmp, dy); } } else { by_pass(out_grad, dy); } } if (dx) { if (y.dims() != x.dims()) { // Maybe need reduce here auto reduce_dim = get_reduce_dims(x.dims(), y.dims()); if (!reduce_dim.size()) { by_pass(out_grad, dx); } else { auto dx_reduce_res = out_grad.sum(phi::vectorize(reduce_dim), x.dtype(), false); auto dx_tmp = reshape(dx_reduce_res, phi::vectorize(x.dims())); set_output(dx_tmp, dx); } } else { by_pass(out_grad, dx); } } } template void sum_grad(const Tensor& x, const Tensor& out_grad, const IntArray& axis, bool keepdim, bool reduce_all, Tensor* x_grad) { if (!x_grad) { return; } std::vector x_dim = phi::vectorize(x.dims()); int64_t axis_size = axis.size(); int64_t x_dim_size = x_dim.size(); reduce_all = false; if (reduce_all || axis_size == 0 || axis_size == x_dim_size) { reduce_all = true; } else { reduce_all = false; } auto x_grad_tmp = Tensor(); if (x_dim_size == 1) { x_grad_tmp = out_grad.expand(IntArray(x_dim)); } else { if (!keepdim) { auto axis_ = std::vector(); if (reduce_all) { for (int64_t i = 1; i < x_dim_size; i++) { axis_.push_back(i); } } else { axis_ = axis.GetData(); for (int64_t i = 0; i < axis_size; i++) { if (axis[i] < 0) { axis_[i] = axis[i] + x_dim_size; } } } auto out_grad_ = unsqueeze(out_grad, axis_); x_grad_tmp = out_grad_.expand(IntArray(x_dim)); } else { x_grad_tmp = out_grad.expand(IntArray(x_dim)); } } set_output(x_grad_tmp, x_grad); } template void divide_grad(const Tensor& x, const Tensor& y, const Tensor& out, const Tensor& out_grad, int axis, Tensor* dx, Tensor* dy) { if (dy) { // dy = -(x/y^2) * dout auto dy_res = -(x / y.pow(2.0)) * out_grad; if (x.dims() != y.dims()) { // Maybe need reduce here phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims()); if (!reduce_dim.size()) { set_output(dy_res, dy); } else { auto dy_reduce_res = dy_res.sum(phi::vectorize(reduce_dim), y.dtype(), false); auto dy_tmp = reshape(dy_reduce_res, phi::vectorize(y.dims())); set_output(dy_tmp, dy); } } else { set_output(dy_res, dy); } } // indicate we will compute dy if (dx) { // dx = (1/y) * dout auto one_tensor = full(phi::vectorize(y.dims()), 1.0, y.dtype()); auto dx_res = one_tensor / y * out_grad; if (y.dims() != x.dims()) { // Maybe need reduce here auto reduce_dim = get_reduce_dims(x.dims(), y.dims()); if (!reduce_dim.size()) { set_output(dx_res, dx); } else { auto dx_reduce_res = dx_res.sum(phi::vectorize(reduce_dim), x.dtype(), false); auto dx_tmp = reshape(dx_reduce_res, phi::vectorize(x.dims())); set_output(dx_tmp, dx); } } else { set_output(dx_res, dx); } } // indicate we will compute dx } template void elementwise_pow_grad(const Tensor& x, const Tensor& y, const Tensor& out_grad, int axis, Tensor* dx, Tensor* dy) { if (dy) { // dy = lnx * x^y auto lnx = log(x); auto x_pow_y = elementwise_pow(x, y); auto dy_res = lnx * x_pow_y; if (x.dims() != y.dims()) { // Maybe need reduce here phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims()); if (!reduce_dim.size()) { set_output(dy_res, dy); } else { auto dy_reduce_res = dy_res.sum(phi::vectorize(reduce_dim), y.dtype(), false); auto dy_tmp = reshape(dy_reduce_res, phi::vectorize(y.dims())); set_output(dy_tmp, dy); } } else { set_output(dy_res, dy); } } // indicate we will compute dy if (dx) { // dx = y * x^(y-1) auto tmp_z = y - 1.0; auto x_pow_z = elementwise_pow(x, tmp_z); auto dx_res = y * x_pow_z; if (y.dims() != x.dims()) { // Maybe need reduce here auto reduce_dim = get_reduce_dims(x.dims(), y.dims()); if (!reduce_dim.size()) { set_output(dx_res, dx); } else { auto dx_reduce_res = dx_res.sum(phi::vectorize(reduce_dim), x.dtype(), false); auto dx_tmp = reshape(dx_reduce_res, phi::vectorize(x.dims())); set_output(dx_tmp, dx); } } else { set_output(dx_res, dx); } } // indicate we will compute dx } template void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) { if (x_grad) { // This calculation is important for resnet. auto x_grad_tmp = (0.5 / out) * out_grad; set_output(x_grad_tmp, x_grad); } } template void floor_grad(const Tensor& out_grad, Tensor* x_grad) { if (x_grad) { auto zero_tensor = full(phi::vectorize(out_grad.dims()), 0.0, out_grad.dtype()); set_output(zero_tensor, x_grad); } } template void concat_grad(const std::vector& x, const Tensor& out_grad, const Scalar& axis, std::vector x_grad) { int axis_value = axis.to(); int rank = x[0].dims().size(); if (axis_value < 0) { axis_value = axis_value + rank; } axis_value = axis_value > 0 ? axis_value : 0; std::vector sections; int x_num = x.size(); for (int i = 0; i < x_num; ++i) { sections.push_back(x[i].dims()[axis_value]); } std::vector x_grad_tmp = split(out_grad, phi::IntArray(sections), axis); for (int i = 0; i < x_num; ++i) { set_output(x_grad_tmp.at(i), x_grad.at(i)); } } template void multiply_grad(const Tensor& x, const Tensor& y, const Tensor& out_grad, int axis, Tensor* x_grad, Tensor* y_grad) { if (x_grad) { auto x_grad_unreduce = out_grad * y; if (x_grad_unreduce.dims() != x.dims()) { auto axes = get_reduce_dims_from_out(x_grad_unreduce.dims(), x.dims()); if (!axes.size()) { set_output(x_grad_unreduce, x_grad); } else { auto x_grad_reduced = x_grad_unreduce.sum( phi::vectorize(axes), x_grad_unreduce.dtype(), false); if (x_grad_reduced.dims().size() != x.dims().size()) { x_grad_reduced = reshape(x_grad_reduced, x.shape()); } set_output(x_grad_reduced, x_grad); } } else { set_output(x_grad_unreduce, x_grad); } } if (y_grad) { auto y_grad_unreduce = out_grad * x; if (y_grad_unreduce.dims() != y.dims()) { auto axes = get_reduce_dims_from_out(y_grad_unreduce.dims(), y.dims()); if (!axes.size()) { set_output(y_grad_unreduce, y_grad); } else { auto y_grad_reduced = y_grad_unreduce.sum( phi::vectorize(axes), y_grad_unreduce.dtype(), false); if (y_grad_reduced.dims().size() != y.dims().size()) { y_grad_reduced = reshape(y_grad_reduced, y.shape()); } set_output(y_grad_reduced, y_grad); } } else { set_output(y_grad_unreduce, y_grad); } } } template void expand_grad(const Tensor& x, const Tensor& out_grad, const IntArray& shape, Tensor* x_grad) { if (x_grad) { auto out_dims = phi::make_ddim(shape.GetData()); if (out_dims != x.dims()) { auto axes = get_reduce_dims(x.dims(), out_dims); if (!axes.size()) { by_pass(out_grad, x_grad); } else { auto reduced = out_grad.sum(phi::vectorize(axes), x.dtype(), false); if (reduced.dims().size() != x.dims().size()) { reduced = reshape(reduced, x.shape()); } set_output(reduced, x_grad); } } else { by_pass(out_grad, x_grad); } } } template void log_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) { if (x_grad) { // dx = dout / x set_output(out_grad / x, x_grad); } } template void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) { if (x_grad) { set_output(out_grad * out, x_grad); } } template void abs_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) { if (x_grad) { auto abs_tmp = abs(x); auto divide_tmp = divide(x, abs_tmp); set_output(out_grad * divide_tmp, x_grad); } } template void matmul_double_grad(const Tensor& x, const Tensor& y, const Tensor& grad_out, const paddle::optional& grad_x_grad, const paddle::optional& grad_y_grad, bool transpose_x, bool transpose_y, Tensor* x_grad, Tensor* y_grad, Tensor* grad_out_grad) { // Get dims from the input x, y, output_grad std::vector x_dims = vectorize(x.dims()); std::vector y_dims = vectorize(y.dims()); std::vector grad_out_dims = vectorize(grad_out.dims()); int x_ndim = x_dims.size(); int y_ndim = y_dims.size(); int dout_ndim = grad_out_dims.size(); // prepare dims for x_ndim <= 1 || y_ndim <= 1 Tensor x_help, y_help, xg_help, yg_help, out_help; if (x_ndim == 1 && y_ndim == 1) { transpose_x = false; transpose_y = false; x_help = reshape(x, IntArray(std::vector({1, x_dims[0]}))); y_help = reshape(y, IntArray(std::vector({y_dims[0], 1}))); if (grad_x_grad) { xg_help = reshape(grad_x_grad.get(), IntArray(std::vector({1, x_dims[0]}))); } if (grad_y_grad) { yg_help = reshape(grad_y_grad.get(), IntArray(std::vector({y_dims[0], 1}))); } out_help = reshape(grad_out, IntArray(std::vector({1, 1}))); } else if (x_ndim == 1) { transpose_x = false; x_help = reshape(x, IntArray(std::vector({1, x_dims[0]}))); y_help = y; if (grad_x_grad) { xg_help = reshape(grad_x_grad.get(), IntArray(std::vector({1, x_dims[0]}))); } if (grad_y_grad) { yg_help = grad_y_grad.get(); } auto tmp_grad_out_dims = grad_out_dims; tmp_grad_out_dims.insert(tmp_grad_out_dims.begin(), 1); out_help = reshape(grad_out, IntArray(tmp_grad_out_dims)); } else if (y_ndim == 1) { transpose_y = false; x_help = x; y_help = reshape(y, IntArray(std::vector({y_dims[0], 1}))); if (grad_x_grad) { xg_help = grad_x_grad.get(); } if (grad_y_grad) { yg_help = reshape(grad_y_grad.get(), IntArray(std::vector({y_dims[0], 1}))); } auto tmp_grad_out_dims = grad_out_dims; tmp_grad_out_dims.push_back(1); out_help = reshape(grad_out, IntArray(tmp_grad_out_dims)); } else { x_help = x; y_help = y; if (grad_x_grad) { xg_help = grad_x_grad.get(); } if (grad_y_grad) { yg_help = grad_y_grad.get(); } out_help = grad_out; } bool is_broadcast = true; if (x_ndim <= 2 && y_ndim <= 2) { is_broadcast = false; } else if (x_ndim != y_ndim) { is_broadcast = true; } else { is_broadcast = !std::equal( x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, y_dims.cbegin()); } Tensor dx, dy, ddout_1, ddout_2, ddout; if (!grad_x_grad && !grad_y_grad) { x_grad = nullptr; y_grad = nullptr; grad_out_grad = nullptr; return; } else if (!grad_x_grad) { y_grad = nullptr; if (!transpose_x && !transpose_y) { if (x_grad) { dx = matmul(out_help, yg_help, false, true); } if (grad_out_grad) { ddout = matmul(x_help, yg_help, false, false); } } else if (!transpose_x && transpose_y) { if (x_grad) { dx = matmul(out_help, yg_help, false, false); } if (grad_out_grad) { ddout = matmul(x_help, yg_help, false, true); } } else if (transpose_x && !transpose_y) { if (x_grad) { dx = matmul(yg_help, out_help, false, true); } if (grad_out_grad) { ddout = matmul(x_help, yg_help, true, false); } } else { if (x_grad) { dx = matmul(yg_help, out_help, true, true); } if (grad_out_grad) { ddout = matmul(x_help, yg_help, true, true); } } } else if (!grad_y_grad) { x_grad = nullptr; if (!transpose_x && !transpose_y) { if (y_grad) { dy = matmul(xg_help, out_help, true, false); } if (grad_out_grad) { ddout = matmul(xg_help, y_help, false, false); } } else if (!transpose_x && transpose_y) { if (y_grad) { dy = matmul(out_help, xg_help, true, false); } if (grad_out_grad) { ddout = matmul(xg_help, y_help, false, true); } } else if (transpose_x && !transpose_y) { if (y_grad) { dy = matmul(xg_help, out_help, false, false); } if (grad_out_grad) { ddout = matmul(xg_help, y_help, true, false); } } else { if (y_grad) { dy = matmul(out_help, xg_help, true, true); } if (grad_out_grad) { ddout = matmul(xg_help, y_help, true, true); } } } else { if (!transpose_x && !transpose_y) { if (x_grad) { dx = matmul(out_help, yg_help, false, true); } if (y_grad) { dy = matmul(xg_help, out_help, true, false); } if (grad_out_grad) { ddout_1 = matmul(x_help, yg_help, false, false); ddout_2 = matmul(xg_help, y_help, false, false); ddout = add(ddout_1, ddout_2); } } else if (!transpose_x && transpose_y) { if (x_grad) { dx = matmul(out_help, yg_help, false, false); } if (y_grad) { dy = matmul(out_help, xg_help, true, false); } if (grad_out_grad) { ddout_1 = matmul(x_help, yg_help, false, true); ddout_2 = matmul(xg_help, y_help, false, true); ddout = add(ddout_1, ddout_2); } } else if (transpose_x && !transpose_y) { if (x_grad) { dx = matmul(yg_help, out_help, false, true); } if (y_grad) { dy = matmul(xg_help, out_help, false, false); } if (grad_out_grad) { ddout_1 = matmul(x_help, yg_help, true, false); ddout_2 = matmul(xg_help, y_help, true, false); ddout = add(ddout_1, ddout_2); } } else { if (x_grad) { dx = matmul(yg_help, out_help, true, true); } if (y_grad) { dy = matmul(out_help, xg_help, true, true); } if (grad_out_grad) { ddout_1 = matmul(x_help, yg_help, true, true); ddout_2 = matmul(xg_help, y_help, true, true); ddout = add(ddout_1, ddout_2); } } } if (is_broadcast) { // Case3: broadcast. It need cost much time to reduce sum for the // broadcast and wastes the memory. // So we should avoid the case in reality. VLOG(3) << "It need cost much time to reduce sum for the broadcast and " "wastes the memory. So we should avoid the case in reality"; // Reduce sum to get grad by ReduceSum if (x_grad) { auto tx_dims = x_dims; auto tx_ndim = x_ndim; auto tdout_ndim = dout_ndim; if (x_ndim == 1) { tx_dims = std::vector({1, x_dims[0]}); tx_ndim = x_ndim + 1; tdout_ndim = dout_ndim + 1; } auto x_grad_reduce_dims = get_reduce_dims(dx, tdout_ndim, tx_ndim, &tx_dims); if (!x_grad_reduce_dims.empty()) { dx = sum(dx, IntArray(x_grad_reduce_dims), dy.dtype(), true); } reshape(dx, IntArray(tx_dims)); } if (y_grad) { auto ty_dims = y_dims; auto ty_ndim = y_ndim; auto tdout_ndim = dout_ndim; if (y_ndim == 1) { ty_dims = std::vector({y_dims[0], 1}); ty_ndim = y_ndim + 1; tdout_ndim = dout_ndim + 1; } auto y_grad_reduce_dims = get_reduce_dims(dy, tdout_ndim, ty_ndim, &ty_dims); if (!y_grad_reduce_dims.empty()) { dy = sum(dy, IntArray(y_grad_reduce_dims), dy.dtype(), true); } reshape(dy, IntArray(ty_dims)); } } // recover the original dim of output (delete 1) std::vector dx_dims = dx.initialized() ? vectorize(dx.dims()) : std::vector({}); std::vector dy_dims = dy.initialized() ? vectorize(dy.dims()) : std::vector({}); std::vector ddout_dims = ddout.initialized() ? vectorize(ddout.dims()) : std::vector({}); if (x_ndim == 1 && y_ndim == 1) { if (dx.initialized() && dx_dims[0] == 1) { dx = reshape(dx, IntArray(x_dims)); } if (dy.initialized() && dy_dims.back() == 1) { dy = reshape(dy, IntArray(y_dims)); } if (ddout.initialized() && ddout_dims == std::vector({1, 1})) { ddout = reshape(ddout, IntArray(std::vector({1}))); } } else if (x_ndim == 1) { if (dx.initialized() && dx_dims[0] == 1) { dx = reshape(dx, IntArray(x_dims)); } if (ddout.initialized() && ddout_dims[0] == 1) { ddout = reshape(ddout, IntArray(std::vector( {ddout_dims.cbegin() + 1, ddout_dims.cend()}))); } } else if (y_ndim == 1) { if (dy.initialized() && dy_dims.back() == 1) { dy = reshape(dy, IntArray(y_dims)); } if (ddout.initialized() && ddout_dims.back() == 1) { ddout = reshape(ddout, IntArray(std::vector( {ddout_dims.cbegin(), ddout_dims.cbegin() + ddout_dims.size() - 1}))); } } if (x_grad) { set_output(dx, x_grad); } if (y_grad) { set_output(dy, y_grad); } if (grad_out_grad) { set_output(ddout, grad_out_grad); } } template void slice_grad(const Tensor& input, const Tensor& out_grad, const std::vector& axes, const IntArray& starts, const IntArray& ends, const std::vector& infer_flags, const std::vector& decrease_axis, Tensor* input_grad) { if (input_grad) { size_t rank = input.dims().size(); auto out_dims = out_grad.dims(); std::vector origin_out_shape; auto in_dims = input.dims(); auto decrease_size = decrease_axis.size(); if (decrease_size > 0) { if (decrease_size == static_cast(in_dims.size())) { // all dims decrease out_dims = phi::make_ddim(std::vector(decrease_size, 1)); } else { origin_out_shape.resize(out_dims.size() + decrease_size, -1); for (size_t i = 0; i < decrease_size; ++i) { origin_out_shape[decrease_axis[i]] = 1; } int index = 0; for (size_t i = 0; i < origin_out_shape.size(); ++i) { if (origin_out_shape[i] == -1) { origin_out_shape[i] = out_dims[index]; ++index; } } out_dims = phi::make_ddim(origin_out_shape); } } std::vector offsets(rank, 0); std::vector extents(rank, 0); for (size_t i = 0; i < rank; ++i) { offsets[i] = 0; extents[i] = out_dims[i]; } for (size_t i = 0; i < axes.size(); ++i) { int axis = axes[i]; int64_t start = starts[i] < 0 ? (starts[i] + in_dims[axis]) : starts[i]; start = std::max(start, static_cast(0)); offsets[axis] = start; } std::vector paddings; for (size_t i = 0; i < rank; ++i) { paddings.push_back(offsets[i]); paddings.push_back((in_dims[i] - out_dims[i]) - offsets[i]); } if (decrease_size > 0 && (decrease_size != static_cast(in_dims.size()))) { auto out_tmp = pad(reshape(out_grad, origin_out_shape), paddings, 0.0); set_output(out_tmp, input_grad); } else { auto out_tmp = pad(out_grad, paddings, 0.0); set_output(out_tmp, input_grad); } } } template void cumsum_grad(const Tensor& x, const Tensor& out_grad, const Scalar& axis, bool flatten, bool exclusive, bool reverse, Tensor* x_grad) { if (x_grad) { auto grad = cumsum(out_grad, axis, flatten, exclusive, !reverse); grad = reshape(grad, x.shape()); set_output(grad, x_grad); } } template void topk_grad(const Tensor& x, const Tensor& indices, const Tensor& out_grad, const Scalar& k, const int& axis, const bool& largest, const bool& sorted, Tensor* x_grad) { if (x_grad) { auto zero_tensor = full(phi::vectorize(x.dims()), 0.0, x.dtype()); auto x_grad_tmp = put_along_axis(zero_tensor, indices, out_grad, axis); set_output(x_grad_tmp, x_grad); } } template void gather_nd_grad(const Tensor& x, const Tensor& index, const Tensor& out_grad, Tensor* x_grad) { if (x_grad) { auto zero_tensor = full(phi::vectorize(x.dims()), 0.0, x.dtype()); auto x_grad_tmp = scatter_nd_add(zero_tensor, index, out_grad); set_output(x_grad_tmp, x_grad); } } } // namespace prim } // namespace paddle