// group_norm_compute.cc (committed by chenjiaoAngel)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/arm/group_norm_compute.h"
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace arm {

void GroupNormCompute::PrepareForRun() {
  // Intentionally empty: nothing is precomputed here; all of the kernel's
  // work happens inside Run().
}

void GroupNormCompute::Run() {
  auto& param = this->Param<param_t>();
  const float* in = param.x->data<float>();
  const float* scale = param.scale->data<float>();
  const float* bias = param.bias->data<float>();
  float* out = param.out->mutable_data<float>();
  float* saved_mean = param.saved_mean->mutable_data<float>();
  float* saved_variance = param.saved_variance->mutable_data<float>();
  float epsilon = param.epsilon;
  int groups = param.groups;
  int channels = param.channels;
  int n = param.x->dims()[0];
  int c = param.x->dims()[1];
  int ch_per_group = channels / groups;
  int height = param.x->dims()[2];
  int width = param.x->dims()[3];
  int spatial_size = ch_per_group * height * width;
  int ngroup = n * groups;
  int cnt = spatial_size >> 4;
  int remain = spatial_size % 16;
// compute saved_mean and saved_variance
#pragma omp parallel for
  for (int n = 0; n < ngroup; ++n) {
    const float* in_p = in + n * spatial_size;
    float sum_spatial = 0.f;
    float summ_spatial = 0.f;
    float32x4_t sum0 = vdupq_n_f32(0.f);
    float32x4_t sum1 = vdupq_n_f32(0.f);
    float32x4_t sum2 = vdupq_n_f32(0.f);
    float32x4_t sum3 = vdupq_n_f32(0.f);
    float32x4_t summ0 = vdupq_n_f32(0.f);
    float32x4_t summ1 = vdupq_n_f32(0.f);
    float32x4_t summ2 = vdupq_n_f32(0.f);
    float32x4_t summ3 = vdupq_n_f32(0.f);
    for (int i = 0; i < cnt; i++) {
      float32x4_t in0 = vld1q_f32(in_p);
      float32x4_t in1 = vld1q_f32(in_p + 4);
      float32x4_t in2 = vld1q_f32(in_p + 8);
      float32x4_t in3 = vld1q_f32(in_p + 12);
      sum0 = vaddq_f32(sum0, in0);
      summ0 = vmlaq_f32(summ0, in0, in0);
      sum1 = vaddq_f32(sum1, in1);
      summ1 = vmlaq_f32(summ1, in1, in1);
      sum2 = vaddq_f32(sum2, in2);
      summ2 = vmlaq_f32(summ2, in2, in2);
      sum3 = vaddq_f32(sum3, in3);
      summ3 = vmlaq_f32(summ3, in3, in3);
      in_p += 16;
    }
    for (int i = 0; i < remain - 3; i += 4) {
C
chenjiaoAngel 已提交
77 78 79 80
      float32x4_t in0 = vld1q_f32(in_p);
      sum1 = vaddq_f32(sum1, in0);
      summ1 = vmlaq_f32(summ1, in0, in0);
      in_p += 4;
C
chenjiaoAngel 已提交
81 82 83 84 85 86 87 88
    }
    float sum = 0.0;
    float summ = 0.0;
    sum0 = vaddq_f32(sum0, sum1);
    sum2 = vaddq_f32(sum2, sum3);
    summ0 = vaddq_f32(summ0, summ1);
    summ2 = vaddq_f32(summ2, summ3);
    for (int i = 0; i < remain % 4; i++) {
C
chenjiaoAngel 已提交
89 90 91
      sum += *in_p;
      summ += (*in_p) * (*in_p);
      in_p++;
C
chenjiaoAngel 已提交
92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126
    }
    sum0 = vaddq_f32(sum0, sum2);
    summ0 = vaddq_f32(summ0, summ2);
    float32x2_t sum_low = vpadd_f32(vget_low_f32(sum0), vget_high_f32(sum0));
    float32x2_t sum_high = vpadd_f32(vget_low_f32(summ0), vget_high_f32(summ0));
    float32x2_t sum_mix = vpadd_f32(sum_low, sum_high);
    sum += vget_lane_f32(sum_mix, 0);
    summ += vget_lane_f32(sum_mix, 1);
    float mean = sum / spatial_size;
    // float variance = summ / spatial_size - mean * mean;
    // the flolowing code has higher precision than above comment code
    float variance = (summ - mean * mean * spatial_size) / spatial_size;
    float std = 1.f / sqrtf(variance + epsilon);
    saved_mean[n] = mean;
    saved_variance[n] = std;
  }
  int in_size = height * width;
  cnt = in_size >> 4;
  remain = in_size % 16;
// compute Group_norm result: out = scale * (in - mean) / std + bias
#pragma omp parallel for
  for (int i = 0; i < ngroup; ++i) {
    const float* in_p = in + i * spatial_size;
    float* out_p = out + i * spatial_size;
    int numc = i % groups;
    numc *= ch_per_group;
    for (int c = 0; c < ch_per_group; c++) {
      int chin = numc + c;
      const float sstd_val = scale[chin] * saved_variance[i];
      const float bias_val = bias[chin];
      const float mean_val = saved_mean[i];
      const float32x4_t vsstd = vdupq_n_f32(sstd_val);
      const float32x4_t vbias = vdupq_n_f32(bias_val);
      const float32x4_t vmean = vdupq_n_f32(mean_val);
      for (int k = 0; k < cnt; k++) {
C
chenjiaoAngel 已提交
127 128 129 130 131 132 133 134 135 136 137 138 139
        float32x4_t in0 = vld1q_f32(in_p);
        float32x4_t in1 = vld1q_f32(in_p + 4);
        float32x4_t in2 = vld1q_f32(in_p + 8);
        float32x4_t in3 = vld1q_f32(in_p + 12);
        float32x4_t submean0 = vsubq_f32(in0, vmean);
        float32x4_t submean1 = vsubq_f32(in1, vmean);
        float32x4_t submean2 = vsubq_f32(in2, vmean);
        float32x4_t submean3 = vsubq_f32(in3, vmean);
        float32x4_t out0 = vmlaq_f32(vbias, submean0, vsstd);
        float32x4_t out1 = vmlaq_f32(vbias, submean1, vsstd);
        float32x4_t out2 = vmlaq_f32(vbias, submean2, vsstd);
        float32x4_t out3 = vmlaq_f32(vbias, submean3, vsstd);
        vst1q_f32(out_p, out0);
C
chenjiaoAngel 已提交
140 141 142
        vst1q_f32(out_p + 4, out1);
        vst1q_f32(out_p + 8, out2);
        vst1q_f32(out_p + 12, out3);
C
chenjiaoAngel 已提交
143 144
        in_p += 16;
        out_p += 16;
C
chenjiaoAngel 已提交
145 146
      }
      for (int k = 0; k < remain - 3; k += 4) {
C
chenjiaoAngel 已提交
147 148 149 150 151 152
        float32x4_t in0 = vld1q_f32(in_p);
        in_p += 4;
        float32x4_t submean0 = vsubq_f32(in0, vmean);
        float32x4_t out0 = vmlaq_f32(vbias, submean0, vsstd);
        vst1q_f32(out_p, out0);
        out_p += 4;
C
chenjiaoAngel 已提交
153 154
      }
      for (int k = 0; k < remain % 4; k++) {
C
chenjiaoAngel 已提交
155 156 157
        *out_p = (*in_p - mean_val) * sstd_val + bias_val;
        in_p++;
        out_p++;
C
chenjiaoAngel 已提交
158
      }
C
chenjiaoAngel 已提交
159
    }
C
chenjiaoAngel 已提交
160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180
  }
}

}  // namespace arm
}  // namespace kernels
}  // namespace lite
}  // namespace paddle

// Registers GroupNormCompute as the "def" kernel of the group_norm op for the
// kARM target, kFloat precision and kNCHW layout. Inputs X, Scale, Bias and
// outputs Y, SavedMean, SavedVariance are all bound to tensors of target kARM.
REGISTER_LITE_KERNEL(group_norm,
                     kARM,
                     kFloat,
                     kNCHW,
                     paddle::lite::kernels::arm::GroupNormCompute,
                     def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindInput("Scale", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindOutput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindOutput("SavedMean", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindOutput("SavedVariance", {LiteType::GetTensorTy(TARGET(kARM))})
    .Finalize();