lrn.cc 3.5 KB
Newer Older
Y
Yan Chunwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/arm/math/lrn.h"
#include "lite/arm/math/funcs.h"

namespace paddle {
namespace lite {
namespace arm {
namespace math {

// Across-channel local response normalization for float data.
//
// For every image n, channel c and spatial position p this computes
//   dout[n][c][p] = din[n][c][p] * (k + alpha * S)^(-beta)
// where S is the sum of din[n][c'][p]^2 over the channel window
//   c' in [c - pre_pad, c + post_pad], clipped to [0, channel),
// with pre_pad = (local_size - 1) / 2 and
// post_pad = local_size - 1 - pre_pad.
// alpha is applied exactly as passed in (it is not divided by local_size
// here).
//
// Parameters:
//   din        - input, read as num * channel * (h * w) contiguous floats
//   dout       - output, written with the same layout as din
//   num        - number of images
//   channel    - number of channels per image
//   h, w       - spatial size; only the product h * w is used
//   local_size - number of neighbouring channels in the window
//   alpha      - scale applied to the sum of squares
//   beta       - exponent (used negated)
//   k          - additive bias inside the power
//
// Each channel plane is processed four floats at a time with NEON, and the
// trailing (h * w) % 4 elements go through a scalar loop.
// NOTE(review): every input channel plane is re-read while neighbouring
// output planes are produced, so din and dout presumably must not alias --
// confirm against callers.
template <>
void compute_across_channels<float>(const float* din,
                                    float* dout,
                                    int num,
                                    int channel,
                                    int h,
                                    int w,
                                    int local_size,
                                    float alpha,
                                    float beta,
                                    float k) {
  const int channel_size = h * w;
  // Number of full 4-float groups in one channel plane.
  const int cnt = channel_size / 4;
  const int pre_pad = (local_size - 1) / 2;
  const int post_pad = local_size - pre_pad - 1;
  const float32x4_t k_val = vdupq_n_f32(k);
  const float32x4_t alpha_val = vdupq_n_f32(alpha);
  // The exponent is negated once so the power directly yields the scale.
  const float32x4_t neg_beta_val = vdupq_n_f32(-beta);
  for (int n = 0; n < num; ++n) {
    const float* din_ptr = din + n * channel * channel_size;
    float* dout_ptr = dout + n * channel * channel_size;
    for (int c = 0; c < channel; ++c) {
      const float* din_ch_ptr = din_ptr + c * channel_size;
      float* dout_ch_ptr = dout_ptr + c * channel_size;
      // Half-open channel window [cs, ce) clipped to the valid channels.
      const int cs = (c - pre_pad) < 0 ? 0 : (c - pre_pad);
      // The upper bound must use post_pad: with an even local_size the
      // window extends one channel further after c than before it.
      const int ce = (c + post_pad) >= channel ? channel : (c + post_pad + 1);
      // Vectorized part: four spatial positions per iteration.
      for (int i = 0; i < cnt; ++i) {
        const int idx = i * 4;
        float32x4_t sum = vdupq_n_f32(0.f);
        for (int ch = cs; ch < ce; ++ch) {
          const float32x4_t v = vld1q_f32(din_ptr + ch * channel_size + idx);
          sum = vmlaq_f32(sum, v, v);
        }
        sum = vaddq_f32(vmulq_f32(sum, alpha_val), k_val);
        // pow_ps is declared outside this file; presumably an element-wise
        // power of the first argument raised to the second.
        const float32x4_t scale = pow_ps(sum, neg_beta_val);
        const float32x4_t center = vld1q_f32(din_ch_ptr + idx);
        vst1q_f32(dout_ch_ptr + idx, vmulq_f32(center, scale));
      }
      // Scalar tail: positions that do not fill a complete 4-float group.
      for (int idx = cnt * 4; idx < channel_size; ++idx) {
        float sum = 0.f;
        for (int ch = cs; ch < ce; ++ch) {
          const float v = din_ptr[ch * channel_size + idx];
          sum += v * v;
        }
        sum = k + sum * alpha;
        dout_ch_ptr[idx] = din_ch_ptr[idx] * pow(sum, -beta);
      }
    }
  }
}

// Within-channel local response normalization for float data.
//
// This variant is not implemented: the call only emits an error log entry.
// None of the arguments are read and nothing is written to dout.
template <>
void compute_within_channels<float>(const float* din,
                                    float* dout,
                                    int num,
                                    int channel,
                                    int h,
                                    int w,
                                    int local_size,
                                    float alpha,
                                    float beta,
                                    float k) {
  // Report the unsupported mode; all parameters are intentionally ignored.
  LOG(ERROR) << "unsupported method!!";
}

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle