未验证 提交 2e2f92a5 编写于 作者: P Pei Yang 提交者: GitHub

fix trt weight bug (#21231)

added splitter "__" between weight name and suffix number to avoid conflicts.
上级 29b63f0a
......@@ -211,7 +211,8 @@ float *TensorRTEngine::GetWeightCPUData(const std::string &name,
const std::vector<float> &scale) {
static int name_suffix_counter = 0;
std::string name_suffix = std::to_string(name_suffix_counter);
std::string name_with_suffix = name + name_suffix;
std::string splitter = "__";
std::string name_with_suffix = name + splitter + name_suffix;
auto w_dims = weight_tensor->dims();
platform::CPUPlace cpu_place;
PADDLE_ENFORCE_EQ(
......
......@@ -159,7 +159,8 @@ class TensorRTEngine {
std::unique_ptr<framework::Tensor> w_tensor) {
static int suffix_counter = 0;
std::string suffix = std::to_string(suffix_counter);
weight_map[w_name + suffix] = std::move(w_tensor);
std::string splitter = "__";
weight_map[w_name + splitter + suffix] = std::move(w_tensor);
suffix_counter += 1;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册