activation_op.h 18.7 KB
Newer Older
1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Q
qijun 已提交
12 13

#pragma once
D
dzhwinter 已提交
14
#include <glog/logging.h>
Y
Yihua Xu 已提交
15
#include <algorithm>
16
#include <memory>
D
dzhwinter 已提交
17 18
#include <string>
#include <unordered_set>
19 20
#include <utility>
#include <vector>
21

C
Clementine 已提交
22 23 24 25 26
#include <cmath>
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif

27
#include <type_traits>
Y
Yi Wang 已提交
28 29
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
30
#include "paddle/fluid/framework/tensor_util.h"
31
#include "paddle/fluid/platform/enforce.h"
32
#include "paddle/fluid/platform/float16.h"
33
#include "paddle/phi/kernels/funcs/blas/blas.h"
34 35 36 37
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

38 39
#include "paddle/phi/kernels/funcs/activation_functor.h"

Q
qijun 已提交
40 41 42
namespace paddle {
namespace operators {

43 44
using framework::To32BitIndex;

45
using ActBwdOpFwdDeps = phi::funcs::ActBwdOpFwdDeps;
46

C
chengduo 已提交
47 48 49 50 51 52
// Operator types that map a zero input to a zero output. Because of that
// property they may be applied directly to the value tensor of a SelectedRows
// variable (see ExtractActivationTensor / ExtractActivationGradTensor).
static std::unordered_set<std::string> CanBeUsedBySelectedRows{
    "abs",  "abs_grad",  "square",
    "square_grad", "sqrt", "sqrt_grad"};

53 54 55 56 57
// Resolves the input tensor X and the output tensor Out of a forward
// activation op.
//
// For op types listed in CanBeUsedBySelectedRows the variables may hold a
// LoDTensor or a SelectedRows (whose value tensor is returned). All other op
// types fetch both as plain tensors.
//
// Throws NotFound if the X or Out variable is missing, or if no Out tensor
// could be obtained.
inline void ExtractActivationTensor(const framework::ExecutionContext& context,
                                    const framework::Tensor** X,
                                    framework::Tensor** Out) {
  const auto* in_variable = context.InputVar("X");
  auto* out_variable = context.OutputVar("Out");

  PADDLE_ENFORCE_NOT_NULL(in_variable,
                          platform::errors::NotFound(
                              "Cannot get input Variable X, variable name = %s",
                              context.InputName("X")));
  PADDLE_ENFORCE_NOT_NULL(
      out_variable, platform::errors::NotFound(
                        "Cannot get output Variable Out, variable name = %s",
                        context.OutputName("Out")));

  const bool accepts_selected_rows =
      CanBeUsedBySelectedRows.count(context.Type()) > 0;
  if (!accepts_selected_rows) {
    *X = context.Input<framework::Tensor>("X");
    *Out = context.Output<framework::Tensor>("Out");
  } else {
    *X = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(
        *in_variable);
    *Out = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
        out_variable);
  }

