elementwise_grad_kernel.cu 8.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
//   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/phi/kernels/elementwise_grad_kernel.h"
16

17
#include "paddle/phi/backends/gpu/gpu_context.h"
18 19 20
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/float16.h"
21 22 23
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
24
#include "paddle/phi/kernels/gpu/elementwise_grad.h"
25
#include "paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h"
26

27
namespace phi {
28 29 30 31 32 33 34 35 36 37 38

template <typename T>
void AddGradFunc(const GPUContext& dev_ctx,
                 const DenseTensor& x,
                 const DenseTensor& y,
                 const DenseTensor& out,
                 const DenseTensor& dout,
                 DenseTensor* dx,
                 DenseTensor* dy,
                 int axis = -1) {
  if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
39
    ElementwiseAddGrad<T>(dev_ctx, x, y, out, dout, dx, dy);
40
  } else {
41
    DefaultElementwiseAddGrad<T>(dev_ctx, x, y, out, dout, dx, dy, axis);
42 43 44 45 46 47 48 49 50 51 52
  }
}

template <typename T, typename Context>
void AddGradKernel(const Context& dev_ctx,
                   const DenseTensor& x,
                   const DenseTensor& y,
                   const DenseTensor& dout,
                   int axis,
                   DenseTensor* dx,
                   DenseTensor* dy) {
53
  phi::AddGradImpl<T>(dev_ctx, x, y, dout, axis, dx, dy, AddGradFunc<T>);
54 55 56 57 58 59 60 61 62 63
}

template <typename T, typename Context>
void AddDoubleGradKernel(const Context& dev_ctx,
                         const DenseTensor& y,
                         paddle::optional<const DenseTensor&> ddx,
                         paddle::optional<const DenseTensor&> ddy,
                         const DenseTensor& dout,
                         int axis,
                         DenseTensor* ddout) {
64
  phi::AddDoubleGradImpl<T>(dev_ctx, y, ddx, ddy, dout, axis, ddout);
65 66 67 68 69 70 71 72 73 74
}

template <typename T, typename Context>
void AddTripleGradKernel(const Context& dev_ctx,
                         const DenseTensor& ddx,
                         const DenseTensor& ddy,
                         const DenseTensor& d_ddout,
                         int axis,
                         DenseTensor* d_ddx,
                         DenseTensor* d_ddy) {
75
  phi::AddGradImpl<T>(
76 77 78
      dev_ctx, ddx, ddy, d_ddout, axis, d_ddx, d_ddy, AddGradFunc<T>);
}

79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103
template <typename T, typename Context>
void SubtractGradKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        const DenseTensor& y,
                        const DenseTensor& dout,
                        int axis,
                        DenseTensor* dx,
                        DenseTensor* dy) {
  // skip out
  auto* out = &dout;
  if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
    elementwise_sub_grad<T>(dev_ctx, x, y, *out, dout, dx, dy);
  } else {
    default_elementwise_sub_grad<T>(dev_ctx, x, y, *out, dout, dx, dy, axis);
  }
}

template <typename T, typename Context>
void SubtractDoubleGradKernel(const Context& dev_ctx,
                              const DenseTensor& y,
                              paddle::optional<const DenseTensor&> ddx,
                              paddle::optional<const DenseTensor&> ddy,
                              const DenseTensor& dout,
                              int axis,
                              DenseTensor* ddout) {
104
  phi::SubtractDoubleGradImpl<T>(dev_ctx, y, ddx, ddy, dout, axis, ddout);
105 106
}

