Commit f9a49d85 authored by MRXLT, committed by GitHub

Merge pull request #196 from MRXLT/general-server-v1

bug fix && add timeline for batch predict
@@ -145,7 +145,7 @@ std::vector<std::vector<float>> PredictorClient::predict(
   int64_t preprocess_start = timeline.TimeStampUS();
   // we save infer_us at fetch_result[fetch_name.size()]
-  fetch_result.resize(fetch_name.size() + 1);
+  fetch_result.resize(fetch_name.size());
   _api.thrd_clear();
   _predictor = _api.fetch_predictor("general_model");
@@ -276,7 +276,11 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
   if (fetch_name.size() == 0) {
     return fetch_result_batch;
   }
-  fetch_result_batch.resize(batch_size + 1);
+  Timer timeline;
+  int64_t preprocess_start = timeline.TimeStampUS();
+  fetch_result_batch.resize(batch_size);
   int fetch_name_num = fetch_name.size();
   for (int bi = 0; bi < batch_size; bi++) {
     fetch_result_batch[bi].resize(fetch_name_num);
@@ -349,13 +353,30 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
               << "itn feed value prepared";
     }
   }
+  int64_t preprocess_end = timeline.TimeStampUS();
+  int64_t client_infer_start = timeline.TimeStampUS();
   Response res;
+  int64_t client_infer_end = 0;
+  int64_t postprocess_start = 0;
+  int64_t postprocess_end = 0;
+  if (FLAGS_profile_client) {
+    if (FLAGS_profile_server) {
+      req.set_profile_server(true);
+    }
+  }
   res.Clear();
   if (_predictor->inference(&req, &res) != 0) {
     LOG(ERROR) << "failed call predictor with req: " << req.ShortDebugString();
     exit(-1);
   } else {
+    client_infer_end = timeline.TimeStampUS();
+    postprocess_start = client_infer_end;
     for (int bi = 0; bi < batch_size; bi++) {
       for (auto &name : fetch_name) {
         int idx = _fetch_name_to_idx[name];
@@ -372,8 +393,30 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
         }
       }
     }
+    postprocess_end = timeline.TimeStampUS();
   }
+  if (FLAGS_profile_client) {
+    std::ostringstream oss;
+    oss << "PROFILE\t"
+        << "prepro_0:" << preprocess_start << " "
+        << "prepro_1:" << preprocess_end << " "
+        << "client_infer_0:" << client_infer_start << " "
+        << "client_infer_1:" << client_infer_end << " ";
+    if (FLAGS_profile_server) {
+      int op_num = res.profile_time_size() / 2;
+      for (int i = 0; i < op_num; ++i) {
+        oss << "op" << i << "_0:" << res.profile_time(i * 2) << " ";
+        oss << "op" << i << "_1:" << res.profile_time(i * 2 + 1) << " ";
+      }
+    }
+    oss << "postpro_0:" << postprocess_start << " ";
+    oss << "postpro_1:" << postprocess_end;
+    fprintf(stderr, "%s\n", oss.str().c_str());
+  }
   return fetch_result_batch;
 }
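Note: with FLAGS_profile_client enabled, batch_predict now writes one tab-separated PROFILE line to stderr with microsecond timestamps for each client-side phase, plus per-op server timestamps (taken pairwise from res.profile_time) when FLAGS_profile_server is also set. An illustrative line, with made-up timestamp values, looks like:

    PROFILE	prepro_0:1583117632310011 prepro_1:1583117632310123 client_infer_0:1583117632310123 client_infer_1:1583117632320456 op0_0:1583117632311000 op0_1:1583117632319800 postpro_0:1583117632320456 postpro_1:1583117632320502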
...
@@ -89,8 +89,8 @@ class Client(object):
         self.client_handle_ = PredictorClient()
         self.client_handle_.init(path)
         read_env_flags = ["profile_client", "profile_server"]
-        self.client_handle_.init_gflags([sys.argv[0]] +
-            ["--tryfromenv=" + ",".join(read_env_flags)])
+        self.client_handle_.init_gflags([sys.argv[
+            0]] + ["--tryfromenv=" + ",".join(read_env_flags)])
         self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
         self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
         self.feed_shapes_ = [var.shape for var in model_conf.feed_var]
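Because profile_client and profile_server are read through gflags' --tryfromenv, profiling can be switched on from the environment before the client loads its config. A minimal sketch, assuming the standard FLAGS_<name> environment-variable convention used by gflags; the config path and endpoint below are illustrative:

    import os
    # Assumption: --tryfromenv picks these up as FLAGS_profile_client / FLAGS_profile_server.
    os.environ["FLAGS_profile_client"] = "true"
    os.environ["FLAGS_profile_server"] = "true"

    from paddle_serving_client import Client
    client = Client()
    client.load_client_config("serving_client_conf.prototxt")  # illustrative path
    client.connect(["127.0.0.1:9292"])  # illustrative endpoint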
@@ -183,18 +183,13 @@ class Client(object):
                                                         fetch_names)
         result_map_batch = []
-        for result in result_batch[:-1]:
+        for result in result_batch:
             result_map = {}
             for i, name in enumerate(fetch_names):
                 result_map[name] = result[i]
             result_map_batch.append(result_map)
-        infer_time = result_batch[-1][0][0]
-        if profile:
-            return result_map_batch, infer_time
-        else:
-            return result_map_batch
+        return result_map_batch

     def release(self):
         self.client_handle_.destroy_predictor()
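After this change batch_predict returns only the list of per-sample fetch dicts; the synthetic trailing entry that used to carry infer_us is gone, and timing is reported through the PROFILE log instead. A minimal usage sketch, continuing from a connected client; the feed name "words" and fetch name "prediction" are illustrative and depend on the model config:

    feed_batch = [{"words": [1, 2, 3]}, {"words": [4, 5, 6]}]
    result_map_batch = client.batch_predict(
        feed_batch=feed_batch, fetch=["prediction"])
    for result_map in result_map_batch:
        print(result_map["prediction"])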
@@ -17,8 +17,10 @@ from .proto import server_configure_pb2 as server_sdk
 from .proto import general_model_config_pb2 as m_config
 import google.protobuf.text_format
 import tarfile
+import socket
 import paddle_serving_server as paddle_serving_server
 from version import serving_server_version
+from contextlib import closing

 class OpMaker(object):
@@ -86,7 +88,6 @@ class Server(object):
         self.reload_interval_s = 10
         self.module_path = os.path.dirname(paddle_serving_server.__file__)
         self.cur_path = os.getcwd()
-        self.vlog_level = 0
         self.use_local_bin = False

     def set_max_concurrency(self, concurrency):
@@ -227,6 +228,8 @@ class Server(object):
             os.system("mkdir {}".format(workdir))
             os.system("touch {}/fluid_time_file".format(workdir))

+        if not self.check_port(port):
+            raise SystemExit("Prot {} is already used".format(port))
         self._prepare_resource(workdir)
         self._prepare_engine(self.model_config_path, device)
         self._prepare_infer_service(port)
@@ -242,6 +245,15 @@ class Server(object):
         self._write_pb_str(resource_fn, self.resource_conf)
         self._write_pb_str(model_toolkit_fn, self.model_toolkit_conf)

+    def check_port(self, port):
+        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
+            sock.settimeout(2)
+            result = sock.connect_ex(('127.0.0.1', port))
+            if result != 0:
+                return True
+            else:
+                return False
+
     def run_server(self):
         # just run server with system command
         # currently we do not load cube
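check_port() probes 127.0.0.1 with connect_ex(), so it returns True when nothing is listening (the port is free) and False when the connection succeeds; the prepare step above raises SystemExit in the latter case. A hedged launcher sketch, assuming the surrounding method is the public prepare_server(workdir, port, device) and that "server" is an already configured Server instance; the port numbers are illustrative:

    # Try a few candidate ports and stop at the first one that is free.
    for port in (9292, 9293, 9294):
        try:
            server.prepare_server(workdir="workdir", port=port, device="cpu")
            break
        except SystemExit:
            print("port {} already in use, trying the next one".format(port))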
...
@@ -17,8 +17,10 @@ from .proto import server_configure_pb2 as server_sdk
 from .proto import general_model_config_pb2 as m_config
 import google.protobuf.text_format
 import tarfile
+import socket
 import paddle_serving_server_gpu as paddle_serving_server
 from version import serving_server_version
+from contextlib import closing

 class OpMaker(object):
@@ -86,8 +88,8 @@ class Server(object):
         self.reload_interval_s = 10
         self.module_path = os.path.dirname(paddle_serving_server.__file__)
         self.cur_path = os.getcwd()
-        self.vlog_level = 0
         self.use_local_bin = False
+        self.gpuid = 0

     def set_max_concurrency(self, concurrency):
         self.max_concurrency = concurrency
@@ -98,9 +100,6 @@ class Server(object):
     def set_port(self, port):
         self.port = port

-    def set_vlog_level(self, vlog_level):
-        slef.vlog_level = vlog_level
-
     def set_reload_interval(self, interval):
         self.reload_interval_s = interval
@@ -214,6 +213,9 @@ class Server(object):
             os.system("mkdir {}".format(workdir))
             os.system("touch {}/fluid_time_file".format(workdir))

+        if not self.check_port(port):
+            raise SystemExit("Prot {} is already used".format(port))
         self._prepare_resource(workdir)
         self._prepare_engine(self.model_config_path, device)
         self._prepare_infer_service(port)
@@ -229,6 +231,15 @@ class Server(object):
         self._write_pb_str(resource_fn, self.resource_conf)
         self._write_pb_str(model_toolkit_fn, self.model_toolkit_conf)

+    def check_port(self, port):
+        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
+            sock.settimeout(2)
+            result = sock.connect_ex(('127.0.0.1', port))
+            if result != 0:
+                return True
+            else:
+                return False
+
     def run_server(self):
         # just run server with system command
         # currently we do not load cube
@@ -247,8 +258,7 @@ class Server(object):
                   "-workflow_path {} " \
                   "-workflow_file {} " \
                   "-bthread_concurrency {} " \
-                  "-gpuid {} " \
-                  "-v {} ".format(
+                  "-gpuid {} ".format(
                       self.bin_path,
                       self.workdir,
                       self.infer_service_fn,
@@ -261,6 +271,7 @@ class Server(object):
                       self.workdir,
                       self.workflow_fn,
                       self.num_threads,
-                      self.gpuid,
-                      self.vlog_level)
+                      self.gpuid,)
+        print("Going to Run Comand")
+        print(command)
         os.system(command)
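The GPU launch command now drops the "-v" verbosity flag together with vlog_level and passes self.gpuid instead, and the assembled command is echoed before os.system runs it. A minimal GPU-side usage sketch, assuming the usual OpMaker/OpSeqMaker/Server workflow of paddle_serving_server_gpu; note that this diff only adds the gpuid attribute and shows no setter, so it is assigned directly here, and the model path and port are illustrative:

    import paddle_serving_server_gpu as serving

    op_maker = serving.OpMaker()
    read_op = op_maker.create('general_reader')
    infer_op = op_maker.create('general_infer')

    op_seq_maker = serving.OpSeqMaker()
    op_seq_maker.add_op(read_op)
    op_seq_maker.add_op(infer_op)

    server = serving.Server()
    server.set_op_sequence(op_seq_maker.get_op_sequence())
    server.load_model_config("serving_server_model")  # illustrative path
    server.gpuid = 0  # attribute added in this diff; no setter is shown
    server.prepare_server(workdir="workdir", port=9393, device="gpu")
    server.run_server()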