  PADDLE_ENFORCE_NOT_NULL(*Out, platform::errors::NotFound(
                                    "Cannot get the tensor from the Variable "
                                    "Output(Out), variable name = %s",
                                    context.OutputName("Out")));
}

81
template <ActBwdOpFwdDeps kDepValue>
82 83 84 85 86 87
inline void ExtractActivationGradTensor(
    const framework::ExecutionContext& context, const framework::Tensor** X,
    const framework::Tensor** Out, const framework::Tensor** dOut,
    framework::Tensor** dX) {
  auto out_grad_var = context.InputVar(framework::GradVarName("Out"));
  auto x_grad_var = context.OutputVar(framework::GradVarName("X"));
88 89
  const framework::Variable* out_var = nullptr;

90 91
  if (static_cast<int>(kDepValue) &
      static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
92
    out_var = context.InputVar("Out");
93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
    PADDLE_ENFORCE_NOT_NULL(
        out_var, platform::errors::NotFound(
                     "Cannot get input Variable Out, variable name = %s",
                     context.InputName("Out")));
  }

  PADDLE_ENFORCE_NOT_NULL(
      out_grad_var, platform::errors::NotFound(
                        "Cannot get input Variable %s, variable name = %s",
                        framework::GradVarName("Out"),
                        context.InputName(framework::GradVarName("Out"))));
  PADDLE_ENFORCE_NOT_NULL(
      x_grad_var, platform::errors::NotFound(
                      "Cannot get output Variable %s, variable name = %s",
                      framework::GradVarName("X"),
                      context.OutputName(framework::GradVarName("X"))));
109

H
hong 已提交
110
  if (CanBeUsedBySelectedRows.count(context.Type())) {
111 112 113 114
    *dOut = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(
        *out_grad_var);
    *dX = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
        x_grad_var);
115 116 117 118 119 120 121 122

    if (out_var) {
      *Out =
          paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*out_var);
    } else {
      *Out = *dOut;  // fake out
    }

123 124 125 126
  } else {
    *Out = context.Input<framework::Tensor>("Out");
    *dOut = context.Input<framework::Tensor>(framework::GradVarName("Out"));
    *dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
127 128 129 130 131 132

    if (out_var) {
      *Out = &(out_var->Get<framework::LoDTensor>());
    } else {
      *Out = *dOut;  // fake out
    }
133
  }
134

135 136 137 138 139
  PADDLE_ENFORCE_NOT_NULL(*dX,
                          platform::errors::NotFound(
                              "Cannot get the tensor from the Variable "
                              "Output(Out), variable name = %s",
                              context.OutputName(framework::GradVarName("X"))));
140

141
  if (static_cast<int>(kDepValue) & static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
C
chengduo 已提交
142
    auto x_var = context.InputVar("X");
143 144 145 146
    PADDLE_ENFORCE_NOT_NULL(x_var, platform::errors::NotFound(
                                       "Cannot get the tensor from the "
                                       "Variable Input(X), variable name = %s",
                                       context.InputName("X")));
H
hong 已提交
147
    if (CanBeUsedBySelectedRows.count(context.Type())) {
148
      *X = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var);
C
chengduo 已提交
149
    } else {
150
      *X = context.Input<framework::Tensor>("X");
C
chengduo 已提交
151
    }
152
  } else {
H
hong 已提交
153
    VLOG(10) << " Inplace activation of Op : " << context.Type();
154 155 156
    *X = *dX;
  }
}
C
chengduo 已提交
157

158 159 160 161 162
// Generic forward kernel for element-wise activation ops.
//
// Functor must provide:
//   * ELEMENT_TYPE,
//   * GetAttrs(), listing the float attributes to copy from the op,
//   * operator()(device, x, out), which evaluates the activation.
template <typename DeviceContext, typename Functor>
class ActivationKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;

  void Compute(const framework::ExecutionContext& context) const override {
    const framework::Tensor* X = nullptr;
    framework::Tensor* Out = nullptr;
    ExtractActivationTensor(context, &X, &Out);
    Out->mutable_data<T>(context.GetPlace());

    auto x = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "X", "Activation"));
    auto out = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(Out, "Output", "Out", "Activation"));
    auto& dev_ctx = context.template device_context<DeviceContext>();
    auto* place = dev_ctx.eigen_device();

    Functor functor;
    // Fill each attribute slot the functor exposes from the op's attributes.
    for (auto& name_and_slot : functor.GetAttrs()) {
      *name_and_slot.second = context.Attr<float>(name_and_slot.first);
    }

    // 32-bit indexing is used only on GPU, and only when every element index
    // fits in an int; it is cheaper there.
    const bool fits_int32 = out.size() < Eigen::NumTraits<int>::highest();
    const bool on_gpu = platform::is_gpu_place(context.GetPlace());
    if (fits_int32 && on_gpu) {
      functor(*place, To32BitIndex(x), To32BitIndex(out));
    } else {
      functor(*place, x, out);
    }
  }
};

Q
QI JUN 已提交
193
template <typename DeviceContext, typename Functor>
194 195
class ActivationGradKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
Q
qijun 已提交
196
 public:
