// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "paddle/fluid/prim/api/all.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/ddim.h"

namespace paddle {
namespace prim {
using Tensor = paddle::experimental::Tensor;
using IntArray =
    paddle::experimental::IntArrayBase<paddle::experimental::Tensor>;
//  These functions should have the same signatures as their phi counterparts,
//  which are defined in paddle/phi/api/backward/backward_api.h
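// Gradient of gather along `axis`: scatter out_grad back into a zero tensor
// shaped like x. The gathered axis is moved to dim 0 first, since scatter
// writes along dim 0, and is moved back afterwards.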
template <typename T>
void gather_grad(const Tensor& x,
                 const Tensor& index,
                 const Tensor& out_grad,
                 const Scalar& axis,
                 bool overwrite,
                 Tensor* grad_x) {
  auto zero_tensor = full<T>(phi::vectorize(x.dims()), 0.0, x.dtype());
  std::vector<int> tmp_perm;

  // move the gathered axis to dim 0
  int axis_value = axis.to<int>();
  tmp_perm.push_back(axis_value);
  // append the remaining dims in their original order
  for (int i = 0; i < x.dims().size(); ++i) {
    if (i != axis_value) {
      tmp_perm.push_back(i);
    }
  }
  std::vector<int> reverse_perm(tmp_perm);
  // build the inverse permutation to restore the original dim order
  for (int i = 0; i < static_cast<int>(tmp_perm.size()); ++i) {
    reverse_perm[tmp_perm[i]] = i;
  }

  // transpose out_grad and the zero tensor into the permuted layout.
  auto tmp_zero_x_grad = transpose<T>(zero_tensor, tmp_perm);
  auto tmp_out_grad = transpose<T>(out_grad, tmp_perm);
  // scatter the gradient, then restore the original dim order
  auto tmp_grad_x = scatter<T>(tmp_zero_x_grad, index, tmp_out_grad, false);
  auto tmp_grad_x_transposed = transpose<T>(tmp_grad_x, reverse_perm);
  set_output<T>(tmp_grad_x_transposed, grad_x);
}

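// Gradient of tanh: with out = tanh(x), dx = dout * (1 - out^2).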
template <typename T>
void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) {
  if (!grad_x) return;
  auto grad_x_tmp = grad_out * (1.0 - out.pow(2.0));
  set_output<T>(grad_x_tmp, grad_x);
}

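// Gradient of out = x - y: dx = dout and dy = -dout, each reduced back to
// the operand's shape when broadcasting was involved.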
template <typename T>
void subtract_grad(const Tensor& x,
                   const Tensor& y,
                   const Tensor& out_grad,
                   int axis,
                   Tensor* dx,
                   Tensor* dy) {
  if (dy) {
    auto scale_out_grad = scale<T>(out_grad, -1.0, 0.0, true);
    if (x.dims() != y.dims()) {
      // the gradient may need to be reduced to y's shape
      phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
      if (!reduce_dim.size()) {
        by_pass<T>(scale_out_grad, dy);
      } else {
        auto dy_reduce_res =
            scale_out_grad.sum(phi::vectorize(reduce_dim), y.dtype(), false);
        auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
        set_output<T>(dy_tmp, dy);
      }
    } else {
      by_pass<T>(scale_out_grad, dy);
    }
  }
  if (dx) {
    if (y.dims() != x.dims()) {
      // the gradient may need to be reduced to x's shape
      auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
      if (!reduce_dim.size()) {
        by_pass<T>(out_grad, dx);
      } else {
        auto dx_reduce_res =
            out_grad.sum(phi::vectorize(reduce_dim), x.dtype(), false);
        auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
        set_output<T>(dx_tmp, dx);
      }
    } else {
      by_pass<T>(out_grad, dx);
    }
  }
}

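// Gradient of out = x + y: dx = dout and dy = dout, each reduced back to
// the operand's shape when broadcasting was involved.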
template <typename T>
void add_grad(const Tensor& x,
              const Tensor& y,
              const Tensor& out_grad,
              int axis,
              Tensor* dx,
              Tensor* dy) {
  if (dy) {
    if (x.dims() != y.dims()) {
      // the gradient may need to be reduced to y's shape
      phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
      if (!reduce_dim.size()) {
        by_pass<T>(out_grad, dy);
      } else {
        auto dy_reduce_res =
            out_grad.sum(phi::vectorize(reduce_dim), y.dtype(), false);
        auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
        set_output<T>(dy_tmp, dy);
      }
    } else {
      by_pass<T>(out_grad, dy);
    }
  }
  if (dx) {
    if (y.dims() != x.dims()) {
      // the gradient may need to be reduced to x's shape
      auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
      if (!reduce_dim.size()) {
        by_pass<T>(out_grad, dx);
      } else {
        auto dx_reduce_res =
            out_grad.sum(phi::vectorize(reduce_dim), x.dtype(), false);
        auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
        set_output<T>(dx_tmp, dx);
      }
    } else {
      by_pass<T>(out_grad, dx);
    }
  }
}

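// Gradient of sum over `axis`: broadcast out_grad back to x's shape,
// unsqueezing the reduced axes first when keepdim is false.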
template <typename T>
void sum_grad(const Tensor& x,
              const Tensor& out_grad,
              const IntArray& axis,
              bool keepdim,
              bool reduce_all,
              Tensor* x_grad) {
  if (!x_grad) {
    return;
  }
  std::vector<int64_t> x_dim = phi::vectorize<int64_t>(x.dims());
  int64_t axis_size = axis.size();
  int64_t x_dim_size = x_dim.size();
  reduce_all = false;
  if (reduce_all || axis_size == 0 || axis_size == x_dim_size) {
    reduce_all = true;
  } else {
    reduce_all = false;
  }
  auto x_grad_tmp = Tensor();
  if (x_dim_size == 1) {
    x_grad_tmp = out_grad.expand(IntArray(x_dim));
  } else {
    if (!keepdim) {
      auto axis_ = std::vector<int64_t>();
      if (reduce_all) {
        for (int64_t i = 1; i < x_dim_size; i++) {
          axis_.push_back(i);
        }
      } else {
        axis_ = axis.GetData();
      }
      auto out_grad_ = unsqueeze<T>(out_grad, axis_);
      x_grad_tmp = out_grad_.expand(IntArray(x_dim));
    } else {
      x_grad_tmp = out_grad.expand(IntArray(x_dim));
    }
  }

  set_output<T>(x_grad_tmp, x_grad);
}

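// Gradient of out = x / y: dx = dout / y and dy = -(x / y^2) * dout, each
// reduced back to the operand's shape when broadcasting was involved.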
template <typename T>
void divide_grad(const Tensor& x,
                 const Tensor& y,
                 const Tensor& out,
                 const Tensor& out_grad,
                 int axis,
                 Tensor* dx,
                 Tensor* dy) {
  if (dy) {
    // dy = -(x/y^2) * dout
    auto dy_res = -(x / y.pow(2.0)) * out_grad;
    if (x.dims() != y.dims()) {
      // the gradient may need to be reduced to y's shape
      phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
      if (!reduce_dim.size()) {
        set_output<T>(dy_res, dy);
      } else {
        auto dy_reduce_res =
            dy_res.sum(phi::vectorize(reduce_dim), y.dtype(), false);
        auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
        set_output<T>(dy_tmp, dy);
      }
    } else {
      set_output<T>(dy_res, dy);
    }
  }  // indicate we will compute dy
  if (dx) {
    // dx = (1/y) * dout
    auto one_tensor = full<T>(phi::vectorize(y.dims()), 1.0, y.dtype());
    auto dx_res = one_tensor / y * out_grad;
    if (y.dims() != x.dims()) {
      // the gradient may need to be reduced to x's shape
      auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
      if (!reduce_dim.size()) {
        set_output<T>(dx_res, dx);
      } else {
        auto dx_reduce_res =
            dx_res.sum(phi::vectorize(reduce_dim), x.dtype(), false);
        auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
        set_output<T>(dx_tmp, dx);
      }
    } else {
      set_output<T>(dx_res, dx);
    }
  }  // indicate we will compute dx
}

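// Gradient of sqrt: with out = sqrt(x), dx = dout * 0.5 / out.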
template <typename T>
void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
  if (x_grad) {
    auto x_grad_tmp = out_grad * 0.5 / out;
    set_output<T>(x_grad_tmp, x_grad);
  }
}

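// Gradient of out = x * y: dx = dout * y and dy = dout * x, each reduced
// back to the operand's shape when broadcasting was involved.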
template <typename T>
void multiply_grad(const Tensor& x,
                   const Tensor& y,
                   const Tensor& out_grad,
                   int axis,
                   Tensor* x_grad,
                   Tensor* y_grad) {
  if (x_grad) {
    auto x_grad_unreduce = out_grad * y;
    if (x_grad_unreduce.dims() != x.dims()) {
      auto axes = get_reduce_dims_from_out(x_grad_unreduce.dims(), x.dims());
      if (!axes.size()) {
        set_output<T>(x_grad_unreduce, x_grad);
      } else {
        auto x_grad_reduced = x_grad_unreduce.sum(
            phi::vectorize(axes), x_grad_unreduce.dtype(), false);
        if (x_grad_reduced.dims().size() != x.dims().size()) {
          x_grad_reduced = reshape<T>(x_grad_reduced, x.shape());
        }
        set_output<T>(x_grad_reduced, x_grad);
      }
    } else {
      set_output<T>(x_grad_unreduce, x_grad);
    }
  }
  if (y_grad) {
    auto y_grad_unreduce = out_grad * x;
    if (y_grad_unreduce.dims() != y.dims()) {
      auto axes = get_reduce_dims_from_out(y_grad_unreduce.dims(), y.dims());
      if (!axes.size()) {
        set_output<T>(y_grad_unreduce, y_grad);
      } else {
        auto y_grad_reduced = y_grad_unreduce.sum(
            phi::vectorize(axes), y_grad_unreduce.dtype(), false);
        if (y_grad_reduced.dims().size() != y.dims().size()) {
          y_grad_reduced = reshape<T>(y_grad_reduced, y.shape());
        }
        set_output<T>(y_grad_reduced, y_grad);
      }
    } else {
      set_output<T>(y_grad_unreduce, y_grad);
    }
  }
}

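// Gradient of expand: sum out_grad over the broadcast dimensions to recover
// x's shape.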
template <typename T>
void expand_grad(const Tensor& x,
                 const Tensor& out_grad,
                 const IntArray& shape,
                 Tensor* x_grad) {
  if (x_grad) {
    auto out_dims = phi::make_ddim(shape.GetData());
    if (out_dims != x.dims()) {
      auto axes = get_reduce_dims(x.dims(), out_dims);
      if (!axes.size()) {
        by_pass<T>(out_grad, x_grad);
      } else {
        auto reduced = out_grad.sum(phi::vectorize(axes), x.dtype(), false);
        if (reduced.dims().size() != x.dims().size()) {
          reduced = reshape<T>(reduced, x.shape());
        }
        set_output<T>(reduced, x_grad);
      }
    } else {
      by_pass<T>(out_grad, x_grad);
    }
  }
}

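// Gradient of exp: with out = exp(x), dx = dout * out.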
template <typename T>
void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
  if (x_grad) {
    set_output<T>(out_grad * out, x_grad);
  }
}

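// Gradient of slice: pad out_grad with zeros so that it occupies the sliced
// region within the input's shape.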
template <typename T>
void slice_grad(const Tensor& input,
                const Tensor& out_grad,
                const std::vector<int64_t>& axes,
                const IntArray& starts,
                const IntArray& ends,
                const std::vector<int64_t>& infer_flags,
                const std::vector<int64_t>& decrease_axis,
                Tensor* input_grad) {
  if (input_grad) {
    size_t rank = input.dims().size();
    auto out_dims = out_grad.dims();
    auto in_dims = input.dims();

    auto decrease_size = decrease_axis.size();
    if (decrease_size > 0) {
      if (decrease_size == static_cast<size_t>(in_dims.size())) {
        // all dims decrease
        out_dims = phi::make_ddim(std::vector<int>(decrease_size, 1));
      } else {
        std::vector<int> origin_out_shape(out_dims.size() + decrease_size, -1);
        for (size_t i = 0; i < decrease_size; ++i) {
          origin_out_shape[decrease_axis[i]] = 1;
        }

        int index = 0;
        for (size_t i = 0; i < origin_out_shape.size(); ++i) {
          if (origin_out_shape[i] == -1) {
            origin_out_shape[i] = out_dims[index];
            ++index;
          }
        }
        out_dims = phi::make_ddim(origin_out_shape);
      }
    }

    std::vector<int> offsets(rank, 0);
    std::vector<int> extents(rank, 0);
    for (size_t i = 0; i < rank; ++i) {
      offsets[i] = 0;
      extents[i] = out_dims[i];
    }

    for (size_t i = 0; i < axes.size(); ++i) {
      int axis = axes[i];
      int64_t start = starts[i] < 0 ? (starts[i] + in_dims[axis]) : starts[i];
      start = std::max(start, static_cast<int64_t>(0));
      offsets[axis] = start;
    }

    std::vector<int> paddings;
    for (size_t i = 0; i < rank; ++i) {
      paddings.push_back(offsets[i]);
      paddings.push_back((in_dims[i] - out_dims[i]) - offsets[i]);
    }

    auto out_tmp = pad<T>(out_grad, paddings, 0.0);
    set_output<T>(out_tmp, input_grad);
  }
}

}  // namespace prim
}  // namespace paddle