activation_op.h 16.6 KB
Newer Older
1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2

L
Luo Tao 已提交
3 4 5 6 7 8 9 10 11
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Q
qijun 已提交
12 13

#pragma once
D
dzhwinter 已提交
14
#include <glog/logging.h>
15

Y
Yihua Xu 已提交
16
#include <algorithm>
17
#include <cmath>
18
#include <memory>
D
dzhwinter 已提交
19 20
#include <string>
#include <unordered_set>
21 22
#include <utility>
#include <vector>
C
Clementine 已提交
23 24 25 26
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif

27
#include <type_traits>
28

Y
Yi Wang 已提交
29 30
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
31
#include "paddle/fluid/framework/tensor_util.h"
32
#include "paddle/fluid/platform/enforce.h"
33
#include "paddle/fluid/platform/float16.h"
34
#include "paddle/phi/kernels/funcs/blas/blas.h"
35

36 37
#include "paddle/phi/kernels/funcs/activation_functor.h"

Q
qijun 已提交
38 39 40
namespace paddle {
namespace operators {

41 42
using framework::To32BitIndex;

43
using ActBwdOpFwdDeps = phi::funcs::ActBwdOpFwdDeps;
44

C
chengduo 已提交
45 46 47 48 49 50
/* The following operators can be used to process SelectedRows, because the
 * output of each of them for a zero input is also zero.
 */
// Fixed lookup table of op types; this header only queries it with count(),
// so it is declared const. Being `static` at namespace scope in a header,
// each translation unit that includes this file gets its own copy.
static const std::unordered_set<std::string> CanBeUsedBySelectedRows = {
    "abs", "abs_grad", "square", "square_grad", "sqrt", "sqrt_grad"};

51
/**
 * Resolves the input/output tensors of a forward activation op.
 *
 * @param context  kernel execution context providing the "X" and "Out" slots.
 * @param X        receives the input tensor.
 * @param Out      receives the (mutable) output tensor.
 *
 * Op types listed in CanBeUsedBySelectedRows are resolved through the
 * LoDTensor-or-SelectedRows helpers; all other op types read plain
 * DenseTensors from the context. Raises NotFound when either variable is
 * missing, or when no output tensor could be obtained.
 */
inline void ExtractActivationTensor(const framework::ExecutionContext& context,
                                    const phi::DenseTensor** X,
                                    phi::DenseTensor** Out) {
  auto in_var = context.InputVar("X");
  auto res_var = context.OutputVar("Out");

  PADDLE_ENFORCE_NOT_NULL(in_var,
                          platform::errors::NotFound(
                              "Cannot get input Variable X, variable name = %s",
                              context.InputName("X")));
  PADDLE_ENFORCE_NOT_NULL(
      res_var,
      platform::errors::NotFound(
          "Cannot get output Variable Out, variable name = %s",
          context.OutputName("Out")));

  const bool accepts_selected_rows =
      CanBeUsedBySelectedRows.count(context.Type()) > 0;
  if (!accepts_selected_rows) {
    *X = context.Input<phi::DenseTensor>("X");
    *Out = context.Output<phi::DenseTensor>("Out");
  } else {
    *X = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*in_var);
    *Out = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
        res_var);
  }

  PADDLE_ENFORCE_NOT_NULL(
      *Out,
      platform::errors::NotFound("Cannot get the tensor from the Variable "
                                 "Output(Out), variable name = %s",
                                 context.OutputName("Out")));
}

81
template <ActBwdOpFwdDeps kDepValue>
82
inline void ExtractActivationGradTensor(
83
    const framework::ExecutionContext& context,
84 85 86 87
    const phi::DenseTensor** X,
    const phi::DenseTensor** Out,
    const phi::DenseTensor** dOut,
    phi::DenseTensor** dX) {
88 89
  auto out_grad_var = context.InputVar(framework::GradVarName("Out"));
  auto x_grad_var = context.OutputVar(framework::GradVarName("X"));
90 91
  const framework::Variable* out_var = nullptr;

92 93
  if (static_cast<int>(kDepValue) &
      static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
94
    out_var = context.InputVar("Out");
95
    PADDLE_ENFORCE_NOT_NULL(
96 97 98 99
        out_var,
        platform::errors::NotFound(
            "Cannot get input Variable Out, variable name = %s",
            context.InputName("Out")));
100 101 102
  }

  PADDLE_ENFORCE_NOT_NULL(
103 104 105 106 107
      out_grad_var,
      platform::errors::NotFound(
          "Cannot get input Variable %s, variable name = %s",
          framework::GradVarName("Out"),
          context.InputName(framework::GradVarName("Out"))));
108
  PADDLE_ENFORCE_NOT_NULL(
109 110 111 112 113
      x_grad_var,
      platform::errors::NotFound(
          "Cannot get output Variable %s, variable name = %s",
          framework::GradVarName("X"),
          context.OutputName(framework::GradVarName("X"))));
114

H
hong 已提交
115
  if (CanBeUsedBySelectedRows.count(context.Type())) {
116 117 118 119
    *dOut = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(
        *out_grad_var);
    *dX = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
        x_grad_var);
120 121 122 123 124 125 126 127

    if (out_var) {
      *Out =
          paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*out_var);
    } else {
      *Out = *dOut;  // fake out
    }

