Unverified commit 62a23aec, authored by TeslaZhao, committed by GitHub

Merge pull request #1130 from zhangjun/low-precision

low-precision support
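This change threads two new deployment options through the serving stack: precision (fp32 / int8 / fp16 / bf16) and use_calib (TensorRT calibration). They are exposed on the C++ engine flags, the serve CLI, the Server/WebService wrappers, the pipeline config, and LocalPredictor. A minimal usage sketch of the Python local-predictor path; the import path and model directory are assumptions, and the keyword arguments follow the load_model_config signature added in this diff:

from paddle_serving_app.local_predict import LocalPredictor  # assumed import path

predictor = LocalPredictor()
predictor.load_model_config(
    "serving_server_model",  # placeholder model directory
    use_gpu=True,
    gpu_id=0,
    use_trt=True,            # run TensorRT subgraphs
    precision="fp16",        # "fp32" / "int8" / "fp16" / "bf16"
    use_calib=False)         # TensorRT calibration, mainly for int8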
@@ -13,8 +13,10 @@
 // limitations under the License.
 
 #pragma once
 
-#include <string>
+#include <algorithm>
+#include <cctype>
 #include <fstream>
+#include <string>
 
 #include "core/predictor/common/inner_common.h"
 #include "core/predictor/common/macros.h"
@@ -26,6 +28,38 @@ namespace predictor {
 namespace butil = base;
 #endif
 
+enum class Precision {
+  kUnk = -1,     // unknown type
+  kFloat32 = 0,  // fp32
+  kInt8,         // int8
+  kHalf,         // fp16
+  kBfloat16,     // bf16
+};
+
+static std::string PrecisionTypeString(const Precision data_type) {
+  switch (data_type) {
+    case Precision::kFloat32:
+      return "kFloat32";
+    case Precision::kInt8:
+      return "kInt8";
+    case Precision::kHalf:
+      return "kHalf";
+    case Precision::kBfloat16:
+      return "kBfloat16";
+    default:
+      return "kUnk";
+  }
+}
+
+static std::string ToLower(const std::string& data) {
+  std::string result = data;
+  std::transform(
+      result.begin(), result.end(), result.begin(), [](unsigned char c) {
+        return std::tolower(c);
+      });
+  return result;
+}
+
 class TimerFlow {
  public:
   static const int MAX_SIZE = 1024;
......
@@ -37,9 +37,24 @@ using paddle_infer::Tensor;
 using paddle_infer::CreatePredictor;
 
 DECLARE_int32(gpuid);
+DECLARE_string(precision);
+DECLARE_bool(use_calib);
 
 static const int max_batch = 32;
 static const int min_subgraph_size = 3;
+static PrecisionType precision_type;
+
+PrecisionType GetPrecision(const std::string& precision_data) {
+  std::string precision_type = predictor::ToLower(precision_data);
+  if (precision_type == "fp32") {
+    return PrecisionType::kFloat32;
+  } else if (precision_type == "int8") {
+    return PrecisionType::kInt8;
+  } else if (precision_type == "fp16") {
+    return PrecisionType::kHalf;
+  }
+  return PrecisionType::kFloat32;
+}
 
 // Engine Base
 class PaddleEngineBase {
@@ -137,6 +152,7 @@ class PaddleInferenceEngine : public PaddleEngineBase {
       // 2000MB GPU memory
       config.EnableUseGpu(2000, FLAGS_gpuid);
     }
+    precision_type = GetPrecision(FLAGS_precision);
 
     if (engine_conf.has_use_trt() && engine_conf.use_trt()) {
       if (!engine_conf.has_use_gpu() || !engine_conf.use_gpu()) {
@@ -145,14 +161,24 @@ class PaddleInferenceEngine : public PaddleEngineBase {
       config.EnableTensorRtEngine(1 << 20,
                                   max_batch,
                                   min_subgraph_size,
-                                  Config::Precision::kFloat32,
+                                  precision_type,
                                   false,
-                                  false);
+                                  FLAGS_use_calib);
       LOG(INFO) << "create TensorRT predictor";
     }
 
     if (engine_conf.has_use_lite() && engine_conf.use_lite()) {
-      config.EnableLiteEngine(PrecisionType::kFloat32, true);
+      config.EnableLiteEngine(precision_type, true);
+    }
+
+    if ((!engine_conf.has_use_lite() && !engine_conf.has_use_gpu()) ||
+        (engine_conf.has_use_lite() && !engine_conf.use_lite() &&
+         engine_conf.has_use_gpu() && !engine_conf.use_gpu())) {
+      if (precision_type == PrecisionType::kInt8) {
+        config.EnableMkldnnQuantizer();
+      } else if (precision_type == PrecisionType::kHalf) {
+        config.EnableMkldnnBfloat16();
+      }
     }
 
     if (engine_conf.has_use_xpu() && engine_conf.use_xpu()) {
@@ -171,7 +197,6 @@ class PaddleInferenceEngine : public PaddleEngineBase {
       config.EnableMemoryOptim();
     }
-
     predictor::AutoLock lock(predictor::GlobalCreateMutex::instance());
     _predictor = CreatePredictor(config);
     if (NULL == _predictor.get()) {
......
@@ -20,6 +20,8 @@ namespace paddle_serving {
 namespace inference {
 
 DEFINE_int32(gpuid, 0, "GPU device id to use");
+DEFINE_string(precision, "fp32", "precision to deploy, default is fp32");
+DEFINE_bool(use_calib, false, "calibration mode, default is false");
 
 REGIST_FACTORY_OBJECT_IMPL_WITH_NAME(
     ::baidu::paddle_serving::predictor::FluidInferEngine<PaddleInferenceEngine>,
......
@@ -27,6 +27,12 @@ logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
 logger = logging.getLogger("LocalPredictor")
 logger.setLevel(logging.INFO)
 
+precision_map = {
+    'int8': paddle_infer.PrecisionType.Int8,
+    'fp32': paddle_infer.PrecisionType.Float32,
+    'fp16': paddle_infer.PrecisionType.Half,
+}
+
 
 class LocalPredictor(object):
     """
@@ -56,6 +62,8 @@ class LocalPredictor(object):
                           use_trt=False,
                           use_lite=False,
                           use_xpu=False,
+                          precision="fp32",
+                          use_calib=False,
                           use_feed_fetch_ops=False):
         """
         Load model configs and create the paddle predictor by Paddle Inference API.
@@ -71,6 +79,8 @@ class LocalPredictor(object):
             use_trt: use nvidia TensorRT optimization, False default
             use_lite: use Paddle-Lite Engine, False default
             use_xpu: run predict on Baidu Kunlun, False default
+            precision: precision mode, "fp32" default
+            use_calib: use TensorRT calibration, False default
             use_feed_fetch_ops: use feed/fetch ops, False default.
         """
         client_config = "{}/serving_server_conf.prototxt".format(model_path)
@@ -88,9 +98,11 @@ class LocalPredictor(object):
         logger.info(
             "LocalPredictor load_model_config params: model_path:{}, use_gpu:{},\
             gpu_id:{}, use_profile:{}, thread_num:{}, mem_optim:{}, ir_optim:{},\
-            use_trt:{}, use_lite:{}, use_xpu: {}, use_feed_fetch_ops:{}".format(
-                model_path, use_gpu, gpu_id, use_profile, thread_num, mem_optim,
-                ir_optim, use_trt, use_lite, use_xpu, use_feed_fetch_ops))
+            use_trt:{}, use_lite:{}, use_xpu: {}, precision: {}, use_calib: {},\
+            use_feed_fetch_ops:{}"
+            .format(model_path, use_gpu, gpu_id, use_profile, thread_num,
+                    mem_optim, ir_optim, use_trt, use_lite, use_xpu, precision,
+                    use_calib, use_feed_fetch_ops))
 
         self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
         self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
@@ -106,6 +118,9 @@ class LocalPredictor(object):
             self.fetch_names_to_idx_[var.alias_name] = i
             self.fetch_names_to_type_[var.alias_name] = var.fetch_type
 
+        precision_type = paddle_infer.PrecisionType.Float32
+        if precision.lower() in precision_map:
+            precision_type = precision_map[precision.lower()]
         if use_profile:
             config.enable_profile()
         if mem_optim:
@@ -121,6 +136,7 @@ class LocalPredictor(object):
             config.enable_use_gpu(100, gpu_id)
         if use_trt:
             config.enable_tensorrt_engine(
+                precision_mode=precision_type,
                 workspace_size=1 << 20,
                 max_batch_size=32,
                 min_subgraph_size=3,
@@ -129,7 +145,7 @@ class LocalPredictor(object):
         if use_lite:
             config.enable_lite_engine(
-                precision_mode=paddle_infer.PrecisionType.Float32,
+                precision_mode=precision_type,
                 zero_copy=True,
                 passes_filter=[],
                 ops_filter=[])
@@ -138,6 +154,11 @@ class LocalPredictor(object):
             # 2MB l3 cache
             config.enable_xpu(8 * 1024 * 1024)
 
+        if not use_gpu and not use_lite:
+            if precision_type == paddle_infer.PrecisionType.Int8:
+                config.enable_quantizer()
+            if precision.lower() == "bf16":
+                config.enable_mkldnn_bfloat16()
         self.predictor = paddle_infer.create_predictor(config)
 
     def predict(self, feed=None, fetch=None, batch=False, log_id=0):
......
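For CPU deployments the same precision string drives the MKL-DNN branch added above: with use_gpu=False and use_lite=False, "int8" enables config.enable_quantizer() and "bf16" enables config.enable_mkldnn_bfloat16(). A small sketch under the same assumptions as the example near the top (placeholder model directory):

predictor = LocalPredictor()
predictor.load_model_config(
    "serving_server_model",  # placeholder model directory
    use_gpu=False,
    precision="bf16")        # CPU bfloat16 via MKL-DNN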
@@ -51,6 +51,16 @@ def serve_args():
         "--name", type=str, default="None", help="Default service name")
     parser.add_argument(
         "--use_mkl", default=False, action="store_true", help="Use MKL")
+    parser.add_argument(
+        "--precision",
+        type=str,
+        default="fp32",
+        help="precision mode(fp32, int8, fp16, bf16)")
+    parser.add_argument(
+        "--use_calib",
+        default=False,
+        action="store_true",
+        help="Use TensorRT Calibration")
     parser.add_argument(
         "--mem_optim_off",
         default=False,
@@ -109,7 +119,7 @@ def start_standard_model(serving_port):  # pylint: disable=doc-string-missing
     if model == "":
         print("You must specify your serving model")
         exit(-1)
 
     for single_model_config in args.model:
         if os.path.isdir(single_model_config):
             pass
@@ -131,11 +141,10 @@ def start_standard_model(serving_port):  # pylint: disable=doc-string-missing
         infer_op_name = "general_detection"
     general_infer_op = op_maker.create(infer_op_name)
     op_seq_maker.add_op(general_infer_op)
-
     general_response_op = op_maker.create('general_response')
     op_seq_maker.add_op(general_response_op)
 
     server = None
     if use_multilang:
         server = serving.MultiLangServer()
@@ -148,6 +157,8 @@ def start_standard_model(serving_port):  # pylint: disable=doc-string-missing
     server.use_mkl(use_mkl)
     server.set_max_body_size(max_body_size)
     server.set_port(port)
+    server.set_precision(args.precision)
+    server.set_use_calib(args.use_calib)
     server.use_encryption_model(use_encryption_model)
     if args.product_name != None:
         server.set_product_name(args.product_name)
@@ -199,7 +210,7 @@ def start_gpu_card_model(index, gpuid, port, args):  # pylint: disable=doc-string-missing
     infer_op_name = "general_infer"
     general_infer_op = op_maker.create(infer_op_name)
     op_seq_maker.add_op(general_infer_op)
 
     general_response_op = op_maker.create('general_response')
     op_seq_maker.add_op(general_response_op)
@@ -210,6 +221,8 @@ def start_gpu_card_model(index, gpuid, port, args):  # pylint: disable=doc-string-missing
     server.set_op_sequence(op_seq_maker.get_op_sequence())
     server.set_num_threads(thread_num)
     server.use_mkl(use_mkl)
+    server.set_precision(args.precision)
+    server.set_use_calib(args.use_calib)
     server.set_memory_optimize(mem_optim)
     server.set_ir_optimize(ir_optim)
     server.set_max_body_size(max_body_size)
@@ -297,7 +310,8 @@ class MainService(BaseHTTPRequestHandler):
         key = base64.b64decode(post_data["key"].encode())
         for single_model_config in args.model:
             if os.path.isfile(single_model_config):
-                raise ValueError("The input of --model should be a dir not file.")
+                raise ValueError(
+                    "The input of --model should be a dir not file.")
             with open(single_model_config + "/key", "wb") as f:
                 f.write(key)
         return True
@@ -309,7 +323,8 @@ class MainService(BaseHTTPRequestHandler):
         key = base64.b64decode(post_data["key"].encode())
         for single_model_config in args.model:
             if os.path.isfile(single_model_config):
-                raise ValueError("The input of --model should be a dir not file.")
+                raise ValueError(
+                    "The input of --model should be a dir not file.")
             with open(single_model_config + "/key", "rb") as f:
                 cur_key = f.read()
             if key != cur_key:
@@ -394,7 +409,9 @@ if __name__ == "__main__":
             device=args.device,
             use_lite=args.use_lite,
             use_xpu=args.use_xpu,
-            ir_optim=args.ir_optim)
+            ir_optim=args.ir_optim,
+            precision=args.precision,
+            use_calib=args.use_calib)
         web_service.run_rpc_service()
 
         app_instance = Flask(__name__)
......
@@ -71,6 +71,8 @@ class Server(object):
         self.max_concurrency = 0
         self.num_threads = 2
         self.port = 8080
+        self.precision = "fp32"
+        self.use_calib = False
         self.reload_interval_s = 10
         self.max_body_size = 64 * 1024 * 1024
         self.module_path = os.path.dirname(paddle_serving_server.__file__)
@@ -113,6 +115,12 @@ class Server(object):
     def set_port(self, port):
         self.port = port
 
+    def set_precision(self, precision="fp32"):
+        self.precision = precision
+
+    def set_use_calib(self, use_calib=False):
+        self.use_calib = use_calib
+
     def set_reload_interval(self, interval):
         self.reload_interval_s = interval
@@ -186,6 +194,10 @@ class Server(object):
             engine.use_trt = self.use_trt
             engine.use_lite = self.use_lite
             engine.use_xpu = self.use_xpu
+            engine.use_gpu = False
+            if self.device == "gpu":
+                engine.use_gpu = True
+
             if os.path.exists('{}/__params__'.format(model_config_path)):
                 engine.combined_model = True
             else:
@@ -472,6 +484,8 @@ class Server(object):
                   "-max_concurrency {} " \
                   "-num_threads {} " \
                   "-port {} " \
+                  "-precision {} " \
+                  "-use_calib {} " \
                   "-reload_interval_s {} " \
                   "-resource_path {} " \
                   "-resource_file {} " \
@@ -485,6 +499,8 @@ class Server(object):
                       self.max_concurrency,
                       self.num_threads,
                       self.port,
+                      self.precision,
+                      self.use_calib,
                       self.reload_interval_s,
                       self.workdir,
                       self.resource_fn,
@@ -500,6 +516,8 @@ class Server(object):
                   "-max_concurrency {} " \
                   "-num_threads {} " \
                   "-port {} " \
+                  "-precision {} " \
+                  "-use_calib {} " \
                   "-reload_interval_s {} " \
                   "-resource_path {} " \
                   "-resource_file {} " \
@@ -514,6 +532,8 @@ class Server(object):
                       self.max_concurrency,
                       self.num_threads,
                       self.port,
+                      self.precision,
+                      self.use_calib,
                       self.reload_interval_s,
                       self.workdir,
                       self.resource_fn,
@@ -562,6 +582,12 @@ class MultiLangServer(object):
     def set_port(self, port):
         self.gport_ = port
 
+    def set_precision(self, precision="fp32"):
+        self.precision = precision
+
+    def set_use_calib(self, use_calib=False):
+        self.use_calib = use_calib
+
     def set_reload_interval(self, interval):
         self.bserver_.set_reload_interval(interval)
......
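On the bRPC server wrapper the new values are plain attributes that get rendered into the worker command line through the "-precision {}" and "-use_calib {}" slots added above. A hedged sketch of wiring them in, assuming Server is importable from the paddle_serving_server package:

from paddle_serving_server.server import Server  # assumed import path

server = Server()
server.set_precision("int8")  # substituted into the "-precision {}" slot
server.set_use_calib(True)    # substituted into the "-use_calib {}" slot
server.set_port(9292)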
@@ -27,6 +27,7 @@ import os
 from paddle_serving_server import pipeline
 from paddle_serving_server.pipeline import Op
 
+
 def port_is_available(port):
     with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
         sock.settimeout(2)
@@ -36,6 +37,7 @@ def port_is_available(port):
     else:
         return False
 
+
 class WebService(object):
     def __init__(self, name="default_service"):
         self.name = name
@@ -63,7 +65,9 @@ class WebService(object):
     def run_service(self):
         self._server.run_server()
 
-    def load_model_config(self, server_config_dir_paths, client_config_path=None):
+    def load_model_config(self,
+                          server_config_dir_paths,
+                          client_config_path=None):
         if isinstance(server_config_dir_paths, str):
             server_config_dir_paths = [server_config_dir_paths]
         elif isinstance(server_config_dir_paths, list):
@@ -73,14 +77,16 @@ class WebService(object):
             if os.path.isdir(single_model_config):
                 pass
             elif os.path.isfile(single_model_config):
-                raise ValueError("The input of --model should be a dir not file.")
+                raise ValueError(
+                    "The input of --model should be a dir not file.")
         self.server_config_dir_paths = server_config_dir_paths
         from .proto import general_model_config_pb2 as m_config
         import google.protobuf.text_format
         file_path_list = []
         for single_model_config in self.server_config_dir_paths:
-            file_path_list.append( "{}/serving_server_conf.prototxt".format(single_model_config) )
+            file_path_list.append("{}/serving_server_conf.prototxt".format(
+                single_model_config))
 
         model_conf = m_config.GeneralModelConfig()
         f = open(file_path_list[0], 'r')
         model_conf = google.protobuf.text_format.Merge(
@@ -109,7 +115,9 @@ class WebService(object):
                             mem_optim=True,
                             use_lite=False,
                             use_xpu=False,
-                            ir_optim=False):
+                            ir_optim=False,
+                            precision="fp32",
+                            use_calib=False):
         device = "gpu"
         if gpuid == -1:
             if use_lite:
@@ -130,7 +138,7 @@ class WebService(object):
             infer_op_name = "general_infer"
         general_infer_op = op_maker.create(infer_op_name)
         op_seq_maker.add_op(general_infer_op)
 
         general_response_op = op_maker.create('general_response')
         op_seq_maker.add_op(general_response_op)
@@ -140,13 +148,16 @@ class WebService(object):
         server.set_memory_optimize(mem_optim)
         server.set_ir_optimize(ir_optim)
         server.set_device(device)
+        server.set_precision(precision)
+        server.set_use_calib(use_calib)
 
         if use_lite:
             server.set_lite()
         if use_xpu:
             server.set_xpu()
 
-        server.load_model_config(self.server_config_dir_paths)#brpc Server support server_config_dir_paths
+        server.load_model_config(self.server_config_dir_paths
+                                 )  #brpc Server support server_config_dir_paths
         if gpuid >= 0:
             server.set_gpuid(gpuid)
         server.prepare_server(workdir=workdir, port=port, device=device)
@@ -159,6 +170,8 @@ class WebService(object):
                        workdir="",
                        port=9393,
                        device="gpu",
+                       precision="fp32",
+                       use_calib=False,
                        use_lite=False,
                        use_xpu=False,
                        ir_optim=False,
@@ -188,7 +201,9 @@ class WebService(object):
                     mem_optim=mem_optim,
                     use_lite=use_lite,
                     use_xpu=use_xpu,
-                    ir_optim=ir_optim))
+                    ir_optim=ir_optim,
+                    precision=precision,
+                    use_calib=use_calib))
         else:
             for i, gpuid in enumerate(self.gpus):
                 self.rpc_service_list.append(
@@ -200,7 +215,9 @@ class WebService(object):
                     mem_optim=mem_optim,
                     use_lite=use_lite,
                     use_xpu=use_xpu,
-                    ir_optim=ir_optim))
+                    ir_optim=ir_optim,
+                    precision=precision,
+                    use_calib=use_calib))
 
     def _launch_web_service(self):
         gpu_num = len(self.gpus)
@@ -297,9 +314,13 @@ class WebService(object):
             # default self.gpus = [0].
             if len(self.gpus) == 0:
                 self.gpus.append(0)
-            self.client.load_model_config(self.server_config_dir_paths[0], use_gpu=True, gpu_id=self.gpus[0])
+            self.client.load_model_config(
+                self.server_config_dir_paths[0],
+                use_gpu=True,
+                gpu_id=self.gpus[0])
         else:
-            self.client.load_model_config(self.server_config_dir_paths[0], use_gpu=False)
+            self.client.load_model_config(
+                self.server_config_dir_paths[0], use_gpu=False)
 
     def run_web_service(self):
         print("This API will be deprecated later. Please do not use it")
......
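WebService exposes the same knobs through prepare_server, which forwards them to default_rpc_service and finally to Server.set_precision / set_use_calib. A usage sketch with placeholder service name, model directory and port, assuming the usual WebService import path:

from paddle_serving_server.web_service import WebService  # assumed import path

web_service = WebService(name="example")
web_service.load_model_config("serving_server_model")  # placeholder model directory
web_service.prepare_server(
    workdir="workdir",
    port=9393,
    device="gpu",
    precision="fp16",
    use_calib=False)
web_service.run_rpc_service()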
@@ -238,6 +238,8 @@ class PipelineServer(object):
                 "devices": "",
                 "mem_optim": True,
                 "ir_optim": False,
+                "precision": "fp32",
+                "use_calib": False,
             },
         }
         for op in self._used_op:
@@ -394,6 +396,8 @@ class ServerYamlConfChecker(object):
             "devices": "",
             "mem_optim": True,
             "ir_optim": False,
+            "precision": "fp32",
+            "use_calib": False,
         }
         conf_type = {
             "model_config": str,
@@ -403,6 +407,8 @@ class ServerYamlConfChecker(object):
             "devices": str,
             "mem_optim": bool,
             "ir_optim": bool,
+            "precision": str,
+            "use_calib": bool,
         }
         conf_qualification = {"thread_num": (">=", 1), }
         ServerYamlConfChecker.check_conf(conf, default_conf, conf_type,
......
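In the pipeline server, precision and use_calib become ordinary keys of each op's local_service_conf, with the defaults and types registered above. Expressed as the Python dict the YAML section is parsed into (values illustrative):

local_service_conf = {
    "model_config": "serving_server_model",  # placeholder model directory
    "devices": "0",
    "mem_optim": True,
    "ir_optim": False,
    "precision": "fp16",
    "use_calib": False,
}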