// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include "paddle/fluid/prim/api/all.h" #include "paddle/phi/common/int_array.h" #include "paddle/phi/core/ddim.h" namespace paddle { namespace prim { using Tensor = paddle::experimental::Tensor; using IntArray = paddle::experimental::IntArrayBase; // This function should have as same signature as phi, which defined in // paddle/phi/api/backward/backward_api.h template void gather_grad(const Tensor& x, const Tensor& index, const Tensor& out_grad, const Scalar& axis, bool overwrite, Tensor* grad_x) { auto zero_tensor = full(phi::vectorize(x.dims()), 0.0, x.dtype()); std::vector tmp_perm; // change axis to rank 0 int axis_value = axis.to(); tmp_perm.push_back(axis_value); // make other ranks for (int i = 0; i < x.dims().size(); ++i) { if (i != axis_value) { tmp_perm.push_back(i); } } std::vector reverse_perm(tmp_perm); // make origin ranks for (int i = 0; i < static_cast(tmp_perm.size()); ++i) { reverse_perm[tmp_perm[i]] = i; } // transpose out_grad and zero grad to target rank. auto tmp_zero_x_grad = transpose(zero_tensor, tmp_perm); auto tmp_out_grad = transpose(out_grad, tmp_perm); // scatter grad to grad_x auto tmp_grad_x = scatter(tmp_zero_x_grad, index, tmp_out_grad, false); auto tmp_grad_x_tranposed = transpose(tmp_grad_x, reverse_perm); set_output(tmp_grad_x_tranposed, grad_x); } template void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) { if (!grad_x) return; auto grad_x_tmp = grad_out * (1.0 - out.pow(2.0)); set_output(grad_x_tmp, grad_x); } template void subtract_grad(const Tensor& x, const Tensor& y, const Tensor& out_grad, int axis, Tensor* dx, Tensor* dy) { if (dy) { auto scale_out_grad = scale(out_grad, -1.0, 0.0, true); if (x.dims() != y.dims()) { // Maybe need reduce here phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims()); if (!reduce_dim.size()) { by_pass(scale_out_grad, dy); } else { auto dy_reduce_res = scale_out_grad.sum(phi::vectorize(reduce_dim), y.dtype(), false); auto dy_tmp = reshape(dy_reduce_res, phi::vectorize(y.dims())); set_output(dy_tmp, dy); } } else { by_pass(scale_out_grad, dy); } } if (dx) { if (y.dims() != x.dims()) { // Maybe need reduce here auto reduce_dim = get_reduce_dims(x.dims(), y.dims()); if (!reduce_dim.size()) { by_pass(out_grad, dx); } else { auto dx_reduce_res = out_grad.sum(phi::vectorize(reduce_dim), x.dtype(), false); auto dx_tmp = reshape(dx_reduce_res, phi::vectorize(x.dims())); set_output(dx_tmp, dx); } } else { by_pass(out_grad, dx); } } } template void add_grad(const Tensor& x, const Tensor& y, const Tensor& out_grad, int axis, Tensor* dx, Tensor* dy) { if (dy) { if (x.dims() != y.dims()) { // Maybe need reduce here phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims()); if (!reduce_dim.size()) { by_pass(out_grad, dy); } else { auto dy_reduce_res = out_grad.sum(phi::vectorize(reduce_dim), y.dtype(), false); auto dy_tmp = reshape(dy_reduce_res, phi::vectorize(y.dims())); set_output(dy_tmp, dy); } } else { by_pass(out_grad, dy); } } if (dx) { if (y.dims() != x.dims()) { // Maybe need reduce here auto reduce_dim = get_reduce_dims(x.dims(), y.dims()); if (!reduce_dim.size()) { by_pass(out_grad, dx); } else { auto dx_reduce_res = out_grad.sum(phi::vectorize(reduce_dim), x.dtype(), false); auto dx_tmp = reshape(dx_reduce_res, phi::vectorize(x.dims())); set_output(dx_tmp, dx); } } else { by_pass(out_grad, dx); } } } template void sum_grad(const Tensor& x, const Tensor& out_grad, const IntArray& axis, bool keepdim, bool reduce_all, Tensor* x_grad) { if (!x_grad) { return; } std::vector x_dim = phi::vectorize(x.dims()); int64_t axis_size = axis.size(); int64_t x_dim_size = x_dim.size(); reduce_all = false; if (reduce_all || axis_size == 0 || axis_size == x_dim_size) { reduce_all = true; } else { reduce_all = false; } auto x_grad_tmp = Tensor(); if (x_dim_size == 1) { x_grad_tmp = out_grad.expand(IntArray(x_dim)); } else { if (!keepdim) { auto axis_ = std::vector(); if (reduce_all) { for (int64_t i = 1; i < x_dim_size; i++) { axis_.push_back(i); } } else { axis_ = axis.GetData(); } auto out_grad_ = unsqueeze(out_grad, axis_); x_grad_tmp = out_grad_.expand(IntArray(x_dim)); } else { x_grad_tmp = out_grad.expand(IntArray(x_dim)); } } set_output(x_grad_tmp, x_grad); } template void divide_grad(const Tensor& x, const Tensor& y, const Tensor& out, const Tensor& out_grad, int axis, Tensor* dx, Tensor* dy) { if (dy) { // dy = -(x/y^2) * dout auto dy_res = -(x / y.pow(2.0)) * out_grad; if (x.dims() != y.dims()) { // Maybe need reduce here phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims()); if (!reduce_dim.size()) { set_output(dy_res, dy); } else { auto dy_reduce_res = dy_res.sum(phi::vectorize(reduce_dim), y.dtype(), false); auto dy_tmp = reshape(dy_reduce_res, phi::vectorize(y.dims())); set_output(dy_tmp, dy); } } else { set_output(dy_res, dy); } } // indicate we will compute dy if (dx) { // dx = (1/y) * dout auto one_tensor = full(phi::vectorize(y.dims()), 1.0, y.dtype()); auto dx_res = one_tensor / y * out_grad; if (y.dims() != x.dims()) { // Maybe need reduce here auto reduce_dim = get_reduce_dims(x.dims(), y.dims()); if (!reduce_dim.size()) { set_output(dx_res, dx); } else { auto dx_reduce_res = dx_res.sum(phi::vectorize(reduce_dim), x.dtype(), false); auto dx_tmp = reshape(dx_reduce_res, phi::vectorize(x.dims())); set_output(dx_tmp, dx); } } else { set_output(dx_res, dx); } } // indicate we will compute dx } template void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) { if (x_grad) { auto x_grad_tmp = out_grad * 0.5 / out; set_output(x_grad_tmp, x_grad); } } template void multiply_grad(const Tensor& x, const Tensor& y, const Tensor& out_grad, int axis, Tensor* x_grad, Tensor* y_grad) { if (x_grad) { auto x_grad_unreduce = out_grad * y; if (x_grad_unreduce.dims() != x.dims()) { auto axes = get_reduce_dims_from_out(x_grad_unreduce.dims(), x.dims()); if (!axes.size()) { set_output(x_grad_unreduce, x_grad); } else { auto x_grad_reduced = x_grad_unreduce.sum( phi::vectorize(axes), x_grad_unreduce.dtype(), false); if (x_grad_reduced.dims().size() != x.dims().size()) { x_grad_reduced = reshape(x_grad_reduced, x.shape()); } set_output(x_grad_reduced, x_grad); } } else { set_output(x_grad_unreduce, x_grad); } } if (y_grad) { auto y_grad_unreduce = out_grad * x; if (y_grad_unreduce.dims() != y.dims()) { auto axes = get_reduce_dims_from_out(y_grad_unreduce.dims(), y.dims()); if (!axes.size()) { set_output(y_grad_unreduce, y_grad); } else { auto y_grad_reduced = y_grad_unreduce.sum( phi::vectorize(axes), y_grad_unreduce.dtype(), false); if (y_grad_reduced.dims().size() != y.dims().size()) { y_grad_reduced = reshape(y_grad_reduced, y.shape()); } set_output(y_grad_reduced, y_grad); } } else { set_output(y_grad_unreduce, y_grad); } } } template void expand_grad(const Tensor& x, const Tensor& out_grad, const IntArray& shape, Tensor* x_grad) { if (x_grad) { auto out_dims = phi::make_ddim(shape.GetData()); if (out_dims != x.dims()) { auto axes = get_reduce_dims(x.dims(), out_dims); if (!axes.size()) { by_pass(out_grad, x_grad); } else { auto reduced = out_grad.sum(phi::vectorize(axes), x.dtype(), false); if (reduced.dims().size() != x.dims().size()) { reduced = reshape(reduced, x.shape()); } set_output(reduced, x_grad); } } else { by_pass(out_grad, x_grad); } } } template void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) { if (x_grad) { set_output(out_grad * out, x_grad); } } template void slice_grad(const Tensor& input, const Tensor& out_grad, const std::vector& axes, const IntArray& starts, const IntArray& ends, const std::vector& infer_flags, const std::vector& decrease_axis, Tensor* input_grad) { if (input_grad) { size_t rank = input.dims().size(); auto out_dims = out_grad.dims(); auto in_dims = input.dims(); auto decrease_size = decrease_axis.size(); if (decrease_size > 0) { if (decrease_size == static_cast(in_dims.size())) { // all dims decrease out_dims = phi::make_ddim(std::vector(decrease_size, 1)); } else { std::vector origin_out_shape(out_dims.size() + decrease_size, -1); for (size_t i = 0; i < decrease_size; ++i) { origin_out_shape[decrease_axis[i]] = 1; } int index = 0; for (size_t i = 0; i < origin_out_shape.size(); ++i) { if (origin_out_shape[i] == -1) { origin_out_shape[i] = out_dims[index]; ++index; } } out_dims = phi::make_ddim(origin_out_shape); } } std::vector offsets(rank, 0); std::vector extents(rank, 0); for (size_t i = 0; i < rank; ++i) { offsets[i] = 0; extents[i] = out_dims[i]; } for (size_t i = 0; i < axes.size(); ++i) { int axis = axes[i]; int64_t start = starts[i] < 0 ? (starts[i] + in_dims[axis]) : starts[i]; start = std::max(start, static_cast(0)); offsets[axis] = start; } std::vector paddings; for (size_t i = 0; i < rank; ++i) { paddings.push_back(offsets[i]); paddings.push_back((in_dims[i] - out_dims[i]) - offsets[i]); } auto out_tmp = pad(out_grad, paddings, 0.0); set_output(out_tmp, input_grad); } } } // namespace prim } // namespace paddle