128
  } else {
129 130 131
    *Out = context.Input<phi::DenseTensor>("Out");
    *dOut = context.Input<phi::DenseTensor>(framework::GradVarName("Out"));
    *dX = context.Output<phi::DenseTensor>(framework::GradVarName("X"));
132 133 134 135 136 137

    if (out_var) {
      *Out = &(out_var->Get<framework::LoDTensor>());
    } else {
      *Out = *dOut;  // fake out
    }
138
  }
139

140 141 142 143 144
  PADDLE_ENFORCE_NOT_NULL(*dX,
                          platform::errors::NotFound(
                              "Cannot get the tensor from the Variable "
                              "Output(Out), variable name = %s",
                              context.OutputName(framework::GradVarName("X"))));
145

146
  if (static_cast<int>(kDepValue) & static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
C
chengduo 已提交
147
    auto x_var = context.InputVar("X");
148 149 150 151 152
    PADDLE_ENFORCE_NOT_NULL(
        x_var,
        platform::errors::NotFound("Cannot get the tensor from the "
                                   "Variable Input(X), variable name = %s",
                                   context.InputName("X")));
H
hong 已提交
153
    if (CanBeUsedBySelectedRows.count(context.Type())) {
154
      *X = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var);
C
chengduo 已提交
155
    } else {
156
      *X = context.Input<phi::DenseTensor>("X");
C
chengduo 已提交
157
    }
158
  } else {
H
hong 已提交
159
    VLOG(10) << " Inplace activation of Op : " << context.Type();
160 161 162
    *X = *dX;
  }
}
C
chengduo 已提交
163

164 165 166 167 168
/**
 * Generic forward kernel for element-wise activations.
 *
 * `Functor` provides ELEMENT_TYPE, a GetAttrs() list of float attributes to
 * copy from the op, and an operator()(device, x, out) that writes the result
 * into `out`.
 */
template <typename DeviceContext, typename Functor>
class ActivationKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;

  void Compute(const framework::ExecutionContext& context) const override {
    const phi::DenseTensor* X = nullptr;
    phi::DenseTensor* Out = nullptr;
    ExtractActivationTensor(context, &X, &Out);
    Out->mutable_data<T>(context.GetPlace());

    // View both tensors as flat 1-D Eigen vectors.
    auto x = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "X", "Activation"));
    auto out = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(Out, "Output", "Out", "Activation"));
    auto& dev_ctx = context.template device_context<DeviceContext>();
    auto* place = dev_ctx.eigen_device();

    Functor functor;
    // Fill every float attribute the functor declares from the op.
    for (auto& attr : functor.GetAttrs()) {
      *attr.second = context.Attr<float>(attr.first);
    }

    // On GPU, 32-bit Eigen indexing is used when the element count fits in
    // an int (it speeds up the computation).
    const bool on_gpu = platform::is_gpu_place(context.GetPlace());
    const bool fits_int32 = out.size() < Eigen::NumTraits<int>::highest();
    if (on_gpu && fits_int32) {
      functor(*place, To32BitIndex(x), To32BitIndex(out));
      return;
    }
    functor(*place, x, out);
  }
};

Q
QI JUN 已提交
199
template <typename DeviceContext, typename Functor>
200 201
class ActivationGradKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
Q
qijun 已提交
202
 public:
203
  using T = typename Functor::ELEMENT_TYPE;
Q
qijun 已提交
204
  void Compute(const framework::ExecutionContext& context) const override {
205 206
    const phi::DenseTensor *X, *Out, *dOut;
    phi::DenseTensor* dX = nullptr;
207
    X = Out = dOut = nullptr;
208 209
    ExtractActivationGradTensor<Functor::FwdDeps()>(
        context, &X, &Out, &dOut, &dX);
Q
qijun 已提交
210
    dX->mutable_data<T>(context.GetPlace());
211 212 213 214 215 216 217 218
    auto dout = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(dOut, "Input", "Out@GRAD", "ActivationGrad"));
    auto out = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(Out, "Input", "Out", "ActivationGrad"));
    auto dx = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(dX, "Input", "X@GRAD", "ActivationGrad"));
    auto x = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "X", "ActivationGrad"));
