activation_compute.h 7.4 KB
Newer Older
L
lhl960107 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

// _USE_MATH_DEFINES must be defined before the first (possibly transitive)
// inclusion of <cmath>, otherwise MSVC does not expose M_SQRT1_2.
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif

#include <algorithm>
#include <cmath>
#include <cstring>
#include <utility>
#include <vector>

#include "lite/backends/x86/math/blas.h"
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/fluid/eigen.h"
#include "lite/operators/op_params.h"
L
lhl960107 已提交
31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74

namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {

// Flags describing which forward-pass tensors an activation's backward
// (gradient) computation depends on. kDepXOut is the OR of kDepX and kDepOut.
enum ActBwdOpFwdDeps {
  kNoDeps = 0,     // needs neither the forward input nor the forward output
  kDepX = 1 << 0,  // needs only the forward input X
  kDepOut = 1 << 1,  // needs only the forward output Out

  // Avoid kDepXOut: Out can always be recomputed from the forward input X in
  // the backward pass.
  // FIXME(zjl): MKLDNN abs nevertheless needs both X and Out, so developers
  // should not rely on this value.
  kDepXOut = kDepX | kDepOut
};

// Common base for the element-wise activation functors in this file.
template <typename T>
struct BaseActivationFunctor {
  // Scalar element type; Activate() reads it to choose the Eigen vector type.
  using ELEMENT_TYPE = T;

  // List of (attribute name, pointer to attribute value) pairs.
  using AttrPair = std::vector<std::pair<const char*, float*>>;

  // The base functor has no attributes, so the list is empty.
  AttrPair GetAttrs() { return {}; }

  // Whether Out may reuse X's memory (in-place execution). That is only safe
  // when the gradient does not read X: sigmoid's gradient does not involve X,
  // so its output could share X's buffer, but abs's gradient reads X, so it
  // cannot run in place. The base conservatively answers false.
  bool Inplace() const { return false; }
};

// Applies the element-wise activation `Functor` to X and writes the result
// into Out, treating both tensors as flat 1-D Eigen vectors.
//
// Functor must expose ELEMENT_TYPE (see BaseActivationFunctor) and be callable
// as Functor()(device, x, out).
//
// Returns false (via CHECK_OR_FALSE) if X or Out is null, true otherwise.
// Out is expected to already hold T-typed storage with X's element count;
// callers in this file call Out->mutable_data<T>() first.
template <typename Functor>
bool Activate(const lite::Tensor* X, lite::Tensor* Out) {
  using T = typename Functor::ELEMENT_TYPE;
  auto place = lite::fluid::EigenDeviceType<TARGET(kX86)>();
  CHECK_OR_FALSE(X)
  CHECK_OR_FALSE(Out)
  auto x = lite::fluid::EigenVector<T>::Flatten(*X);
  auto out = lite::fluid::EigenVector<T>::Flatten(*Out);
  Functor()(place, x, out);
  return true;
}

// Element-wise square: out = x^2, evaluated on the given Eigen device.
template <typename T>
struct SquareFunctor : public BaseActivationFunctor<T> {
  template <typename Dev, typename In, typename Res>
  void operator()(Dev dev, In in, Res res) const {
    res.device(dev) = in.square();
  }
};

// x86 kernel computing Out = X^2 element-wise.
template <typename T>
class SquareCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
 public:
  using param_t = operators::ActivationParam;

  void Run() override {
    operators::ActivationParam& act =
        *param_.get_mutable<operators::ActivationParam>();
    auto in = act.X;
    auto result = act.Out;
    // Request T-typed storage on the output before the functor writes to it.
    result->template mutable_data<T>();
    Activate<SquareFunctor<T>>(in, result);
  }

  virtual ~SquareCompute() = default;
};

// Element-wise rectifier: out = max(x, 0), evaluated on the given Eigen device.
template <typename T>
struct ReluFunctor : public BaseActivationFunctor<T> {
  template <typename Dev, typename In, typename Res>
  void operator()(Dev dev, In in, Res res) const {
    const T zero = static_cast<T>(0);
    res.device(dev) = in.cwiseMax(zero);
  }
};

// x86 kernel computing Out = max(X, 0) element-wise.
template <typename T>
class ReluCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
 public:
  using param_t = operators::ActivationParam;

  void Run() override {
    operators::ActivationParam& act =
        *param_.get_mutable<operators::ActivationParam>();
    auto in = act.X;
    auto result = act.Out;
    // Request T-typed storage on the output before the functor writes to it.
    result->template mutable_data<T>();
    Activate<ReluFunctor<T>>(in, result);
  }

  virtual ~ReluCompute() = default;
};

126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159
// Element-wise leaky ReLU: out = x if x > 0, otherwise alpha * x.
//
// The selection is expressed as a single element-wise min/max:
//   alpha < 1  ->  max(x, alpha * x)  (x wins for x > 0, alpha*x for x < 0)
//   alpha >= 1 ->  min(x, alpha * x)  (the inequality flips once alpha >= 1)
// Using max unconditionally would return alpha*x for positive inputs (and x
// for negative ones) whenever alpha > 1.
template <typename T>
struct LeakyReluFunctor {
  float alpha;
  explicit LeakyReluFunctor(float alpha_) : alpha(alpha_) {}

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    if (alpha < 1.f) {
      out.device(d) = x.cwiseMax(static_cast<T>(alpha) * x);
    } else {
      out.device(d) = x.cwiseMin(static_cast<T>(alpha) * x);
    }
  }
};

// x86 kernel computing Out = leaky_relu(X) element-wise, with the negative
// slope taken from param.Leaky_relu_alpha.
template <typename T>
class LeakyReluCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
 public:
  using param_t = operators::ActivationParam;

  void Run() override {
    auto& param = *param_.get_mutable<operators::ActivationParam>();
    auto X = param.X;
    auto Out = param.Out;
    // Validate both tensors before anything dereferences them.
    CHECK(X);
    CHECK(Out);
    Out->template mutable_data<T>();

    auto place = lite::fluid::EigenDeviceType<TARGET(kX86)>();
    auto x = lite::fluid::EigenVector<T>::Flatten(*X);
    auto out = lite::fluid::EigenVector<T>::Flatten(*Out);
    // The functor needs the alpha attribute, so it cannot go through
    // Activate(), which default-constructs its functor.
    LeakyReluFunctor<T> functor(param.Leaky_relu_alpha);
    functor(place, x, out);
  }

  virtual ~LeakyReluCompute() = default;
};

160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229
// Element-wise hyperbolic tangent:
//   tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
// delegated to Eigen's tanh on the given device.
template <typename T>
struct TanhFunctor : public BaseActivationFunctor<T> {
  template <typename Dev, typename In, typename Res>
  void operator()(Dev dev, In in, Res res) const {
    res.device(dev) = in.tanh();
  }
};

// x86 kernel computing Out = tanh(X) element-wise.
template <typename T>
class TanhCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
 public:
  using param_t = operators::ActivationParam;

  void Run() override {
    operators::ActivationParam& act =
        *param_.get_mutable<operators::ActivationParam>();
    auto in = act.X;
    auto result = act.Out;
    // Request T-typed storage on the output before the functor writes to it.
    result->template mutable_data<T>();
    Activate<TanhFunctor<T>>(in, result);
  }

  virtual ~TanhCompute() = default;
};

// Element-wise GELU: gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))).
// M_SQRT1_2 is the math-library constant 1/sqrt(2).
template <typename T>
struct GeluFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
// The execution/device context cannot be passed in here, so the code path is
// chosen with preprocessor macros instead; CUDA (NVCC) builds, Windows and
// Apple always take the Eigen path below.
//
// MKL path, computed in place in out_data:
//   out = 0; out += (1/sqrt(2)) * x (AXPY); out = erf(out) (VMERF);
//   out += 1; out = x * out (VMUL); out *= 0.5.
// NOTE(review): CBlas<T>::AXPY/VMERF/VMUL are project wrappers, presumably
// over MKL cblas/VML routines, and VML_LA presumably selects VML's
// low-accuracy mode; confirm in lite/backends/x86/math/blas.h.
// NOTE(review): n is an int; Eigen sizes are wider, so element counts above
// INT_MAX would be truncated here.
#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
    !defined(__OSX__) && !defined(PADDLE_WITH_CUDA)
    auto x_data = x.data();
    auto out_data = out.data();
    int n = std::min(x.size(), out.size());

    std::memset(out_data, 0, n * sizeof(T));
    paddle::lite::x86::math::CBlas<T>::AXPY(
        n, static_cast<T>(M_SQRT1_2), x_data, 1, out_data, 1);
    paddle::lite::x86::math::CBlas<T>::VMERF(n, out_data, out_data, VML_LA);
    for (int i = 0; i < n; i++) {
      out_data[i] += static_cast<T>(1);
    }
    paddle::lite::x86::math::CBlas<T>::VMUL(n, x_data, out_data, out_data);
    for (int i = 0; i < n; i++) {
      out_data[i] *= static_cast<T>(0.5);
    }
#else
    // Portable path: the same formula as a single Eigen expression.
    auto temp = (x * static_cast<T>(M_SQRT1_2)).erf();
    out.device(d) = x * static_cast<T>(0.5) * (static_cast<T>(1) + temp);
#endif
  }
};

// x86 kernel computing Out = gelu(X) element-wise.
template <typename T>
class GeluCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
 public:
  using param_t = operators::ActivationParam;

  void Run() override {
    operators::ActivationParam& act =
        *param_.get_mutable<operators::ActivationParam>();
    auto in = act.X;
    auto result = act.Out;
    // Request T-typed storage on the output before the functor writes to it.
    result->template mutable_data<T>();
    Activate<GeluFunctor<T>>(in, result);
  }

  virtual ~GeluCompute() = default;
};

230 231 232 233 234 235 236 237 238 239
// softsign(x) = x / (1 + |x|)
//
// x86 kernel computing Out = softsign(X) with a plain scalar loop over all
// X->numel() elements. Out is expected to already have X's shape.
template <typename T>
class SoftsignCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
 public:
  using param_t = operators::ActivationParam;

  void Run() override {
    auto& param = *param_.get_mutable<operators::ActivationParam>();
    // Validate both tensors before anything dereferences them.
    CHECK(param.X);
    CHECK(param.Out);

    const T* x_data = param.X->template data<T>();
    T* out_data = param.Out->template mutable_data<T>();
    const size_t x_size = static_cast<size_t>(param.X->numel());
    for (size_t i = 0; i < x_size; ++i) {
      out_data[i] = x_data[i] / (static_cast<T>(1) + std::abs(x_data[i]));
    }
  }

  virtual ~SoftsignCompute() = default;
};

L
lhl960107 已提交
251 252 253 254
}  // namespace x86
}  // namespace kernels
}  // namespace lite
}  // namespace paddle