activation_op.h

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <glog/logging.h>

#include <algorithm>
#include <cmath>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif

#include <type_traits>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

#include "paddle/phi/kernels/funcs/activation_functor.h"

namespace paddle {
namespace operators {

using framework::To32BitIndex;

using ActBwdOpFwdDeps = phi::funcs::ActBwdOpFwdDeps;

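// ActBwdOpFwdDeps (from phi) records which forward tensors a backward functor
// depends on: the forward input X, the forward output Out, or both.
// ExtractActivationGradTensor<kDepValue> below uses it to fetch only the
// variables that are actually required.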
/* The following operators can be used to process SelectedRows, because the
 * output of these operators for a zero input is zero as well.
 */
static std::unordered_set<std::string> CanBeUsedBySelectedRows = {
    "abs", "abs_grad", "square", "square_grad", "sqrt", "sqrt_grad"};

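// ExtractActivationTensor fetches the value tensors behind Input("X") and
// Output("Out"). For the op types listed in CanBeUsedBySelectedRows it unwraps
// a possible SelectedRows variable; otherwise it reads the LoDTensor directly.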
inline void ExtractActivationTensor(const framework::ExecutionContext& context,
                                    const framework::Tensor** X,
                                    framework::Tensor** Out) {
  auto x_var = context.InputVar("X");
  auto out_var = context.OutputVar("Out");
  PADDLE_ENFORCE_NOT_NULL(x_var,
                          platform::errors::NotFound(
                              "Cannot get input Variable X, variable name = %s",
                              context.InputName("X")));
  PADDLE_ENFORCE_NOT_NULL(
      out_var, platform::errors::NotFound(
                   "Cannot get output Variable Out, variable name = %s",
                   context.OutputName("Out")));
  if (CanBeUsedBySelectedRows.count(context.Type())) {
    *X = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var);
    *Out = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
        out_var);
  } else {
    *X = context.Input<framework::Tensor>("X");
    *Out = context.Output<framework::Tensor>("Out");
  }

  PADDLE_ENFORCE_NOT_NULL(*Out, platform::errors::NotFound(
                                    "Cannot get the tensor from the Variable "
                                    "Output(Out), variable name = %s",
                                    context.OutputName("Out")));
}

template <ActBwdOpFwdDeps kDepValue>
inline void ExtractActivationGradTensor(
    const framework::ExecutionContext& context, const framework::Tensor** X,
    const framework::Tensor** Out, const framework::Tensor** dOut,
    framework::Tensor** dX) {
  auto out_grad_var = context.InputVar(framework::GradVarName("Out"));
  auto x_grad_var = context.OutputVar(framework::GradVarName("X"));
  const framework::Variable* out_var = nullptr;

  if (static_cast<int>(kDepValue) &
      static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
    out_var = context.InputVar("Out");
    PADDLE_ENFORCE_NOT_NULL(
        out_var, platform::errors::NotFound(
                     "Cannot get input Variable Out, variable name = %s",
                     context.InputName("Out")));
  }

  PADDLE_ENFORCE_NOT_NULL(
      out_grad_var, platform::errors::NotFound(
                        "Cannot get input Variable %s, variable name = %s",
                        framework::GradVarName("Out"),
                        context.InputName(framework::GradVarName("Out"))));
  PADDLE_ENFORCE_NOT_NULL(
      x_grad_var, platform::errors::NotFound(
                      "Cannot get output Variable %s, variable name = %s",
                      framework::GradVarName("X"),
                      context.OutputName(framework::GradVarName("X"))));

  if (CanBeUsedBySelectedRows.count(context.Type())) {
    *dOut = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(
        *out_grad_var);
    *dX = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
        x_grad_var);

    if (out_var) {
      *Out =
          paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*out_var);
    } else {
      *Out = *dOut;  // fake out
    }

  } else {
    *Out = context.Input<framework::Tensor>("Out");
    *dOut = context.Input<framework::Tensor>(framework::GradVarName("Out"));
    *dX = context.Output<framework::Tensor>(framework::GradVarName("X"));

    if (out_var) {
      *Out = &(out_var->Get<framework::LoDTensor>());
    } else {
      *Out = *dOut;  // fake out
    }
  }

  PADDLE_ENFORCE_NOT_NULL(*dX,
                          platform::errors::NotFound(
                              "Cannot get the tensor from the Variable "
                              "Output(X@GRAD), variable name = %s",
                              context.OutputName(framework::GradVarName("X"))));

  if (static_cast<int>(kDepValue) & static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
    auto x_var = context.InputVar("X");
    PADDLE_ENFORCE_NOT_NULL(x_var, platform::errors::NotFound(
                                       "Cannot get the tensor from the "
                                       "Variable Input(X), variable name = %s",
                                       context.InputName("X")));
    if (CanBeUsedBySelectedRows.count(context.Type())) {
      *X = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var);
    } else {
      *X = context.Input<framework::Tensor>("X");
    }
  } else {
    VLOG(10) << " Inplace activation of Op : " << context.Type();
    *X = *dX;
  }
}

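// Generic forward kernel: extract X/Out, copy the op's float attributes into
// the functor via GetAttrs(), then evaluate the functor on the flattened
// tensors (switching to 32-bit indexing on the GPU when the element count
// fits).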
template <typename DeviceContext, typename Functor>
class ActivationKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;

  void Compute(const framework::ExecutionContext& context) const override {
    const framework::Tensor* X = nullptr;
    framework::Tensor* Out = nullptr;
    ExtractActivationTensor(context, &X, &Out);
    Out->mutable_data<T>(context.GetPlace());

    auto x = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "X", "Activation"));
    auto out = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(Out, "Output", "Out", "Activation"));
    auto* place =
        context.template device_context<DeviceContext>().eigen_device();
    Functor functor;

    auto attrs = functor.GetAttrs();
    for (auto& attr : attrs) {
      *attr.second = context.Attr<float>(attr.first);
    }
    // use 32bit index to speed up computation
    bool use_32bit_index = out.size() < Eigen::NumTraits<int>::highest();
    bool is_gpu_place = platform::is_gpu_place(context.GetPlace());
    if (use_32bit_index && is_gpu_place) {
      functor(*place, To32BitIndex(x), To32BitIndex(out));
    } else {
      functor(*place, x, out);
    }
  }
};

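// Generic backward kernel. Functor::FwdDeps() tells ExtractActivationGradTensor
// which of X/Out are real dependencies; a tensor that is not required is
// aliased to dOut/dX as a placeholder (the "fake out" assignment above), so a
// functor is expected to read only the tensors its FwdDeps() declares.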
template <typename DeviceContext, typename Functor>
class ActivationGradKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;
  void Compute(const framework::ExecutionContext& context) const override {
    const framework::Tensor *X, *Out, *dOut;
    framework::Tensor* dX = nullptr;
    X = Out = dOut = nullptr;
    ExtractActivationGradTensor<Functor::FwdDeps()>(context, &X, &Out, &dOut,
                                                    &dX);
    dX->mutable_data<T>(context.GetPlace());
    auto dout = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(dOut, "Input", "Out@GRAD", "ActivationGrad"));
    auto out = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(Out, "Input", "Out", "ActivationGrad"));
    auto dx = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(dX, "Input", "X@GRAD", "ActivationGrad"));
    auto x = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "X", "ActivationGrad"));
    auto* place =
        context.template device_context<DeviceContext>().eigen_device();
    Functor functor;
    auto attrs = functor.GetAttrs();
    for (auto& attr : attrs) {
      *attr.second = context.Attr<float>(attr.first);
    }
    // use 32bit index to speed up computation
    bool use_32bit_index = out.size() < Eigen::NumTraits<int>::highest();
    bool is_gpu_place = platform::is_gpu_place(context.GetPlace());
    if (use_32bit_index && is_gpu_place) {
      functor(*place, To32BitIndex(x), To32BitIndex(out), To32BitIndex(dout),
              To32BitIndex(dx));
    } else {
      functor(*place, x, out, dout, dx);
    }
  }
};

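// Base class for the hand-written functors below. GetAttrs() returns
// (attribute name, field pointer) pairs; the kernels above use them to copy
// the op's float attributes (e.g. "threshold") into the functor before it is
// invoked.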
template <typename T>
struct BaseActivationFunctor {
  using ELEMENT_TYPE = T;

  using AttrPair = std::vector<std::pair<const char*, float*>>;

  AttrPair GetAttrs() { return AttrPair(); }
};

#define USE_PHI_FUNCTOR(name)                         \
  template <typename T>                               \
  using name##Functor = phi::funcs::name##Functor<T>; \
  template <typename T>                               \
  using name##GradFunctor = phi::funcs::name##GradFunctor<T>;

#define USE_PHI_DOUBLE_GRAD_FUNCTOR(name) \
  template <typename T>                   \
  using name##GradGradFunctor = phi::funcs::name##GradGradFunctor<T>;

#define USE_PHI_TRIPLE_GRAD_FUNCTOR(name) \
  template <typename T>                   \
  using name##TripleGradFunctor = phi::funcs::name##TripleGradFunctor<T>;

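// For example, USE_PHI_FUNCTOR(Tanh) expands to
//   template <typename T>
//   using TanhFunctor = phi::funcs::TanhFunctor<T>;
//   template <typename T>
//   using TanhGradFunctor = phi::funcs::TanhGradFunctor<T>;
// i.e. the functor names used below are aliases of the phi implementations.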
USE_PHI_FUNCTOR(Cos)
USE_PHI_FUNCTOR(Tan)
USE_PHI_FUNCTOR(Acos)
USE_PHI_FUNCTOR(Sin)
USE_PHI_FUNCTOR(Asin)
USE_PHI_FUNCTOR(Atan)
USE_PHI_FUNCTOR(Sinh)
USE_PHI_FUNCTOR(Cosh)
USE_PHI_FUNCTOR(Asinh)
USE_PHI_FUNCTOR(Acosh)
USE_PHI_FUNCTOR(Atanh)
USE_PHI_FUNCTOR(Tanh)
USE_PHI_DOUBLE_GRAD_FUNCTOR(Tanh)
USE_PHI_TRIPLE_GRAD_FUNCTOR(Tanh)
USE_PHI_FUNCTOR(BRelu)
USE_PHI_FUNCTOR(ThresholdedRelu)
USE_PHI_FUNCTOR(LeakyRelu)
USE_PHI_DOUBLE_GRAD_FUNCTOR(LeakyRelu)
USE_PHI_FUNCTOR(HardShrink)
USE_PHI_FUNCTOR(SoftShrink)
USE_PHI_FUNCTOR(TanhShrink)
USE_PHI_FUNCTOR(Silu)
USE_PHI_FUNCTOR(ELU)
USE_PHI_DOUBLE_GRAD_FUNCTOR(ELU)
USE_PHI_FUNCTOR(Sigmoid)
USE_PHI_DOUBLE_GRAD_FUNCTOR(Sigmoid)
USE_PHI_TRIPLE_GRAD_FUNCTOR(Sigmoid)
USE_PHI_FUNCTOR(LogSigmoid)
USE_PHI_FUNCTOR(HardSigmoid)
USE_PHI_FUNCTOR(Log)
USE_PHI_DOUBLE_GRAD_FUNCTOR(Log)
USE_PHI_FUNCTOR(Log2)
USE_PHI_FUNCTOR(Log10)
USE_PHI_FUNCTOR(Log1p)
USE_PHI_FUNCTOR(Swish)
USE_PHI_FUNCTOR(HardSwish)
USE_PHI_FUNCTOR(Pow)
USE_PHI_FUNCTOR(Exp)
USE_PHI_FUNCTOR(Expm1)
USE_PHI_FUNCTOR(Mish)
USE_PHI_FUNCTOR(STanh)
USE_PHI_FUNCTOR(Reciprocal)
USE_PHI_FUNCTOR(Square)
USE_PHI_DOUBLE_GRAD_FUNCTOR(Square)
USE_PHI_FUNCTOR(Sqrt)
USE_PHI_DOUBLE_GRAD_FUNCTOR(Sqrt)
USE_PHI_FUNCTOR(Rsqrt)
USE_PHI_DOUBLE_GRAD_FUNCTOR(Rsqrt)
USE_PHI_FUNCTOR(Softplus)
USE_PHI_FUNCTOR(CELU)
USE_PHI_DOUBLE_GRAD_FUNCTOR(CELU)

template <typename T>
using ELUGradNegativeAlphaFunctor = phi::funcs::ELUGradNegativeAlphaFunctor<T>;

template <typename T>
using RoundFunctor = phi::funcs::RoundFunctor<T>;

template <typename T>
using FloorFunctor = phi::funcs::FloorFunctor<T>;

template <typename T>
using CeilFunctor = phi::funcs::CeilFunctor<T>;

template <typename T>
using ZeroGradFunctor = phi::funcs::ZeroGradFunctor<T>;

// relu(x) = max(x, 0)

template <typename T>
using ReluCPUFunctor = phi::funcs::ReluCPUFunctor<T>;
template <typename T>
using ReluGradFunctor = phi::funcs::ReluGradFunctor<T>;

template <typename T>
using ReluGradGradFunctor = phi::funcs::ReluGradGradFunctor<T>;

template <typename T>
using ReluCUDAFunctor = phi::funcs::ReluCUDAFunctor<T>;

// relu6(x) = min(max(0, x), 6)
template <typename T>
struct Relu6Functor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) =
        x.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(threshold));
  }
};

template <typename T>
struct Relu6GradFunctor : public BaseActivationFunctor<T> {
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) =
        dout * ((out > static_cast<T>(0)) * (out < static_cast<T>(threshold)))
                   .template cast<T>();
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

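// soft_relu(x) = log(1 + exp(clip(x, -threshold, threshold)))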
template <typename T>
struct SoftReluFunctor : public BaseActivationFunctor<T> {
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto tmp = static_cast<T>(threshold);
    auto temp = x.cwiseMax(-tmp).cwiseMin(tmp);
    out.device(d) = (static_cast<T>(1) + temp.exp()).log();
  }
};

template <typename T>
struct SoftReluGradFunctor : public BaseActivationFunctor<T> {
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto tmp = static_cast<T>(threshold);
    auto temp = ((out > -tmp) * (out < tmp)).template cast<T>();
    dx.device(d) = dout * (static_cast<T>(1) - (-out).exp()) * temp;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

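// ELU backward dispatches on the sign of alpha: ELUGradFunctor for alpha > 0,
// ELUGradNegativeAlphaFunctor otherwise, presumably because for a negative
// alpha the sign of Out no longer identifies which branch of the forward pass
// produced it.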
template <typename DeviceContext, typename T>
class ELUGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* X = context.Input<framework::Tensor>("X");
    auto* Out = context.Input<framework::Tensor>("Out");
    auto* dOut =
        context.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
    const float alpha = context.Attr<float>("alpha");
    dX->mutable_data<T>(context.GetPlace());

    auto x = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "X", "elu_grad"));
    auto out = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(Out, "Input", "Out", "elu_grad"));
    auto dout = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(dOut, "Input", "dOut", "elu_grad"));
    auto dx = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(dX, "Output", "dX", "elu_grad"));
    auto* place =
        context.template device_context<DeviceContext>().eigen_device();

    if (alpha > 0) {
      ELUGradFunctor<T> functor;
      functor.alpha = alpha;
      functor(*place, x, out, dout, dx);
    } else {
      ELUGradNegativeAlphaFunctor<T> functor;
      functor.alpha = alpha;
      functor(*place, x, out, dout, dx);
    }
  }
};

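// abs double grad: ddout = ddx * sign(x)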
template <typename T>
struct AbsGradGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device>
  void operator()(const Device& dev, const framework::Tensor* X,
                  const framework::Tensor* Out, const framework::Tensor* ddX,
                  framework::Tensor* ddOut, framework::Tensor* dOut,
                  framework::Tensor* dX) const {
    auto* d = dev.eigen_device();
    auto ddx = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "AbsGradGrad"));
    auto x = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "X", "AbsGradGrad"));
    if (ddOut) {
      auto ddout = framework::EigenVector<T>::Flatten(
          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "AbsGradGrad"));
      ddout.device(*d) = ddx * x.sign();
    }
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// TODO(dengkaipeng): double gradient calculation for Square/Sqrt needs
// DOut(dy) as an input (not an output), so tensor extraction is different from
// the others. Implement the extraction kernel separately here.
inline void ExtractDoubleGradTensorWithInputDOut(
    const framework::ExecutionContext& ctx, const framework::Tensor** X,
    const framework::Tensor** ddX, framework::Tensor** dX,
    const framework::Tensor** dOut, framework::Tensor** ddOut) {
  // extract ddX(input), ddOut(output)
  auto ddx_var = ctx.InputVar("DDX");
  auto ddo_var = ctx.OutputVar("DDOut");
  PADDLE_ENFORCE_NOT_NULL(
      ddx_var, platform::errors::NotFound(
                   "Cannot get input Variable DDX, variable name = %s",
                   ctx.InputName("DDX")));
  *ddX = ctx.Input<framework::Tensor>("DDX");
  if (ddo_var) {
    *ddOut = ctx.Output<framework::Tensor>("DDOut");
  }
  PADDLE_ENFORCE_NOT_NULL(
      ddX,
      platform::errors::NotFound(
          "Cannot get the tensor from the Variable DDX, variable name = %s",
          ctx.OutputName("DDX")));

  // extract x(input), dx(output)
  auto x_var = ctx.InputVar("X");
  PADDLE_ENFORCE_NOT_NULL(
      x_var, platform::errors::NotFound(
                 "Cannot get input Variable X, variable name = %s",
                 ctx.InputName("X")));
  auto dx_var = ctx.OutputVar("DX");
  *X = ctx.Input<framework::Tensor>("X");
  if (dx_var) {
    *dX = ctx.Output<framework::Tensor>("DX");
  }

  // extract dOut(input)
  auto dout_var = ctx.InputVar("DOut");
  if (dout_var) {
    *dOut = ctx.Input<framework::Tensor>("DOut");
  }
}

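// softsign(x) = x / (1 + |x|)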
template <typename T>
struct SoftsignFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x / (static_cast<T>(1) + x.abs());
  }
};

// d(softsign(x))/dx = 1 / (1 + |x|)^2
// Taken from https://en.wikipedia.org/wiki/Activation_function

template <typename T>
struct SoftsignGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) =
        dout * (static_cast<T>(1) / (static_cast<T>(1) + x.abs()).square());
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

}  // namespace operators
}  // namespace paddle

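// FOR_EACH_ACTIVATION_OP enumerates the activations still implemented with the
// hand-written functors in this header; __macro is expected to be a kernel
// registration macro (typically supplied by activation_op.cc / .cu), applied
// to each (op_type, OpName, Functor, GradFunctor) tuple.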
#define FOR_EACH_ACTIVATION_OP(__macro)                               \
  __macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor); \
  __macro(softsign, Softsign, SoftsignFunctor, SoftsignGradFunctor);  \
  __macro(relu6, Relu6, Relu6Functor, Relu6GradFunctor);