/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PRELU_OP

#include "operators/kernel/prelu_kernel.h"
#include <operators/math/transform.h>
#if __ARM_NEON
#include <arm_neon.h>
#endif

namespace paddle_mobile {
namespace operators {
// Element-wise PReLU activation: positive inputs pass through unchanged,
// negative inputs are scaled by a fixed slope.
template <typename T>
struct PReluFunctor {
  explicit PReluFunctor(float slope) : slope_(slope) {}

  // Returns `in` when it is positive, otherwise `in * slope_`.
  inline T operator()(T in) const {
    if (in > 0) {
      return in;
    }
    return in * slope_;
  }

  float slope_ = 0.0f;
};

/*
 * @b 特化到具体平台的实现, param 从 op 层传入
 * */
I
itminner 已提交
37
template <>
N
nhzlx 已提交
38
void PReluKernel<CPU, float>::Compute(const PReluParam<CPU> &param) const {
39 40
  auto *x = param.InputX();
  auto *alpha = param.InputAlpha();
I
itminner 已提交
41
  auto *out = param.Out();
42
  std::string mode = param.Mode();
43
  auto *x_ptr = x->data<float>();
44
  auto *o_ptr = out->mutable_data<float>();
45
  auto *alpha_ptr = alpha->data<float>();
46 47
  int numel = x->numel();
  auto dim = x->dims();
48 49
  int k = dim[0] * dim[1];
  int n = dim[2] * dim[3];
50 51 52
  int index = 0;
  int i = 0;
  int temp = 0;
53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98
#if __ARM_NEON
  #pragma omp parallel for
  for (int i = 0; i < k; i++) {
    float32x4_t zero = vdupq_n_f32(0.0);
    float32x4_t cv;
    float32x4_t cv1;
    float32x4_t cv2;
    float32x4_t pv;
    for (int j = 0; (j + 3) < n; j += 4) {
      const float *in = x_ptr + i * n + j;
      float *out = o_ptr + i * n + j;
      cv = vld1q_f32(in);
      cv1 = vmaxq_f32(cv, zero);
      cv2 = vminq_f32(cv, zero);
      if (mode == "channel") {
        cv2 = vmulq_n_f32(cv2, alpha_ptr[i]);
      } else if (mode == "element") {
        pv = vld1q_f32(alpha_ptr + i * n + j);
        cv2 = vmulq_f32(cv2, pv);
      } else {
        cv2 = vmulq_n_f32(cv2, alpha_ptr[0]);
      }
      cv = vaddq_f32(cv1, cv2);
      vst1q_f32(out, cv);
    }
    int j;
    for (j = 0; (j + 3) < n; j += 4) {
    }
    for (int m = j; m < n; m++) {
      if (mode == "channel") {
        o_ptr[i * n + m] = x_ptr[i * n + m] > 0
                               ? x_ptr[i * n + m]
                               : alpha_ptr[i] * x_ptr[i * n + m];
      } else if (mode == "element") {
        o_ptr[i * n + m] = x_ptr[i * n + m] > 0
                               ? x_ptr[i * n + m]
                               : alpha_ptr[i * n + m] * x_ptr[i * n + m];
      } else {
        o_ptr[i * n + m] = x_ptr[i * n + m] > 0
                               ? x_ptr[i * n + m]
                               : alpha_ptr[0] * x_ptr[i * n + m];
      }
    }
  }

#else
99
  if (mode == "channel") {
100
    temp = numel / (dim[0] * dim[1]);
101
#pragma omp parallel for
102 103 104 105 106
    for (i = 0; i < numel; i++) {
      index = (i / temp) % dim[1];
      o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i];
    }
  } else if (mode == "element") {
107
#pragma omp parallel for
108 109 110 111
    for (i = 0; i < numel; i++) {
      o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[i] * x_ptr[i];
    }
  } else {
112
#pragma omp parallel for
113 114
    for (i = 0; i < numel; i++) {
      o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[0] * x_ptr[i];
I
itminner 已提交
115 116
    }
  }
117
#endif
I
itminner 已提交
118 119
}
}  // namespace operators
}  // namespace paddle_mobile

#endif