提交 a5113877 编写于 作者: Z zhangjun

fix

上级 c2928447
...@@ -43,8 +43,6 @@ DECLARE_bool(enable_model_toolkit); ...@@ -43,8 +43,6 @@ DECLARE_bool(enable_model_toolkit);
DECLARE_string(enable_protocol_list); DECLARE_string(enable_protocol_list);
DECLARE_bool(enable_cube); DECLARE_bool(enable_cube);
DECLARE_bool(enable_general_model); DECLARE_bool(enable_general_model);
DECLARE_string(precision);
DECLARE_bool(use_calib);
// STATIC Variables // STATIC Variables
extern const char* START_OP_NAME; extern const char* START_OP_NAME;
......
...@@ -29,21 +29,22 @@ namespace butil = base; ...@@ -29,21 +29,22 @@ namespace butil = base;
#endif #endif
enum class Precision { enum class Precision {
kFloat32 = 0, ///< fp32 kUnk = -1, // unknown type
kInt8, ///< int8 kFloat32 = 0, // fp32
kHalf, ///< fp16 kInt8, // int8
kBfloat16, ///< bf16 kHalf, // fp16
kBfloat16 // bf16
}; };
string PrecisionTypeString(const Precision data_type) { std::string PrecisionTypeString(const Precision data_type) {
switch (data_type) { switch (data_type) {
case 0: case Precision::kFloat32:
return "kFloat32"; return "kFloat32";
case 1: case Precision::kInt8:
return "kInt8"; return "kInt8";
case 2: case Precision::kHalf:
return "kHalf"; return "kHalf";
case 3: case Precision::kBfloat16:
return "kBloat16"; return "kBloat16";
default: default:
return "unUnk"; return "unUnk";
...@@ -59,20 +60,6 @@ std::string ToLower(const std::string& data) { ...@@ -59,20 +60,6 @@ std::string ToLower(const std::string& data) {
return result; return result;
} }
// Map a user-supplied precision string ("fp32"/"int8"/"fp16"/"bf16",
// case-insensitive) to the corresponding Precision enum value.
// Unrecognized strings fall back to Precision::kFloat32.
Precision GetPrecision(const std::string& precision_data) {
  std::string precision_type = ToLower(precision_data);
  if (precision_type == "fp32") {
    return Precision::kFloat32;
  } else if (precision_type == "int8") {
    // fixed: was misspelled `Precison::kInt8` (undeclared identifier)
    return Precision::kInt8;
  } else if (precision_type == "fp16") {
    return Precision::kHalf;
  } else if (precision_type == "bf16") {
    return Precision::kBfloat16;
  }
  // fixed: was `return "unknow type";` — a string literal cannot be
  // converted to Precision; default to fp32 for unknown inputs.
  return Precision::kFloat32;
}
class TimerFlow { class TimerFlow {
public: public:
static const int MAX_SIZE = 1024; static const int MAX_SIZE = 1024;
......
...@@ -37,10 +37,24 @@ using paddle_infer::Tensor; ...@@ -37,10 +37,24 @@ using paddle_infer::Tensor;
using paddle_infer::CreatePredictor; using paddle_infer::CreatePredictor;
DECLARE_int32(gpuid); DECLARE_int32(gpuid);
DECLARE_string(precision);
DECLARE_bool(use_calib);
static const int max_batch = 32; static const int max_batch = 32;
static const int min_subgraph_size = 3; static const int min_subgraph_size = 3;
static predictor::Precision precision_type; static PrecisionType precision_type;
// Translate a precision string ("fp32"/"int8"/"fp16", case-insensitive)
// into paddle_infer's PrecisionType. Anything unrecognized (including
// "bf16", which paddle_infer::PrecisionType does not model here) falls
// back to PrecisionType::kFloat32.
PrecisionType GetPrecision(const std::string& precision_data) {
  // Renamed local (was `precision_type`) so it no longer shadows the
  // file-static `precision_type` variable declared above.
  const std::string lowered = predictor::ToLower(precision_data);
  if (lowered == "fp32") {
    return PrecisionType::kFloat32;
  } else if (lowered == "int8") {
    return PrecisionType::kInt8;
  } else if (lowered == "fp16") {
    return PrecisionType::kHalf;
  }
  return PrecisionType::kFloat32;
}
// Engine Base // Engine Base
class PaddleEngineBase { class PaddleEngineBase {
...@@ -138,7 +152,7 @@ class PaddleInferenceEngine : public PaddleEngineBase { ...@@ -138,7 +152,7 @@ class PaddleInferenceEngine : public PaddleEngineBase {
// 2000MB GPU memory // 2000MB GPU memory
config.EnableUseGpu(2000, FLAGS_gpuid); config.EnableUseGpu(2000, FLAGS_gpuid);
} }
precision_type = predictor::GetPrecision(FLAGS_precision); precision_type = GetPrecision(FLAGS_precision);
if (engine_conf.has_use_trt() && engine_conf.use_trt()) { if (engine_conf.has_use_trt() && engine_conf.use_trt()) {
if (!engine_conf.has_use_gpu() || !engine_conf.use_gpu()) { if (!engine_conf.has_use_gpu() || !engine_conf.use_gpu()) {
...@@ -149,7 +163,7 @@ class PaddleInferenceEngine : public PaddleEngineBase { ...@@ -149,7 +163,7 @@ class PaddleInferenceEngine : public PaddleEngineBase {
min_subgraph_size, min_subgraph_size,
precision_type, precision_type,
false, false,
use_calib); FLAGS_use_calib);
LOG(INFO) << "create TensorRT predictor"; LOG(INFO) << "create TensorRT predictor";
} }
...@@ -160,9 +174,9 @@ class PaddleInferenceEngine : public PaddleEngineBase { ...@@ -160,9 +174,9 @@ class PaddleInferenceEngine : public PaddleEngineBase {
if ((!engine_conf.has_use_lite() && !engine_conf.has_use_gpu()) || if ((!engine_conf.has_use_lite() && !engine_conf.has_use_gpu()) ||
(engine_conf.has_use_lite() && !engine_conf.use_lite() && (engine_conf.has_use_lite() && !engine_conf.use_lite() &&
engine_conf.has_use_gpu() && !engine_conf.use_gpu())) { engine_conf.has_use_gpu() && !engine_conf.use_gpu())) {
if (precision_type == Precision::kInt8) { if (precision_type == PrecisionType::kInt8) {
config.EnableMkldnnQuantizer(); config.EnableMkldnnQuantizer();
} else if (precision_type == Precision::kHalf) { } else if (precision_type == PrecisionType::kHalf) {
config.EnableMkldnnBfloat16(); config.EnableMkldnnBfloat16();
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册