Commit 0664e2aa authored by H HexToString

fix bug

Parent d7ebdfa0
@@ -84,9 +84,12 @@ const std::string getFileBySuffix(
   while ((dirp = readdir(dp)) != nullptr) {
     if (dirp->d_type == DT_REG) {
       for (int idx = 0; idx < suffixVector.size(); ++idx) {
-        if (std::string(dirp->d_name).find(suffixVector[idx]) !=
-            std::string::npos) {
-          fileName = static_cast<std::string>(dirp->d_name);
+        std::string fileName_in_Dir = static_cast<std::string>(dirp->d_name);
+        if (fileName_in_Dir.length() >= suffixVector[idx].length() &&
+            fileName_in_Dir.substr(
+                fileName_in_Dir.length() - suffixVector[idx].length(),
+                suffixVector[idx].length()) == suffixVector[idx]) {
+          fileName = fileName_in_Dir;
           break;
         }
       }
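Note on this hunk: the old code used std::string::find, which accepts any filename that merely contains one of the suffixes anywhere in its name; the new code only accepts filenames that actually end with the suffix. A minimal standalone sketch of the new comparison (the helper name endsWith and the example filenames are illustrative, not part of the commit):

#include <string>

// Hypothetical helper mirroring the check added in this commit: true only
// when `name` ends with `suffix`, not when the suffix appears mid-string.
static bool endsWith(const std::string& name, const std::string& suffix) {
  return name.length() >= suffix.length() &&
         name.substr(name.length() - suffix.length(), suffix.length()) ==
             suffix;
}

// endsWith("serving_server.pdmodel", "model")  -> true
// endsWith("model_list.txt", "model")          -> false
//   (the old find()-based check would have returned true here)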
@@ -166,8 +169,10 @@ class PaddleInferenceEngine : public EngineCore {
     }
     Config config;
-    std::vector<std::string> suffixParaVector = {".pdiparams", "__params__", "params"};
-    std::vector<std::string> suffixModelVector = {".pdmodel", "__model__", "model"};
+    std::vector<std::string> suffixParaVector = {
+        ".pdiparams", "__params__", "params"};
+    std::vector<std::string> suffixModelVector = {
+        ".pdmodel", "__model__", "model"};
     std::string paraFileName = getFileBySuffix(model_path, suffixParaVector);
     std::string modelFileName = getFileBySuffix(model_path, suffixModelVector);
@@ -273,23 +278,20 @@ class PaddleInferenceEngine : public EngineCore {
       config.SetXpuDeviceId(gpu_id);
     }
-    if (engine_conf.has_use_ascend_cl() &&
-        engine_conf.use_ascend_cl()) {
+    if (engine_conf.has_use_ascend_cl() && engine_conf.use_ascend_cl()) {
       if (engine_conf.has_use_lite() && engine_conf.use_lite()) {
         // for ascend 310
         FLAGS_nnadapter_device_names = "huawei_ascend_npu";
         FLAGS_nnadapter_context_properties =
-            "HUAWEI_ASCEND_NPU_SELECTED_DEVICE_IDS=" +
-            std::to_string(gpu_id);
+            "HUAWEI_ASCEND_NPU_SELECTED_DEVICE_IDS=" + std::to_string(gpu_id);
         FLAGS_nnadapter_model_cache_dir = "";
         config.NNAdapter()
             .Enable()
             .SetDeviceNames({FLAGS_nnadapter_device_names})
             .SetContextProperties(FLAGS_nnadapter_context_properties)
             .SetModelCacheDir(FLAGS_nnadapter_model_cache_dir);
         LOG(INFO) << "Enable Lite NNAdapter for Ascend,"
-                  << "nnadapter_device_names="
-                  << FLAGS_nnadapter_device_names
+                  << "nnadapter_device_names=" << FLAGS_nnadapter_device_names
                   << ",nnadapter_context_properties="
                   << FLAGS_nnadapter_context_properties
                   << ",nnadapter_model_cache_dir="
......