trt_int8_calibrator.cc 5.3 KB
Newer Older
N
nhzlx 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"

#include <cstdint>
#include <utility>

#include "glog/logging.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace inference {
namespace tensorrt {

// set the batch size before constructing the thread to execute engine
25
int TRTInt8Calibrator::getBatchSize() const TRT_NOEXCEPT { return batch_size_; }
N
nhzlx 已提交
26 27

// Creates a calibrator that owns one device buffer per engine input.
//
// `buffers` maps each input name to its size in bytes. For every entry a
// tensor is allocated on `place` and its raw pointer plus byte size are
// recorded in `data_buffers_`: setBatch() copies incoming batches into these
// buffers and getBatch() hands the pointers out as bindings.
TRTInt8Calibrator::TRTInt8Calibrator(
    const std::unordered_map<std::string, size_t>& buffers,
    int batch_size,
    std::string engine_name,
    const platform::Place place)
    : batch_size_(batch_size), engine_name_(std::move(engine_name)) {
  VLOG(4) << "Init a new calibrator: " << engine_name_;
  for (const auto& it : buffers) {
    const std::string& input_name = it.first;
    // Keep the byte count as size_t; narrowing it to int would overflow for
    // very large inputs.
    const size_t data_size = it.second;
    // Storage is allocated in int16_t elements. Round up so an odd byte count
    // still yields at least `data_size` bytes, because setBatch() copies
    // exactly `data_size` bytes into this buffer.
    const int64_t num_ele = static_cast<int64_t>(
        (data_size + sizeof(int16_t) - 1) / sizeof(int16_t));

    phi::DenseTensor temp_tensor;
    temp_tensor.Resize(phi::make_ddim({num_ele}));
    // Allocate BEFORE storing the tensor. The copy kept in data_tensors_ must
    // carry the allocation so the buffer outlives `temp_tensor`; a copy taken
    // before mutable_data() has no storage, which left the recorded pointer
    // dangling once `temp_tensor` was destroyed.
    // NOTE(review): relies on DenseTensor copies sharing their allocation
    // holder — confirm against phi::DenseTensor.
    void* device_ptr =
        static_cast<void*>(temp_tensor.mutable_data<int16_t>(place));
    data_tensors_.push_back(temp_tensor);
    data_buffers_[input_name] =
        std::pair<void*, size_t>(device_ptr, data_size);
  }
}

TRTInt8Calibrator::TRTInt8Calibrator(const std::string& calib_data)
    : batch_size_(0),
      calib_running_(false),
      data_is_set_(false),
      done_(true),
      calibration_table_(calib_data) {}

void TRTInt8Calibrator::waitAndSetDone() {
  std::unique_lock<std::mutex> lk(mut_);
  while ((calib_running_ || data_is_set_) && !done_) cond_.wait(lk);
  if (!done_) {
    done_ = true;
    cond_.notify_all();
  }
}

N
nhzlx 已提交
66 67
// There might be more than one input for trt subgraph,
// So, we use a map to store input information.
N
nhzlx 已提交
68 69 70 71
bool TRTInt8Calibrator::setBatch(
    const std::unordered_map<std::string, void*>& data) {
  VLOG(3) << "set batch: " << engine_name_;
  std::unique_lock<std::mutex> lk(mut_);
N
nhzlx 已提交
72 73 74 75
  //  There is a producer and a consumer. The producer set the batch data and
  //  the consumer get the batch data. The size of the data pool is one.
  //  So, the producer has to wait for the consumer to finish processing before
  //  they can set the data.
N
nhzlx 已提交
76
  while ((calib_running_ || data_is_set_) && (!done_)) cond_.wait(lk);
N
nhzlx 已提交
77 78
  // The done_ is set to true using waitAndSetDone, When all calibration data
  // are processed.
N
nhzlx 已提交
79 80 81
  if (done_) return false;

  // Sets the batch.
N
nhzlx 已提交
82
  for (const auto& it : data) {
N
nhzlx 已提交
83 84
    auto dataptr = data_buffers_.find(it.first);
    if (dataptr == data_buffers_.end()) {
85 86
      PADDLE_THROW(platform::errors::Fatal(
          "%s input name '%s' does not match with the buffer names.",
87 88
          engine_name_,
          it.first));
N
nhzlx 已提交
89 90
    }
    const auto& d = dataptr->second;
91
    PADDLE_ENFORCE_GPU_SUCCESS(
92
        cudaMemcpy(d.first, it.second, d.second, cudaMemcpyDeviceToDevice));
N
nhzlx 已提交
93 94 95 96 97 98 99
  }

  data_is_set_ = true;
  cond_.notify_all();
  return true;
}

100 101
bool TRTInt8Calibrator::getBatch(void** bindings,
                                 const char** names,
102
                                 int num_bindings) TRT_NOEXCEPT {
N
nhzlx 已提交
103 104
  VLOG(4) << "get batch: " << engine_name_;
  std::unique_lock<std::mutex> lk(mut_);
N
nhzlx 已提交
105 106
  // The consumer has just finished processing a data.
  // The producer can set the data again.
N
nhzlx 已提交
107 108 109
  calib_running_ = false;
  cond_.notify_all();

N
nhzlx 已提交
110
  // As long as there is data in the pool, the consumer can get it.
N
nhzlx 已提交
111 112 113 114 115 116 117
  while (!data_is_set_ && !done_) cond_.wait(lk);
  if (done_) return false;

  // Gets the batch
  for (int i = 0; i < num_bindings; i++) {
    auto it = data_buffers_.find(names[i]);
    if (it == data_buffers_.end()) {
G
Galaxy1458 已提交
118 119 120 121 122 123 124 125
      try {
        PADDLE_THROW(platform::errors::Fatal(
            "Calibration engine asked for unknown tensor "
            "name '%s' at position %d.",
            names[i],
            i));
      } catch (std::exception& e) {
      }
N
nhzlx 已提交
126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141
    }
    bindings[i] = it->second.first;
  }

  data_is_set_ = false;
  calib_running_ = true;
  VLOG(4) << "get batch done: " << engine_name_;
  return true;
}

// Unconditionally marks calibration as finished and wakes every thread
// blocked on cond_ (producer and consumer alike).
void TRTInt8Calibrator::setDone() {
  std::lock_guard<std::mutex> guard(mut_);
  done_ = true;
  cond_.notify_all();
}

142 143
const void* TRTInt8Calibrator::readCalibrationCache(size_t& length)
    TRT_NOEXCEPT {
N
nhzlx 已提交
144 145 146 147 148 149
  if (calibration_table_.empty()) return nullptr;
  length = calibration_table_.size();
  return calibration_table_.data();
}

// Stores a calibration table of `length` bytes starting at `ptr`, replacing
// any previously stored table; it is served back by readCalibrationCache().
void TRTInt8Calibrator::writeCalibrationCache(const void* ptr,
                                              std::size_t length) TRT_NOEXCEPT {
  if (ptr == nullptr || length == 0) {
    // Never build a string from a null pointer; an empty table is what
    // readCalibrationCache() treats as "no cache".
    calibration_table_.clear();
  } else {
    calibration_table_.assign(static_cast<const char*>(ptr), length);
  }
  VLOG(4) << "Got calibration data for " << engine_name_ << " " << ptr
          << " length=" << length;
}
// Destructor: only emits a trace message; the members (buffers, tensors,
// calibration table) are destroyed implicitly afterwards.
TRTInt8Calibrator::~TRTInt8Calibrator() {
  VLOG(4) << "Destroying calibrator for "
          << engine_name_;
}

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle