quantization.cpp 2.6 KB
Newer Older
H
hanbuhe 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

H
hanbuhe 已提交
15
#include "fpga/quantization.h"

#include <algorithm>
#include <cstring>
#include <memory>

namespace paddle_mobile {
namespace fpga {

// Transpose a 4-D tensor from NCHW (channel-major) to NHWC (channel-minor)
// layout.
//
// data_in:  `num` images, each stored as channel*height*width elements in
//           CHW order, read sequentially.
// data_out: destination buffer of the same total size, written in HWC order
//           per image. Must not alias data_in.
template <typename Dtype>
static void chw_to_hwc(Dtype* data_in, Dtype* data_out, int64_t num,
                       int64_t channel, int64_t height, int64_t width) {
  const int64_t image_size = channel * height * width;
  for (int n = 0; n < num; n++) {
    int64_t amount_per_row = width * channel;
    for (int c = 0; c < channel; c++) {
      for (int h = 0; h < height; h++) {
        int64_t offset_height = h * amount_per_row;
        for (int w = 0; w < width; w++) {
          *(data_out + offset_height + w * channel + c) = *(data_in++);
        }
      }
    }
    // Advance to the next image in the output buffer. The previous code
    // advanced by `num` (the batch count) instead of the per-image element
    // count, corrupting every image after the first.
    data_out += image_size;
  }
}

H
hanbuhe 已提交
38
// Return the largest absolute value in data[0, num); 0 for an empty range.
template <typename Dtype>
static Dtype find_max(Dtype* data, int64_t num) {
  Dtype result = 0;
  for (int64_t i = 0; i < num; ++i) {
    const Dtype magnitude = data[i] < 0 ? -data[i] : data[i];
    if (magnitude > result) {
      result = magnitude;
    }
  }
  return result;
}

49
// template <typename Dtype>
H
hanbuhe 已提交
50
// Quantize a filter tensor to int8 and transpose its layout from NCHW to
// NHWC in place. If the tensor currently holds 32-bit floats, each value is
// scaled into the signed 8-bit range; if it already holds int8, only the
// scale is (re)computed and the bytes are copied. The resulting scale
// factor is stored on the tensor via SetFpgaScale().
void quantize_filter(framework::Tensor *filter) {
  DLOG << "quantilize_filter........";

  float scale = 0;
  // Max representable magnitude for signed 8-bit: 2^(8-1) - 1 = 127.
  constexpr float fix_range = 127.0f;

  const auto batch_size = filter->dims()[0];
  const auto channel = filter->dims()[1];
  const auto height = filter->dims()[2];
  const auto width = filter->dims()[3];
  const auto numel = filter->numel();

  // Scratch buffer for the quantized CHW data; owned via unique_ptr so it is
  // released on every path (the old code used `delete` on a `new[]` array,
  // which is undefined behavior).
  std::unique_ptr<int8_t[]> tmp_data(new int8_t[numel]);

  if (filter->type() == typeid(float)) {
    // 32-bit filter -> 8-bit filter.
    auto* float_data = filter->data<float>();
    auto max = find_max<float>(float_data, numel);
    // NOTE(review): an all-zero filter yields max == 0 and scale == inf;
    // presumably filters are never all zero — confirm upstream.
    scale = (fix_range / max);
    DLOG << "scale:" << scale;

    for (int64_t i = 0; i < numel; ++i) {
      tmp_data[i] = static_cast<int8_t>(float_data[i] * scale);
    }
  } else {
    // Already int8: recompute the scale and copy the raw bytes.
    auto max = find_max<int8_t>(filter->data<int8_t>(), numel);
    scale = (fix_range / max);
    std::memcpy(tmp_data.get(), filter->data<int8_t>(),
                static_cast<size_t>(numel));
  }
  // NCHW -> NHWC; mutable_data<int8_t>() supplies the destination storage.
  chw_to_hwc<int8_t>(tmp_data.get(), filter->mutable_data<int8_t>(),
                     batch_size, channel, height, width);
  filter->SetFpgaScale(scale);
}

}  // namespace fpga
}  // namespace paddle_mobile