未验证 提交 76156d12 编写于 作者: Z zhoutianzi666 提交者: GitHub

[inference TRT]template GetWeightCPUData (#43993)

* template GetWeightCPUData
上级 267d3191
...@@ -390,33 +390,36 @@ nvinfer1::ITensor *TensorRTEngine::GetITensor(const std::string &name) { ...@@ -390,33 +390,36 @@ nvinfer1::ITensor *TensorRTEngine::GetITensor(const std::string &name) {
return itensor_map_[name]; return itensor_map_[name];
} }
std::unordered_map<std::string, nvinfer1::ITensor *>
*TensorRTEngine::GetITensorMap() {
return &itensor_map_;
}
void TensorRTEngine::SetRuntimeBatch(size_t batch_size) { void TensorRTEngine::SetRuntimeBatch(size_t batch_size) {
runtime_batch_ = batch_size; runtime_batch_ = batch_size;
} }
float *TensorRTEngine::GetWeightCPUData(const std::string &name, template <typename T = float>
T *TensorRTEngine::GetWeightCPUData(const std::string &name,
framework::Tensor *weight_tensor) { framework::Tensor *weight_tensor) {
static int name_suffix_counter = 0; std::unique_ptr<framework::Tensor> cpu_weight_tensor(new framework::Tensor());
std::string name_suffix = std::to_string(name_suffix_counter);
std::string splitter = "__";
std::string name_with_suffix = name + splitter + name_suffix;
platform::CPUPlace cpu_place; platform::CPUPlace cpu_place;
PADDLE_ENFORCE_EQ(weight_map.count(name_with_suffix), cpu_weight_tensor->Resize(weight_tensor->dims());
0,
platform::errors::AlreadyExists(
"The weight named %s is set into the weight map "
"twice in TRT OP converter.",
name_with_suffix));
weight_map[name_with_suffix].reset(new framework::Tensor());
weight_map[name_with_suffix]->Resize(weight_tensor->dims());
paddle::framework::TensorCopySync( paddle::framework::TensorCopySync(
*weight_tensor, cpu_place, weight_map[name_with_suffix].get()); *weight_tensor, cpu_place, cpu_weight_tensor.get());
float *weight_data = T *weight_data = cpu_weight_tensor->mutable_data<T>(cpu_place);
weight_map[name_with_suffix]->mutable_data<float>(cpu_place); SetWeights(name, std::move(cpu_weight_tensor));
name_suffix_counter += 1;
return weight_data; return weight_data;
} }
template float *TensorRTEngine::GetWeightCPUData(
const std::string &name, framework::Tensor *weight_tensor);
template int32_t *TensorRTEngine::GetWeightCPUData(
const std::string &name, framework::Tensor *weight_tensor);
template int64_t *TensorRTEngine::GetWeightCPUData(
const std::string &name, framework::Tensor *weight_tensor);
int TensorRTEngine::GetRuntimeBatch() { return runtime_batch_; } int TensorRTEngine::GetRuntimeBatch() { return runtime_batch_; }
nvinfer1::IPluginV2Layer *TensorRTEngine::AddPlugin( nvinfer1::IPluginV2Layer *TensorRTEngine::AddPlugin(
......
...@@ -268,6 +268,7 @@ class TensorRTEngine { ...@@ -268,6 +268,7 @@ class TensorRTEngine {
void SetITensor(const std::string& name, nvinfer1::ITensor* tensor); void SetITensor(const std::string& name, nvinfer1::ITensor* tensor);
// Get an ITensor called name. // Get an ITensor called name.
nvinfer1::ITensor* GetITensor(const std::string& name); nvinfer1::ITensor* GetITensor(const std::string& name);
std::unordered_map<std::string, nvinfer1::ITensor*>* GetITensorMap();
nvinfer1::ICudaEngine* engine() { return infer_engine_.get(); } nvinfer1::ICudaEngine* engine() { return infer_engine_.get(); }
nvinfer1::IExecutionContext* context() { nvinfer1::IExecutionContext* context() {
...@@ -405,8 +406,8 @@ class TensorRTEngine { ...@@ -405,8 +406,8 @@ class TensorRTEngine {
void SetTensorDynamicRange(nvinfer1::ITensor* tensor, float range) { void SetTensorDynamicRange(nvinfer1::ITensor* tensor, float range) {
quant_dynamic_range_[tensor] = range; quant_dynamic_range_[tensor] = range;
} }
template <typename T = float>
float* GetWeightCPUData(const std::string& name, T* GetWeightCPUData(const std::string& name,
framework::Tensor* weight_tensor); framework::Tensor* weight_tensor);
// A pointer to CPU memory is needed of the TRT weight. // A pointer to CPU memory is needed of the TRT weight.
...@@ -424,7 +425,14 @@ class TensorRTEngine { ...@@ -424,7 +425,14 @@ class TensorRTEngine {
static int suffix_counter = 0; static int suffix_counter = 0;
std::string suffix = std::to_string(suffix_counter); std::string suffix = std::to_string(suffix_counter);
std::string splitter = "__"; std::string splitter = "__";
weight_map[w_name + splitter + suffix] = std::move(w_tensor); std::string name_with_suffix = w_name + splitter + suffix;
PADDLE_ENFORCE_EQ(weight_map.count(name_with_suffix),
0,
platform::errors::AlreadyExists(
"The weight named %s is set into the weight map "
"twice in TRT OP converter.",
name_with_suffix));
weight_map[name_with_suffix] = std::move(w_tensor);
suffix_counter += 1; suffix_counter += 1;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册