diff --git a/core/predictor/common/constant.h b/core/predictor/common/constant.h
index dd4c5733c410864f3bf8449891bd90e1aec457b1..ec0c5b4ee292f840b6fd8638b891f9d341463dd5 100644
--- a/core/predictor/common/constant.h
+++ b/core/predictor/common/constant.h
@@ -43,8 +43,6 @@ DECLARE_bool(enable_model_toolkit);
 DECLARE_string(enable_protocol_list);
 DECLARE_bool(enable_cube);
 DECLARE_bool(enable_general_model);
-DECLARE_string(precision);
-DECLARE_bool(use_calib);
 
 // STATIC Variables
 extern const char* START_OP_NAME;
diff --git a/core/predictor/common/utils.h b/core/predictor/common/utils.h
index 6108389bbbfa5609fc12ee39381906e50ad0b847..7fee442d43a2e9344150647ddba16e91c24a3044 100644
--- a/core/predictor/common/utils.h
+++ b/core/predictor/common/utils.h
@@ -29,21 +29,22 @@ namespace butil = base;
 #endif
 
 enum class Precision {
-  kFloat32 = 0,  ///< fp32
-  kInt8,         ///< int8
-  kHalf,         ///< fp16
-  kBfloat16,     ///< bf16
+  kUnk = -1,     // unknown type
+  kFloat32 = 0,  // fp32
+  kInt8,         // int8
+  kHalf,         // fp16
+  kBfloat16      // bf16
 };
 
-string PrecisionTypeString(const Precision data_type) {
+std::string PrecisionTypeString(const Precision data_type) {
   switch (data_type) {
-    case 0:
+    case Precision::kFloat32:
       return "kFloat32";
-    case 1:
+    case Precision::kInt8:
       return "kInt8";
-    case 2:
+    case Precision::kHalf:
       return "kHalf";
-    case 3:
-      return "kBloat16";
+    case Precision::kBfloat16:
+      return "kBfloat16";
     default:
-      return "unUnk";
+      return "kUnk";
@@ -59,20 +60,6 @@ std::string ToLower(const std::string& data) {
   return result;
 }
 
-Precision GetPrecision(const std::string& precision_data) {
-  std::string precision_type = ToLower(precision_data);
-  if (precision_type == "fp32") {
-    return Precision::kFloat32;
-  } else if (precision_type == "int8") {
-    return Precison::kInt8;
-  } else if (precision_type == "fp16") {
-    return Precision::kHalf;
-  } else if (precision_type == "bf16") {
-    return Precision::kBfloat16;
-  }
-  return "unknow type";
-}
-
 class TimerFlow {
  public:
   static const int MAX_SIZE = 1024;
diff --git a/paddle_inference/paddle/include/paddle_engine.h b/paddle_inference/paddle/include/paddle_engine.h
index 10bcdbfd7a075016649a3b986be2d777f3fbe3a8..a2be5257aeedb984a9d30b9946c707fcf3ff824d 100644
--- a/paddle_inference/paddle/include/paddle_engine.h
+++ b/paddle_inference/paddle/include/paddle_engine.h
@@ -37,10 +37,24 @@
 using paddle_infer::Tensor;
 using paddle_infer::CreatePredictor;
 
 DECLARE_int32(gpuid);
+DECLARE_string(precision);
+DECLARE_bool(use_calib);
 static const int max_batch = 32;
 static const int min_subgraph_size = 3;
-static predictor::Precision precision_type;
+static PrecisionType precision_type;
+
+inline PrecisionType GetPrecision(const std::string& precision_data) {
+  std::string precision_type = predictor::ToLower(precision_data);
+  if (precision_type == "fp32") {
+    return PrecisionType::kFloat32;
+  } else if (precision_type == "int8") {
+    return PrecisionType::kInt8;
+  } else if (precision_type == "fp16") {
+    return PrecisionType::kHalf;
+  }
+  return PrecisionType::kFloat32;
+}
 
 // Engine Base
 class PaddleEngineBase {
@@ -138,7 +152,7 @@ class PaddleInferenceEngine : public PaddleEngineBase {
       // 2000MB GPU memory
       config.EnableUseGpu(2000, FLAGS_gpuid);
     }
-    precision_type = predictor::GetPrecision(FLAGS_precision);
+    precision_type = GetPrecision(FLAGS_precision);
 
     if (engine_conf.has_use_trt() && engine_conf.use_trt()) {
       if (!engine_conf.has_use_gpu() || !engine_conf.use_gpu()) {
@@ -149,7 +163,7 @@ class PaddleInferenceEngine : public PaddleEngineBase {
           min_subgraph_size,
           precision_type,
           false,
-          use_calib);
+          FLAGS_use_calib);
       LOG(INFO) << "create TensorRT predictor";
     }
 
@@ -160,9 +174,9 @@
     if ((!engine_conf.has_use_lite() && !engine_conf.has_use_gpu()) ||
         (engine_conf.has_use_lite() && !engine_conf.use_lite() &&
         engine_conf.has_use_gpu() && !engine_conf.use_gpu())) {
-      if (precision_type == Precision::kInt8) {
+      if (precision_type == PrecisionType::kInt8) {
         config.EnableMkldnnQuantizer();
-      } else if (precision_type == Precision::kHalf) {
+      } else if (precision_type == PrecisionType::kHalf) {
         config.EnableMkldnnBfloat16();
       }
     }