197
  using T = typename Functor::ELEMENT_TYPE;
Q
qijun 已提交
198
  void Compute(const framework::ExecutionContext& context) const override {
199 200 201
    const framework::Tensor *X, *Out, *dOut;
    framework::Tensor* dX = nullptr;
    X = Out = dOut = nullptr;
202 203
    ExtractActivationGradTensor<Functor::FwdDeps()>(context, &X, &Out, &dOut,
                                                    &dX);
Q
qijun 已提交
204
    dX->mutable_data<T>(context.GetPlace());
205 206 207 208 209 210 211 212
    auto dout = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(dOut, "Input", "Out@GRAD", "ActivationGrad"));
    auto out = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(Out, "Input", "Out", "ActivationGrad"));
    auto dx = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(dX, "Input", "X@GRAD", "ActivationGrad"));
    auto x = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "X", "ActivationGrad"));
Q
QI JUN 已提交
213 214
    auto* place =
        context.template device_context<DeviceContext>().eigen_device();
Q
qijun 已提交
215
    Functor functor;
216 217 218 219
    auto attrs = functor.GetAttrs();
    for (auto& attr : attrs) {
      *attr.second = context.Attr<float>(attr.first);
    }
220 221 222 223 224 225 226 227 228
    // use 32bit index to speed up computation
    bool use_32bit_index = out.size() < Eigen::NumTraits<int>::highest();
    bool is_gpu_place = platform::is_gpu_place(context.GetPlace());
    if (use_32bit_index && is_gpu_place) {
      functor(*place, To32BitIndex(x), To32BitIndex(out), To32BitIndex(dout),
              To32BitIndex(dx));
    } else {
      functor(*place, x, out, dout, dx);
    }
Q
qijun 已提交
229 230 231
  }
};

232 233 234 235 236 237 238 239 240
// Common base for activation functors.
//
// ELEMENT_TYPE names the element type the functor operates on. GetAttrs()
// returns the float attributes the kernels fill in before invoking the
// functor; it is empty by default, and derived functors override it.
template <typename T>
struct BaseActivationFunctor {
  using ELEMENT_TYPE = T;

  // Each entry pairs an attribute name with the member that receives it.
  using AttrPair = std::vector<std::pair<const char*, float*>>;

  AttrPair GetAttrs() { return {}; }
};

241 242 243 244 245 246
// Re-export phi's forward/backward functor pair for `name` into this
// namespace as <name>Functor<T> and <name>GradFunctor<T>.
#define USE_PHI_FUNCTOR(name)                                 \
  template <typename T>                                       \
  using name##Functor = phi::funcs::name##Functor<T>;         \
  template <typename T>                                       \
  using name##GradFunctor = phi::funcs::name##GradFunctor<T>;

// Re-export phi's second-order gradient functor as <name>GradGradFunctor<T>.
#define USE_PHI_DOUBLE_GRAD_FUNCTOR(name)                             \
  template <typename T>                                               \
  using name##GradGradFunctor = phi::funcs::name##GradGradFunctor<T>;

// Re-export phi's third-order gradient functor as <name>TripleGradFunctor<T>.
#define USE_PHI_TRIPLE_GRAD_FUNCTOR(name)                                 \
  template <typename T>                                                   \
  using name##TripleGradFunctor = phi::funcs::name##TripleGradFunctor<T>;

