From 76156d12625037ee836aaeeb389f3edb4a5b6a5b Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <39978853+zhoutianzi666@users.noreply.github.com> Date: Fri, 1 Jul 2022 13:18:35 +0800 Subject: [PATCH] [inference TRT]template GetWeightCPUData (#43993) * template GetWeightCPUData --- paddle/fluid/inference/tensorrt/engine.cc | 39 ++++++++++++----------- paddle/fluid/inference/tensorrt/engine.h | 16 +++++++--- 2 files changed, 33 insertions(+), 22 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc index 82c51311a03..9fe8f67e6a6 100644 --- a/paddle/fluid/inference/tensorrt/engine.cc +++ b/paddle/fluid/inference/tensorrt/engine.cc @@ -390,33 +390,36 @@ nvinfer1::ITensor *TensorRTEngine::GetITensor(const std::string &name) { return itensor_map_[name]; } +std::unordered_map + *TensorRTEngine::GetITensorMap() { + return &itensor_map_; +} + void TensorRTEngine::SetRuntimeBatch(size_t batch_size) { runtime_batch_ = batch_size; } -float *TensorRTEngine::GetWeightCPUData(const std::string &name, - framework::Tensor *weight_tensor) { - static int name_suffix_counter = 0; - std::string name_suffix = std::to_string(name_suffix_counter); - std::string splitter = "__"; - std::string name_with_suffix = name + splitter + name_suffix; +template +T *TensorRTEngine::GetWeightCPUData(const std::string &name, + framework::Tensor *weight_tensor) { + std::unique_ptr cpu_weight_tensor(new framework::Tensor()); platform::CPUPlace cpu_place; - PADDLE_ENFORCE_EQ(weight_map.count(name_with_suffix), - 0, - platform::errors::AlreadyExists( - "The weight named %s is set into the weight map " - "twice in TRT OP converter.", - name_with_suffix)); - weight_map[name_with_suffix].reset(new framework::Tensor()); - weight_map[name_with_suffix]->Resize(weight_tensor->dims()); + cpu_weight_tensor->Resize(weight_tensor->dims()); paddle::framework::TensorCopySync( - *weight_tensor, cpu_place, weight_map[name_with_suffix].get()); - float *weight_data = - weight_map[name_with_suffix]->mutable_data(cpu_place); - name_suffix_counter += 1; + *weight_tensor, cpu_place, cpu_weight_tensor.get()); + T *weight_data = cpu_weight_tensor->mutable_data(cpu_place); + SetWeights(name, std::move(cpu_weight_tensor)); return weight_data; } +template float *TensorRTEngine::GetWeightCPUData( + const std::string &name, framework::Tensor *weight_tensor); +template int32_t *TensorRTEngine::GetWeightCPUData( + const std::string &name, framework::Tensor *weight_tensor); + +template int64_t *TensorRTEngine::GetWeightCPUData( + const std::string &name, framework::Tensor *weight_tensor); + int TensorRTEngine::GetRuntimeBatch() { return runtime_batch_; } nvinfer1::IPluginV2Layer *TensorRTEngine::AddPlugin( diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index 8d28d1c05ea..c75f7dd17cb 100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -268,6 +268,7 @@ class TensorRTEngine { void SetITensor(const std::string& name, nvinfer1::ITensor* tensor); // Get an ITensor called name. nvinfer1::ITensor* GetITensor(const std::string& name); + std::unordered_map* GetITensorMap(); nvinfer1::ICudaEngine* engine() { return infer_engine_.get(); } nvinfer1::IExecutionContext* context() { @@ -405,9 +406,9 @@ class TensorRTEngine { void SetTensorDynamicRange(nvinfer1::ITensor* tensor, float range) { quant_dynamic_range_[tensor] = range; } - - float* GetWeightCPUData(const std::string& name, - framework::Tensor* weight_tensor); + template + T* GetWeightCPUData(const std::string& name, + framework::Tensor* weight_tensor); // A pointer to CPU memory is needed of the TRT weight. // Before TRT runs, fluid loads weight into GPU storage. @@ -424,7 +425,14 @@ class TensorRTEngine { static int suffix_counter = 0; std::string suffix = std::to_string(suffix_counter); std::string splitter = "__"; - weight_map[w_name + splitter + suffix] = std::move(w_tensor); + std::string name_with_suffix = w_name + splitter + suffix; + PADDLE_ENFORCE_EQ(weight_map.count(name_with_suffix), + 0, + platform::errors::AlreadyExists( + "The weight named %s is set into the weight map " + "twice in TRT OP converter.", + name_with_suffix)); + weight_map[name_with_suffix] = std::move(w_tensor); suffix_counter += 1; } -- GitLab