Q
QI JUN 已提交
219 220
    auto* place =
        context.template device_context<DeviceContext>().eigen_device();
Q
qijun 已提交
221
    Functor functor;
222 223 224 225
    auto attrs = functor.GetAttrs();
    for (auto& attr : attrs) {
      *attr.second = context.Attr<float>(attr.first);
    }
226 227 228 229
    // use 32bit index to speed up computation
    bool use_32bit_index = out.size() < Eigen::NumTraits<int>::highest();
    bool is_gpu_place = platform::is_gpu_place(context.GetPlace());
    if (use_32bit_index && is_gpu_place) {
230 231 232 233
      functor(*place,
              To32BitIndex(x),
              To32BitIndex(out),
              To32BitIndex(dout),
234 235 236 237
              To32BitIndex(dx));
    } else {
      functor(*place, x, out, dout, dx);
    }
Q
qijun 已提交
238 239 240
  }
};

241 242 243 244 245 246 247 248 249
/**
 * Common base for activation functors.
 *
 * ELEMENT_TYPE exposes the scalar type to the generic kernels. AttrPair is a
 * list of (attribute name, destination) pairs; the kernels write the op's
 * float attribute of that name into each destination. The base declares no
 * attributes.
 */
template <typename T>
struct BaseActivationFunctor {
  using ELEMENT_TYPE = T;
  using AttrPair = std::vector<std::pair<const char*, float*>>;

  AttrPair GetAttrs() { return {}; }
};

250 251 252 253 254 255
// Aliases phi's `op##Functor<T>` and `op##GradFunctor<T>` into this
// namespace under the same names.
#define USE_PHI_FUNCTOR(op)                       \
  template <typename T>                           \
  using op##Functor = phi::funcs::op##Functor<T>; \
  template <typename T>                           \
  using op##GradFunctor = phi::funcs::op##GradFunctor<T>;

// Aliases phi's `op##GradGradFunctor<T>` into this namespace.
#define USE_PHI_DOUBLE_GRAD_FUNCTOR(op) \
  template <typename T>                 \
  using op##GradGradFunctor = phi::funcs::op##GradGradFunctor<T>;

// Aliases phi's `op##TripleGradFunctor<T>` into this namespace.
#define USE_PHI_TRIPLE_GRAD_FUNCTOR(op) \
  template <typename T>                 \
  using op##TripleGradFunctor = phi::funcs::op##TripleGradFunctor<T>;

264 265 266 267 268 269 270 271 272 273 274
// Trigonometric and hyperbolic functions.
USE_PHI_FUNCTOR(Cos)
USE_PHI_FUNCTOR(Tan)
USE_PHI_FUNCTOR(Acos)
USE_PHI_FUNCTOR(Sin)
USE_PHI_FUNCTOR(Asin)
USE_PHI_FUNCTOR(Atan)
USE_PHI_FUNCTOR(Sinh)
USE_PHI_FUNCTOR(Cosh)
USE_PHI_FUNCTOR(Asinh)
USE_PHI_FUNCTOR(Acosh)
USE_PHI_FUNCTOR(Atanh)
USE_PHI_FUNCTOR(Tanh)
// Each op is aliased exactly once: an earlier revision listed Exp twice,
// redeclaring the same alias templates.
USE_PHI_FUNCTOR(Exp)
USE_PHI_DOUBLE_GRAD_FUNCTOR(Tanh)
USE_PHI_TRIPLE_GRAD_FUNCTOR(Tanh)
// ReLU family and shrink functions.
USE_PHI_FUNCTOR(BRelu)
USE_PHI_FUNCTOR(ThresholdedRelu)
USE_PHI_FUNCTOR(Relu6)
USE_PHI_FUNCTOR(LeakyRelu)
USE_PHI_DOUBLE_GRAD_FUNCTOR(LeakyRelu)
USE_PHI_FUNCTOR(HardShrink)
USE_PHI_FUNCTOR(SoftShrink)
USE_PHI_FUNCTOR(TanhShrink)
USE_PHI_FUNCTOR(Silu)
USE_PHI_FUNCTOR(ELU)
USE_PHI_DOUBLE_GRAD_FUNCTOR(ELU)
USE_PHI_FUNCTOR(Softsign)
// Sigmoid family.
USE_PHI_FUNCTOR(Sigmoid)
USE_PHI_DOUBLE_GRAD_FUNCTOR(Sigmoid)
USE_PHI_TRIPLE_GRAD_FUNCTOR(Sigmoid)
USE_PHI_FUNCTOR(LogSigmoid)
USE_PHI_FUNCTOR(HardSigmoid)
// Logarithms.
USE_PHI_FUNCTOR(Log)
USE_PHI_DOUBLE_GRAD_FUNCTOR(Log)
USE_PHI_FUNCTOR(Log2)
USE_PHI_FUNCTOR(Log10)
USE_PHI_FUNCTOR(Log1p)
// Miscellaneous element-wise functions.
USE_PHI_FUNCTOR(Swish)
USE_PHI_FUNCTOR(HardSwish)
USE_PHI_FUNCTOR(Pow)
USE_PHI_FUNCTOR(Expm1)
USE_PHI_FUNCTOR(Mish)
USE_PHI_FUNCTOR(STanh)
USE_PHI_FUNCTOR(Reciprocal)
USE_PHI_FUNCTOR(Square)
USE_PHI_DOUBLE_GRAD_FUNCTOR(Square)
USE_PHI_FUNCTOR(Sqrt)
USE_PHI_DOUBLE_GRAD_FUNCTOR(Sqrt)
USE_PHI_FUNCTOR(Rsqrt)
USE_PHI_DOUBLE_GRAD_FUNCTOR(Rsqrt)
USE_PHI_FUNCTOR(Softplus)
USE_PHI_FUNCTOR(CELU)
USE_PHI_DOUBLE_GRAD_FUNCTOR(CELU)
Y
YuanRisheng 已提交
318 319 320