255 256 257 258 259 260 261 262 263 264 265
// Activation functors implemented in phi, re-exported under their fluid
// names. Each name is imported exactly once; the second, redundant
// USE_PHI_FUNCTOR(Exp) that used to follow Pow has been removed.
USE_PHI_FUNCTOR(Cos)
USE_PHI_FUNCTOR(Tan)
USE_PHI_FUNCTOR(Acos)
USE_PHI_FUNCTOR(Sin)
USE_PHI_FUNCTOR(Asin)
USE_PHI_FUNCTOR(Atan)
USE_PHI_FUNCTOR(Sinh)
USE_PHI_FUNCTOR(Cosh)
USE_PHI_FUNCTOR(Asinh)
USE_PHI_FUNCTOR(Acosh)
USE_PHI_FUNCTOR(Atanh)
USE_PHI_FUNCTOR(Tanh)
USE_PHI_FUNCTOR(Exp)
USE_PHI_DOUBLE_GRAD_FUNCTOR(Tanh)
USE_PHI_TRIPLE_GRAD_FUNCTOR(Tanh)
USE_PHI_FUNCTOR(BRelu)
USE_PHI_FUNCTOR(ThresholdedRelu)
USE_PHI_FUNCTOR(LeakyRelu)
USE_PHI_DOUBLE_GRAD_FUNCTOR(LeakyRelu)
USE_PHI_FUNCTOR(HardShrink)
USE_PHI_FUNCTOR(SoftShrink)
USE_PHI_FUNCTOR(TanhShrink)
USE_PHI_FUNCTOR(Silu)
USE_PHI_FUNCTOR(ELU)
USE_PHI_DOUBLE_GRAD_FUNCTOR(ELU)
USE_PHI_FUNCTOR(Sigmoid)
USE_PHI_DOUBLE_GRAD_FUNCTOR(Sigmoid)
USE_PHI_TRIPLE_GRAD_FUNCTOR(Sigmoid)
USE_PHI_FUNCTOR(LogSigmoid)
USE_PHI_FUNCTOR(HardSigmoid)
USE_PHI_FUNCTOR(Log)
USE_PHI_DOUBLE_GRAD_FUNCTOR(Log)
USE_PHI_FUNCTOR(Log2)
USE_PHI_FUNCTOR(Log10)
USE_PHI_FUNCTOR(Log1p)
USE_PHI_FUNCTOR(Swish)
USE_PHI_FUNCTOR(HardSwish)
USE_PHI_FUNCTOR(Pow)
USE_PHI_FUNCTOR(Expm1)
USE_PHI_FUNCTOR(Mish)
USE_PHI_FUNCTOR(STanh)
USE_PHI_FUNCTOR(Reciprocal)
USE_PHI_FUNCTOR(Square)
USE_PHI_DOUBLE_GRAD_FUNCTOR(Square)
USE_PHI_FUNCTOR(Sqrt)
USE_PHI_DOUBLE_GRAD_FUNCTOR(Sqrt)
USE_PHI_FUNCTOR(Rsqrt)
USE_PHI_DOUBLE_GRAD_FUNCTOR(Rsqrt)
USE_PHI_FUNCTOR(Softplus)
USE_PHI_FUNCTOR(CELU)
USE_PHI_DOUBLE_GRAD_FUNCTOR(CELU)
Y
YuanRisheng 已提交
307 308 309

// Further phi functors re-exported under their fluid names. Each alias is
// declared exactly once; a second, identical ELUGradNegativeAlphaFunctor
// alias has been removed.

// Gradient functor that ELUGradKernel uses when alpha is not positive.
template <typename T>
using ELUGradNegativeAlphaFunctor = phi::funcs::ELUGradNegativeAlphaFunctor<T>;

template <typename T>
using RoundFunctor = phi::funcs::RoundFunctor<T>;

template <typename T>
using FloorFunctor = phi::funcs::FloorFunctor<T>;

template <typename T>
using CeilFunctor = phi::funcs::CeilFunctor<T>;

template <typename T>
using ZeroGradFunctor = phi::funcs::ZeroGradFunctor<T>;
R
ronnywang 已提交
325

Q
qijun 已提交
326
// relu(x) = max(x, 0).
// The CPU/CUDA forward functors and the first- and second-order gradient
// functors all come from phi and are re-exported here.
template <typename T>
using ReluCPUFunctor = phi::funcs::ReluCPUFunctor<T>;

template <typename T>
using ReluCUDAFunctor = phi::funcs::ReluCUDAFunctor<T>;

template <typename T>
using ReluGradFunctor = phi::funcs::ReluGradFunctor<T>;

template <typename T>
using ReluGradGradFunctor = phi::funcs::ReluGradGradFunctor<T>;
Q
qijun 已提交
338

339 340 341 342 343 344 345 346 347
// relu6(x) = min(max(x, 0), threshold), where `threshold` is filled from the
// op's "threshold" attribute through GetAttrs().
template <typename T>
struct Relu6Functor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    const auto lower = static_cast<T>(0);
    const auto upper = static_cast<T>(threshold);
    out.device(d) = x.cwiseMax(lower).cwiseMin(upper);
  }
};

// Gradient of relu6: dx = dout where 0 < out < threshold, and 0 elsewhere.
// The mask is computed from Out, so the forward output is the only
// dependency (kDepOut).
template <typename T>
struct Relu6GradFunctor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto inside =
        (out > static_cast<T>(0)) * (out < static_cast<T>(threshold));
    dx.device(d) = dout * inside.template cast<T>();
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

375 376 377 378 379 380
// soft_relu(x) = log(1 + exp(clip(x, -threshold, threshold))), where
// `threshold` comes from the op's "threshold" attribute.
template <typename T>
struct SoftReluFunctor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    const auto bound = static_cast<T>(threshold);
    auto clipped = x.cwiseMax(-bound).cwiseMin(bound);
    out.device(d) = (static_cast<T>(1) + clipped.exp()).log();
  }
};

390 391 392 393 394 395
// Gradient of soft_relu, expressed through Out only (kDepOut).
// Since out = log(1 + e^x), it follows that 1 - e^{-out} = sigmoid(x). The
// result is multiplied by a mask that is 1 where -threshold < out < threshold.
//
// NOTE(review): the mask compares `out`, not `x`, against +/-threshold,
// because X is not a forward dependency.
//   * For a positive threshold, out = log(1 + e^clip(x)) >= 0 > -threshold,
//     so the lower bound never masks anything.
//   * The upper bound trips at x = log(e^threshold - 1) rather than at
//     x = threshold.
// This only approximates the clipped gradient, and the approximation is
// noticeable for small thresholds. Confirm it is intended before relying on
// exact gradients near the clip points.
template <typename T>
struct SoftReluGradFunctor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    const auto bound = static_cast<T>(threshold);
    auto in_range = ((out > -bound) * (out < bound)).template cast<T>();
    dx.device(d) = dout * (static_cast<T>(1) - (-out).exp()) * in_range;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

Z
zhupengyang 已提交
409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443
// Backward kernel for ELU. The functor is chosen by the sign of the op's
// "alpha" attribute: ELUGradFunctor when alpha > 0, otherwise
// ELUGradNegativeAlphaFunctor. Both receive X, Out, dOut and write dX.
template <typename DeviceContext, typename T>
class ELUGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const auto* X = context.Input<framework::Tensor>("X");
    const auto* Out = context.Input<framework::Tensor>("Out");
    const auto* dOut =
        context.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
    const float alpha = context.Attr<float>("alpha");
    dX->mutable_data<T>(context.GetPlace());

    auto x = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "X", "elu_grad"));
    auto out = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(Out, "Input", "Out", "elu_grad"));
    auto dout = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(dOut, "Input", "dOut", "elu_grad"));
    auto dx = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(dX, "Output", "dX", "elu_grad"));
    auto& dev_ctx = context.template device_context<DeviceContext>();
    auto* place = dev_ctx.eigen_device();

    if (alpha > 0) {
      ELUGradFunctor<T> positive_alpha_grad;
      positive_alpha_grad.alpha = alpha;
      positive_alpha_grad(*place, x, out, dout, dx);
      return;
    }
    ELUGradNegativeAlphaFunctor<T> other_alpha_grad;
    other_alpha_grad.alpha = alpha;
    other_alpha_grad(*place, x, out, dout, dx);
  }
};

