activation_op.h 14.4 KB
Newer Older
1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Q
qijun 已提交
12 13

#pragma once
D
dzhwinter 已提交
14
#include <glog/logging.h>
15

Y
Yihua Xu 已提交
16
#include <algorithm>
17
#include <cmath>
18
#include <memory>
D
dzhwinter 已提交
19 20
#include <string>
#include <unordered_set>
21 22
#include <utility>
#include <vector>
C
Clementine 已提交
23 24 25 26
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif

27
#include <type_traits>
28

Y
Yi Wang 已提交
29
#include "paddle/fluid/framework/eigen.h"
C
Charles-hit 已提交
30
#include "paddle/fluid/framework/infershape_utils.h"
Y
Yi Wang 已提交
31
#include "paddle/fluid/framework/op_registry.h"
32
#include "paddle/fluid/framework/tensor_util.h"
33
#include "paddle/fluid/platform/enforce.h"
34
#include "paddle/fluid/platform/float16.h"
35
#include "paddle/phi/kernels/funcs/blas/blas.h"
36

37 38
#include "paddle/phi/kernels/funcs/activation_functor.h"

Q
qijun 已提交
39 40 41
namespace paddle {
namespace operators {

42 43
using framework::To32BitIndex;

44
using ActBwdOpFwdDeps = phi::funcs::ActBwdOpFwdDeps;
45

C
chengduo 已提交
46 47 48 49 50 51
/* The following operator can be used to process SelectedRows, because the
 * output of those operator for zero is zero too.
 */
static std::unordered_set<std::string> CanBeUsedBySelectedRows = {
    "abs", "abs_grad", "square", "square_grad", "sqrt", "sqrt_grad"};

52
// Resolves the forward input "X" and output "Out" of an activation op.
//
// For op types listed in CanBeUsedBySelectedRows, both tensors are obtained
// through the LoDTensor-or-SelectedRows accessors. Every other op type reads
// them directly as phi::DenseTensor.
//
// Raises a NotFound error when the X or Out variable is missing, or when no
// output tensor could be obtained. *X itself is not null-checked here.
inline void ExtractActivationTensor(const framework::ExecutionContext& context,
                                    const phi::DenseTensor** X,
                                    phi::DenseTensor** Out) {
  auto* const in_var = context.InputVar("X");
  auto* const result_var = context.OutputVar("Out");
  PADDLE_ENFORCE_NOT_NULL(in_var,
                          platform::errors::NotFound(
                              "Cannot get input Variable X, variable name = %s",
                              context.InputName("X")));
  PADDLE_ENFORCE_NOT_NULL(
      result_var,
      platform::errors::NotFound(
          "Cannot get output Variable Out, variable name = %s",
          context.OutputName("Out")));

  const bool accepts_selected_rows =
      CanBeUsedBySelectedRows.count(context.Type()) > 0;
  *X = accepts_selected_rows
           ? paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*in_var)
           : context.Input<phi::DenseTensor>("X");
  *Out = accepts_selected_rows
             ? paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
                   result_var)
             : context.Output<phi::DenseTensor>("Out");

  PADDLE_ENFORCE_NOT_NULL(
      *Out,
      platform::errors::NotFound("Cannot get the tensor from the Variable "
                                 "Output(Out), variable name = %s",
                                 context.OutputName("Out")));
}

82
template <ActBwdOpFwdDeps kDepValue>
83
inline void ExtractActivationGradTensor(
84
    const framework::ExecutionContext& context,
85 86 87 88
    const phi::DenseTensor** X,
    const phi::DenseTensor** Out,
    const phi::DenseTensor** dOut,
    phi::DenseTensor** dX) {
89 90
  auto out_grad_var = context.InputVar(framework::GradVarName("Out"));
  auto x_grad_var = context.OutputVar(framework::GradVarName("X"));
91 92
  const framework::Variable* out_var = nullptr;

93 94
  if (static_cast<int>(kDepValue) &
      static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
95
    out_var = context.InputVar("Out");
96
    PADDLE_ENFORCE_NOT_NULL(
97 98 99 100
        out_var,
        platform::errors::NotFound(
            "Cannot get input Variable Out, variable name = %s",
            context.InputName("Out")));
101 102 103
  }

  PADDLE_ENFORCE_NOT_NULL(
104 105 106 107 108
      out_grad_var,
      platform::errors::NotFound(
          "Cannot get input Variable %s, variable name = %s",
          framework::GradVarName("Out"),
          context.InputName(framework::GradVarName("Out"))));
109
  PADDLE_ENFORCE_NOT_NULL(
110 111 112 113 114
      x_grad_var,
      platform::errors::NotFound(
          "Cannot get output Variable %s, variable name = %s",
          framework::GradVarName("X"),
          context.OutputName(framework::GradVarName("X"))));
115

H
hong 已提交
116
  if (CanBeUsedBySelectedRows.count(context.Type())) {
117 118 119 120
    *dOut = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(
        *out_grad_var);
    *dX = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
        x_grad_var);
121 122 123 124 125 126 127 128

    if (out_var) {
      *Out =
          paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*out_var);
    } else {
      *Out = *dOut;  // fake out
    }

129
  } else {
130 131 132
    *Out = context.Input<phi::DenseTensor>("Out");
    *dOut = context.Input<phi::DenseTensor>(framework::GradVarName("Out"));
    *dX = context.Output<phi::DenseTensor>(framework::GradVarName("X"));
133 134

    if (out_var) {
135
      *Out = &(out_var->Get<phi::DenseTensor>());
136 137 138
    } else {
      *Out = *dOut;  // fake out
    }
139
  }
140

141 142 143 144 145
  PADDLE_ENFORCE_NOT_NULL(*dX,
                          platform::errors::NotFound(
                              "Cannot get the tensor from the Variable "
                              "Output(Out), variable name = %s",
                              context.OutputName(framework::GradVarName("X"))));
146

147
  if (static_cast<int>(kDepValue) & static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
C
chengduo 已提交
148
    auto x_var = context.InputVar("X");
149 150 151 152 153
    PADDLE_ENFORCE_NOT_NULL(
        x_var,
        platform::errors::NotFound("Cannot get the tensor from the "
                                   "Variable Input(X), variable name = %s",
                                   context.InputName("X")));
H
hong 已提交
154
    if (CanBeUsedBySelectedRows.count(context.Type())) {
155
      *X = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var);
C
chengduo 已提交
156
    } else {
157
      *X = context.Input<phi::DenseTensor>("X");
C
chengduo 已提交
158
    }
159
  } else {
H
hong 已提交
160
    VLOG(10) << " Inplace activation of Op : " << context.Type();
161 162 163
    *X = *dX;
  }
}
C
chengduo 已提交
164

165 166 167 168 169
// Forward kernel shared by the element-wise activation operators.
//
// Functor provides ELEMENT_TYPE, GetAttrs() (name -> float* pairs that are
// filled from the op's float attributes) and a call operator
// functor(eigen_device, x, out) writing the activation of x into out.
template <typename DeviceContext, typename Functor>
class ActivationKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;

  void Compute(const framework::ExecutionContext& context) const override {
    const phi::DenseTensor* X = nullptr;
    phi::DenseTensor* Out = nullptr;
    ExtractActivationTensor(context, &X, &Out);
    Out->mutable_data<T>(context.GetPlace());

    auto x = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "X", "Activation"));
    auto out = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(Out, "Output", "Out", "Activation"));
    auto* place =
        context.template device_context<DeviceContext>().eigen_device();

    Functor functor;
    // Copy every float attribute the functor declares into its member.
    for (auto& name_and_slot : functor.GetAttrs()) {
      *name_and_slot.second = context.Attr<float>(name_and_slot.first);
    }

    // 32-bit indexing (faster) is used only on GPU and only when the element
    // count fits below the int maximum.
    const bool fits_int32 = out.size() < Eigen::NumTraits<int>::highest();
    const bool on_gpu = platform::is_gpu_place(context.GetPlace());
    if (!fits_int32 || !on_gpu) {
      functor(*place, x, out);
      return;
    }
    functor(*place, To32BitIndex(x), To32BitIndex(out));
  }
};

