// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "paddle/fluid/prim/api/all.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/ddim.h"

namespace paddle {
namespace prim {
using Tensor = paddle::experimental::Tensor;
using IntArray =
    paddle::experimental::IntArrayBase<paddle::experimental::Tensor>;
//  These functions should have the same signatures as the phi backward APIs
//  declared in paddle/phi/api/backward/backward_api.h.
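
// Cast backward: dx is out_grad cast to `dtype`.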
template <typename T>
void cast_grad(const Tensor& out_grad, DataType dtype, Tensor* x_grad) {
  if (x_grad) {
    auto res = cast<T>(out_grad, dtype);
    set_output<T>(res, x_grad);
  }
}
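
// Gather backward: scatter out_grad into a zero tensor shaped like x at the
// gathered indices along `axis`.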
template <typename T>
void gather_grad(const Tensor& x,
                 const Tensor& index,
                 const Tensor& out_grad,
                 const Scalar& axis,
                 bool overwrite,
                 Tensor* grad_x) {
  auto zero_tensor = full<T>(phi::vectorize(x.dims()), 0.0, x.dtype());
  std::vector<int> tmp_perm;

  // move the gather axis to position 0 (normalize a negative axis first)
  int axis_value = axis.to<int>();
  if (axis_value < 0) {
    axis_value += x.dims().size();
  }
  tmp_perm.push_back(axis_value);
  // append the remaining axes in their original order
  for (int i = 0; i < x.dims().size(); ++i) {
    if (i != axis_value) {
      tmp_perm.push_back(i);
    }
  }
  std::vector<int> reverse_perm(tmp_perm);
  // build the inverse permutation that restores the original axis order
  for (int i = 0; i < static_cast<int>(tmp_perm.size()); ++i) {
    reverse_perm[tmp_perm[i]] = i;
  }

  // transpose the zero tensor and out_grad so the gather axis is leading
  auto tmp_zero_x_grad = transpose<T>(zero_tensor, tmp_perm);
  auto tmp_out_grad = transpose<T>(out_grad, tmp_perm);
  // scatter out_grad into the zero tensor, then transpose back to x's layout
  auto tmp_grad_x = scatter<T>(tmp_zero_x_grad, index, tmp_out_grad, false);
  auto tmp_grad_x_transposed = transpose<T>(tmp_grad_x, reverse_perm);
  set_output<T>(tmp_grad_x_transposed, grad_x);
}

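// d tanh(x)/dx = 1 - tanh(x)^2 = 1 - out^2, so dx = dout * (1 - out^2).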
template <typename T>
void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) {
  if (!grad_x) return;
  auto grad_x_tmp = grad_out * (1.0 - out.pow(2.0));
  set_output<T>(grad_x_tmp, grad_x);
}

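// Reshape backward: dx is dout reshaped back to x's shape.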
template <typename T>
void reshape_grad(const Tensor& x, const Tensor& grad_out, Tensor* grad_x) {
  if (grad_x) {
    auto grad_x_tmp = reshape<T>(grad_out, phi::vectorize(x.dims()));
    set_output<T>(grad_x_tmp, grad_x);
  }
}

template <typename T>
void transpose_grad(const Tensor& grad_out,
                    const std::vector<int>& perm,
                    Tensor* grad_x) {
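  // The gradient of transpose is a transpose with the inverse permutation.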
  if (grad_x) {
    std::vector<int> reverse_perm(perm);
    // build the inverse permutation (reverse_perm[perm[i]] = i)
    for (int i = 0; i < static_cast<int>(perm.size()); ++i) {
      reverse_perm[perm[i]] = i;
    }
    auto grad_x_tmp = transpose<T>(grad_out, reverse_perm);
    set_output<T>(grad_x_tmp, grad_x);
  }
}

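// Subtract backward: d(x - y)/dx = dout and d(x - y)/dy = -dout; when x and y
// were broadcast, the gradients are reduced back to the operand shapes.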
template <typename T>
void subtract_grad(const Tensor& x,
                   const Tensor& y,
                   const Tensor& out_grad,
                   int axis,
                   Tensor* dx,
                   Tensor* dy) {
  if (dy) {
    auto scale_out_grad = scale<T>(out_grad, -1.0, 0.0, true);
    if (x.dims() != y.dims()) {
      // x and y were broadcast together, so reduce dy back to y's shape.
      phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
      if (!reduce_dim.size()) {
        by_pass<T>(scale_out_grad, dy);
      } else {
        auto dy_reduce_res =
            scale_out_grad.sum(phi::vectorize(reduce_dim), y.dtype(), false);
        auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
        set_output<T>(dy_tmp, dy);
      }
    } else {
      by_pass<T>(scale_out_grad, dy);
    }
  }
  if (dx) {
    if (y.dims() != x.dims()) {
      // x and y were broadcast together, so reduce dx back to x's shape.
      auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
      if (!reduce_dim.size()) {
        by_pass<T>(out_grad, dx);
      } else {
        auto dx_reduce_res =
            out_grad.sum(phi::vectorize(reduce_dim), x.dtype(), false);
        auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
        set_output<T>(dx_tmp, dx);
      }
    } else {
      by_pass<T>(out_grad, dx);
    }
  }
}

template <typename T>
void add_grad(const Tensor& x,
              const Tensor& y,
              const Tensor& out_grad,
              int axis,
              Tensor* dx,
              Tensor* dy) {
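  // d(x + y)/dx = d(x + y)/dy = dout; reduce over the broadcast dimensions
  // when the operand shapes differ.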
  if (dy) {
    if (x.dims() != y.dims()) {
      // x and y were broadcast together, so reduce dy back to y's shape.
      phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
      if (!reduce_dim.size()) {
        by_pass<T>(out_grad, dy);
      } else {
        auto dy_reduce_res =
            out_grad.sum(phi::vectorize(reduce_dim), y.dtype(), false);
        auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
        set_output<T>(dy_tmp, dy);
      }
    } else {
      by_pass<T>(out_grad, dy);
    }
  }
  if (dx) {
    if (y.dims() != x.dims()) {
      // x and y were broadcast together, so reduce dx back to x's shape.
      auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
      if (!reduce_dim.size()) {
        by_pass<T>(out_grad, dx);
      } else {
        auto dx_reduce_res =
            out_grad.sum(phi::vectorize(reduce_dim), x.dtype(), false);
        auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
        set_output<T>(dx_tmp, dx);
      }
    } else {
      by_pass<T>(out_grad, dx);
    }
  }
}

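// Sum backward: dout is broadcast back to x's shape; when keepdim is false it
// is first unsqueezed along the reduced axes.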
template <typename T>
void sum_grad(const Tensor& x,
              const Tensor& out_grad,
              const IntArray& axis,
              bool keepdim,
              bool reduce_all,
              Tensor* x_grad) {
  if (!x_grad) {
    return;
  }
  std::vector<int64_t> x_dim = phi::vectorize<int64_t>(x.dims());
  int64_t axis_size = axis.size();
  int64_t x_dim_size = x_dim.size();
  // The incoming reduce_all flag is ignored; it is recomputed from axis.
  reduce_all = (axis_size == 0 || axis_size == x_dim_size);
  auto x_grad_tmp = Tensor();
  if (x_dim_size == 1) {
    x_grad_tmp = out_grad.expand(IntArray(x_dim));
  } else {
    if (!keepdim) {
      auto axis_ = std::vector<int64_t>();
      if (reduce_all) {
        for (int64_t i = 1; i < x_dim_size; i++) {
          axis_.push_back(i);
        }
      } else {
        axis_ = axis.GetData();
      }
      auto out_grad_ = unsqueeze<T>(out_grad, axis_);
      x_grad_tmp = out_grad_.expand(IntArray(x_dim));
    } else {
      x_grad_tmp = out_grad.expand(IntArray(x_dim));
    }
  }

  set_output<T>(x_grad_tmp, x_grad);
}

template <typename T>
void divide_grad(const Tensor& x,
                 const Tensor& y,
                 const Tensor& out,
                 const Tensor& out_grad,
                 int axis,
                 Tensor* dx,
                 Tensor* dy) {
  if (dy) {
    // dy = -(x/y^2) * dout
    auto dy_res = -(x / y.pow(2.0)) * out_grad;
    if (x.dims() != y.dims()) {
      // x and y were broadcast together, so reduce dy back to y's shape.
      phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
      if (!reduce_dim.size()) {
        set_output<T>(dy_res, dy);
      } else {
        auto dy_reduce_res =
            dy_res.sum(phi::vectorize(reduce_dim), y.dtype(), false);
        auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
        set_output<T>(dy_tmp, dy);
      }
    } else {
      set_output<T>(dy_res, dy);
    }
  }  // indicate we will compute dy
  if (dx) {
    // dx = (1/y) * dout
    auto one_tensor = full<T>(phi::vectorize(y.dims()), 1.0, y.dtype());
    auto dx_res = one_tensor / y * out_grad;
    if (y.dims() != x.dims()) {
      // x and y were broadcast together, so reduce dx back to x's shape.
      auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
      if (!reduce_dim.size()) {
        set_output<T>(dx_res, dx);
      } else {
        auto dx_reduce_res =
            dx_res.sum(phi::vectorize(reduce_dim), x.dtype(), false);
        auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
        set_output<T>(dx_tmp, dx);
      }
    } else {
      set_output<T>(dx_res, dx);
    }
  }  // indicate we will compute dx
}

template <typename T>
void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
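  // d sqrt(x)/dx = 1 / (2 * sqrt(x)), i.e. 0.5 / out.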
  if (x_grad) {
    auto x_grad_tmp = out_grad * 0.5 / out;
    set_output<T>(x_grad_tmp, x_grad);
  }
}

template <typename T>
void multiply_grad(const Tensor& x,
                   const Tensor& y,
                   const Tensor& out_grad,
                   int axis,
                   Tensor* x_grad,
                   Tensor* y_grad) {
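  // d(x * y)/dx = y and d(x * y)/dy = x; each gradient is reduced back to its
  // operand's shape when the inputs were broadcast.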
  if (x_grad) {
    auto x_grad_unreduce = out_grad * y;
    if (x_grad_unreduce.dims() != x.dims()) {
      auto axes = get_reduce_dims_from_out(x_grad_unreduce.dims(), x.dims());
      if (!axes.size()) {
        set_output<T>(x_grad_unreduce, x_grad);
      } else {
        auto x_grad_reduced = x_grad_unreduce.sum(
            phi::vectorize(axes), x_grad_unreduce.dtype(), false);
        if (x_grad_reduced.dims().size() != x.dims().size()) {
          x_grad_reduced = reshape<T>(x_grad_reduced, x.shape());
        }
        set_output<T>(x_grad_reduced, x_grad);
      }
    } else {
      set_output<T>(x_grad_unreduce, x_grad);
    }
  }
  if (y_grad) {
    auto y_grad_unreduce = out_grad * x;
    if (y_grad_unreduce.dims() != y.dims()) {
      auto axes = get_reduce_dims_from_out(y_grad_unreduce.dims(), y.dims());
      if (!axes.size()) {
        set_output<T>(y_grad_unreduce, y_grad);
      } else {
        auto y_grad_reduced = y_grad_unreduce.sum(
            phi::vectorize(axes), y_grad_unreduce.dtype(), false);
        if (y_grad_reduced.dims().size() != y.dims().size()) {
          y_grad_reduced = reshape<T>(y_grad_reduced, y.shape());
        }
        set_output<T>(y_grad_reduced, y_grad);
      }
    } else {
      set_output<T>(y_grad_unreduce, y_grad);
    }
  }
}

template <typename T>
void expand_grad(const Tensor& x,
                 const Tensor& out_grad,
                 const IntArray& shape,
                 Tensor* x_grad) {
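  // dx sums dout over the dimensions broadcast by expand and reshapes the
  // result back to x's shape if the ranks still differ.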
  if (x_grad) {
    auto out_dims = phi::make_ddim(shape.GetData());
    if (out_dims != x.dims()) {
      auto axes = get_reduce_dims(x.dims(), out_dims);
      if (!axes.size()) {
        by_pass<T>(out_grad, x_grad);
      } else {
        auto reduced = out_grad.sum(phi::vectorize(axes), x.dtype(), false);
        if (reduced.dims().size() != x.dims().size()) {
          reduced = reshape<T>(reduced, x.shape());
        }
        set_output<T>(reduced, x_grad);
      }
    } else {
      by_pass<T>(out_grad, x_grad);
    }
  }
}

template <typename T>
void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
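  // d exp(x)/dx = exp(x) = out, so dx = dout * out.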
  if (x_grad) {
    set_output<T>(out_grad * out, x_grad);
  }
}

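// Slice backward: restore any axes removed by decrease_axis, then zero-pad
// dout back to the input's shape, using the slice start offsets as the
// leading paddings along each axis.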
template <typename T>
void slice_grad(const Tensor& input,
                const Tensor& out_grad,
                const std::vector<int64_t>& axes,
                const IntArray& starts,
                const IntArray& ends,
                const std::vector<int64_t>& infer_flags,
                const std::vector<int64_t>& decrease_axis,
                Tensor* input_grad) {
  if (input_grad) {
    size_t rank = input.dims().size();
    auto out_dims = out_grad.dims();
    auto in_dims = input.dims();

    auto decrease_size = decrease_axis.size();
    if (decrease_size > 0) {
      if (decrease_size == static_cast<size_t>(in_dims.size())) {
        // all dims decrease
        out_dims = phi::make_ddim(std::vector<int>(decrease_size, 1));
      } else {
        std::vector<int> origin_out_shape(out_dims.size() + decrease_size, -1);
        for (size_t i = 0; i < decrease_size; ++i) {
          origin_out_shape[decrease_axis[i]] = 1;
        }

        int index = 0;
        for (size_t i = 0; i < origin_out_shape.size(); ++i) {
          if (origin_out_shape[i] == -1) {
            origin_out_shape[i] = out_dims[index];
            ++index;
          }
        }
        out_dims = phi::make_ddim(origin_out_shape);
      }
    }

    std::vector<int> offsets(rank, 0);
    std::vector<int> extents(rank, 0);
    for (size_t i = 0; i < rank; ++i) {
      offsets[i] = 0;
      extents[i] = out_dims[i];
    }

    for (size_t i = 0; i < axes.size(); ++i) {
      int axis = axes[i];
      int64_t start = starts[i] < 0 ? (starts[i] + in_dims[axis]) : starts[i];
      start = std::max(start, static_cast<int64_t>(0));
      offsets[axis] = start;
    }

    std::vector<int> paddings;
    for (size_t i = 0; i < rank; ++i) {
      paddings.push_back(offsets[i]);
      paddings.push_back((in_dims[i] - out_dims[i]) - offsets[i]);
    }

    auto out_tmp = pad<T>(out_grad, paddings, 0.0);
    set_output<T>(out_tmp, input_grad);
  }
}

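// Cumsum backward: the gradient of a cumulative sum is a cumulative sum of
// dout along the same axis in the opposite direction.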
template <typename T>
void cumsum_grad(const Tensor& x,
                 const Tensor& out_grad,
                 const Scalar& axis,
                 bool flatten,
                 bool exclusive,
                 bool reverse,
                 Tensor* x_grad) {
  if (x_grad) {
    auto grad = cumsum<T>(out_grad, axis, flatten, exclusive, !reverse);
    grad = reshape<T>(grad, x.shape());
    set_output<T>(grad, x_grad);
  }
}

}  // namespace prim
}  // namespace paddle