// Gradient functor ELUGradKernel selects when alpha <= 0. Declared once only:
// an earlier revision repeated this alias template further down.
template <typename T>
using ELUGradNegativeAlphaFunctor = phi::funcs::ELUGradNegativeAlphaFunctor<T>;

// Rounding functors.
template <typename T>
using RoundFunctor = phi::funcs::RoundFunctor<T>;

template <typename T>
using FloorFunctor = phi::funcs::FloorFunctor<T>;

template <typename T>
using CeilFunctor = phi::funcs::CeilFunctor<T>;

template <typename T>
using ZeroGradFunctor = phi::funcs::ZeroGradFunctor<T>;

// relu(x) = max(x, 0)
template <typename T>
using ReluCPUFunctor = phi::funcs::ReluCPUFunctor<T>;
template <typename T>
using ReluGradFunctor = phi::funcs::ReluGradFunctor<T>;

template <typename T>
using ReluGradGradFunctor = phi::funcs::ReluGradGradFunctor<T>;

template <typename T>
using ReluCUDAFunctor = phi::funcs::ReluCUDAFunctor<T>;
Q
qijun 已提交
349

350 351 352 353 354 355
/**
 * soft_relu forward: out = log(1 + exp(clip(x, -threshold, threshold))).
 *
 * `threshold` is populated from the op attribute of the same name through
 * GetAttrs().
 */
template <typename T>
struct SoftReluFunctor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    const auto bound = static_cast<T>(threshold);
    // Clamp x into [-bound, bound] before exponentiating.
    auto clipped = x.cwiseMax(-bound).cwiseMin(bound);
    out.device(d) = (static_cast<T>(1) + clipped.exp()).log();
  }
};

365 366 367 368 369 370
/**
 * soft_relu backward:
 *   dx = dout * (1 - exp(-out)), zeroed where out is outside
 *   (-threshold, threshold).
 * Only the forward output Out is required (FwdDeps() == kDepOut).
 * NOTE(review): the mask is evaluated on `out`, while the forward clip is
 * applied to `x`; confirm this asymmetry is intended.
 */
template <typename T>
struct SoftReluGradFunctor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    const auto bound = static_cast<T>(threshold);
    // 1 inside the open interval (-bound, bound), 0 elsewhere.
    auto in_range = ((out > -bound) * (out < bound)).template cast<T>();
    dx.device(d) = dout * (static_cast<T>(1) - (-out).exp()) * in_range;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

Z
zhupengyang 已提交
387 388 389 390
/**
 * Backward kernel for ELU.
 *
 * Reads X, Out and Out@GRAD and writes X@GRAD. The functor is chosen by the
 * "alpha" attribute: ELUGradFunctor when alpha > 0, otherwise
 * ELUGradNegativeAlphaFunctor.
 */
