提交 1af674b7 编写于 作者: Z zhangjun

make case non-sensitive

上级 15033071
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <algorithm>
#include <cctype>
#include <fstream> #include <fstream>
#include <string> #include <string>
#include "core/predictor/common/inner_common.h" #include "core/predictor/common/inner_common.h"
...@@ -48,14 +50,24 @@ string PrecisionTypeString(const Precision data_type) { ...@@ -48,14 +50,24 @@ string PrecisionTypeString(const Precision data_type) {
} }
} }
std::string ToLower(const std::string& data) {
std::string result = data;
std::transform(
result.begin(), result.end(), result.begin(), [](unsigned char c) {
return tolower(c);
});
return result;
}
Precision GetPrecision(const std::string& precision_data) { Precision GetPrecision(const std::string& precision_data) {
if (precision_data == "fp32") { std::string precision_type = ToLower(precision_data);
if (precision_type == "fp32") {
return Precision::kFloat32; return Precision::kFloat32;
} else if (precision_data == "int8") { } else if (precision_type == "int8") {
return Precison::kInt8; return Precison::kInt8;
} else if (precision_data == "fp16") { } else if (precision_type == "fp16") {
return Precision::kHalf; return Precision::kHalf;
} else if (precision_data == "bf16") { } else if (precision_type == "bf16") {
return Precision::kBfloat16; return Precision::kBfloat16;
} }
return "unknow type"; return "unknow type";
......
...@@ -119,8 +119,8 @@ class LocalPredictor(object): ...@@ -119,8 +119,8 @@ class LocalPredictor(object):
self.fetch_names_to_type_[var.alias_name] = var.fetch_type self.fetch_names_to_type_[var.alias_name] = var.fetch_type
precision_type = paddle_infer.PrecisionType.Float32 precision_type = paddle_infer.PrecisionType.Float32
if precision in precision_map: if precision.lower() in precision_map:
precision_type = precision_map[precision] precision_type = precision_map[precision.lower()]
if use_profile: if use_profile:
config.enable_profile() config.enable_profile()
if mem_optim: if mem_optim:
...@@ -157,7 +157,7 @@ class LocalPredictor(object): ...@@ -157,7 +157,7 @@ class LocalPredictor(object):
if not use_gpu and not use_lite: if not use_gpu and not use_lite:
if precision_type == paddle_infer.PrecisionType.Int8: if precision_type == paddle_infer.PrecisionType.Int8:
config.enable_quantizer() config.enable_quantizer()
if precision == "bf16": if precision.lower() == "bf16":
config.enable_mkldnn_bfloat16() config.enable_mkldnn_bfloat16()
self.predictor = paddle_infer.create_predictor(config) self.predictor = paddle_infer.create_predictor(config)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册