diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc
index 82c51311a03d5fbffdc563d24d91e69f9643fa5f..9fe8f67e6a6573a31883b7816d4beec79b621057 100644
--- a/paddle/fluid/inference/tensorrt/engine.cc
+++ b/paddle/fluid/inference/tensorrt/engine.cc
@@ -390,33 +390,36 @@ nvinfer1::ITensor *TensorRTEngine::GetITensor(const std::string &name) {
   return itensor_map_[name];
 }
 
+std::unordered_map<std::string, nvinfer1::ITensor *>
+    *TensorRTEngine::GetITensorMap() {
+  return &itensor_map_;
+}
+
 void TensorRTEngine::SetRuntimeBatch(size_t batch_size) {
   runtime_batch_ = batch_size;
 }
 
-float *TensorRTEngine::GetWeightCPUData(const std::string &name,
-                                        framework::Tensor *weight_tensor) {
-  static int name_suffix_counter = 0;
-  std::string name_suffix = std::to_string(name_suffix_counter);
-  std::string splitter = "__";
-  std::string name_with_suffix = name + splitter + name_suffix;
+template <typename T>
+T *TensorRTEngine::GetWeightCPUData(const std::string &name,
+                                    framework::Tensor *weight_tensor) {
+  std::unique_ptr<framework::Tensor> cpu_weight_tensor(new framework::Tensor());
   platform::CPUPlace cpu_place;
-  PADDLE_ENFORCE_EQ(weight_map.count(name_with_suffix),
-                    0,
-                    platform::errors::AlreadyExists(
-                        "The weight named %s is set into the weight map "
-                        "twice in TRT OP converter.",
-                        name_with_suffix));
-  weight_map[name_with_suffix].reset(new framework::Tensor());
-  weight_map[name_with_suffix]->Resize(weight_tensor->dims());
+  cpu_weight_tensor->Resize(weight_tensor->dims());
   paddle::framework::TensorCopySync(
-      *weight_tensor, cpu_place, weight_map[name_with_suffix].get());
-  float *weight_data =
-      weight_map[name_with_suffix]->mutable_data<float>(cpu_place);
-  name_suffix_counter += 1;
+      *weight_tensor, cpu_place, cpu_weight_tensor.get());
+  T *weight_data = cpu_weight_tensor->mutable_data<T>(cpu_place);
+  SetWeights(name, std::move(cpu_weight_tensor));
   return weight_data;
 }
 
+template float *TensorRTEngine::GetWeightCPUData<float>(
+    const std::string &name, framework::Tensor *weight_tensor);
+template int32_t *TensorRTEngine::GetWeightCPUData<int32_t>(
+    const std::string &name, framework::Tensor *weight_tensor);
+
+template int64_t *TensorRTEngine::GetWeightCPUData<int64_t>(
+    const std::string &name, framework::Tensor *weight_tensor);
+
 int TensorRTEngine::GetRuntimeBatch() { return runtime_batch_; }
 
 nvinfer1::IPluginV2Layer *TensorRTEngine::AddPlugin(
diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h
index 8d28d1c05ea141d020b6fa105bd38f99e49dc540..c75f7dd17cb95edaab4050d23ed4abc13becc46f 100644
--- a/paddle/fluid/inference/tensorrt/engine.h
+++ b/paddle/fluid/inference/tensorrt/engine.h
@@ -268,6 +268,7 @@ class TensorRTEngine {
   void SetITensor(const std::string& name, nvinfer1::ITensor* tensor);
   // Get an ITensor called name.
   nvinfer1::ITensor* GetITensor(const std::string& name);
+  std::unordered_map<std::string, nvinfer1::ITensor*>* GetITensorMap();
 
   nvinfer1::ICudaEngine* engine() { return infer_engine_.get(); }
   nvinfer1::IExecutionContext* context() {
@@ -405,9 +406,9 @@ class TensorRTEngine {
   void SetTensorDynamicRange(nvinfer1::ITensor* tensor, float range) {
     quant_dynamic_range_[tensor] = range;
   }
-
-  float* GetWeightCPUData(const std::string& name,
-                          framework::Tensor* weight_tensor);
+  template <typename T>
+  T* GetWeightCPUData(const std::string& name,
+                      framework::Tensor* weight_tensor);
 
   // A pointer to CPU memory is needed of the TRT weight.
   // Before TRT runs, fluid loads weight into GPU storage.
@@ -424,7 +425,14 @@ class TensorRTEngine {
     static int suffix_counter = 0;
     std::string suffix = std::to_string(suffix_counter);
     std::string splitter = "__";
-    weight_map[w_name + splitter + suffix] = std::move(w_tensor);
+    std::string name_with_suffix = w_name + splitter + suffix;
+    PADDLE_ENFORCE_EQ(weight_map.count(name_with_suffix),
+                      0,
+                      platform::errors::AlreadyExists(
+                          "The weight named %s is set into the weight map "
+                          "twice in TRT OP converter.",
+                          name_with_suffix));
+    weight_map[name_with_suffix] = std::move(w_tensor);
     suffix_counter += 1;
   }