// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #ifndef _USE_MATH_DEFINES #define _USE_MATH_DEFINES #endif #include #include "paddle/fluid/prim/api/all.h" #include "paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h" #include "paddle/fluid/prim/api/generated_prim/prim_generated_api.h" #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/common/int_array.h" #include "paddle/phi/core/ddim.h" namespace paddle { namespace prim { using Tensor = paddle::Tensor; using IntArray = paddle::experimental::IntArrayBase; // This function should have as same signature as phi, which defined in // paddle/phi/api/backward/backward_api.h template void hardswish_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) { if (x_grad) { auto offset = full(phi::vectorize(x.dims()), 3.0, x.dtype()); auto condition = less_equal(x, offset); auto tmp1 = where(condition, out_grad * ((x / 3.0) + 0.5), out_grad); auto res = where( less_than(x, full(phi::vectorize(x.dims()), -3.0, x.dtype())), full(phi::vectorize(x.dims()), 0.0, x.dtype()), tmp1); set_output(res, x_grad); } } template void leaky_relu_grad(const Tensor& out, const Tensor& out_grad, float negative_slope, Tensor* x_grad) { if (x_grad) { auto condition = greater_than( out, full(phi::vectorize(out.dims()), 0.0, out.dtype())); auto res = where(condition, out_grad, out_grad * negative_slope); set_output(res, x_grad); } } template void silu_grad(const Tensor& x, const Tensor& out, const Tensor& out_grad, Tensor* x_grad) { if (x_grad) { auto sigmoid = out / x; auto res = out_grad * sigmoid * (1.0 + x * (1.0 - sigmoid)); set_output(res, x_grad); } } template void relu_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) { if (x_grad) { auto condition = greater_than( out, full(phi::vectorize(out.dims()), 0.0, out.dtype())); auto res = where(condition, out_grad, full(phi::vectorize(out.dims()), 0.0, out.dtype())); set_output(res, x_grad); } } template void softmax_grad(const Tensor& out, const Tensor& out_grad, int axis, Tensor* x_grad) { if (x_grad) { if (out_grad.dims().size() > 0) { if (axis >= 0) { auto new_out_grad = out_grad * out; auto tmp_x_grad = new_out_grad - out * sum(new_out_grad, {axis}, out.dtype(), true); set_output(tmp_x_grad, x_grad); } else { auto new_out_grad = out_grad * out; auto tmp_x_grad = new_out_grad - out * sum(new_out_grad, {out.dims().size() + axis}, out.dtype(), true); set_output(tmp_x_grad, x_grad); } } else { set_output( full(phi::vectorize(out_grad.dims()), 0.0, out_grad.dtype()), x_grad); } } } template void cast_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) { if (x_grad) { auto res = cast(out_grad, x.dtype()); set_output(res, x_grad); } } template void gather_grad(const Tensor& x, const Tensor& index, const Tensor& out_grad, const Scalar& axis, Tensor* grad_x) { auto zero_tensor = full(phi::vectorize(x.dims()), 0.0, x.dtype()); std::vector tmp_perm; // change axis to rank 0 int axis_value = axis.to(); tmp_perm.push_back(axis_value); // make other ranks for (int i = 0; i < x.dims().size(); ++i) { if (i != axis_value) { tmp_perm.push_back(i); } } std::vector reverse_perm(tmp_perm); // make origin ranks for (int i = 0; i < static_cast(tmp_perm.size()); ++i) { if (tmp_perm[i] >= 0) { reverse_perm[tmp_perm[i]] = i; } else { reverse_perm[tmp_perm[i] + tmp_perm.size()] = i; } } // transpose out_grad and zero grad to target rank. auto tmp_zero_x_grad = transpose(zero_tensor, tmp_perm); auto tmp_out_grad = transpose(out_grad, tmp_perm); // scatter grad to grad_x auto tmp_grad_x = scatter(tmp_zero_x_grad, index, tmp_out_grad, false); auto tmp_grad_x_tranposed = transpose(tmp_grad_x, reverse_perm); set_output(tmp_grad_x_tranposed, grad_x); } template void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) { if (!grad_x) return; auto grad_x_tmp = grad_out * (1 - out * out); set_output(grad_x_tmp, grad_x); } template void reshape_grad(const Tensor& x, const Tensor& grad_out, Tensor* grad_x) { if (grad_x) { auto grad_x_tmp = reshape(grad_out, phi::vectorize(x.dims())); set_output(grad_x_tmp, grad_x); } } template void transpose_grad(const Tensor& grad_out, const std::vector& perm, Tensor* grad_x) { if (grad_x) { std::vector reverse_perm(perm); // make origin ranks for (int i = 0; i < static_cast(perm.size()); ++i) { if (perm[i] >= 0) { reverse_perm[perm[i]] = i; } else { reverse_perm[perm[i] + perm.size()] = i; } } auto grad_x_tmp = transpose(grad_out, reverse_perm); set_output(grad_x_tmp, grad_x); } } template void subtract_grad(const Tensor& x, const Tensor& y, const Tensor& out_grad, int axis, Tensor* dx, Tensor* dy) { if (dy) { auto scale_out_grad = scale(out_grad, -1.0, 0.0, true); if (x.dims() != y.dims()) { // Maybe need reduce here phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims()); if (!reduce_dim.size()) { by_pass(scale_out_grad, dy); } else { auto dy_reduce_res = scale_out_grad.sum(phi::vectorize(reduce_dim), y.dtype(), false); auto dy_tmp = reshape(dy_reduce_res, phi::vectorize(y.dims())); set_output(dy_tmp, dy); } } else { by_pass(scale_out_grad, dy); } } if (dx) { if (y.dims() != x.dims()) { // Maybe need reduce here auto reduce_dim = get_reduce_dims(x.dims(), y.dims()); if (!reduce_dim.size()) { by_pass(out_grad, dx); } else { auto dx_reduce_res = out_grad.sum(phi::vectorize(reduce_dim), x.dtype(), false); auto dx_tmp = reshape(dx_reduce_res, phi::vectorize(x.dims())); set_output(dx_tmp, dx); } } else { by_pass(out_grad, dx); } } } template void add_grad(const Tensor& x, const Tensor& y, const Tensor& out_grad, int axis, Tensor* dx, Tensor* dy) { if (dy) { if (x.dims() != y.dims()) { // Maybe need reduce here phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims()); if (!reduce_dim.size()) { by_pass(out_grad, dy); } else { auto dy_reduce_res = out_grad.sum(phi::vectorize(reduce_dim), y.dtype(), false); auto dy_tmp = reshape(dy_reduce_res, phi::vectorize(y.dims())); set_output(dy_tmp, dy); } } else { by_pass(out_grad, dy); } } if (dx) { if (y.dims() != x.dims()) { // Maybe need reduce here auto reduce_dim = get_reduce_dims(x.dims(), y.dims()); if (!reduce_dim.size()) { by_pass(out_grad, dx); } else { auto dx_reduce_res = out_grad.sum(phi::vectorize(reduce_dim), x.dtype(), false); auto dx_tmp = reshape(dx_reduce_res, phi::vectorize(x.dims())); set_output(dx_tmp, dx); } } else { by_pass(out_grad, dx); } } } template void sum_grad(const Tensor& x, const Tensor& out_grad, const IntArray& axis, bool keepdim, bool reduce_all, Tensor* x_grad) { if (!x_grad) { return; } std::vector x_dim = phi::vectorize(x.dims()); int64_t axis_size = axis.size(); int64_t x_dim_size = x_dim.size(); reduce_all = false; if (reduce_all || axis_size == 0 || axis_size == x_dim_size) { reduce_all = true; } else { reduce_all = false; } auto x_grad_tmp = Tensor(); if (x_dim_size == 1) { x_grad_tmp = out_grad.expand(IntArray(x_dim)); } else { if (!keepdim) { auto axis_ = std::vector(); if (reduce_all) { for (int64_t i = 0; i < x_dim_size; i++) { axis_.push_back(i); } } else { axis_ = axis.GetData(); for (int64_t i = 0; i < axis_size; i++) { if (axis[i] < 0) { axis_[i] = axis[i] + x_dim_size; } } } auto out_grad_shape = get_unsqueeze_dims(out_grad, axis_); auto out_grad_ = reshape(out_grad, out_grad_shape); x_grad_tmp = out_grad_.expand(IntArray(x_dim)); } else { x_grad_tmp = out_grad.expand(IntArray(x_dim)); } } set_output(x_grad_tmp, x_grad); } template void divide_grad(const Tensor& x, const Tensor& y, const Tensor& out, const Tensor& out_grad, int axis, Tensor* dx, Tensor* dy) { if (dy) { // dy = -(x/y^2) * dout auto dy_res = -(x / y.pow(2.0)) * out_grad; if (x.dims() != y.dims()) { // Maybe need reduce here phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims()); if (!reduce_dim.size()) { set_output(dy_res, dy); } else { auto dy_reduce_res = dy_res.sum(phi::vectorize(reduce_dim), y.dtype(), false); auto dy_tmp = reshape(dy_reduce_res, phi::vectorize(y.dims())); set_output(dy_tmp, dy); } } else { set_output(dy_res, dy); } } // indicate we will compute dy if (dx) { // dx = (1/y) * dout auto one_tensor = full(phi::vectorize(y.dims()), 1.0, y.dtype()); auto dx_res = one_tensor / y * out_grad; if (y.dims() != x.dims()) { // Maybe need reduce here auto reduce_dim = get_reduce_dims(x.dims(), y.dims()); if (!reduce_dim.size()) { set_output(dx_res, dx); } else { auto dx_reduce_res = dx_res.sum(phi::vectorize(reduce_dim), x.dtype(), false); auto dx_tmp = reshape(dx_reduce_res, phi::vectorize(x.dims())); set_output(dx_tmp, dx); } } else { set_output(dx_res, dx); } } // indicate we will compute dx } template void elementwise_pow_grad(const Tensor& x, const Tensor& y, const Tensor& out_grad, Tensor* dx, Tensor* dy) { if (dy) { // dy = lnx * x^y auto lnx = log(x); auto x_pow_y = elementwise_pow(x, y); auto dy_res = lnx * x_pow_y * out_grad; if (x.dims() != y.dims()) { // Maybe need reduce here phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims()); if (!reduce_dim.size()) { set_output(dy_res, dy); } else { auto dy_reduce_res = dy_res.sum(phi::vectorize(reduce_dim), y.dtype(), false); auto dy_tmp = reshape(dy_reduce_res, phi::vectorize(y.dims())); set_output(dy_tmp, dy); } } else { set_output(dy_res, dy); } } // indicate we will compute dy if (dx) { // dx = y * x^(y-1) auto tmp_z = y - 1.0; auto x_pow_z = elementwise_pow(x, tmp_z); auto dx_res = y * x_pow_z * out_grad; if (y.dims() != x.dims()) { // Maybe need reduce here auto reduce_dim = get_reduce_dims(x.dims(), y.dims()); if (!reduce_dim.size()) { set_output(dx_res, dx); } else { auto dx_reduce_res = dx_res.sum(phi::vectorize(reduce_dim), x.dtype(), false); auto dx_tmp = reshape(dx_reduce_res, phi::vectorize(x.dims())); set_output(dx_tmp, dx); } } else { set_output(dx_res, dx); } } // indicate we will compute dx } template void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) { if (x_grad) { // This calculation is important for resnet. auto x_grad_tmp = (0.5 / out) * out_grad; set_output(x_grad_tmp, x_grad); } } template void floor_grad(const Tensor& out_grad, Tensor* x_grad) { if (x_grad) { auto zero_tensor = full(phi::vectorize(out_grad.dims()), 0.0, out_grad.dtype()); set_output(zero_tensor, x_grad); } } template void concat_grad(const std::vector& x, const Tensor& out_grad, const Scalar& axis, std::vector x_grad) { int axis_value = axis.to(); int rank = x[0].dims().size(); if (axis_value < 0) { axis_value = axis_value + rank; } axis_value = axis_value > 0 ? axis_value : 0; std::vector sections; int x_num = x.size(); for (int i = 0; i < x_num; ++i) { sections.push_back(x[i].dims()[axis_value]); } std::vector x_grad_tmp = split(out_grad, phi::IntArray(sections), axis_value); for (int i = 0; i < x_num; ++i) { set_output(x_grad_tmp.at(i), x_grad.at(i)); } } template void multiply_grad(const Tensor& x, const Tensor& y, const Tensor& out_grad, int axis, Tensor* x_grad, Tensor* y_grad) { if (x_grad) { auto x_grad_unreduce = out_grad * y; if (x_grad_unreduce.dims() != x.dims()) { auto axes = get_reduce_dims_from_out(x_grad_unreduce.dims(), x.dims()); if (!axes.size()) { set_output(x_grad_unreduce, x_grad); } else { auto x_grad_reduced = x_grad_unreduce.sum( phi::vectorize(axes), x_grad_unreduce.dtype(), false); if (x_grad_reduced.dims().size() != x.dims().size()) { x_grad_reduced = reshape(x_grad_reduced, x.shape()); } set_output(x_grad_reduced, x_grad); } } else { set_output(x_grad_unreduce, x_grad); } } if (y_grad) { auto y_grad_unreduce = out_grad * x; if (y_grad_unreduce.dims() != y.dims()) { auto axes = get_reduce_dims_from_out(y_grad_unreduce.dims(), y.dims()); if (!axes.size()) { set_output(y_grad_unreduce, y_grad); } else { auto y_grad_reduced = y_grad_unreduce.sum( phi::vectorize(axes), y_grad_unreduce.dtype(), false); if (y_grad_reduced.dims().size() != y.dims().size()) { y_grad_reduced = reshape(y_grad_reduced, y.shape()); } set_output(y_grad_reduced, y_grad); } } else { set_output(y_grad_unreduce, y_grad); } } } template void expand_grad(const Tensor& x, const Tensor& out_grad, const IntArray& shape, Tensor* x_grad) { if (x_grad) { auto out_dims = phi::make_ddim(shape.GetData()); if (out_dims != x.dims()) { auto axes = get_reduce_dims(x.dims(), out_dims); if (!axes.size()) { by_pass(out_grad, x_grad); } else { auto reduced = out_grad.sum(phi::vectorize(axes), x.dtype(), false); if (reduced.dims().size() != x.dims().size()) { reduced = reshape(reduced, x.shape()); } set_output(reduced, x_grad); } } else { by_pass(out_grad, x_grad); } } } template void log_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) { if (x_grad) { // dx = dout / x set_output(out_grad / x, x_grad); } } template void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) { if (x_grad) { if (out.dtype() == phi::DataType::FLOAT16 || out.dtype() == phi::DataType::BFLOAT16) { Tensor out_promote = cast(out, phi::DataType::FLOAT32); Tensor out_grad_promote = cast(out_grad, phi::DataType::FLOAT32); set_output(cast(out_promote * out_grad_promote, out.dtype()), x_grad); } else { set_output(out_grad * out, x_grad); } } } template void sigmoid_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) { if (x_grad) { set_output(out_grad * (out * (1 - out)), x_grad); } } template void abs_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) { if (x_grad) { auto abs_tmp = abs(x); auto divide_tmp = divide(x, abs_tmp); set_output(out_grad * divide_tmp, x_grad); } } template void slice_grad(const Tensor& input, const Tensor& out_grad, const std::vector& axes, const IntArray& starts, const IntArray& ends, const std::vector& infer_flags, const std::vector& decrease_axis, Tensor* input_grad) { if (input_grad) { size_t rank = input.dims().size(); auto out_dims = out_grad.dims(); std::vector origin_out_shape; auto in_dims = input.dims(); auto decrease_size = decrease_axis.size(); if (decrease_size > 0) { if (decrease_size == static_cast(in_dims.size())) { // all dims decrease out_dims = phi::make_ddim(std::vector(decrease_size, 1)); } else { origin_out_shape.resize(out_dims.size() + decrease_size, -1); for (size_t i = 0; i < decrease_size; ++i) { origin_out_shape[decrease_axis[i]] = 1; } int index = 0; for (size_t i = 0; i < origin_out_shape.size(); ++i) { if (origin_out_shape[i] == -1) { origin_out_shape[i] = out_dims[index]; ++index; } } out_dims = phi::make_ddim(origin_out_shape); } } std::vector offsets(rank, 0); std::vector extents(rank, 0); for (size_t i = 0; i < rank; ++i) { offsets[i] = 0; extents[i] = out_dims[i]; } for (size_t i = 0; i < axes.size(); ++i) { int axis = axes[i]; int64_t start = starts[i] < 0 ? (starts[i] + in_dims[axis]) : starts[i]; start = std::max(start, static_cast(0)); offsets[axis] = start; } std::vector paddings; for (size_t i = 0; i < rank; ++i) { paddings.push_back(offsets[i]); paddings.push_back((in_dims[i] - out_dims[i]) - offsets[i]); } if (decrease_size > 0 && (decrease_size != static_cast(in_dims.size()))) { auto out_tmp = pad(reshape(out_grad, origin_out_shape), paddings, 0.0); set_output(out_tmp, input_grad); } else { auto out_tmp = pad(out_grad, paddings, 0.0); set_output(out_tmp, input_grad); } } } template void group_norm_grad(const Tensor& x, const paddle::optional& scale, const paddle::optional& bias, const Tensor& y, const Tensor& mean, const Tensor& variance, const Tensor& out_grad, float epsilon, int groups, const std::string& data_layout, Tensor* x_grad, Tensor* scale_grad, Tensor* bias_grad) { // x.shape=[n,c,h,w] // y.shape=[n,c,h,w] // g_size = c/g // scale.shape=[c] // mean, var: shape=[n, g] // inv_std = rsqrt(var + epsilon) // ds = sum(dy * x, axes=(2,3)) // db = sum(dy, axes=(2,3)) // // cal d_x: // s = g / (h*w*c) // if scale: // ds_val = sum((ds * scale).reshape(n, g, g_size), axes=2) // db_val = sum((db * scale).reshape(n, g, g_size), axes=2) // p1 = (inv_std.reshape(n, g, 1)) * (scale.reshape(1, g, g_size)) // else: // ds_val = sum(ds.reshape(n, g, g_size), axes=2) // db_val = sum(db.reshape(n, g, g_size), axes=2) // p1 = (inv_std.reshape(n, g, 1)) * (ones(1, g, g_size)) // p2 = (db_val * mean - ds_val) * inv_std * inv_std * inv_std * s // p3 = -p2 * mean - db_val * inv_std * s // p1.reshape(n, g, g_size, 1) // p2.reshape(n, g, 1, 1) // p3.reshape(n, g, 1, 1) // d_x = dy.reshape(n, g, g_size, h*w) * p1 + x.reshape(n, g, g_size, h*w)* p2 // + p3 // // cal d_scale: // temp = ds.reshape(n, g, g_size) - db.reshape(n, g, g_size) * // mean.reshape(n, g, 1) // d_scale = sum(temp * inv_std.reshape(n, g, 1), axes=0).reshape(c) // // cal d_bias: // d_bias = sum(dy, axes=(0,2,3)) DataLayout data_layout_ = phi::StringToDataLayout(data_layout); if (data_layout_ != DataLayout::kNCHW) { PADDLE_THROW(phi::errors::InvalidArgument("Unsupported storage order: %s", data_layout)); } Tensor x_data = x; Tensor out_grad_data = out_grad; if (x.dtype() == phi::DataType::FLOAT16) { x_data = cast(x, phi::DataType::FLOAT32); } if (out_grad.dtype() == phi::DataType::FLOAT16) { out_grad_data = cast(out_grad, phi::DataType::FLOAT32); } std::vector x_dims = phi::vectorize(x.dims()); auto add_axis = std::vector({-1}); const int N = x_dims[0]; const int C = x_dims[1]; const int hw = x_dims[2] * x_dims[3]; const int g_num = C / groups; auto reduce_axis = IntArray(std::vector({2, 3})); auto shape_group = IntArray(std::vector({N, groups, g_num})); auto whole_group_shape = IntArray(std::vector({N, groups, g_num, hw})); auto scale_ptr = scale.get_ptr(); auto bias_ptr = bias.get_ptr(); auto inv_std = sqrt(1.0 / (variance + epsilon)); auto inv_std_mul_s = inv_std / hw / g_num; auto dtype = x_data.dtype(); auto sum_y_grad_mul_x = sum(out_grad_data * x_data, reduce_axis, dtype, false); auto sum_y_grad = sum(out_grad_data, reduce_axis, dtype, false); if (x_grad) { Tensor d1; Tensor d2; Tensor p1; if (scale_ptr) { auto scale_data = scale.get(); if (scale_data.dtype() == phi::DataType::FLOAT16) { scale_data = cast(scale_data, phi::DataType::FLOAT32); } d1 = (reshape(sum_y_grad_mul_x * scale_data, shape_group)) .sum(std::vector({2}), dtype, false); d2 = (reshape(sum_y_grad * scale_data, shape_group)) .sum(std::vector({2}), dtype, false); p1 = reshape(inv_std, std::vector({N, groups, 1})) * reshape(scale_data, std::vector({1, groups, g_num})); } else { d1 = (reshape(sum_y_grad_mul_x, shape_group)) .sum(std::vector({2}), dtype, false); d2 = (reshape(sum_y_grad, shape_group)) .sum(std::vector({2}), dtype, false); p1 = (reshape(inv_std, std::vector({N, groups, 1}))) .expand(IntArray(shape_group)); } auto p2 = (d2 * mean - d1) * (inv_std_mul_s * inv_std * inv_std); auto p3 = -p2 * mean - d2 * inv_std_mul_s; auto first_shape = get_unsqueeze_dims(p1, std::vector({3})); auto second_shape = get_unsqueeze_dims(p2, std::vector({2, 3})); p1 = reshape(p1, first_shape); p2 = reshape(p2, second_shape); p3 = reshape(p3, second_shape); auto tmp_1 = reshape(out_grad_data, whole_group_shape) * p1; auto tmp_2 = reshape(x_data, whole_group_shape) * p2 + p3; auto x_grad_data = tmp_1 + tmp_2; x_grad_data = reshape(x_grad_data, x.shape()); if (x.dtype() == phi::DataType::FLOAT16) { x_grad_data = cast(x_grad_data, x.dtype()); } set_output(x_grad_data, x_grad); } if (scale_grad) { if (scale_ptr) { auto third_shape = get_unsqueeze_dims(mean, std::vector({2})); auto tmp1 = (reshape(sum_y_grad_mul_x, shape_group) - reshape(sum_y_grad, shape_group) * reshape(mean, third_shape)) * reshape(inv_std, third_shape); auto scale_grad_tmp = reshape(tmp1.sum(std::vector({0}), dtype, false), IntArray(std::vector({C}))); set_output(scale_grad_tmp, scale_grad); } else { scale_grad = nullptr; } } if (bias_grad) { if (bias_ptr) { auto bias_grad_tmp = sum_y_grad.sum(std::vector({0}), dtype, false); set_output(bias_grad_tmp, bias_grad); } else { bias_grad = nullptr; } } } template void layer_norm_grad(const Tensor& x, const paddle::optional& scale, const paddle::optional& bias, const Tensor& mean, const Tensor& variance, const Tensor& out_grad, float epsilon, int begin_norm_axis, Tensor* x_grad, Tensor* scale_grad, Tensor* bias_grad) { auto x_dims = x.dims(); auto shape_1 = 1; // front part auto shape_2 = 1; // back part for (int i = 0; i < begin_norm_axis; ++i) { shape_1 *= x_dims[i]; } for (int i = begin_norm_axis; i < x.dims().size(); ++i) { shape_2 *= x_dims[i]; } auto scale_ptr = scale.get_ptr(); auto bias_ptr = bias.get_ptr(); auto x_cast = reshape(x, std::vector({shape_1, shape_2})); auto out_grad_cast = reshape(out_grad, std::vector({shape_1, shape_2})); auto mean_ = reshape(mean, std::vector({shape_1, 1})); auto variance_ = reshape(variance, std::vector({shape_1, 1})); Tensor scale_cast; if (scale_ptr) { scale_cast = reshape(*scale_ptr, std::vector({1, shape_2})); } // cast dtype to float32 if dtype =float16 or bfloat16 if (x.dtype() == phi::DataType::FLOAT16 || x.dtype() == phi::DataType::BFLOAT16) { x_cast = cast(x_cast, phi::DataType::FLOAT32); out_grad_cast = cast(out_grad_cast, phi::DataType::FLOAT32); if (scale_ptr) { scale_cast = cast(scale_cast, phi::DataType::FLOAT32); } } auto x_sub_mean = x_cast - mean_; // M,N auto tmp = (1.0 / (variance_ + epsilon)); // M,1 auto sqrt_var_1 = sqrt(tmp); // M,1 auto x_sub_mean_mul_sqrt_var_1 = x_sub_mean * sqrt_var_1; if (x_grad) { auto out_grad_scale = out_grad_cast; // M,N if (scale_ptr) { out_grad_scale = out_grad_cast * scale_cast; // M,N * 1,N = M,N } auto dx_end = sqrt_var_1 * out_grad_scale; auto d_mean = dx_end.sum(std::vector({1}), x_cast.dtype(), true); // M,1 auto d_std_1 = (tmp * x_sub_mean * out_grad_scale) .sum(std::vector({1}), x_cast.dtype(), true); // M,1 auto d_std = d_std_1 * x_sub_mean_mul_sqrt_var_1; // M,1 * M,N = M,N auto d_mean_d_std = (1.0 / shape_2) * (d_mean + d_std); auto x_grad_tmp = dx_end - d_mean_d_std; x_grad_tmp = reshape(x_grad_tmp, phi::vectorize(x.dims())); if (x.dtype() == phi::DataType::FLOAT16 || x.dtype() == phi::DataType::BFLOAT16) { x_grad_tmp = cast(x_grad_tmp, x.dtype()); } set_output(x_grad_tmp, x_grad); } if (scale_grad) { if (scale_ptr) { auto scale_grad_tmp = (x_sub_mean_mul_sqrt_var_1 * out_grad_cast) .sum(std::vector({0}), x_cast.dtype(), true); scale_grad_tmp = reshape(scale_grad_tmp, scale_ptr->shape()); if (scale_ptr->dtype() == phi::DataType::FLOAT16 || scale_ptr->dtype() == phi::DataType::BFLOAT16) { scale_grad_tmp = cast(scale_grad_tmp, scale_ptr->dtype()); } set_output(scale_grad_tmp, scale_grad); } else { scale_grad = nullptr; } } if (bias_grad) { if (bias_ptr) { auto bias_grad_tmp = out_grad_cast.sum(std::vector({0}), x_cast.dtype(), true); bias_grad_tmp = reshape(bias_grad_tmp, bias_ptr->shape()); if (bias_ptr->dtype() == phi::DataType::FLOAT16 || bias_ptr->dtype() == phi::DataType::BFLOAT16) { bias_grad_tmp = cast(bias_grad_tmp, bias_ptr->dtype()); } set_output(bias_grad_tmp, bias_grad); } else { bias_grad = nullptr; } } } template void cumsum_grad(const Tensor& x, const Tensor& out_grad, const Scalar& axis, bool flatten, bool exclusive, bool reverse, Tensor* x_grad) { if (x_grad) { auto grad = cumsum(out_grad, axis, flatten, exclusive, !reverse); grad = reshape(grad, x.shape()); set_output(grad, x_grad); } } template void split_grad(const std::vector& out_grad, const Scalar& axis, Tensor* x_grad) { if (x_grad) { auto grad = concat(out_grad, axis); set_output(grad, x_grad); } } template void topk_grad(const Tensor& x, const Tensor& indices, const Tensor& out_grad, const Scalar& k, const int& axis, const bool& largest, const bool& sorted, Tensor* x_grad) { if (x_grad) { auto zero_tensor = full(phi::vectorize(x.dims()), 0.0, x.dtype()); auto x_grad_tmp = put_along_axis(zero_tensor, indices, out_grad, axis); set_output(x_grad_tmp, x_grad); } } template void gather_nd_grad(const Tensor& x, const Tensor& index, const Tensor& out_grad, Tensor* x_grad) { if (x_grad) { auto zero_tensor = full(phi::vectorize(x.dims()), 0.0, x.dtype()); auto x_grad_tmp = scatter_nd_add(zero_tensor, index, out_grad); set_output(x_grad_tmp, x_grad); } } template void prod_grad(const Tensor& x, const Tensor& out, const Tensor& out_grad, const IntArray& axis, bool keep_dim, bool reduce_all, Tensor* x_grad) { if (x_grad) { std::vector x_dim = phi::vectorize(x.dims()); int64_t axis_size = axis.size(); int64_t x_dim_size = x_dim.size(); reduce_all = false; if (reduce_all || axis_size == 0 || axis_size == x_dim_size) { reduce_all = true; } else { reduce_all = false; } auto x_grad_tmp = Tensor(); auto out_tmp = Tensor(); if (x_dim_size == 1) { x_grad_tmp = out_grad.expand(IntArray(x_dim)); out_tmp = out.expand(IntArray(x_dim)); } else { if (!keep_dim) { auto axis_ = std::vector(); if (reduce_all) { for (int64_t i = 0; i < x_dim_size; i++) { axis_.push_back(i); } } else { axis_ = axis.GetData(); for (int64_t i = 0; i < axis_size; i++) { if (axis[i] < 0) { axis_[i] = axis[i] + x_dim_size; } } } auto out_grad_shape = get_unsqueeze_dims(out_grad, axis_); auto out_grad_ = reshape(out_grad, out_grad_shape); x_grad_tmp = out_grad_.expand(IntArray(x_dim)); auto out_ = reshape(out, out_grad_shape); out_tmp = out_.expand(IntArray(x_dim)); } else { x_grad_tmp = out_grad.expand(IntArray(x_dim)); out_tmp = out.expand(IntArray(x_dim)); } } auto x_grad_res = x_grad_tmp * out_tmp * (1 / x); set_output(x_grad_res, x_grad); } } template void max_grad(const Tensor& x, const Tensor& out, const Tensor& out_grad, const IntArray& axis, bool keepdim, bool reduce_all, Tensor* x_grad) { if (!x_grad) { return; } auto zero_tensor = full(phi::vectorize(x.dims()), 0.0, x.dtype()); std::vector x_dim = phi::vectorize(x.dims()); int64_t axis_size = axis.size(); int64_t x_dim_size = x_dim.size(); reduce_all = false; if (reduce_all || axis_size == 0 || axis_size == x_dim_size) { reduce_all = true; } else { reduce_all = false; } auto x_grad_tmp = Tensor(); if (x_dim_size == 0 || x_dim_size == 1 || keepdim) { auto out_grad_tmp = out_grad.expand(IntArray(x_dim)); auto out_tmp = out.expand(IntArray(x_dim)); auto mask = equal(x, out_tmp); x_grad_tmp = where(mask, out_grad_tmp, zero_tensor); } else { auto axis_ = std::vector(); if (reduce_all) { for (int64_t i = 0; i < x_dim_size; i++) { axis_.push_back(i); } } else { axis_ = axis.GetData(); for (int64_t i = 0; i < axis_size; i++) { if (axis[i] < 0) { axis_[i] = axis[i] + x_dim_size; } } } auto out_grad_shape = get_unsqueeze_dims(out_grad, axis_); auto out_grad_ = reshape(out_grad, out_grad_shape); auto out_ = reshape(out, out_grad_shape); auto out_grad_tmp = out_grad_.expand(IntArray(x_dim)); auto out_tmp = out_.expand(IntArray(x_dim)); auto mask = equal(x, out_tmp); x_grad_tmp = where(mask, out_grad_tmp, zero_tensor); } set_output(x_grad_tmp, x_grad); } template void assign_grad(const Tensor& out_grad, Tensor* x_grad) { if (x_grad) { by_pass(out_grad, x_grad); } } template void erf_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) { if (x_grad) { auto m_2_sqrt_pi = full(phi::vectorize(x.dims()), M_2_SQRTPI, x.dtype()); auto neg_one = full(phi::vectorize(x.dims()), -1.0, x.dtype()); auto neg_tmp = neg_one * x * x; auto mul_tmp = m_2_sqrt_pi * exp(neg_tmp); set_output(out_grad * mul_tmp, x_grad); } } template void maximum_grad(const Tensor& x, const Tensor& y, const Tensor& out_grad, Tensor* x_grad, Tensor* y_grad) { if (x_grad) { auto x_tmp = cast(greater_than(x, y), out_grad.dtype()); auto dx_res = out_grad * x_tmp; if (y.dims() != x.dims()) { // Maybe need reduce here auto reduce_dim = get_reduce_dims(x.dims(), y.dims()); if (!reduce_dim.size()) { set_output(dx_res, x_grad); } else { auto dx_reduce_res = dx_res.sum(phi::vectorize(reduce_dim), x.dtype(), false); auto dx_tmp = reshape(dx_reduce_res, phi::vectorize(x.dims())); set_output(dx_tmp, x_grad); } } else { set_output(dx_res, x_grad); } } if (y_grad) { auto y_tmp = cast(less_equal(x, y), out_grad.dtype()); auto dy_res = out_grad * y_tmp; if (x.dims() != y.dims()) { // Maybe need reduce here phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims()); if (!reduce_dim.size()) { set_output(dy_res, y_grad); } else { auto dy_reduce_res = dy_res.sum(phi::vectorize(reduce_dim), y.dtype(), false); auto dy_tmp = reshape(dy_reduce_res, phi::vectorize(y.dims())); set_output(dy_tmp, y_grad); } } else { set_output(dy_res, y_grad); } } } template void dropout_grad(const Tensor& mask, const Tensor& out_grad, const Scalar& p, bool is_test, const std::string& mode, Tensor* x_grad) { if (!x_grad) return; if (is_test) { if (mode == "upscale_in_train") { by_pass(out_grad, x_grad); } else { set_output(out_grad * (1.0 - p.to()), x_grad); } } else { if (mode == "upscale_in_train") { if (p.to() == 1.0f) { set_output(scale(out_grad, 0.0), x_grad); } else { set_output(scale(out_grad * cast(mask, out_grad.dtype()), 1.0 / (1.0 - p.to())), x_grad); } } else { set_output(out_grad * cast(mask, out_grad.dtype()), x_grad); } } } template void sin_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) { auto x_grad_tmp = cos(x) * out_grad; set_output(x_grad_tmp, x_grad); } template void cos_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) { auto x_grad_tmp = -sin(x) * out_grad; set_output(x_grad_tmp, x_grad); } template void scatter_grad(const Tensor& index, const Tensor& updates, const Tensor& out_grad, bool overwrite, Tensor* x_grad, Tensor* updates_grad) { if (x_grad) { auto zero_tensor = full(phi::vectorize(updates.dims()), 0.0, updates.dtype()); auto tmp_grad = scatter(out_grad, index, zero_tensor, false); set_output(tmp_grad, x_grad); } if (updates_grad) { Scalar tmp_zero = 0; auto tmp_updates_grad = gather(out_grad, index, tmp_zero); set_output(tmp_updates_grad, updates_grad); } } template void batch_norm_grad(const Tensor& x, const Tensor& scale, const Tensor& bias, const paddle::optional& mean_out, const paddle::optional& variance_out, const Tensor& saved_mean, const Tensor& saved_variance, const paddle::optional& reserve_space, const Tensor& out_grad, float momentum, float epsilon, const std::string& data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, Tensor* x_grad, Tensor* scale_grad, Tensor* bias_grad) { use_global_stats = is_test || use_global_stats; DataLayout data_layout_ = phi::StringToDataLayout(data_layout); Tensor x_data = x; Tensor out_grad_data = out_grad; if (x.dtype() == phi::DataType::FLOAT16) { x_data = cast(x, phi::DataType::FLOAT32); } if (out_grad.dtype() == phi::DataType::FLOAT16) { out_grad_data = cast(out_grad, phi::DataType::FLOAT32); } auto x_dims = x_data.dims(); const int C = (data_layout_ == DataLayout::kNCHW ? x_dims[1] : x_dims[x_dims.size() - 1]); int nume = 1; for (auto i = 0; i < x_dims.size(); i++) { nume = nume * x_dims[i]; } const int nhw = nume / C; if (x_dims.size() == 2 && data_layout_ == DataLayout::kNCHW) { data_layout_ = DataLayout::kNHWC; } auto run_var = variance_out.get(); auto run_mean = mean_out.get(); Tensor mean_data; Tensor rsqrt_var; if (use_global_stats) { auto eps = full(phi::vectorize(run_var.dims()), epsilon, run_var.dtype()); mean_data = run_mean; rsqrt_var = (run_var + eps).pow(-0.5); } else { mean_data = saved_mean; rsqrt_var = saved_variance; } // inv_var = 1 / sqrt(var + eps) // reduce_axis = [0, 2, 3] (NCHW) [0, 1, 2] (NHWC) // // d_bias = np.sum(d_y, reduce_axis) // d_scale = np.sum((X - mean) / inv_var * dy, reduce_axis) // // train mode // d_x = (1. / nhw) * scale * inv_var // *(nhw * d_y - np.sum(d_y, reduce_axis) - (X - mean) * inv_var * inv_var * // np.sum(d_y * (X - mean), reduce_axis)) // // test mode // d_x = d_y * scale * inv_var std::vector nchw_to_nhwc_dim = {0, 2, 3, 1}; std::vector nhwc_to_nchw_dim = {0, 3, 1, 2}; auto reduce_axis = IntArray(std::vector{0, 1, 2}); auto dtype = x_data.dtype(); switch (data_layout_) { case DataLayout::kNCHW: { auto nhwc_x = transpose(x_data, nchw_to_nhwc_dim); auto nhwc_out_grad = transpose(out_grad_data, nchw_to_nhwc_dim); auto nhwc_out_grad_sum = sum(nhwc_out_grad, reduce_axis, dtype, false); auto sum_dout_mul_diff = sum( nhwc_out_grad * (nhwc_x - mean_data), reduce_axis, dtype, false); if (x_grad) { if (use_global_stats) { auto nhwc_x_grad = scale * rsqrt_var * nhwc_out_grad; auto nchw_x_grad = transpose(nhwc_x_grad, nhwc_to_nchw_dim); if (x.dtype() == phi::DataType::FLOAT16) { nchw_x_grad = cast(nchw_x_grad, x.dtype()); } set_output(nchw_x_grad, x_grad); } else { auto part1 = scale * rsqrt_var; auto mean_temp1 = nhwc_out_grad_sum / nhw; auto mean_temp2 = sum_dout_mul_diff / nhw * rsqrt_var * rsqrt_var; auto part2 = nhwc_out_grad - mean_temp1 - (nhwc_x - mean_data) * mean_temp2; auto x_grad_data = part1 * part2; auto nchw_x_grad = transpose(x_grad_data, nhwc_to_nchw_dim); if (x.dtype() == phi::DataType::FLOAT16) { nchw_x_grad = cast(nchw_x_grad, x.dtype()); } set_output(nchw_x_grad, x_grad); } } if (scale_grad) { auto scale_grad_data = sum_dout_mul_diff * rsqrt_var; set_output(scale_grad_data, scale_grad); } if (bias_grad) { set_output(nhwc_out_grad_sum, bias_grad); } break; } case DataLayout::kNHWC: { if (x_grad) { auto out_grad_data_sum = sum(out_grad_data, reduce_axis, dtype, false); auto nhwc_sum_dout_mul_diff = sum( out_grad_data * (x_data - mean_data), reduce_axis, dtype, false); if (use_global_stats) { auto x_grad_data = scale * rsqrt_var * out_grad_data; if (x.dtype() == phi::DataType::FLOAT16) { x_grad_data = cast(x_grad_data, x.dtype()); } set_output(x_grad_data, x_grad); } else { auto part1 = scale * rsqrt_var; auto mean_temp1 = out_grad_data_sum / nhw; auto mean_temp2 = nhwc_sum_dout_mul_diff / nhw * rsqrt_var * rsqrt_var; auto part2 = out_grad_data - mean_temp1 - (x_data - mean_data) * mean_temp2; auto x_grad_data = part1 * part2; if (x.dtype() == phi::DataType::FLOAT16) { x_grad_data = cast(x_grad_data, x.dtype()); } set_output(x_grad_data, x_grad); } if (scale_grad) { auto scale_grad_data = nhwc_sum_dout_mul_diff * rsqrt_var; set_output(scale_grad_data, scale_grad); } if (bias_grad) { set_output(out_grad_data_sum, bias_grad); } } break; } default: PADDLE_THROW(phi::errors::InvalidArgument("Unknown storage order: %s", data_layout)); } } template void instance_norm_grad(const Tensor& x, const paddle::optional& scale, const Tensor& saved_mean, const Tensor& saved_variance, const Tensor& y_grad, float epsilon, Tensor* x_grad, Tensor* scale_grad, Tensor* bias_grad) { const int n = x.dims()[0]; const int c = x.dims()[1]; const int h = x.dims()[2]; const int w = x.dims()[3]; Tensor x_hat; Tensor std_inv; if (scale_grad || x_grad) { auto mean = reshape(saved_mean, IntArray({n, c, 1, 1})) .tile(IntArray({1, 1, h, w})); std_inv = reshape(saved_variance, IntArray({n, c, 1, 1})) .tile(IntArray({1, 1, h, w})); x_hat = (x - mean) * std_inv; } // x_grad = scale * inv_var * (y_grad - y_grad.mean(2,3) - x_hat * (y_grad * // x_hat).mean((h,w))) if (x_grad) { auto scale_t = reshape(scale.get_ptr() ? scale.get() : full(IntArray({c}), 1., x.dtype()), IntArray({1, c, 1, 1})) .tile(IntArray({n, 1, h, w})); set_output( (scale_t * std_inv) * (y_grad - y_grad.sum(IntArray({2, 3}), y_grad.dtype(), true) / (h * w) - (x_hat * ((y_grad * x_hat).sum(IntArray({2, 3}), y_grad.dtype(), true) / (h * w)))), x_grad); } // scale_grad = x_hat * y_grad.sum(n, h, w) if (scale_grad) { set_output((y_grad * x_hat).sum(IntArray({0, 2, 3})), scale_grad); } // d_bias = y_grad.sum(n, h, w) if (bias_grad) { set_output(y_grad.sum(IntArray({0, 2, 3})), bias_grad); } } template void gelu_grad(const Tensor& x, const Tensor& out_grad, bool approximate, Tensor* x_grad) { if (!x_grad) return; // Promote to fp32 when the input type is fp16 for keeping consistent with // phi kernel if (x.dtype() == phi::DataType::FLOAT16 || x.dtype() == phi::DataType::BFLOAT16) { auto promoted_x = cast(x, phi::DataType::FLOAT32); auto promoted_out_grad = cast(out_grad, phi::DataType::FLOAT32); if (approximate) { float kbeta = M_SQRT2 * M_2_SQRTPI * 0.5; float kkappa = 0.044715; auto x_sq = promoted_x * promoted_x; auto x_cube = x_sq * promoted_x; auto inner = kbeta * (promoted_x + kkappa * x_cube); auto tanh_inner = tanh(inner); auto left = scale(promoted_x, 0.5); auto right = scale(tanh_inner, 1., 1.); auto left_derivative = scale(right, 0.5); auto tanh_derivative = scale(tanh_inner * tanh_inner, -1., 1.); auto inner_derivative = kbeta * (scale(3 * kkappa * x_sq, 1., 1.)); auto right_derivative = left * tanh_derivative * inner_derivative; set_output( cast(promoted_out_grad * (left_derivative + right_derivative), x.type()), x_grad); } else { float kalpha = M_SQRT1_2; float kbeta = M_2_SQRTPI * M_SQRT1_2 * 0.5; auto cdf = scale(scale(erf(kalpha * promoted_x), 1., 1.), 0.5); auto pdf = kbeta * exp(scale(promoted_x * promoted_x, -0.5)); set_output( cast(promoted_out_grad * (cdf + promoted_x * pdf), x.type()), x_grad); } } else { // Scale only support fp32 attr in static graph mode, use elementwise_xx // when precision is over fp32. if (approximate) { auto kBeta = M_SQRT2 * M_2_SQRTPI * 0.5; auto kKappa = 0.044715; auto x_sq = x * x; auto x_cube = x_sq * x; auto inner = kBeta * (x + kKappa * x_cube); auto tanh_inner = tanh(inner); auto left = scale(x, 0.5); auto right = scale(tanh_inner, 1., 1.); auto left_derivative = scale(right, 0.5); auto tanh_derivative = scale(tanh_inner * tanh_inner, -1., 1.); auto inner_derivative = kBeta * (scale(3 * kKappa * x_sq, 1., 1.)); auto right_derivative = left * tanh_derivative * inner_derivative; set_output(out_grad * (left_derivative + right_derivative), x_grad); } else { auto kAlpha = M_SQRT1_2; auto kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5; auto cdf = scale(scale(erf(kAlpha * x), 1., 1.), 0.5); auto pdf = kBeta * exp(scale(x * x, -0.5)); set_output(out_grad * (cdf + x * pdf), x_grad); } } } template void minimum_grad(const Tensor& x, const Tensor& y, const Tensor& out_grad, Tensor* x_grad, Tensor* y_grad) { if (x_grad) { auto x_tmp = cast(less_than(x, y), out_grad.dtype()); auto dx_res = out_grad * x_tmp; if (y.dims() != x.dims()) { // Maybe need reduce here auto reduce_dim = get_reduce_dims(x.dims(), y.dims()); if (!reduce_dim.size()) { set_output(dx_res, x_grad); } else { auto dx_reduce_res = dx_res.sum(phi::vectorize(reduce_dim), x.dtype(), false); auto dx_tmp = reshape(dx_reduce_res, phi::vectorize(x.dims())); set_output(dx_tmp, x_grad); } } else { set_output(dx_res, x_grad); } } if (y_grad) { auto y_tmp = cast(greater_equal(x, y), out_grad.dtype()); auto dy_res = out_grad * y_tmp; if (x.dims() != y.dims()) { // Maybe need reduce here phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims()); if (!reduce_dim.size()) { set_output(dy_res, y_grad); } else { auto dy_reduce_res = dy_res.sum(phi::vectorize(reduce_dim), y.dtype(), false); auto dy_tmp = reshape(dy_reduce_res, phi::vectorize(y.dims())); set_output(dy_tmp, y_grad); } } else { set_output(dy_res, y_grad); } } } template void tile_grad(const Tensor& x, const Tensor& out_grad, const IntArray& repeat_times, Tensor* x_grad) { if (x_grad) { auto repeat_times_data = repeat_times.GetData(); auto out_grad_shape = phi::vectorize(out_grad.dims()); auto result = out_grad; for (int i = 0; i < static_cast(repeat_times_data.size()); i++) { int size = out_grad_shape[i] / repeat_times_data[i]; std::vector sections(repeat_times_data[i], size); auto split_arr = split(result, IntArray(sections), i); result = full(phi::vectorize(split_arr[0].dims()), 0.0, x.dtype()); for (int j = 0; j < static_cast(split_arr.size()); j++) { result = split_arr[j] + result; } } result = reshape(result, x.shape()); set_output(result, x_grad); } } template void roll_grad(const Tensor& x, const Tensor& out_grad, const IntArray& shifts, const std::vector& axis, Tensor* x_grad) { if (x_grad) { auto shifts_ = shifts.GetData(); int64_t nums = shifts_.size(); for (int64_t i = 0; i < nums; i++) { shifts_[i] = 0 - shifts_[i]; } auto x_grad_output = roll(out_grad, shifts_, axis); set_output(x_grad_output, x_grad); } } template void pad_grad(const Tensor& input, const Tensor& out_grad, const std::vector& paddings, const Scalar& pad_value, Tensor* input_grad) { if (input_grad) { size_t rank = input.dims().size(); auto out_dims = out_grad.dims(); std::vector starts(rank, 0); std::vector ends(rank, 0); std::vector axes(rank, 0); std::vector infer_flags(rank, 1); std::vector decrease_axis({}); for (size_t i = 0; i < rank; ++i) { starts[i] = static_cast(paddings[2 * i]); ends[i] = static_cast(out_dims[i] - paddings[2 * i + 1]); axes[i] = i; } auto out_tmp = slice(out_grad, axes, starts, ends, infer_flags, decrease_axis); set_output(out_tmp, input_grad); } } template void scatter_nd_add_grad(const Tensor& index, const Tensor& updates, const Tensor& out_grad, Tensor* x_grad, Tensor* updates_grad) { if (x_grad) { by_pass(out_grad, x_grad); } if (updates_grad) { // Gradient by Gather: dUpdates = dO[Ids] auto tmp_updates_grad = gather_nd(out_grad, index); set_output(tmp_updates_grad, updates_grad); } } } // namespace prim } // namespace paddle