template <typename DeviceContext, typename T>
class ELUGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* X = context.Input<phi::DenseTensor>("X");
    auto* Out = context.Input<phi::DenseTensor>("Out");
    auto* dOut = context.Input<phi::DenseTensor>(framework::GradVarName("Out"));
    auto* dX = context.Output<phi::DenseTensor>(framework::GradVarName("X"));
    const float alpha = context.Attr<float>("alpha");
    dX->mutable_data<T>(context.GetPlace());

    auto x = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "X", "elu_grad"));
    auto out = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(Out, "Input", "Out", "elu_grad"));
    auto dout = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(dOut, "Input", "dOut", "elu_grad"));
    auto dx = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(dX, "Output", "dX", "elu_grad"));
    auto* place =
        context.template device_context<DeviceContext>().eigen_device();

    // Configure the selected functor with alpha, then apply it.
    auto apply = [&](auto functor) {
      functor.alpha = alpha;
      functor(*place, x, out, dout, dx);
    };
    if (alpha > 0) {
      apply(ELUGradFunctor<T>());
    } else {
      apply(ELUGradNegativeAlphaFunctor<T>());
    }
  }
};

Z
Zhong Hui 已提交
421 422 423
/**
 * Second-order gradient of abs: DDOut = DDX * sign(X).
 *
 * DDX and X are always validated. DDOut is written only when it is
 * non-null. Out, dOut and dX are part of the interface but are not used.
 */
template <typename T>
struct AbsGradGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device>
  void operator()(const Device& dev,
                  const phi::DenseTensor* X,
                  const phi::DenseTensor* Out,
                  const phi::DenseTensor* ddX,
                  phi::DenseTensor* ddOut,
                  phi::DenseTensor* dOut,
                  phi::DenseTensor* dX) const {
    auto* eigen_dev = dev.eigen_device();
    auto dd_x = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "AbsGradGrad"));
    auto x_vec = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "X", "AbsGradGrad"));
    if (ddOut == nullptr) {
      return;  // DDOut was not requested.
    }
    auto dd_out = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddOut, "Output", "DDOut", "AbsGradGrad"));
    dd_out.device(*eigen_dev) = dd_x * x_vec.sign();
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

445 446
// TODO(dengkaipeng): double gradient calculation for Square/Sqrt need
// DOut(dy) as input(not output), tensor extraction is different from
447
// others. Impliment extraction kernel separately here.
448
inline void ExtractDoubleGradTensorWithInputDOut(
449
    const framework::ExecutionContext& ctx,
450 451 452 453 454
    const phi::DenseTensor** X,
    const phi::DenseTensor** ddX,
    phi::DenseTensor** dX,
    const phi::DenseTensor** dOut,
    phi::DenseTensor** ddOut) {
455 456 457
  // extract ddX(output), ddOut(input)
  auto ddx_var = ctx.InputVar("DDX");
  auto ddo_var = ctx.OutputVar("DDOut");
458
  PADDLE_ENFORCE_NOT_NULL(
459 460 461 462
      ddx_var,
      platform::errors::NotFound(
          "Cannot get input Variable Out, variable name = %s",
          ctx.InputName("DDX")));
463
  *ddX = ctx.Input<phi::DenseTensor>("DDX");
464
  if (ddo_var) {
465
    *ddOut = ctx.Output<phi::DenseTensor>("DDOut");
466
  }
467 468 469 470 471
  PADDLE_ENFORCE_NOT_NULL(
      ddX,
      platform::errors::NotFound(
          "Cannot get the tensor from the Variable DDX, variable name = %s",
          ctx.OutputName("DDX")));
472 473 474

  // extract x(input), dx(output)
  auto x_var = ctx.InputVar("X");
475
  PADDLE_ENFORCE_NOT_NULL(
476 477 478 479
      x_var,
      platform::errors::NotFound(
          "Cannot get input Variable Out, variable name = %s",
          ctx.InputName("X")));
480
  auto dx_var = ctx.OutputVar("DX");
481
  *X = ctx.Input<phi::DenseTensor>("X");
482
  if (dx_var) {
483
    *dX = ctx.Output<phi::DenseTensor>("DX");
484 485 486 487 488
  }

  // extract dOut(input)
  auto dout_var = ctx.InputVar("DOut");
  if (dout_var) {
489
    *dOut = ctx.Input<phi::DenseTensor>("DOut");
490 491 492
  }
}

Q
qijun 已提交
493 494
}  // namespace operators
}  // namespace paddle
495

496 497
// Invokes `macro` once for each activation op listed here (currently only
// soft_relu) as: macro(op_type, OpName, ForwardFunctor, BackwardFunctor);
// The parameter avoids a double-underscore name, because such identifiers
// are reserved for the implementation in C++.
#define FOR_EACH_ACTIVATION_OP(macro) \
  macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor);