// elementwise_grad_kernel.cu
//   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/phi/kernels/elementwise_grad_kernel.h"
16

17
#include "paddle/phi/backends/gpu/gpu_context.h"
18 19 20
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/float16.h"
21 22 23
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
24
#include "paddle/phi/kernels/gpu/elementwise_grad.h"
25
#include "paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h"
26

27
namespace phi {
28 29 30 31 32 33 34 35 36 37 38

template <typename T>
void AddGradFunc(const GPUContext& dev_ctx,
                 const DenseTensor& x,
                 const DenseTensor& y,
                 const DenseTensor& out,
                 const DenseTensor& dout,
                 DenseTensor* dx,
                 DenseTensor* dy,
                 int axis = -1) {
  if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
39
    ElementwiseAddGrad<T>(dev_ctx, x, y, out, dout, dx, dy);
40
  } else {
41
    DefaultElementwiseAddGrad<T>(dev_ctx, x, y, out, dout, dx, dy, axis);
42 43 44 45 46 47 48 49 50 51 52
  }
}

template <typename T, typename Context>
void AddGradKernel(const Context& dev_ctx,
                   const DenseTensor& x,
                   const DenseTensor& y,
                   const DenseTensor& dout,
                   int axis,
                   DenseTensor* dx,
                   DenseTensor* dy) {
53
  phi::AddGradImpl<T>(dev_ctx, x, y, dout, axis, dx, dy, AddGradFunc<T>);
54 55 56 57 58 59 60 61 62 63
}

template <typename T, typename Context>
void AddDoubleGradKernel(const Context& dev_ctx,
                         const DenseTensor& y,
                         paddle::optional<const DenseTensor&> ddx,
                         paddle::optional<const DenseTensor&> ddy,
                         const DenseTensor& dout,
                         int axis,
                         DenseTensor* ddout) {
64
  phi::AddDoubleGradImpl<T>(dev_ctx, y, ddx, ddy, dout, axis, ddout);
65 66 67 68 69 70 71 72 73 74
}

template <typename T, typename Context>
void AddTripleGradKernel(const Context& dev_ctx,
                         const DenseTensor& ddx,
                         const DenseTensor& ddy,
                         const DenseTensor& d_ddout,
                         int axis,
                         DenseTensor* d_ddx,
                         DenseTensor* d_ddy) {
75
  phi::AddGradImpl<T>(
76 77 78
      dev_ctx, ddx, ddy, d_ddout, axis, d_ddx, d_ddy, AddGradFunc<T>);
}

79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103
template <typename T, typename Context>
void SubtractGradKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        const DenseTensor& y,
                        const DenseTensor& dout,
                        int axis,
                        DenseTensor* dx,
                        DenseTensor* dy) {
  // skip out
  auto* out = &dout;
  if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
    elementwise_sub_grad<T>(dev_ctx, x, y, *out, dout, dx, dy);
  } else {
    default_elementwise_sub_grad<T>(dev_ctx, x, y, *out, dout, dx, dy, axis);
  }
}

template <typename T, typename Context>
void SubtractDoubleGradKernel(const Context& dev_ctx,
                              const DenseTensor& y,
                              paddle::optional<const DenseTensor&> ddx,
                              paddle::optional<const DenseTensor&> ddy,
                              const DenseTensor& dout,
                              int axis,
                              DenseTensor* ddout) {
104
  phi::SubtractDoubleGradImpl<T>(dev_ctx, y, ddx, ddy, dout, axis, ddout);
105 106
}

107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138
// Kernel entry point registered below as `divide_grad`.
//
// Dispatches on which outputs were requested:
//   * dx and dy : one GetGradXAndYOut call over {dout, out, y} using
//                 DivGradXYFunctor;
//   * dx only   : GetGradXOrYOut over {dout, y} using DivGradXFunctor;
//   * dy only   : GetGradXOrYOut over {dout, out, y} using DivGradYFunctor;
//   * neither   : nothing is computed.
// `x` is not read by this function.
template <typename T, typename Context>
void DivideGradKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      const DenseTensor& y,
                      const DenseTensor& out,
                      const DenseTensor& dout,
                      int axis,
                      DenseTensor* dx,
                      DenseTensor* dy) {
  const auto place = dev_ctx.GetPlace();

  if (dx == nullptr && dy == nullptr) {
    return;
  }

  if (dx == nullptr) {
    // Only dy was requested.
    std::vector<const DenseTensor*> dy_inputs{&dout, &out, &y};
    GetGradXOrYOut<ElementwiseType::kTernary, T>(dev_ctx,
                                                 place,
                                                 axis,
                                                 dy_inputs,
                                                 dout,
                                                 dy,
                                                 funcs::DivGradYFunctor<T>());
    return;
  }

  if (dy == nullptr) {
    // Only dx was requested.
    std::vector<const DenseTensor*> dx_inputs{&dout, &y};
    GetGradXOrYOut<ElementwiseType::kBinary, T>(dev_ctx,
                                                place,
                                                axis,
                                                dx_inputs,
                                                dout,
                                                dx,
                                                funcs::DivGradXFunctor<T>());
    return;
  }

  // Both outputs were requested: produce them with a single call.
  std::vector<const DenseTensor*> xy_inputs{&dout, &out, &y};
  GetGradXAndYOut<ElementwiseType::kTernary, T>(
      dev_ctx,
      place,
      axis,
      xy_inputs,
      dout,
      dx,
      dy,
      funcs::DivGradXYFunctor<T, T>());
}

// (blame view) committed by YuanRisheng
// Kernel entry point registered below as `multiply_grad`.
//
// Runs two steps in order: a pre-processing call on (dout, dx), then the
// gradient computation via ElementwiseMulGrad.
template <typename T, typename Context>
void MultiplyGradKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        const DenseTensor& y,
                        const DenseTensor& dout,
                        int axis,
                        DenseTensor* dx,
                        DenseTensor* dy) {
  // NOTE(review): the helper is defined outside this file; presumably it
  // prepares dx from dout's metadata -- confirm before relying on it.
  funcs::ElementwiseGradPreProcess(dout, dx);

  ElementwiseMulGrad<T>(dev_ctx,
                        x,
                        y,
                        dout,
                        dx,
                        dy,
                        axis);
}

151
}  // namespace phi
152

153
// Registers phi::AddGradKernel under the name `add_grad` for the GPU backend
// and ALL_LAYOUT, once per element type listed. The trailing `{}` is the
// macro's body, left empty here.
PD_REGISTER_KERNEL(add_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::AddGradKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
165

166
// Registers phi::AddDoubleGradKernel under the name `add_double_grad` for the
// GPU backend and ALL_LAYOUT, once per element type listed. The trailing `{}`
// is the macro's body, left empty here.
PD_REGISTER_KERNEL(add_double_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::AddDoubleGradKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
178

179
// Registers phi::AddTripleGradKernel under the name `add_triple_grad` for the
// GPU backend and ALL_LAYOUT, once per element type listed. The trailing `{}`
// is the macro's body, left empty here.
PD_REGISTER_KERNEL(add_triple_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::AddTripleGradKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
191

192
// Registers phi::SubtractGradKernel under the name `subtract_grad` for the GPU
// backend and ALL_LAYOUT, once per element type listed. The trailing `{}` is
// the macro's body, left empty here.
PD_REGISTER_KERNEL(subtract_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::SubtractGradKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
204

205
// Registers phi::SubtractDoubleGradKernel under the name
// `subtract_double_grad` for the GPU backend and ALL_LAYOUT, once per element
// type listed. The trailing `{}` is the macro's body, left empty here.
PD_REGISTER_KERNEL(subtract_double_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::SubtractDoubleGradKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242

// Registers phi::DivideGradKernel under the name `divide_grad` for the GPU
// backend and ALL_LAYOUT, once per element type listed. The trailing `{}` is
// the macro's body, left empty here.
PD_REGISTER_KERNEL(divide_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::DivideGradKernel,
                   float,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   double,
                   int,
                   int64_t,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}

// Registers phi::DivideDoubleGradKernel under the name `divide_double_grad`
// for the GPU backend and ALL_LAYOUT, once per element type listed.
// NOTE(review): the kernel is not defined in this file; presumably it comes
// from one of the included headers -- confirm.
PD_REGISTER_KERNEL(divide_double_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::DivideDoubleGradKernel,
                   float,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   double,
                   int,
                   int64_t,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
// (blame view) committed by YuanRisheng

// Registers phi::MultiplyGradKernel under the name `multiply_grad` for the GPU
// backend and ALL_LAYOUT, once per element type listed. Unlike the add,
// subtract and divide registrations in this file, the list includes `bool`.
PD_REGISTER_KERNEL(multiply_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::MultiplyGradKernel,
                   float,
                   phi::dtype::float16,
                   double,
                   int,
                   int64_t,
                   bool,
                   phi::dtype::bfloat16,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}

// Registers phi::MultiplyDoubleGradKernel under the name
// `multiply_double_grad` for the GPU backend and ALL_LAYOUT, once per element
// type listed (including `bool`).
// NOTE(review): the kernel is not defined in this file; presumably it comes
// from one of the included headers -- confirm.
PD_REGISTER_KERNEL(multiply_double_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::MultiplyDoubleGradKernel,
                   float,
                   phi::dtype::float16,
                   double,
                   int,
                   int64_t,
                   bool,
                   phi::dtype::bfloat16,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}

// Registers phi::MultiplyTripleGradKernel under the name
// `multiply_triple_grad` for the GPU backend and ALL_LAYOUT, once per element
// type listed (including `bool`).
// NOTE(review): the kernel is not defined in this file; presumably it comes
// from one of the included headers -- confirm.
PD_REGISTER_KERNEL(multiply_triple_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::MultiplyTripleGradKernel,
                   float,
                   phi::dtype::float16,
                   double,
                   int,
                   int64_t,
                   bool,
                   phi::dtype::bfloat16,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301
// Registers phi::ElementwiseFMaxGradKernel under the name
// `elementwise_fmax_grad` for the GPU backend and ALL_LAYOUT. Only float,
// double, int and int64_t are registered; no half-precision or complex types.
// NOTE(review): the kernel is not defined in this file; presumably it comes
// from one of the included headers -- confirm.
PD_REGISTER_KERNEL(elementwise_fmax_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::ElementwiseFMaxGradKernel,
                   float,
                   double,
                   int,
                   int64_t) {}

// Registers phi::ElementwiseFMinGradKernel under the name
// `elementwise_fmin_grad` for the GPU backend and ALL_LAYOUT. Only float,
// double, int and int64_t are registered; no half-precision or complex types.
// NOTE(review): the kernel is not defined in this file; presumably it comes
// from one of the included headers -- confirm.
PD_REGISTER_KERNEL(elementwise_fmin_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::ElementwiseFMinGradKernel,
                   float,
                   double,
                   int,
                   int64_t) {}