Q
QI JUN 已提交
200
template <typename DeviceContext, typename Functor>
201 202
class ActivationGradKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
Q
qijun 已提交
203
 public:
204
  using T = typename Functor::ELEMENT_TYPE;
Q
qijun 已提交
205
  void Compute(const framework::ExecutionContext& context) const override {
206 207
    const phi::DenseTensor *X, *Out, *dOut;
    phi::DenseTensor* dX = nullptr;
208
    X = Out = dOut = nullptr;
209 210
    ExtractActivationGradTensor<Functor::FwdDeps()>(
        context, &X, &Out, &dOut, &dX);
Q
qijun 已提交
211
    dX->mutable_data<T>(context.GetPlace());
212 213 214 215 216 217 218 219
    auto dout = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(dOut, "Input", "Out@GRAD", "ActivationGrad"));
    auto out = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(Out, "Input", "Out", "ActivationGrad"));
    auto dx = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(dX, "Input", "X@GRAD", "ActivationGrad"));
    auto x = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "X", "ActivationGrad"));
Q
QI JUN 已提交
220 221
    auto* place =
        context.template device_context<DeviceContext>().eigen_device();
Q
qijun 已提交
222
    Functor functor;
223 224 225 226
    auto attrs = functor.GetAttrs();
    for (auto& attr : attrs) {
      *attr.second = context.Attr<float>(attr.first);
    }
227 228 229 230
    // use 32bit index to speed up computation
    bool use_32bit_index = out.size() < Eigen::NumTraits<int>::highest();
    bool is_gpu_place = platform::is_gpu_place(context.GetPlace());
    if (use_32bit_index && is_gpu_place) {
231 232 233 234
      functor(*place,
              To32BitIndex(x),
              To32BitIndex(out),
              To32BitIndex(dout),
235 236 237 238
              To32BitIndex(dx));
    } else {
      functor(*place, x, out, dout, dx);
    }
Q
qijun 已提交
239 240 241
  }
};

242 243 244 245 246 247 248 249 250
// Common base of the activation functors in this header.
//
// ELEMENT_TYPE is the element type the kernels instantiate with. GetAttrs()
// lists (attribute name, destination) pairs for the float attributes a
// functor reads; the base functor has none, so the list is empty.
template <typename T>
struct BaseActivationFunctor {
  using ELEMENT_TYPE = T;
  using AttrPair = std::vector<std::pair<const char*, float*>>;

  AttrPair GetAttrs() { return {}; }
};

251 252 253 254 255 256
// USE_PHI_FUNCTOR(Name) declares, at the point of expansion, the alias
// templates NameFunctor<T> and NameGradFunctor<T> referring to
// phi::funcs::NameFunctor<T> and phi::funcs::NameGradFunctor<T>.
#define USE_PHI_FUNCTOR(name)                              \
  template <typename T>                                    \
  using name##Functor = phi::funcs::name##Functor<T>;      \
                                                           \
  template <typename T>                                    \
  using name##GradFunctor = phi::funcs::name##GradFunctor<T>;

257 258 259 260 261 262 263 264
// USE_PHI_DOUBLE_GRAD_FUNCTOR(Name) declares NameGradGradFunctor<T> as an
// alias of phi::funcs::NameGradGradFunctor<T>.
#define USE_PHI_DOUBLE_GRAD_FUNCTOR(name)                               \
  template <typename T>                                                 \
  using name##GradGradFunctor = phi::funcs::name##GradGradFunctor<T>;

