提交 f9a49d85 编写于 作者: M MRXLT 提交者: GitHub

Merge pull request #196 from MRXLT/general-server-v1

bug fix && add timeline for batch predict
......@@ -145,7 +145,7 @@ std::vector<std::vector<float>> PredictorClient::predict(
int64_t preprocess_start = timeline.TimeStampUS();
// we save infer_us at fetch_result[fetch_name.size()]
fetch_result.resize(fetch_name.size() + 1);
fetch_result.resize(fetch_name.size());
_api.thrd_clear();
_predictor = _api.fetch_predictor("general_model");
......@@ -276,7 +276,11 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
if (fetch_name.size() == 0) {
return fetch_result_batch;
}
fetch_result_batch.resize(batch_size + 1);
Timer timeline;
int64_t preprocess_start = timeline.TimeStampUS();
fetch_result_batch.resize(batch_size);
int fetch_name_num = fetch_name.size();
for (int bi = 0; bi < batch_size; bi++) {
fetch_result_batch[bi].resize(fetch_name_num);
......@@ -349,13 +353,30 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
<< "itn feed value prepared";
}
int64_t preprocess_end = timeline.TimeStampUS();
int64_t client_infer_start = timeline.TimeStampUS();
Response res;
int64_t client_infer_end = 0;
int64_t postprocess_start = 0;
int64_t postprocess_end = 0;
if (FLAGS_profile_client) {
if (FLAGS_profile_server) {
req.set_profile_server(true);
}
}
res.Clear();
if (_predictor->inference(&req, &res) != 0) {
LOG(ERROR) << "failed call predictor with req: " << req.ShortDebugString();
exit(-1);
} else {
client_infer_end = timeline.TimeStampUS();
postprocess_start = client_infer_end;
for (int bi = 0; bi < batch_size; bi++) {
for (auto &name : fetch_name) {
int idx = _fetch_name_to_idx[name];
......@@ -372,8 +393,30 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
}
}
}
postprocess_end = timeline.TimeStampUS();
}
if (FLAGS_profile_client) {
std::ostringstream oss;
oss << "PROFILE\t"
<< "prepro_0:" << preprocess_start << " "
<< "prepro_1:" << preprocess_end << " "
<< "client_infer_0:" << client_infer_start << " "
<< "client_infer_1:" << client_infer_end << " ";
if (FLAGS_profile_server) {
int op_num = res.profile_time_size() / 2;
for (int i = 0; i < op_num; ++i) {
oss << "op" << i << "_0:" << res.profile_time(i * 2) << " ";
oss << "op" << i << "_1:" << res.profile_time(i * 2 + 1) << " ";
}
}
oss << "postpro_0:" << postprocess_start << " ";
oss << "postpro_1:" << postprocess_end;
fprintf(stderr, "%s\n", oss.str().c_str());
}
return fetch_result_batch;
}
......
......@@ -89,8 +89,8 @@ class Client(object):
self.client_handle_ = PredictorClient()
self.client_handle_.init(path)
read_env_flags = ["profile_client", "profile_server"]
self.client_handle_.init_gflags([sys.argv[0]] +
["--tryfromenv=" + ",".join(read_env_flags)])
self.client_handle_.init_gflags([sys.argv[
0]] + ["--tryfromenv=" + ",".join(read_env_flags)])
self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.feed_shapes_ = [var.shape for var in model_conf.feed_var]
......@@ -183,17 +183,12 @@ class Client(object):
fetch_names)
result_map_batch = []
for result in result_batch[:-1]:
for result in result_batch:
result_map = {}
for i, name in enumerate(fetch_names):
result_map[name] = result[i]
result_map_batch.append(result_map)
infer_time = result_batch[-1][0][0]
if profile:
return result_map_batch, infer_time
else:
return result_map_batch
def release(self):
......
......@@ -17,8 +17,10 @@ from .proto import server_configure_pb2 as server_sdk
from .proto import general_model_config_pb2 as m_config
import google.protobuf.text_format
import tarfile
import socket
import paddle_serving_server as paddle_serving_server
from version import serving_server_version
from contextlib import closing
class OpMaker(object):
......@@ -86,7 +88,6 @@ class Server(object):
self.reload_interval_s = 10
self.module_path = os.path.dirname(paddle_serving_server.__file__)
self.cur_path = os.getcwd()
self.vlog_level = 0
self.use_local_bin = False
def set_max_concurrency(self, concurrency):
......@@ -227,6 +228,8 @@ class Server(object):
os.system("mkdir {}".format(workdir))
os.system("touch {}/fluid_time_file".format(workdir))
if not self.check_port(port):
raise SystemExit("Prot {} is already used".format(port))
self._prepare_resource(workdir)
self._prepare_engine(self.model_config_path, device)
self._prepare_infer_service(port)
......@@ -242,6 +245,15 @@ class Server(object):
self._write_pb_str(resource_fn, self.resource_conf)
self._write_pb_str(model_toolkit_fn, self.model_toolkit_conf)
def check_port(self, port):
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
sock.settimeout(2)
result = sock.connect_ex(('127.0.0.1', port))
if result != 0:
return True
else:
return False
def run_server(self):
# just run server with system command
# currently we do not load cube
......
......@@ -17,8 +17,10 @@ from .proto import server_configure_pb2 as server_sdk
from .proto import general_model_config_pb2 as m_config
import google.protobuf.text_format
import tarfile
import socket
import paddle_serving_server_gpu as paddle_serving_server
from version import serving_server_version
from contextlib import closing
class OpMaker(object):
......@@ -86,8 +88,8 @@ class Server(object):
self.reload_interval_s = 10
self.module_path = os.path.dirname(paddle_serving_server.__file__)
self.cur_path = os.getcwd()
self.vlog_level = 0
self.use_local_bin = False
self.gpuid = 0
def set_max_concurrency(self, concurrency):
self.max_concurrency = concurrency
......@@ -98,9 +100,6 @@ class Server(object):
def set_port(self, port):
self.port = port
def set_vlog_level(self, vlog_level):
slef.vlog_level = vlog_level
def set_reload_interval(self, interval):
self.reload_interval_s = interval
......@@ -214,6 +213,9 @@ class Server(object):
os.system("mkdir {}".format(workdir))
os.system("touch {}/fluid_time_file".format(workdir))
if not self.check_port(port):
raise SystemExit("Prot {} is already used".format(port))
self._prepare_resource(workdir)
self._prepare_engine(self.model_config_path, device)
self._prepare_infer_service(port)
......@@ -229,6 +231,15 @@ class Server(object):
self._write_pb_str(resource_fn, self.resource_conf)
self._write_pb_str(model_toolkit_fn, self.model_toolkit_conf)
def check_port(self, port):
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
sock.settimeout(2)
result = sock.connect_ex(('127.0.0.1', port))
if result != 0:
return True
else:
return False
def run_server(self):
# just run server with system command
# currently we do not load cube
......@@ -247,8 +258,7 @@ class Server(object):
"-workflow_path {} " \
"-workflow_file {} " \
"-bthread_concurrency {} " \
"-gpuid {} " \
"-v {} ".format(
"-gpuid {} ".format(
self.bin_path,
self.workdir,
self.infer_service_fn,
......@@ -261,6 +271,7 @@ class Server(object):
self.workdir,
self.workflow_fn,
self.num_threads,
self.gpuid,
self.vlog_level)
self.gpuid,)
print("Going to Run Comand")
print(command)
os.system(command)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册