Unverified · Commit 76156d12 · authored by zhoutianzi666, committed by GitHub

[inference TRT] template GetWeightCPUData (#43993)

* template GetWeightCPUData

Parent 267d3191
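For context: the change turns the float-only `GetWeightCPUData` into a function template with `T = float` as its default, keeps the definition in the .cc file, and explicitly instantiates it for `float`, `int32_t`, and `int64_t`. A minimal self-contained sketch of that pattern, using an illustrative `Engine` class rather than Paddle's real API:

```cpp
#include <cstdint>
#include <cstdio>
#include <string>

class Engine {
 public:
  // Default template argument: existing call sites written for float
  // keep compiling unchanged.
  template <typename T = float>
  T* GetWeightCPUData(const std::string& name);

 private:
  alignas(8) unsigned char buf_[64] = {};  // stand-in for weight storage
};

// The definition stays in the .cc file, so each element type used by a
// caller needs an explicit instantiation below.
template <typename T>
T* Engine::GetWeightCPUData(const std::string& name) {
  std::printf("fetching weight %s\n", name.c_str());
  return reinterpret_cast<T*>(buf_);
}

template float* Engine::GetWeightCPUData(const std::string&);
template int32_t* Engine::GetWeightCPUData(const std::string&);
template int64_t* Engine::GetWeightCPUData(const std::string&);

int main() {
  Engine e;
  float* f = e.GetWeightCPUData("fc_w");            // T defaults to float
  int32_t* i = e.GetWeightCPUData<int32_t>("idx");  // explicit element type
  (void)f;
  (void)i;
  return 0;
}
```

Because the definition is not visible in the header, only the instantiated element types will link; supporting a new weight type means adding one more instantiation line, exactly as the commit does for `int32_t` and `int64_t`.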
paddle/fluid/inference/tensorrt/engine.cc
@@ -390,33 +390,36 @@ nvinfer1::ITensor *TensorRTEngine::GetITensor(const std::string &name) {
   return itensor_map_[name];
 }
 
+std::unordered_map<std::string, nvinfer1::ITensor *>
+    *TensorRTEngine::GetITensorMap() {
+  return &itensor_map_;
+}
+
 void TensorRTEngine::SetRuntimeBatch(size_t batch_size) {
   runtime_batch_ = batch_size;
 }
 
-float *TensorRTEngine::GetWeightCPUData(const std::string &name,
-                                        framework::Tensor *weight_tensor) {
-  static int name_suffix_counter = 0;
-  std::string name_suffix = std::to_string(name_suffix_counter);
-  std::string splitter = "__";
-  std::string name_with_suffix = name + splitter + name_suffix;
+template <typename T = float>
+T *TensorRTEngine::GetWeightCPUData(const std::string &name,
+                                    framework::Tensor *weight_tensor) {
+  std::unique_ptr<framework::Tensor> cpu_weight_tensor(new framework::Tensor());
   platform::CPUPlace cpu_place;
-  PADDLE_ENFORCE_EQ(weight_map.count(name_with_suffix),
-                    0,
-                    platform::errors::AlreadyExists(
-                        "The weight named %s is set into the weight map "
-                        "twice in TRT OP converter.",
-                        name_with_suffix));
-  weight_map[name_with_suffix].reset(new framework::Tensor());
-  weight_map[name_with_suffix]->Resize(weight_tensor->dims());
+  cpu_weight_tensor->Resize(weight_tensor->dims());
   paddle::framework::TensorCopySync(
-      *weight_tensor, cpu_place, weight_map[name_with_suffix].get());
-  float *weight_data =
-      weight_map[name_with_suffix]->mutable_data<float>(cpu_place);
-  name_suffix_counter += 1;
+      *weight_tensor, cpu_place, cpu_weight_tensor.get());
+  T *weight_data = cpu_weight_tensor->mutable_data<T>(cpu_place);
+  SetWeights(name, std::move(cpu_weight_tensor));
   return weight_data;
 }
 
+template float *TensorRTEngine::GetWeightCPUData(
+    const std::string &name, framework::Tensor *weight_tensor);
+template int32_t *TensorRTEngine::GetWeightCPUData(
+    const std::string &name, framework::Tensor *weight_tensor);
+template int64_t *TensorRTEngine::GetWeightCPUData(
+    const std::string &name, framework::Tensor *weight_tensor);
+
 int TensorRTEngine::GetRuntimeBatch() { return runtime_batch_; }
 
 nvinfer1::IPluginV2Layer *TensorRTEngine::AddPlugin(
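Note the ownership hand-off in the new body: the weight is copied into a fresh CPU tensor, ownership moves into `SetWeights`, and the raw data pointer is returned, staying valid because the engine's `weight_map` keeps the tensor alive. A hedged stand-alone sketch of that flow, with `std::vector<float>` standing in for `framework::Tensor`:

```cpp
#include <cstdio>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Mirrors TensorRTEngine::weight_map, which owns the CPU copies.
std::map<std::string, std::unique_ptr<std::vector<float>>> weight_map;

float* GetWeightCPUData(const std::string& name,
                        const std::vector<float>& src) {
  // Copy into a fresh CPU buffer (TensorCopySync stand-in)...
  auto cpu_copy = std::make_unique<std::vector<float>>(src);
  float* data = cpu_copy->data();
  // ...then hand ownership to the map (SetWeights stand-in); the raw
  // pointer stays valid for as long as the map entry does.
  weight_map[name] = std::move(cpu_copy);
  return data;
}

int main() {
  std::vector<float> device_weight = {1.f, 2.f, 3.f};
  float* w = GetWeightCPUData("fc_w__0", device_weight);
  std::printf("%.1f\n", w[1]);  // prints 2.0, backed by the map-owned copy
  return 0;
}
```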
paddle/fluid/inference/tensorrt/engine.h
@@ -268,6 +268,7 @@
   void SetITensor(const std::string& name, nvinfer1::ITensor* tensor);
   // Get an ITensor called name.
   nvinfer1::ITensor* GetITensor(const std::string& name);
+  std::unordered_map<std::string, nvinfer1::ITensor*>* GetITensorMap();
   nvinfer1::ICudaEngine* engine() { return infer_engine_.get(); }
   nvinfer1::IExecutionContext* context() {
@@ -405,8 +406,8 @@
   void SetTensorDynamicRange(nvinfer1::ITensor* tensor, float range) {
     quant_dynamic_range_[tensor] = range;
   }
 
-  float* GetWeightCPUData(const std::string& name,
-                          framework::Tensor* weight_tensor);
+  template <typename T = float>
+  T* GetWeightCPUData(const std::string& name,
+                      framework::Tensor* weight_tensor);
 
   // A pointer to CPU memory is needed of the TRT weight.
@@ -424,7 +425,14 @@
     static int suffix_counter = 0;
     std::string suffix = std::to_string(suffix_counter);
     std::string splitter = "__";
-    weight_map[w_name + splitter + suffix] = std::move(w_tensor);
+    std::string name_with_suffix = w_name + splitter + suffix;
+    PADDLE_ENFORCE_EQ(weight_map.count(name_with_suffix),
+                      0,
+                      platform::errors::AlreadyExists(
+                          "The weight named %s is set into the weight map "
+                          "twice in TRT OP converter.",
+                          name_with_suffix));
+    weight_map[name_with_suffix] = std::move(w_tensor);
     suffix_counter += 1;
   }
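The `PADDLE_ENFORCE_EQ` guard that used to live in `GetWeightCPUData` now sits here in `SetWeights`, next to the suffix counter that makes each registered key unique. A small sketch of the scheme, with `assert` standing in for `PADDLE_ENFORCE_EQ`:

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <vector>

std::map<std::string, std::unique_ptr<std::vector<float>>> weight_map;

void SetWeights(const std::string& w_name,
                std::unique_ptr<std::vector<float>> w_tensor) {
  static int suffix_counter = 0;
  std::string name_with_suffix =
      w_name + "__" + std::to_string(suffix_counter);
  // Defensive guard (PADDLE_ENFORCE_EQ stand-in): overwriting an existing
  // entry would destroy the stored tensor and dangle earlier pointers.
  assert(weight_map.count(name_with_suffix) == 0);
  weight_map[name_with_suffix] = std::move(w_tensor);
  suffix_counter += 1;
}

int main() {
  SetWeights("fc_w", std::make_unique<std::vector<float>>(4));  // "fc_w__0"
  SetWeights("fc_w", std::make_unique<std::vector<float>>(4));  // "fc_w__1"
  return 0;
}
```

Because the counter increments on every call, repeated registrations of the same variable name land under distinct keys ("fc_w__0", "fc_w__1", ...), and the check turns any accidental key reuse into a hard error instead of silently replacing the stored tensor.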