Z
Zhong Hui 已提交
444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461
// Second-order gradient of abs: ddOut = ddX * sign(X).
// Only ddOut is written, and only when it is non-null. The Out, dOut and dX
// parameters are not used by this functor.
template <typename T>
struct AbsGradGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device>
  void operator()(const Device& dev, const framework::Tensor* X,
                  const framework::Tensor* Out, const framework::Tensor* ddX,
                  framework::Tensor* ddOut, framework::Tensor* dOut,
                  framework::Tensor* dX) const {
    auto* eigen_dev = dev.eigen_device();
    auto ddx = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "AbsGradGrad"));
    auto x = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "X", "AbsGradGrad"));
    if (ddOut == nullptr) {
      return;
    }
    auto ddout = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddOut, "Output", "DDOut", "AbsGradGrad"));
    ddout.device(*eigen_dev) = ddx * x.sign();
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

465 466
// TODO(dengkaipeng): double gradient calculation for Square/Sqrt need
// DOut(dy) as input(not output), tensor extraction is different from
467
// others. Impliment extraction kernel separately here.
468 469 470 471 472 473 474
inline void ExtractDoubleGradTensorWithInputDOut(
    const framework::ExecutionContext& ctx, const framework::Tensor** X,
    const framework::Tensor** ddX, framework::Tensor** dX,
    const framework::Tensor** dOut, framework::Tensor** ddOut) {
  // extract ddX(output), ddOut(input)
  auto ddx_var = ctx.InputVar("DDX");
  auto ddo_var = ctx.OutputVar("DDOut");
475 476 477 478
  PADDLE_ENFORCE_NOT_NULL(
      ddx_var, platform::errors::NotFound(
                   "Cannot get input Variable Out, variable name = %s",
                   ctx.InputName("DDX")));
479 480 481 482
  *ddX = ctx.Input<framework::Tensor>("DDX");
  if (ddo_var) {
    *ddOut = ctx.Output<framework::Tensor>("DDOut");
  }
483 484 485 486 487
  PADDLE_ENFORCE_NOT_NULL(
      ddX,
      platform::errors::NotFound(
          "Cannot get the tensor from the Variable DDX, variable name = %s",
          ctx.OutputName("DDX")));
488 489 490

  // extract x(input), dx(output)
  auto x_var = ctx.InputVar("X");
491 492
  PADDLE_ENFORCE_NOT_NULL(
      x_var, platform::errors::NotFound(
493
                 "Cannot get input Variable Out, variable name = %s",
494
                 ctx.InputName("X")));
495 496 497 498 499 500 501 502 503 504 505 506 507
  auto dx_var = ctx.OutputVar("DX");
  *X = ctx.Input<framework::Tensor>("X");
  if (dx_var) {
    *dX = ctx.Output<framework::Tensor>("DX");
  }

  // extract dOut(input)
  auto dout_var = ctx.InputVar("DOut");
  if (dout_var) {
    *dOut = ctx.Input<framework::Tensor>("DOut");
  }
}

508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530
// softsign(x) = x / (1 + |x|)
template <typename T>
struct SoftsignFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto denominator = static_cast<T>(1) + x.abs();
    out.device(d) = x / denominator;
  }
};

// Gradient of softsign: d softsign(x) / dx = 1 / (1 + |x|)^2
// (see https://en.wikipedia.org/wiki/Activation_function).
// The formula uses X, so the forward input is the only dependency (kDepX).
template <typename T>
struct SoftsignGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto denom_squared = (static_cast<T>(1) + x.abs()).square();
    dx.device(d) = dout * (static_cast<T>(1) / denom_squared);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

Q
qijun 已提交
531 532
}  // namespace operators
}  // namespace paddle
533

534 535 536 537
// X-macro over the activation ops whose functors are defined in this header.
// Each entry passes four arguments to __macro: the op type string (e.g.
// soft_relu), presumably a CamelCase op name prefix, the forward functor and
// the gradient functor.
#define FOR_EACH_ACTIVATION_OP(__macro)                                   \
  __macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor);     \
  __macro(softsign, Softsign, SoftsignFunctor, SoftsignGradFunctor);      \
  __macro(relu6, Relu6, Relu6Functor, Relu6GradFunctor);