activation_grad_kernel.cc 11.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/activation_grad_kernel.h"
16
#include "paddle/phi/kernels/gelu_grad_kernel.h"
17 18

#include "paddle/phi/backends/onednn/onednn_context.h"
19
#include "paddle/phi/backends/onednn/onednn_reuse.h"
20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/activation_functor.h"

namespace phi {

// Defines `<name>GradKernel(dev_ctx, x, dout, <attr>, dx)`: a backward kernel
// that reads the forward input `x` and hands its single float attribute to
// `functor_class<T>` as the first (alpha) parameter, with the second (beta)
// parameter fixed to 0.
#define DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX( \
    name, functor_class, attr)                             \
  template <typename T, typename Context>                  \
  void name##GradKernel(const Context& dev_ctx,            \
                        const DenseTensor& x,              \
                        const DenseTensor& dout,           \
                        float attr,                        \
                        DenseTensor* dx) {                 \
    const functor_class<T> grad_functor{};                 \
    grad_functor(dev_ctx, x, dout, attr, 0, dx);           \
  }

// Defines `<name>GradKernel(dev_ctx, out, dout, dx)`: a parameter-free
// backward kernel that hands its first tensor (named `out`) and `dout` to
// `functor_class<T>`, with both alpha and beta set to 0.
#define DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(name, functor_class) \
  template <typename T, typename Context>                                \
  void name##GradKernel(const Context& dev_ctx,                          \
                        const DenseTensor& out,                          \
                        const DenseTensor& dout,                         \
                        DenseTensor* dx) {                               \
    const functor_class<T> grad_functor{};                               \
    grad_functor(dev_ctx, out, dout, 0, 0, dx);                          \
  }

template <typename T>
void eltwise_grad(const OneDNNContext& dev_ctx,
                  const DenseTensor& x,
                  const DenseTensor& dout,
                  float alpha,
                  float beta,
                  DenseTensor* dx,
                  dnnl::algorithm algorithm) {
57 58 59 60 61 62 63
  funcs::ActivationOneDNNHandler<T> handler(algorithm,
                                            alpha,
                                            beta,
                                            dev_ctx.GetEngine(),
                                            dev_ctx.GetPlace(),
                                            &x,
                                            &dout);
64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87

  auto src_memory_p = handler.AcquireBackwardSrcMemory(&x);
  auto diff_dst_memory_p = handler.AcquireDiffDstMemory(&dout);
  auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
  auto activation_backward_p = handler.AcquireBackwardPrimitive();

  auto& astream = OneDNNContext::tls().get_stream();
  activation_backward_p->execute(astream,
                                 {{DNNL_ARG_SRC, *src_memory_p},
                                  {DNNL_ARG_DIFF_DST, *diff_dst_memory_p},
                                  {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
  astream.wait();

  dx->set_mem_desc(diff_src_memory_p->get_desc());
}

template <typename T>
void eltwise_grad_use_out(const OneDNNContext& dev_ctx,
                          const DenseTensor& out,
                          const DenseTensor& dout,
                          float alpha,
                          float beta,
                          DenseTensor* dx,
                          dnnl::algorithm algorithm) {
88 89 90 91 92 93 94
  funcs::ActivationOneDNNHandler<T> handler(algorithm,
                                            alpha,
                                            beta,
                                            dev_ctx.GetEngine(),
                                            dev_ctx.GetPlace(),
                                            &out,
                                            &dout);
95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111

  auto dst_memory_p = handler.AcquireBackwardSrcMemory(&out);
  auto diff_dst_memory_p = handler.AcquireDiffDstMemory(&dout);
  auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
  auto activation_backward_p = handler.AcquireBackwardPrimitive();

  auto& astream = OneDNNContext::tls().get_stream();
  activation_backward_p->execute(astream,
                                 {{DNNL_ARG_DST, *dst_memory_p},
                                  {DNNL_ARG_DIFF_DST, *diff_dst_memory_p},
                                  {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
  astream.wait();

  dx->set_mem_desc(diff_src_memory_p->get_desc());
}

template <typename T, dnnl::algorithm algorithm>
112
struct OneDNNActivationGradFunc : public funcs::BaseActivationFunctor<T> {
113 114 115 116 117 118 119 120 121 122 123
  void operator()(const OneDNNContext& dev_ctx,
                  const DenseTensor& x,
                  const DenseTensor& dout,
                  float alpha,
                  float beta,
                  DenseTensor* dx) const {
    eltwise_grad<T>(dev_ctx, x, dout, alpha, beta, dx, algorithm);
  }
};

template <typename T, dnnl::algorithm algorithm>
124
struct OneDNNActivationGradUseOutFunc : public funcs::BaseActivationFunctor<T> {
125 126 127 128 129 130 131 132 133 134
  void operator()(const OneDNNContext& dev_ctx,
                  const DenseTensor& out,
                  const DenseTensor& dout,
                  float alpha,
                  float beta,
                  DenseTensor* dx) const {
    eltwise_grad_use_out<T>(dev_ctx, out, dout, alpha, beta, dx, algorithm);
  }
};

// Per-activation aliases that bind each activation to its oneDNN eltwise
// algorithm.
// - `...GradFunctor` aliases use OneDNNActivationGradFunc, which reads the
//   forward input (DNNL_ARG_SRC).
// - `...GradUseOutFunctor` aliases use OneDNNActivationGradUseOutFunc with a
//   `*_use_dst_for_bwd` algorithm, which reads the forward output
//   (DNNL_ARG_DST) instead.
template <typename T>
using AbsOneDNNGradFunctor =
    OneDNNActivationGradFunc<T, dnnl::algorithm::eltwise_abs>;

template <typename T>
using EluOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
    T,
    dnnl::algorithm::eltwise_elu_use_dst_for_bwd>;

template <typename T>
using ExpOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
    T,
    dnnl::algorithm::eltwise_exp_use_dst_for_bwd>;

template <typename T>
using HardSwishOneDNNGradFunctor =
    OneDNNActivationGradFunc<T, dnnl::algorithm::eltwise_hardswish>;

template <typename T>
using MishOneDNNGradFunctor =
    OneDNNActivationGradFunc<T, dnnl::algorithm::eltwise_mish>;

template <typename T>
using GeluTanhOneDNNGradFunctor =
    OneDNNActivationGradFunc<T, dnnl::algorithm::eltwise_gelu_tanh>;

template <typename T>
using GeluErfOneDNNGradFunctor =
    OneDNNActivationGradFunc<T, dnnl::algorithm::eltwise_gelu_erf>;

template <typename T>
using ReluOneDNNGradFunctor =
    OneDNNActivationGradFunc<T, dnnl::algorithm::eltwise_relu>;

template <typename T>
using Relu6OneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
    T,
    dnnl::algorithm::eltwise_clip_v2_use_dst_for_bwd>;

template <typename T>
using SigmoidOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
    T,
    dnnl::algorithm::eltwise_logistic_use_dst_for_bwd>;

template <typename T>
using SqrtOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
    T,
    dnnl::algorithm::eltwise_sqrt_use_dst_for_bwd>;

template <typename T>
using SwishOneDNNGradFunctor =
    OneDNNActivationGradFunc<T, dnnl::algorithm::eltwise_swish>;

template <typename T>
using TanhOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc<
    T,
    dnnl::algorithm::eltwise_tanh_use_dst_for_bwd>;

// Parameter-free backward kernels: (first tensor, dout) -> dx, alpha = beta = 0.
// NOTE(review): the DEPOUT macro names its first tensor `out`, but Abs and
// Relu are wired to src-based functors (eltwise_abs / eltwise_relu). For Relu
// this is harmless because relu(x) > 0 exactly when x > 0. For Abs it is only
// correct if the framework passes the forward input x in that position.
// Confirm this against abs_grad's declaration.
DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Abs, AbsOneDNNGradFunctor);
DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Exp, ExpOneDNNGradUseOutFunctor);
DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu, ReluOneDNNGradFunctor);
DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Sigmoid,
                                            SigmoidOneDNNGradUseOutFunctor);
DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Sqrt, SqrtOneDNNGradUseOutFunctor);
DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, TanhOneDNNGradUseOutFunctor);

// Backward kernels that read the forward input x and forward one float
// attribute to oneDNN as alpha (beta = 0):
// - leaky_relu's `alpha` (eltwise_relu with a non-zero alpha),
// - mish's `threshold`,
// - swish's `beta`.
DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu,
                                                  ReluOneDNNGradFunctor,
                                                  alpha);
DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Mish,
                                                  MishOneDNNGradFunctor,
                                                  threshold);
DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Swish,
                                                  SwishOneDNNGradFunctor,
                                                  beta);

template <typename T, typename Context>
void EluGradKernel(const Context& dev_ctx,
                   const DenseTensor& x,
                   const DenseTensor& out,
                   const DenseTensor& dout,
                   float alpha,
                   DenseTensor* dx) {
  EluOneDNNGradUseOutFunctor<T> functor;
  functor(dev_ctx, out, dout, alpha, 0, dx);
}

template <typename T, typename Context>
void GeluGradKernel(const Context& dev_ctx,
                    const DenseTensor& x,
                    const DenseTensor& out_grad,
                    bool approximate,
                    DenseTensor* x_grad) {
  if (approximate) {
    GeluTanhOneDNNGradFunctor<T> functor;
    functor(dev_ctx, x, out_grad, 0, 0, x_grad);
  } else {
    GeluErfOneDNNGradFunctor<T> functor;
    functor(dev_ctx, x, out_grad, 0, 0, x_grad);
  }
}

237 238 239 240 241 242 243 244
template <typename T, typename Context>
void HardSwishGradKernel(const Context& dev_ctx,
                         const DenseTensor& x,
                         const DenseTensor& dout,
                         float threshold,
                         float scale,
                         float offset,
                         DenseTensor* dx) {
245
  HardSwishOneDNNGradFunctor<T> functor;
246 247 248 249
  functor(dev_ctx, x, dout, threshold, 0, dx);
}

template <typename T, typename Context>
250 251 252 253 254 255 256
void Relu6GradKernel(const Context& dev_ctx,
                     const DenseTensor& out,
                     const DenseTensor& dout,
                     float threshold,
                     DenseTensor* dx) {
  Relu6OneDNNGradUseOutFunctor<T> functor;
  functor(dev_ctx, out, dout, 0, threshold, dx);
257 258 259 260 261 262 263 264 265 266 267 268 269 270 271
}

}  // namespace phi

// Registers phi::ReluGradKernel as the OneDNN relu_grad kernel for all
// layouts, for float and bfloat16. It is spelled out here instead of using
// PD_REGISTER_ACTIVATION_GRAD_KERNEL, which is only defined further below.
PD_REGISTER_KERNEL(relu_grad,
                   OneDNN,
                   ALL_LAYOUT,
                   phi::ReluGradKernel,
                   float,
                   phi::dtype::bfloat16) {}

// Registers phi::<func> as the OneDNN kernel for <name>, for all layouts and
// for float and bfloat16.
#define PD_REGISTER_ACTIVATION_GRAD_KERNEL(name, func) \
  PD_REGISTER_KERNEL(                                  \
      name, OneDNN, ALL_LAYOUT, phi::func, float, phi::dtype::bfloat16) {}

PD_REGISTER_ACTIVATION_GRAD_KERNEL(abs_grad, AbsGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_grad, EluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(exp_grad, ExpGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(gelu_grad, GeluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_swish_grad, HardSwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(relu6_grad, Relu6GradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_grad, TanhGradKernel)