107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138
// Gradient kernel entry point for elementwise divide.
//
// Chooses one of three paths depending on which gradients are requested:
//   * dx and dy: GetGradXAndYOut (ternary) over {dout, out, y} with
//     funcs::DivGradXYFunctor<T, T>.
//   * dx only:   GetGradXOrYOut (binary) over {dout, y} with
//     funcs::DivGradXFunctor<T>.
//   * dy only:   GetGradXOrYOut (ternary) over {dout, out, y} with
//     funcs::DivGradYFunctor<T>.
// When neither is requested, no helper is invoked.
//
// dev_ctx: device context; also the source of the place passed to helpers.
// x:       forward input; not read in this function body.
// y, out:  forward input / output used as functor inputs.
// dout:    gradient with respect to the forward output.
// axis:    forwarded unchanged to the selected helper.
// dx, dy:  output gradient tensors; either may be nullptr.
template <typename T, typename Context>
void DivideGradKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      const DenseTensor& y,
                      const DenseTensor& out,
                      const DenseTensor& dout,
                      int axis,
                      DenseTensor* dx,
                      DenseTensor* dy) {
  const auto place = dev_ctx.GetPlace();
  const bool want_dx = (dx != nullptr);
  const bool want_dy = (dy != nullptr);

  // Nothing requested: no helper is called.
  if (!want_dx && !want_dy) {
    return;
  }

  // Both gradients requested: one call fills dx and dy.
  if (want_dx && want_dy) {
    std::vector<const DenseTensor*> inputs = {&dout, &out, &y};
    GetGradXAndYOut<ElementwiseType::kTernary, T>(
        dev_ctx,
        place,
        axis,
        inputs,
        dout,
        dx,
        dy,
        funcs::DivGradXYFunctor<T, T>());
    return;
  }

  // Only dx requested.
  if (want_dx) {
    std::vector<const DenseTensor*> inputs = {&dout, &y};
    GetGradXOrYOut<ElementwiseType::kBinary, T>(
        dev_ctx, place, axis, inputs, dout, dx, funcs::DivGradXFunctor<T>());
    return;
  }

  // Only dy requested.
  std::vector<const DenseTensor*> inputs = {&dout, &out, &y};
  GetGradXOrYOut<ElementwiseType::kTernary, T>(
      dev_ctx, place, axis, inputs, dout, dy, funcs::DivGradYFunctor<T>());
}

139
}  // namespace phi
140

141
// Registers phi::AddGradKernel under the name add_grad for the GPU backend
// with the ALL_LAYOUT tag, instantiated for each data type listed below.
PD_REGISTER_KERNEL(add_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::AddGradKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
153

154
// Registers phi::AddDoubleGradKernel under the name add_double_grad for the
// GPU backend with the ALL_LAYOUT tag, instantiated for each data type listed
// below.
PD_REGISTER_KERNEL(add_double_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::AddDoubleGradKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
166

167
// Registers phi::AddTripleGradKernel under the name add_triple_grad for the
// GPU backend with the ALL_LAYOUT tag, instantiated for each data type listed
// below.
PD_REGISTER_KERNEL(add_triple_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::AddTripleGradKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
179

180
// Registers phi::SubtractGradKernel under the name subtract_grad for the GPU
// backend with the ALL_LAYOUT tag, instantiated for each data type listed
// below.
PD_REGISTER_KERNEL(subtract_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::SubtractGradKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
192

193
// Registers phi::SubtractDoubleGradKernel under the name subtract_double_grad
// for the GPU backend with the ALL_LAYOUT tag, instantiated for each data
// type listed below.
PD_REGISTER_KERNEL(subtract_double_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::SubtractDoubleGradKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230

// Registers phi::DivideGradKernel under the name divide_grad for the GPU
// backend with the ALL_LAYOUT tag, instantiated for each data type listed
// below.
PD_REGISTER_KERNEL(divide_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::DivideGradKernel,
                   float,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   double,
                   int,
                   int64_t,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}

// Registers phi::DivideDoubleGradKernel under the name divide_double_grad for
// the GPU backend with the ALL_LAYOUT tag, instantiated for each data type
// listed below.
// NOTE(review): phi::DivideDoubleGradKernel is not defined in the visible part
// of this file; presumably it comes from one of the included headers —
// confirm.
PD_REGISTER_KERNEL(divide_double_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::DivideDoubleGradKernel,
                   float,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   double,
                   int,
                   int64_t,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}