// USE_PHI_TRIPLE_GRAD_FUNCTOR(Name) declares NameTripleGradFunctor<T> as an
// alias of phi::funcs::NameTripleGradFunctor<T>.
#define USE_PHI_TRIPLE_GRAD_FUNCTOR(name)                               \
  template <typename T>                                                 \
  using name##TripleGradFunctor = phi::funcs::name##TripleGradFunctor<T>;

265 266 267 268 269
// brelu is served by phi's hard_tanh functors (forward and gradient).
template <typename T>
using BReluFunctor = phi::funcs::HardTanhFunctor<T>;

template <typename T>
using BReluGradFunctor = phi::funcs::HardTanhGradFunctor<T>;

270
// Forward/backward functor aliases taken from phi::funcs.
USE_PHI_FUNCTOR(Tanh)
USE_PHI_FUNCTOR(Relu6)

// LeakyRelu additionally exposes its second-order gradient functor.
USE_PHI_FUNCTOR(LeakyRelu)
USE_PHI_DOUBLE_GRAD_FUNCTOR(LeakyRelu)

USE_PHI_FUNCTOR(HardShrink)
USE_PHI_FUNCTOR(ELU)
USE_PHI_FUNCTOR(Sigmoid)
USE_PHI_FUNCTOR(HardSigmoid)
USE_PHI_FUNCTOR(Swish)
USE_PHI_FUNCTOR(HardSwish)
USE_PHI_FUNCTOR(Pow)
USE_PHI_FUNCTOR(Mish)
USE_PHI_FUNCTOR(STanh)
Y
YuanRisheng 已提交
283 284 285

// Single aliases for phi::funcs functors that are not declared in
// Functor/GradFunctor pairs by USE_PHI_FUNCTOR. Each alias template is
// declared exactly once (repeating an alias-template declaration is
// redundant and its validity is compiler-dependent).
template <typename T>
using ELUGradNegativeAlphaFunctor = phi::funcs::ELUGradNegativeAlphaFunctor<T>;

template <typename T>
using RoundFunctor = phi::funcs::RoundFunctor<T>;

template <typename T>
using FloorFunctor = phi::funcs::FloorFunctor<T>;

template <typename T>
using CeilFunctor = phi::funcs::CeilFunctor<T>;

template <typename T>
using ZeroGradFunctor = phi::funcs::ZeroGradFunctor<T>;
R
ronnywang 已提交
301

Q
qijun 已提交
302
// relu(x) = max(x, 0)
// Forward functors (separate CPU and CUDA variants) followed by the first- and
// second-order gradient functors, all aliased from phi::funcs.
template <typename T>
using ReluCPUFunctor = phi::funcs::ReluCPUFunctor<T>;
template <typename T>
using ReluCUDAFunctor = phi::funcs::ReluCUDAFunctor<T>;

template <typename T>
using ReluGradFunctor = phi::funcs::ReluGradFunctor<T>;
template <typename T>
using ReluGradGradFunctor = phi::funcs::ReluGradGradFunctor<T>;
Q
qijun 已提交
314

315 316 317 318 319 320
// soft_relu forward:
//   out = log(1 + exp(clip(x, -threshold, threshold)))
// `threshold` is filled from the op's float attribute of the same name.
template <typename T>
struct SoftReluFunctor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto bound = static_cast<T>(threshold);
    auto clipped = x.cwiseMax(-bound).cwiseMin(bound);
    out.device(d) = (static_cast<T>(1) + clipped.exp()).log();
  }
};

