/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "fpga/quantization.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>

namespace paddle_mobile {
namespace fpga {

template <typename Dtype>
H
hanbuhe 已提交
22 23
static void chw_to_hwc(Dtype* data_in, Dtype* data_out, int64_t num,
                       int64_t channel, int64_t height, int64_t width) {
H
hanbuhe 已提交
24
  for (int n = 0; n < num; n++) {
H
hanbuhe 已提交
25
    int64_t amount_per_row = width * channel;
H
hanbuhe 已提交
26 27
    for (int c = 0; c < channel; c++) {
      for (int h = 0; h < height; h++) {
H
hanbuhe 已提交
28
        int64_t offset_height = h * amount_per_row;
H
hanbuhe 已提交
29 30 31 32 33 34 35 36 37
        for (int w = 0; w < width; w++) {
          *(data_out + offset_height + w * channel + c) = *(data_in++);
        }
      }
    }
    data_out += num;
  }
}

template <typename Dtype>
H
hanbuhe 已提交
39
static Dtype find_max(Dtype* data, int64_t num) {
H
hanbuhe 已提交
40 41
  Dtype max = 0;
  for (int i = 0; i < num; ++i) {
H
hanbuhe 已提交
42 43 44
    Dtype value = data[i];
    Dtype abs = value > 0 ? value : -value;
    max = std::max(max, abs);
H
hanbuhe 已提交
45 46 47 48
  }
  return max;
}

// template <typename Dtype>
H
hanbuhe 已提交
50
void quantize_filter(framework::Tensor* filter) {
51
  DLOG << "quantilize_filter........" << filter->dims();
C
chonwhite 已提交
52

H
hanbuhe 已提交
53
  float scale = 0;
H
hanbuhe 已提交
54
  auto fix_range = static_cast<float>(std::pow(2, 8 - 1) - 1);
H
hanbuhe 已提交
55

H
hanbuhe 已提交
56
  auto* tmp_data = new int8_t[filter->numel()];
H
hanbuhe 已提交
57

H
hanbuhe 已提交
58 59
  // 32bit filter -> 8bit filter;
  if (filter->type() == typeid(float)) {
H
hanbuhe 已提交
60 61
    auto* float_data = filter->data<float>();
    auto max = find_max<float>(float_data, filter->numel());
H
hanbuhe 已提交
62

H
hanbuhe 已提交
63 64
    scale = (fix_range / max);
    DLOG << "scale:" << scale;
H
hanbuhe 已提交
65 66

    for (int i = 0; i < filter->numel(); ++i) {
H
hanbuhe 已提交
67
      tmp_data[i] = (int8_t)(float_data[i] * scale);
H
hanbuhe 已提交
68
    }
H
hanbuhe 已提交
69
  } else {
H
hanbuhe 已提交
70 71 72
    auto max = find_max<int8_t>(filter->data<int8_t>(), filter->numel());
    scale = (fix_range / max);
    std::memcpy(tmp_data, filter->data<int8_t>(), (size_t)filter->numel());
H
hanbuhe 已提交
73
  }
74 75 76 77 78 79 80 81 82 83 84 85 86

  if (filter->dims().size() == 4) {
    const auto batch_size = filter->dims()[0];
    const auto channel = filter->dims()[1];
    const auto height = filter->dims()[2];
    const auto width = filter->dims()[3];
    chw_to_hwc<int8_t>(tmp_data, filter->mutable_data<int8_t>(), batch_size,
                       channel, height, width);
  } else if (filter->dims().size() == 2) {
    std::memcpy(filter->mutable_data<int8_t>(), tmp_data,
                (size_t)filter->numel());
  }

H
hanbuhe 已提交
87
  delete tmp_data;
H
hanbuhe 已提交
88
  filter->SetFpgaScale(scale);
H
hanbuhe 已提交
89 90 91 92
}

}  // namespace fpga
}  // namespace paddle_mobile