composite_backward_api.h 9.5 KB
Newer Older
J
Jiabin Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
X
xiaoguoguo626807 已提交
16
#include "paddle/fluid/prim/api/generated/prim_api/prim_api.h"
J
Jiabin Yang 已提交
17 18
#include "paddle/fluid/prim/api/manual/prim_api/prim_api.h"
#include "paddle/fluid/prim/api/manual/utils/utils.h"
19 20 21
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/ddim.h"

J
Jiabin Yang 已提交
22 23
namespace paddle {
namespace prim {
24 25 26 27 28
// Shorthand for the tensor type taken and produced by the composite backward
// rules declared in this header.
using Tensor = paddle::experimental::Tensor;
// Integer-array type specialised on the tensor alias above; used for
// axis/shape arguments in this header.
using IntArray =
    paddle::experimental::IntArrayBase<paddle::experimental::Tensor>;
//  The functions below should have the same signatures as their phi
//  counterparts, which are declared in paddle/phi/api/backward/backward_api.h
J
Jiabin Yang 已提交
29 30
template <typename T>
void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) {
31
  if (!grad_x) return;
J
Jiabin Yang 已提交
32 33 34
  auto tmp = pow<T>(out, 2.0);
  tmp = scale<T>(tmp, -1.0, 1.0, true);
  auto grad_x_tmp = multiply<T>(grad_out, tmp);
35
  set_output<T>(grad_x_tmp, grad_x);
J
Jiabin Yang 已提交
36
}
37

38 39 40 41 42 43 44 45 46
template <typename T>
void subtract_grad(const Tensor& x,
                   const Tensor& y,
                   const Tensor& out_grad,
                   int axis,
                   Tensor* dx,
                   Tensor* dy) {
  if (dy) {
    auto scale_out_grad = scale<T>(out_grad, -1.0, 0.0, true);
47
    if (x.dims() != y.dims()) {
48
      // Maybe need reduce here
49 50 51 52 53 54 55
      phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
      if (!reduce_dim.size()) {
        by_pass<T>(scale_out_grad, dy);
      } else {
        auto dy_reduce_res = sum<T>(
            scale_out_grad, phi::vectorize(reduce_dim), y.dtype(), false);
        auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
56
        set_output<T>(dy_tmp, dy);
57
      }
58 59 60 61 62
    } else {
      by_pass<T>(scale_out_grad, dy);
    }
  }
  if (dx) {
63
    if (y.dims() != x.dims()) {
64
      // Maybe need reduce here
65 66 67 68 69 70 71
      auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
      if (!reduce_dim.size()) {
        by_pass<T>(out_grad, dx);
      } else {
        auto dx_reduce_res =
            sum<T>(out_grad, phi::vectorize(reduce_dim), x.dtype(), false);
        auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
72
        set_output<T>(dx_tmp, dx);
73
      }
74 75 76 77 78 79 80 81 82 83 84 85 86 87
    } else {
      by_pass<T>(out_grad, dx);
    }
  }
}

template <typename T>
void add_grad(const Tensor& x,
              const Tensor& y,
              const Tensor& out_grad,
              int axis,
              Tensor* dx,
              Tensor* dy) {
  if (dy) {
88
    if (x.dims() != y.dims()) {
89
      // Maybe need reduce here
90 91 92 93 94 95 96
      phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
      if (!reduce_dim.size()) {
        by_pass<T>(out_grad, dy);
      } else {
        auto dy_reduce_res =
            sum<T>(out_grad, phi::vectorize(reduce_dim), y.dtype(), false);
        auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
97
        set_output<T>(dy_tmp, dy);
98 99
      }

100 101 102 103 104
    } else {
      by_pass<T>(out_grad, dy);
    }
  }
  if (dx) {
105
    if (y.dims() != x.dims()) {
106
      // Maybe need reduce here
107 108 109 110 111 112 113
      auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
      if (!reduce_dim.size()) {
        by_pass<T>(out_grad, dx);
      } else {
        auto dx_reduce_res =
            sum<T>(out_grad, phi::vectorize(reduce_dim), x.dtype(), false);
        auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
114
        set_output<T>(dx_tmp, dx);
115
      }
116 117 118 119 120 121
    } else {
      by_pass<T>(out_grad, dx);
    }
  }
}

122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141
template <typename T>
void sum_grad(const Tensor& x,
              const Tensor& out_grad,
              const IntArray& axis,
              bool keepdim,
              bool reduce_all,
              Tensor* x_grad) {
  if (!x_grad) {
    return;
  }
  std::vector<int> x_dim = phi::vectorize<int>(x.dims());
  int64_t axis_size = axis.size();
  int64_t x_dim_size = x_dim.size();
  reduce_all = false;
  if (reduce_all || axis_size == 0 || axis_size == x_dim_size) {
    reduce_all = true;
  } else {
    reduce_all = false;
  }
  auto x_grad_tmp = Tensor();
142 143 144 145 146 147 148 149 150 151 152
  if (x_dim_size == 1) {
    x_grad_tmp = expand<T>(out_grad, IntArray(x_dim));
  } else {
    if (!keepdim) {
      auto axis_ = std::vector<int64_t>();
      if (reduce_all) {
        for (int64_t i = 1; i < x_dim_size; i++) {
          axis_.push_back(i);
        }
      } else {
        axis_ = axis.GetData();
153
      }
154 155
      auto out_grad_ = unsqueeze<T>(out_grad, axis_);
      x_grad_tmp = expand<T>(out_grad_, IntArray(x_dim));
156
    } else {
157
      x_grad_tmp = expand<T>(out_grad, IntArray(x_dim));
158 159 160
    }
  }

161
  set_output<T>(x_grad_tmp, x_grad);
162 163
}

164 165 166 167 168 169 170 171 172 173 174 175 176 177
template <typename T>
void divide_grad(const Tensor& x,
                 const Tensor& y,
                 const Tensor& out,
                 const Tensor& out_grad,
                 int axis,
                 Tensor* dx,
                 Tensor* dy) {
  if (dy) {
    // dy = -(x/y^2) * dout
    auto tmp0 = pow<T>(y, 2.0);
    auto tmp1 = divide<T>(x, tmp0);
    auto tmp2 = scale<T>(tmp1, -1.0, 0.0, true);
    auto dy_res = multiply<T>(tmp2, out_grad);
178
    if (x.dims() != y.dims()) {
179
      // Maybe need reduce here
180 181
      phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
      if (!reduce_dim.size()) {
182
        set_output<T>(dy_res, dy);
183 184 185 186
      } else {
        auto dy_reduce_res =
            sum<T>(dy_res, phi::vectorize(reduce_dim), y.dtype(), false);
        auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
187
        set_output<T>(dy_tmp, dy);
188
      }
189
    } else {
190
      set_output<T>(dy_res, dy);
191 192 193 194
    }
  }  // indicate we will compute dy
  if (dx) {
    // dx = (1/y) * dout
195
    auto one_tensor = full<T>(phi::vectorize(y.dims()), 1.0, y.dtype());
196 197
    auto tmp0 = divide<T>(one_tensor, y);
    auto dx_res = multiply<T>(tmp0, out_grad);
198
    if (y.dims() != x.dims()) {
199
      // Maybe need reduce here
200 201
      auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
      if (!reduce_dim.size()) {
202
        set_output<T>(dx_res, dx);
203 204 205 206
      } else {
        auto dx_reduce_res =
            sum<T>(dx_res, phi::vectorize(reduce_dim), x.dtype(), false);
        auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
207
        set_output<T>(dx_tmp, dx);
208 209
      }

210
    } else {
211
      set_output<T>(dx_res, dx);
212 213 214
    }
  }  // indicate we will compute dx
}
215 216 217 218 219 220 221

template <typename T>
void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
  if (x_grad) {
    auto div_x = full<T>(phi::vectorize(out.dims()), 0.5);
    auto tmp = divide<T>(div_x, out);
    auto x_grad_tmp = multiply<T>(out_grad, tmp);
222
    set_output<T>(x_grad_tmp, x_grad);
223 224
  }
}
225 226 227 228 229 230 231 232 233 234 235 236 237

template <typename T>
void multiply_grad(const Tensor& x,
                   const Tensor& y,
                   const Tensor& out_grad,
                   int axis,
                   Tensor* x_grad,
                   Tensor* y_grad) {
  if (x_grad) {
    auto x_grad_unreduce = multiply<T>(out_grad, y);
    if (x.dims() != y.dims()) {
      auto axes = get_reduce_dims(x.dims(), y.dims());
      if (!axes.size()) {
238
        set_output<T>(x_grad_unreduce, x_grad);
239 240 241 242 243 244 245 246
      } else {
        auto x_grad_reduced = sum<T>(x_grad_unreduce,
                                     phi::vectorize(axes),
                                     x_grad_unreduce.dtype(),
                                     false);
        if (x_grad_reduced.dims().size() != x.dims().size()) {
          x_grad_reduced = reshape<T>(x_grad_reduced, x.shape());
        }
247
        set_output<T>(x_grad_reduced, x_grad);
248 249
      }
    } else {
250
      set_output<T>(x_grad_unreduce, x_grad);
251 252 253 254 255 256 257
    }
  }
  if (y_grad) {
    auto y_grad_unreduce = multiply<T>(out_grad, x);
    if (y.dims() != x.dims()) {
      auto axes = get_reduce_dims(y.dims(), x.dims());
      if (!axes.size()) {
258
        set_output<T>(y_grad_unreduce, y_grad);
259 260 261 262 263 264 265 266
      } else {
        auto y_grad_reduced = sum<T>(y_grad_unreduce,
                                     phi::vectorize(axes),
                                     y_grad_unreduce.dtype(),
                                     false);
        if (y_grad_reduced.dims().size() != y.dims().size()) {
          y_grad_reduced = reshape<T>(y_grad_reduced, y.shape());
        }
267
        set_output<T>(y_grad_reduced, y_grad);
268 269
      }
    } else {
270
      set_output<T>(y_grad_unreduce, y_grad);
271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290
    }
  }
}

template <typename T>
void expand_grad(const Tensor& x,
                 const Tensor& out_grad,
                 const IntArray& shape,
                 Tensor* x_grad) {
  if (x_grad) {
    auto out_dims = phi::make_ddim(shape.GetData());
    if (out_dims != x.dims()) {
      auto axes = get_reduce_dims(x.dims(), out_dims);
      if (!axes.size()) {
        by_pass<T>(out_grad, x_grad);
      } else {
        auto reduced = sum<T>(out_grad, phi::vectorize(axes), x.dtype(), false);
        if (reduced.dims().size() != x.dims().size()) {
          reduced = reshape<T>(reduced, x.shape());
        }
291
        set_output<T>(reduced, x_grad);
292 293 294 295 296 297 298 299 300 301
      }
    } else {
      by_pass<T>(out_grad, x_grad);
    }
  }
}

template <typename T>
void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
  if (x_grad) {
302
    set_output<T>(multiply<T>(out_grad, out), x_grad);
303 304 305
  }
}

J
Jiabin Yang 已提交
306 307
}  // namespace prim
}  // namespace paddle