330 331 332 333 334 335
// soft_relu backward, computed from the forward output only (FwdDeps() is
// kDepOut, so `x` is not read):
//   dx = dout * (1 - exp(-out)) * [-threshold < out < threshold]
// NOTE(review): the mask is evaluated on `out` rather than on `x`; because
// out = log(1 + exp(...)) is positive, this is not identical to masking
// -threshold < x < threshold -- confirm this is the intended definition.
template <typename T>
struct SoftReluGradFunctor : public BaseActivationFunctor<T> {
  float threshold;

  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto bound = static_cast<T>(threshold);
    auto in_range = ((out > -bound) * (out < bound)).template cast<T>();
    auto slope = static_cast<T>(1) - (-out).exp();
    dx.device(d) = dout * slope * in_range;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

Z
Zhong Hui 已提交
352 353 354
// abs double-grad: ddOut = ddX * sign(X), written only when ddOut is given.
// Out, dOut and dX are part of the signature but are not used here.
// DDX and X are validated (GET_DATA_SAFELY) before the ddOut check.
template <typename T>
struct AbsGradGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device>
  void operator()(const Device& dev,
                  const phi::DenseTensor* X,
                  const phi::DenseTensor* Out,
                  const phi::DenseTensor* ddX,
                  phi::DenseTensor* ddOut,
                  phi::DenseTensor* dOut,
                  phi::DenseTensor* dX) const {
    auto* eigen_dev = dev.eigen_device();
    auto ddx = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "AbsGradGrad"));
    auto x = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "X", "AbsGradGrad"));
    if (ddOut == nullptr) {
      return;
    }
    auto ddout = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddOut, "Output", "DDOut", "AbsGradGrad"));
    ddout.device(*eigen_dev) = ddx * x.sign();
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

376 377
// TODO(dengkaipeng): double gradient calculation for Square/Sqrt need
// DOut(dy) as input(not output), tensor extraction is different from
378
// others. Impliment extraction kernel separately here.
379
inline void ExtractDoubleGradTensorWithInputDOut(
380
    const framework::ExecutionContext& ctx,
381 382 383 384 385
    const phi::DenseTensor** X,
    const phi::DenseTensor** ddX,
    phi::DenseTensor** dX,
    const phi::DenseTensor** dOut,
    phi::DenseTensor** ddOut) {
386 387 388
  // extract ddX(output), ddOut(input)
  auto ddx_var = ctx.InputVar("DDX");
  auto ddo_var = ctx.OutputVar("DDOut");
389
  PADDLE_ENFORCE_NOT_NULL(
390 391 392 393
      ddx_var,
      platform::errors::NotFound(
          "Cannot get input Variable Out, variable name = %s",
          ctx.InputName("DDX")));
394
  *ddX = ctx.Input<phi::DenseTensor>("DDX");
395
  if (ddo_var) {
396
    *ddOut = ctx.Output<phi::DenseTensor>("DDOut");
397
  }
398 399 400 401 402
  PADDLE_ENFORCE_NOT_NULL(
      ddX,
      platform::errors::NotFound(
          "Cannot get the tensor from the Variable DDX, variable name = %s",
          ctx.OutputName("DDX")));
403 404 405

  // extract x(input), dx(output)
  auto x_var = ctx.InputVar("X");
406
  PADDLE_ENFORCE_NOT_NULL(
407 408 409 410
      x_var,
      platform::errors::NotFound(
          "Cannot get input Variable Out, variable name = %s",
          ctx.InputName("X")));
411
  auto dx_var = ctx.OutputVar("DX");
412
  *X = ctx.Input<phi::DenseTensor>("X");
413
  if (dx_var) {
414
    *dX = ctx.Output<phi::DenseTensor>("DX");
415 416 417 418 419
  }

  // extract dOut(input)
  auto dout_var = ctx.InputVar("DOut");
  if (dout_var) {
420
    *dOut = ctx.Input<phi::DenseTensor>("DOut");
421 422 423
  }
}

Q
qijun 已提交
424 425
}  // namespace operators
}  // namespace paddle
426

427 428
// X-macro over the activation ops still defined in this header: invokes
// __macro(op_type, OpName, ForwardFunctor, GradFunctor) once per op, with a
// trailing semicolon after each invocation.
#define FOR_EACH_ACTIVATION_OP(__macro)                                  \
  __macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor);