提交 a5113877 编写于 作者: Z zhangjun

fix

上级 c2928447
......@@ -43,8 +43,6 @@ DECLARE_bool(enable_model_toolkit);
DECLARE_string(enable_protocol_list);
DECLARE_bool(enable_cube);
DECLARE_bool(enable_general_model);
DECLARE_string(precision);
DECLARE_bool(use_calib);
// STATIC Variables
extern const char* START_OP_NAME;
......
......@@ -29,21 +29,22 @@ namespace butil = base;
#endif
// Inference compute precision requested via --precision.
// kUnk marks an unrecognized / unset precision string.
enum class Precision {
  kUnk = -1,     // unknown type
  kFloat32 = 0,  // fp32
  kInt8,         // int8
  kHalf,         // fp16
  kBfloat16      // bf16
};
string PrecisionTypeString(const Precision data_type) {
std::string PrecisionTypeString(const Precision data_type) {
switch (data_type) {
case 0:
case Precision::kFloat32:
return "kFloat32";
case 1:
case Precision::kInt8:
return "kInt8";
case 2:
case Precision::kHalf:
return "kHalf";
case 3:
case Precision::kBfloat16:
return "kBloat16";
default:
return "unUnk";
......@@ -59,20 +60,6 @@ std::string ToLower(const std::string& data) {
return result;
}
// Map a precision name ("fp32"/"int8"/"fp16"/"bf16", any case) to the
// Precision enum. Returns Precision::kUnk for unrecognized strings.
Precision GetPrecision(const std::string& precision_data) {
  std::string precision_type = ToLower(precision_data);
  if (precision_type == "fp32") {
    return Precision::kFloat32;
  } else if (precision_type == "int8") {
    // fixed: was misspelled "Precison::kInt8" (did not compile)
    return Precision::kInt8;
  } else if (precision_type == "fp16") {
    return Precision::kHalf;
  } else if (precision_type == "bf16") {
    return Precision::kBfloat16;
  }
  // fixed: was `return "unknow type";` — a string literal cannot be
  // returned from a Precision-returning function.
  return Precision::kUnk;
}
class TimerFlow {
public:
static const int MAX_SIZE = 1024;
......
......@@ -37,10 +37,24 @@ using paddle_infer::Tensor;
using paddle_infer::CreatePredictor;
DECLARE_int32(gpuid);
DECLARE_string(precision);
DECLARE_bool(use_calib);
static const int max_batch = 32;
static const int min_subgraph_size = 3;
static predictor::Precision precision_type;
static PrecisionType precision_type;
// Translate a precision name ("fp32"/"int8"/"fp16", case-insensitive)
// into paddle_infer's PrecisionType. Anything unrecognized — including
// "fp32" itself — falls back to kFloat32.
PrecisionType GetPrecision(const std::string& precision_data) {
  const std::string lowered = predictor::ToLower(precision_data);
  if (lowered == "int8") {
    return PrecisionType::kInt8;
  }
  if (lowered == "fp16") {
    return PrecisionType::kHalf;
  }
  // "fp32" and any unknown value both map to full precision.
  return PrecisionType::kFloat32;
}
// Engine Base
class PaddleEngineBase {
......@@ -138,7 +152,7 @@ class PaddleInferenceEngine : public PaddleEngineBase {
// 2000MB GPU memory
config.EnableUseGpu(2000, FLAGS_gpuid);
}
precision_type = predictor::GetPrecision(FLAGS_precision);
precision_type = GetPrecision(FLAGS_precision);
if (engine_conf.has_use_trt() && engine_conf.use_trt()) {
if (!engine_conf.has_use_gpu() || !engine_conf.use_gpu()) {
......@@ -149,7 +163,7 @@ class PaddleInferenceEngine : public PaddleEngineBase {
min_subgraph_size,
precision_type,
false,
use_calib);
FLAGS_use_calib);
LOG(INFO) << "create TensorRT predictor";
}
......@@ -160,9 +174,9 @@ class PaddleInferenceEngine : public PaddleEngineBase {
if ((!engine_conf.has_use_lite() && !engine_conf.has_use_gpu()) ||
(engine_conf.has_use_lite() && !engine_conf.use_lite() &&
engine_conf.has_use_gpu() && !engine_conf.use_gpu())) {
if (precision_type == Precision::kInt8) {
if (precision_type == PrecisionType::kInt8) {
config.EnableMkldnnQuantizer();
} else if (precision_type == Precision::kHalf) {
} else if (precision_type == PrecisionType::kHalf) {
config.EnableMkldnnBfloat16();
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册