// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
#include "glog/logging.h"

namespace paddle {
namespace inference {
namespace tensorrt {

// The batch size is set before constructing the thread that executes the
// engine.
int TRTInt8Calibrator::getBatchSize() const { return batch_size_; }

TRTInt8Calibrator::TRTInt8Calibrator(
    const std::unordered_map<std::string, size_t>& buffers, int batch_size,
    std::string engine_name, const platform::Place place)
    : batch_size_(batch_size), engine_name_(engine_name) {
  VLOG(4) << "Init a new calibrator: " << engine_name_;
  // Allocate a staging device buffer for each input of the TRT subgraph.
  for (const auto& it : buffers) {
    framework::Tensor temp_tensor;
    std::string input_name = it.first;
    int data_size = it.second;
    int num_ele = data_size / sizeof(int16_t);
    framework::DDim data_shape = framework::make_ddim({num_ele});
    temp_tensor.Resize(data_shape);
    data_tensors_.push_back(temp_tensor);
    data_buffers_[input_name] = std::pair<void*, size_t>(
        static_cast<void*>(temp_tensor.mutable_data<int16_t>(place)),
        data_size);
  }
}
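
// Example construction for calibration mode (a minimal sketch, not code from
// this file): `buffers` maps each TRT-subgraph input name to its per-batch
// size in bytes. The input name, shape, and batch size below are illustrative.
//
//   std::unordered_map<std::string, size_t> buffers;
//   buffers["input_0"] = 8 * 3 * 224 * 224 * sizeof(float);
//   TRTInt8Calibrator calibrator(buffers, /*batch_size=*/8, "trt_engine_0",
//                                platform::CUDAPlace(0));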

TRTInt8Calibrator::TRTInt8Calibrator(const std::string& calib_data)
    : batch_size_(0),
      calib_running_(false),
      data_is_set_(false),
      done_(true),
      calibration_table_(calib_data) {}

// Wait until any in-flight batch has been consumed, then mark calibration as
// done so that getBatch() returns false and the build thread can exit.
void TRTInt8Calibrator::waitAndSetDone() {
  std::unique_lock<std::mutex> lk(mut_);
  while ((calib_running_ || data_is_set_) && !done_) cond_.wait(lk);
  if (!done_) {
    done_ = true;
    cond_.notify_all();
  }
}

// There might be more than one input for the TRT subgraph,
// so we use a map to store the input information.
bool TRTInt8Calibrator::setBatch(
    const std::unordered_map<std::string, void*>& data) {
  VLOG(3) << "set batch: " << engine_name_;
  std::unique_lock<std::mutex> lk(mut_);
  // There is a producer and a consumer. The producer sets the batch data and
  // the consumer gets the batch data. The size of the data pool is one, so
  // the producer has to wait for the consumer to finish processing the
  // previous batch before it can set new data.
  while ((calib_running_ || data_is_set_) && (!done_)) cond_.wait(lk);
  // done_ is set to true by waitAndSetDone() once all calibration data
  // have been processed.
  if (done_) return false;

  // Copy the batch into the staging device buffers.
  for (const auto& it : data) {
    auto dataptr = data_buffers_.find(it.first);
    if (dataptr == data_buffers_.end()) {
      PADDLE_THROW(platform::errors::Fatal(
          "%s input name '%s' does not match the buffer names.",
          engine_name_, it.first));
    }
    const auto& d = dataptr->second;
    PADDLE_ENFORCE_CUDA_SUCCESS(
        cudaMemcpy(d.first, it.second, d.second, cudaMemcpyDeviceToDevice));
  }

  data_is_set_ = true;
  cond_.notify_all();
  return true;
}
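
// Producer-side sketch (illustrative; NextBatchOnDevice() is a hypothetical
// helper that yields a device pointer to the next batch, or nullptr when the
// data source is exhausted):
//
//   std::unordered_map<std::string, void*> batch;
//   while (void* gpu_ptr = NextBatchOnDevice()) {
//     batch["input_0"] = gpu_ptr;
//     if (!calibrator.setBatch(batch)) break;  // false once calibration is done
//   }
//   calibrator.waitAndSetDone();  // unblock the consumer after the last batch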

bool TRTInt8Calibrator::getBatch(void** bindings, const char** names,
                                 int num_bindings) {
  VLOG(4) << "get batch: " << engine_name_;
  std::unique_lock<std::mutex> lk(mut_);
  // The consumer has just finished processing a batch,
  // so the producer can set new data.
  calib_running_ = false;
  cond_.notify_all();

  // As long as there is data in the pool, the consumer can get it.
  while (!data_is_set_ && !done_) cond_.wait(lk);
  if (done_) return false;

  // Get the batch: point each binding at its staging buffer.
  for (int i = 0; i < num_bindings; i++) {
    auto it = data_buffers_.find(names[i]);
    if (it == data_buffers_.end()) {
      PADDLE_THROW(
          platform::errors::Fatal("Calibration engine asked for unknown tensor "
                                  "name '%s' at position %d.",
                                  names[i], i));
    }
    bindings[i] = it->second.first;
  }

  data_is_set_ = false;
  calib_running_ = true;
  VLOG(4) << "get batch done: " << engine_name_;
  return true;
}
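
// Consumer-side note: getBatch() is invoked by TensorRT's builder through the
// IInt8Calibrator interface while the engine is built. A sketch against the
// pre-TRT7 builder API (an assumption, not code from this file):
//
//   builder->setInt8Mode(true);
//   builder->setInt8Calibrator(&calibrator);
//   nvinfer1::ICudaEngine* engine =
//       builder->buildCudaEngine(*network);  // calls getBatch() repeatedly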

void TRTInt8Calibrator::setDone() {
  std::unique_lock<std::mutex> lk(mut_);
  done_ = true;
  cond_.notify_all();
}

// Called by TensorRT: returns the cached calibration table so recalibration
// can be skipped, or nullptr when there is no cache and calibration must run.
const void* TRTInt8Calibrator::readCalibrationCache(size_t& length) {
  if (calibration_table_.empty()) return nullptr;
  length = calibration_table_.size();
  return calibration_table_.data();
}

void TRTInt8Calibrator::writeCalibrationCache(const void* ptr,
                                              std::size_t length) {
  calibration_table_ = std::string(static_cast<const char*>(ptr), length);
  VLOG(4) << "Got calibration data for " << engine_name_ << " " << ptr
          << " length=" << length;
}
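
// Cache round-trip sketch: the table handed to writeCalibrationCache() can be
// persisted and later passed to the std::string constructor above, so that
// readCalibrationCache() returns it and TensorRT skips recalibration
// (LoadFromDisk is a hypothetical helper):
//
//   std::string table = LoadFromDisk("calib_table_trt_engine_0");
//   TRTInt8Calibrator cached_calibrator(table);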

TRTInt8Calibrator::~TRTInt8Calibrator() {
  VLOG(4) << "Destroying calibrator for " << engine_name_;
}

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle