提交 e10ca77a 编写于 作者: B barrierye

fix code conflict

...@@ -108,7 +108,6 @@ void PredictorClient::set_predictor_conf(const std::string &conf_path, ...@@ -108,7 +108,6 @@ void PredictorClient::set_predictor_conf(const std::string &conf_path,
_predictor_path = conf_path; _predictor_path = conf_path;
_predictor_conf = conf_file; _predictor_conf = conf_file;
} }
int PredictorClient::destroy_predictor() { int PredictorClient::destroy_predictor() {
_api.thrd_finalize(); _api.thrd_finalize();
_api.destroy(); _api.destroy();
...@@ -160,6 +159,7 @@ int PredictorClient::batch_predict( ...@@ -160,6 +159,7 @@ int PredictorClient::batch_predict(
VLOG(2) << "fetch general model predictor done."; VLOG(2) << "fetch general model predictor done.";
VLOG(2) << "float feed name size: " << float_feed_name.size(); VLOG(2) << "float feed name size: " << float_feed_name.size();
VLOG(2) << "int feed name size: " << int_feed_name.size(); VLOG(2) << "int feed name size: " << int_feed_name.size();
VLOG(2) << "max body size : " << brpc::fLU64::FLAGS_max_body_size;
Request req; Request req;
for (auto &name : fetch_name) { for (auto &name : fetch_name) {
req.add_fetch_var_names(name); req.add_fetch_var_names(name);
...@@ -179,12 +179,16 @@ int PredictorClient::batch_predict( ...@@ -179,12 +179,16 @@ int PredictorClient::batch_predict(
tensor_vec.push_back(inst->add_tensor_array()); tensor_vec.push_back(inst->add_tensor_array());
} }
VLOG(2) << "batch [" << bi << "] int_feed_name and float_feed_name" VLOG(2) << "batch [" << bi << "] int_feed_name and float_feed_name "
<< "prepared"; << "prepared";
int vec_idx = 0; int vec_idx = 0;
VLOG(2) << "tensor_vec size " << tensor_vec.size() << " float shape "
<< float_shape.size();
for (auto &name : float_feed_name) { for (auto &name : float_feed_name) {
int idx = _feed_name_to_idx[name]; int idx = _feed_name_to_idx[name];
Tensor *tensor = tensor_vec[idx]; Tensor *tensor = tensor_vec[idx];
VLOG(2) << "prepare float feed " << name << " shape size "
<< float_shape[vec_idx].size();
for (uint32_t j = 0; j < float_shape[vec_idx].size(); ++j) { for (uint32_t j = 0; j < float_shape[vec_idx].size(); ++j) {
tensor->add_shape(float_shape[vec_idx][j]); tensor->add_shape(float_shape[vec_idx][j]);
} }
...@@ -202,6 +206,8 @@ int PredictorClient::batch_predict( ...@@ -202,6 +206,8 @@ int PredictorClient::batch_predict(
for (auto &name : int_feed_name) { for (auto &name : int_feed_name) {
int idx = _feed_name_to_idx[name]; int idx = _feed_name_to_idx[name];
Tensor *tensor = tensor_vec[idx]; Tensor *tensor = tensor_vec[idx];
VLOG(2) << "prepare int feed " << name << " shape size "
<< int_shape[vec_idx].size();
for (uint32_t j = 0; j < int_shape[vec_idx].size(); ++j) { for (uint32_t j = 0; j < int_shape[vec_idx].size(); ++j) {
tensor->add_shape(int_shape[vec_idx][j]); tensor->add_shape(int_shape[vec_idx][j]);
} }
...@@ -250,8 +256,11 @@ int PredictorClient::batch_predict( ...@@ -250,8 +256,11 @@ int PredictorClient::batch_predict(
model.set_engine_name(output.engine_name()); model.set_engine_name(output.engine_name());
for (auto &name : fetch_name) { for (auto &name : fetch_name) {
int idx = _fetch_name_to_idx[name]; // int idx = _fetch_name_to_idx[name];
int idx = 0;
int shape_size = output.insts(0).tensor_array(idx).shape_size(); int shape_size = output.insts(0).tensor_array(idx).shape_size();
VLOG(2) << "fetch var " << name << " index " << idx << " shape size "
<< shape_size;
model._shape_map[name].resize(shape_size); model._shape_map[name].resize(shape_size);
for (int i = 0; i < shape_size; ++i) { for (int i = 0; i < shape_size; ++i) {
model._shape_map[name][i] = model._shape_map[name][i] =
...@@ -264,11 +273,14 @@ int PredictorClient::batch_predict( ...@@ -264,11 +273,14 @@ int PredictorClient::batch_predict(
model._lod_map[name][i] = output.insts(0).tensor_array(idx).lod(i); model._lod_map[name][i] = output.insts(0).tensor_array(idx).lod(i);
} }
} }
idx += 1;
} }
for (auto &name : fetch_name) { for (auto &name : fetch_name) {
int idx = _fetch_name_to_idx[name]; // int idx = _fetch_name_to_idx[name];
int idx = 0;
if (_fetch_name_to_type[name] == 0) { if (_fetch_name_to_type[name] == 0) {
VLOG(2) << "ferch var " << name << "type int";
model._int64_value_map[name].resize( model._int64_value_map[name].resize(
output.insts(0).tensor_array(idx).int64_data_size()); output.insts(0).tensor_array(idx).int64_data_size());
int size = output.insts(0).tensor_array(idx).int64_data_size(); int size = output.insts(0).tensor_array(idx).int64_data_size();
...@@ -277,6 +289,7 @@ int PredictorClient::batch_predict( ...@@ -277,6 +289,7 @@ int PredictorClient::batch_predict(
output.insts(0).tensor_array(idx).int64_data(i); output.insts(0).tensor_array(idx).int64_data(i);
} }
} else { } else {
VLOG(2) << "fetch var " << name << "type float";
model._float_value_map[name].resize( model._float_value_map[name].resize(
output.insts(0).tensor_array(idx).float_data_size()); output.insts(0).tensor_array(idx).float_data_size());
int size = output.insts(0).tensor_array(idx).float_data_size(); int size = output.insts(0).tensor_array(idx).float_data_size();
...@@ -285,6 +298,7 @@ int PredictorClient::batch_predict( ...@@ -285,6 +298,7 @@ int PredictorClient::batch_predict(
output.insts(0).tensor_array(idx).float_data(i); output.insts(0).tensor_array(idx).float_data(i);
} }
} }
idx += 1;
} }
predict_res_batch.add_model_res(std::move(model)); predict_res_batch.add_model_res(std::move(model));
} }
......
...@@ -58,6 +58,8 @@ int GeneralResponseOp::inference() { ...@@ -58,6 +58,8 @@ int GeneralResponseOp::inference() {
std::shared_ptr<PaddleGeneralModelConfig> model_config = std::shared_ptr<PaddleGeneralModelConfig> model_config =
resource.get_general_model_config(); resource.get_general_model_config();
VLOG(2) << "max body size : " << brpc::fLU64::FLAGS_max_body_size;
std::vector<int> fetch_index; std::vector<int> fetch_index;
fetch_index.resize(req->fetch_var_names_size()); fetch_index.resize(req->fetch_var_names_size());
for (int i = 0; i < req->fetch_var_names_size(); ++i) { for (int i = 0; i < req->fetch_var_names_size(); ++i) {
......
...@@ -111,7 +111,9 @@ class Client(object): ...@@ -111,7 +111,9 @@ class Client(object):
self.result_handle_ = PredictorRes() self.result_handle_ = PredictorRes()
self.client_handle_ = PredictorClient() self.client_handle_ = PredictorClient()
self.client_handle_.init(path) self.client_handle_.init(path)
read_env_flags = ["profile_client", "profile_server"] if "FLAGS_max_body_size" not in os.environ:
os.environ["FLAGS_max_body_size"] = str(512 * 1024 * 1024)
read_env_flags = ["profile_client", "profile_server", "max_body_size"]
self.client_handle_.init_gflags([sys.argv[ self.client_handle_.init_gflags([sys.argv[
0]] + ["--tryfromenv=" + ",".join(read_env_flags)]) 0]] + ["--tryfromenv=" + ",".join(read_env_flags)])
self.feed_names_ = [var.alias_name for var in model_conf.feed_var] self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
...@@ -223,8 +225,6 @@ class Client(object): ...@@ -223,8 +225,6 @@ class Client(object):
for i, feed_i in enumerate(feed_batch): for i, feed_i in enumerate(feed_batch):
int_slot = [] int_slot = []
float_slot = [] float_slot = []
int_shape = []
float_shape = []
for key in feed_i: for key in feed_i:
if key not in self.feed_names_: if key not in self.feed_names_:
raise ValueError("Wrong feed name: {}.".format(key)) raise ValueError("Wrong feed name: {}.".format(key))
......
...@@ -139,6 +139,7 @@ class Server(object): ...@@ -139,6 +139,7 @@ class Server(object):
self.num_threads = 4 self.num_threads = 4
self.port = 8080 self.port = 8080
self.reload_interval_s = 10 self.reload_interval_s = 10
self.max_body_size = 64 * 1024 * 1024
self.module_path = os.path.dirname(paddle_serving_server.__file__) self.module_path = os.path.dirname(paddle_serving_server.__file__)
self.cur_path = os.getcwd() self.cur_path = os.getcwd()
self.use_local_bin = False self.use_local_bin = False
...@@ -151,6 +152,14 @@ class Server(object): ...@@ -151,6 +152,14 @@ class Server(object):
def set_num_threads(self, threads): def set_num_threads(self, threads):
self.num_threads = threads self.num_threads = threads
def set_max_body_size(self, body_size):
if body_size >= self.max_body_size:
self.max_body_size = body_size
else:
print(
"max_body_size is less than default value, will use default value in service."
)
def set_port(self, port): def set_port(self, port):
self.port = port self.port = port
...@@ -383,7 +392,8 @@ class Server(object): ...@@ -383,7 +392,8 @@ class Server(object):
"-resource_file {} " \ "-resource_file {} " \
"-workflow_path {} " \ "-workflow_path {} " \
"-workflow_file {} " \ "-workflow_file {} " \
"-bthread_concurrency {} ".format( "-bthread_concurrency {} " \
"-max_body_size {} ".format(
self.bin_path, self.bin_path,
self.workdir, self.workdir,
self.infer_service_fn, self.infer_service_fn,
...@@ -395,7 +405,8 @@ class Server(object): ...@@ -395,7 +405,8 @@ class Server(object):
self.resource_fn, self.resource_fn,
self.workdir, self.workdir,
self.workflow_fn, self.workflow_fn,
self.num_threads) self.num_threads,
self.max_body_size)
print("Going to Run Command") print("Going to Run Command")
print(command) print(command)
os.system(command) os.system(command)
...@@ -91,6 +91,7 @@ class Monitor(object): ...@@ -91,6 +91,7 @@ class Monitor(object):
model_name)) model_name))
return model_name return model_name
tar_model_path = os.path.join(local_tmp_path, model_name) tar_model_path = os.path.join(local_tmp_path, model_name)
_LOGGER.info("try to unpack remote file({})".format(tar_model_path))
if not tarfile.is_tarfile(tar_model_path): if not tarfile.is_tarfile(tar_model_path):
raise Exception('not a tar packaged file type. {}'.format( raise Exception('not a tar packaged file type. {}'.format(
self._check_param_help('remote_model_name', model_name))) self._check_param_help('remote_model_name', model_name)))
...@@ -105,10 +106,11 @@ class Monitor(object): ...@@ -105,10 +106,11 @@ class Monitor(object):
self._check_param_help('local_tmp_path', local_tmp_path))) self._check_param_help('local_tmp_path', local_tmp_path)))
finally: finally:
os.remove(tar_model_path) os.remove(tar_model_path)
_LOGGER.debug('remove packed file({}).'.format(model_name)) _LOGGER.debug('remove packed file({}).'.format(tar_model_path))
_LOGGER.info('using unpacked filename: {}.'.format( _LOGGER.info('using unpacked filename: {}.'.format(
unpacked_filename)) unpacked_filename))
if not os.path.exists(unpacked_filename): if not os.path.exists(
os.path.join(local_tmp_path, unpacked_filename)):
raise Exception('file not exist. {}'.format( raise Exception('file not exist. {}'.format(
self._check_param_help('unpacked_filename', self._check_param_help('unpacked_filename',
unpacked_filename))) unpacked_filename)))
...@@ -124,13 +126,14 @@ class Monitor(object): ...@@ -124,13 +126,14 @@ class Monitor(object):
'_local_tmp_path', '_interval' '_local_tmp_path', '_interval'
] ]
self._print_params(params) self._print_params(params)
if not os.path.exists(self._local_tmp_path): local_tmp_path = os.path.join(self._local_path, self._local_tmp_path)
_LOGGER.info('mkdir: {}'.format(self._local_tmp_path)) _LOGGER.info('local_tmp_path: {}'.format(local_tmp_path))
os.makedirs(self._local_tmp_path) if not os.path.exists(local_tmp_path):
_LOGGER.info('mkdir: {}'.format(local_tmp_path))
os.makedirs(local_tmp_path)
while True: while True:
[flag, timestamp] = self._exist_remote_file( [flag, timestamp] = self._exist_remote_file(
self._remote_path, self._remote_donefile_name, self._remote_path, self._remote_donefile_name, local_tmp_path)
self._local_tmp_path)
if flag: if flag:
if self._remote_donefile_timestamp is None or \ if self._remote_donefile_timestamp is None or \
timestamp != self._remote_donefile_timestamp: timestamp != self._remote_donefile_timestamp:
...@@ -139,15 +142,15 @@ class Monitor(object): ...@@ -139,15 +142,15 @@ class Monitor(object):
self._remote_donefile_timestamp = timestamp self._remote_donefile_timestamp = timestamp
self._pull_remote_dir(self._remote_path, self._pull_remote_dir(self._remote_path,
self._remote_model_name, self._remote_model_name,
self._local_tmp_path) local_tmp_path)
_LOGGER.info('pull remote model({}).'.format( _LOGGER.info('pull remote model({}).'.format(
self._remote_model_name)) self._remote_model_name))
unpacked_filename = self._decompress_model_file( unpacked_filename = self._decompress_model_file(
self._local_tmp_path, self._remote_model_name, local_tmp_path, self._remote_model_name,
self._unpacked_filename) self._unpacked_filename)
self._update_local_model( self._update_local_model(local_tmp_path, unpacked_filename,
self._local_tmp_path, unpacked_filename, self._local_path,
self._local_path, self._local_model_name) self._local_model_name)
_LOGGER.info('update local model({}).'.format( _LOGGER.info('update local model({}).'.format(
self._local_model_name)) self._local_model_name))
self._update_local_donefile(self._local_path, self._update_local_donefile(self._local_path,
...@@ -220,7 +223,12 @@ class HadoopMonitor(Monitor): ...@@ -220,7 +223,12 @@ class HadoopMonitor(Monitor):
local_dirpath = os.path.join(local_tmp_path, dirname) local_dirpath = os.path.join(local_tmp_path, dirname)
if os.path.exists(local_dirpath): if os.path.exists(local_dirpath):
_LOGGER.info('remove old temporary model file({}).'.format(dirname)) _LOGGER.info('remove old temporary model file({}).'.format(dirname))
shutil.rmtree(local_dirpath) if self._unpacked_filename is None:
# the remote file is model folder.
shutil.rmtree(local_dirpath)
else:
# the remote file is a packed model file
os.remove(local_dirpath)
remote_dirpath = os.path.join(remote_path, dirname) remote_dirpath = os.path.join(remote_path, dirname)
cmd = '{} -get {} {} 2>/dev/null'.format(self._cmd_prefix, cmd = '{} -get {} {} 2>/dev/null'.format(self._cmd_prefix,
remote_dirpath, local_dirpath) remote_dirpath, local_dirpath)
...@@ -301,8 +309,8 @@ class FTPMonitor(Monitor): ...@@ -301,8 +309,8 @@ class FTPMonitor(Monitor):
os.path.join(remote_path, remote_dirname), name, os.path.join(remote_path, remote_dirname), name,
os.path.join(local_tmp_path, remote_dirname), overwrite) os.path.join(local_tmp_path, remote_dirname), overwrite)
else: else:
self._download_remote_file(remote_dirname, name, self._download_remote_file(remote_dirpath, name,
local_tmp_path, overwrite) local_dirpath, overwrite)
except ftplib.error_perm: except ftplib.error_perm:
_LOGGER.debug('{} is file.'.format(remote_dirname)) _LOGGER.debug('{} is file.'.format(remote_dirname))
self._download_remote_file(remote_path, remote_dirname, self._download_remote_file(remote_path, remote_dirname,
...@@ -325,17 +333,17 @@ class GeneralMonitor(Monitor): ...@@ -325,17 +333,17 @@ class GeneralMonitor(Monitor):
def _get_local_file_timestamp(self, filename): def _get_local_file_timestamp(self, filename):
return os.path.getmtime(filename) return os.path.getmtime(filename)
def _exist_remote_file(self, path, filename, local_tmp_path): def _exist_remote_file(self, remote_path, filename, local_tmp_path):
remote_filepath = os.path.join(path, filename) remote_filepath = os.path.join(remote_path, filename)
url = '{}/{}'.format(self._general_host, remote_filepath) url = '{}/{}'.format(self._general_host, remote_filepath)
_LOGGER.debug('remote file url: {}'.format(url)) _LOGGER.debug('remote file url: {}'.format(url))
cmd = 'wget -N -P {} {} &>/dev/null'.format(local_tmp_path, url) # only for check donefile, which is not a folder.
cmd = 'wget -nd -N -P {} {} &>/dev/null'.format(local_tmp_path, url)
_LOGGER.debug('wget cmd: {}'.format(cmd)) _LOGGER.debug('wget cmd: {}'.format(cmd))
if os.system(cmd) != 0: if os.system(cmd) != 0:
_LOGGER.debug('remote file({}) not exist.'.format(filename)) _LOGGER.debug('remote file({}) not exist.'.format(remote_filepath))
return [False, None] return [False, None]
else: else:
_LOGGER.debug('download remote file({}).'.format(filename))
timestamp = self._get_local_file_timestamp( timestamp = self._get_local_file_timestamp(
os.path.join(local_tmp_path, filename)) os.path.join(local_tmp_path, filename))
return [True, timestamp] return [True, timestamp]
...@@ -344,7 +352,13 @@ class GeneralMonitor(Monitor): ...@@ -344,7 +352,13 @@ class GeneralMonitor(Monitor):
remote_dirpath = os.path.join(remote_path, dirname) remote_dirpath = os.path.join(remote_path, dirname)
url = '{}/{}'.format(self._general_host, remote_dirpath) url = '{}/{}'.format(self._general_host, remote_dirpath)
_LOGGER.debug('remote file url: {}'.format(url)) _LOGGER.debug('remote file url: {}'.format(url))
cmd = 'wget -nH -r -P {} {} &>/dev/null'.format(local_tmp_path, url) if self._unpacked_filename is None:
# the remote file is model folder.
cmd = 'wget -nH -r -P {} {} &>/dev/null'.format(
os.path.join(local_tmp_path, dirname), url)
else:
# the remote file is a packed model file
cmd = 'wget -nd -N -P {} {} &>/dev/null'.format(local_tmp_path, url)
_LOGGER.debug('wget cmd: {}'.format(cmd)) _LOGGER.debug('wget cmd: {}'.format(cmd))
if os.system(cmd) != 0: if os.system(cmd) != 0:
raise Exception('pull remote dir failed. {}'.format( raise Exception('pull remote dir failed. {}'.format(
...@@ -352,7 +366,11 @@ class GeneralMonitor(Monitor): ...@@ -352,7 +366,11 @@ class GeneralMonitor(Monitor):
def parse_args(): def parse_args():
''' parse args. ''' """ parse args.
Returns:
parser.parse_args().
"""
parser = argparse.ArgumentParser(description="Monitor") parser = argparse.ArgumentParser(description="Monitor")
parser.add_argument( parser.add_argument(
"--type", type=str, default='general', help="Type of remote server") "--type", type=str, default='general', help="Type of remote server")
......
...@@ -41,6 +41,11 @@ def parse_args(): # pylint: disable=doc-string-missing ...@@ -41,6 +41,11 @@ def parse_args(): # pylint: disable=doc-string-missing
"--device", type=str, default="cpu", help="Type of device") "--device", type=str, default="cpu", help="Type of device")
parser.add_argument( parser.add_argument(
"--mem_optim", type=bool, default=False, help="Memory optimize") "--mem_optim", type=bool, default=False, help="Memory optimize")
parser.add_argument(
"--max_body_size",
type=int,
default=512 * 1024 * 1024,
help="Limit sizes of messages")
return parser.parse_args() return parser.parse_args()
...@@ -52,6 +57,7 @@ def start_standard_model(): # pylint: disable=doc-string-missing ...@@ -52,6 +57,7 @@ def start_standard_model(): # pylint: disable=doc-string-missing
workdir = args.workdir workdir = args.workdir
device = args.device device = args.device
mem_optim = args.mem_optim mem_optim = args.mem_optim
max_body_size = args.max_body_size
if model == "": if model == "":
print("You must specify your serving model") print("You must specify your serving model")
...@@ -72,6 +78,7 @@ def start_standard_model(): # pylint: disable=doc-string-missing ...@@ -72,6 +78,7 @@ def start_standard_model(): # pylint: disable=doc-string-missing
server.set_op_sequence(op_seq_maker.get_op_sequence()) server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_num_threads(thread_num) server.set_num_threads(thread_num)
server.set_memory_optimize(mem_optim) server.set_memory_optimize(mem_optim)
server.set_max_body_size(max_body_size)
server.load_model_config(model) server.load_model_config(model)
server.prepare_server(workdir=workdir, port=port, device=device) server.prepare_server(workdir=workdir, port=port, device=device)
......
...@@ -47,6 +47,11 @@ def serve_args(): ...@@ -47,6 +47,11 @@ def serve_args():
"--name", type=str, default="None", help="Default service name") "--name", type=str, default="None", help="Default service name")
parser.add_argument( parser.add_argument(
"--mem_optim", type=bool, default=False, help="Memory optimize") "--mem_optim", type=bool, default=False, help="Memory optimize")
parser.add_argument(
"--max_body_size",
type=int,
default=512 * 1024 * 1024,
help="Limit sizes of messages")
return parser.parse_args() return parser.parse_args()
...@@ -163,6 +168,7 @@ class Server(object): ...@@ -163,6 +168,7 @@ class Server(object):
self.num_threads = 4 self.num_threads = 4
self.port = 8080 self.port = 8080
self.reload_interval_s = 10 self.reload_interval_s = 10
self.max_body_size = 64 * 1024 * 1024
self.module_path = os.path.dirname(paddle_serving_server.__file__) self.module_path = os.path.dirname(paddle_serving_server.__file__)
self.cur_path = os.getcwd() self.cur_path = os.getcwd()
self.check_cuda() self.check_cuda()
...@@ -176,6 +182,14 @@ class Server(object): ...@@ -176,6 +182,14 @@ class Server(object):
def set_num_threads(self, threads): def set_num_threads(self, threads):
self.num_threads = threads self.num_threads = threads
def set_max_body_size(self, body_size):
if body_size >= self.max_body_size:
self.max_body_size = body_size
else:
print(
"max_body_size is less than default value, will use default value in service."
)
def set_port(self, port): def set_port(self, port):
self.port = port self.port = port
...@@ -414,7 +428,8 @@ class Server(object): ...@@ -414,7 +428,8 @@ class Server(object):
"-workflow_path {} " \ "-workflow_path {} " \
"-workflow_file {} " \ "-workflow_file {} " \
"-bthread_concurrency {} " \ "-bthread_concurrency {} " \
"-gpuid {} ".format( "-gpuid {} " \
"-max_body_size {} ".format(
self.bin_path, self.bin_path,
self.workdir, self.workdir,
self.infer_service_fn, self.infer_service_fn,
...@@ -427,7 +442,8 @@ class Server(object): ...@@ -427,7 +442,8 @@ class Server(object):
self.workdir, self.workdir,
self.workflow_fn, self.workflow_fn,
self.num_threads, self.num_threads,
self.gpuid,) self.gpuid,
self.max_body_size)
print("Going to Run Comand") print("Going to Run Comand")
print(command) print(command)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
Start monitor with one line command
Example:
python -m paddle_serving_server.monitor
"""
import os
import time
import argparse
import commands
import datetime
import shutil
import tarfile
import logging
_LOGGER = logging.getLogger(__name__)
class Monitor(object):
    '''
    Monitor base class. It is used to monitor the remote model, pull and update the local model.

    Subclasses provide `_exist_remote_file` and `_pull_remote_dir` for one
    kind of remote storage; `run` drives the polling loop on top of them.
    '''

    def __init__(self, interval):
        '''
        Args:
            interval: seconds to sleep between two polls (passed to time.sleep).
        '''
        # Everything below except `_interval` starts unset and must be filled
        # in through the set_* methods before `run` is called; `run` checks
        # the ones it requires.
        self._remote_path = None
        self._remote_model_name = None
        self._remote_donefile_name = None
        self._local_path = None
        self._local_model_name = None
        self._local_timestamp_file = None
        self._interval = interval
        # Timestamp of the donefile seen by the last successful poll.
        self._remote_donefile_timestamp = None
        self._local_tmp_path = None
        # None means the remote model is fetched as-is (no unpacking step).
        self._unpacked_filename = None

    def set_remote_path(self, remote_path):
        ''' Set the remote directory that holds the model and the donefile. '''
        self._remote_path = remote_path

    def set_remote_model_name(self, model_name):
        ''' Set the name of the remote model file or folder. '''
        self._remote_model_name = model_name

    def set_remote_donefile_name(self, donefile_name):
        ''' Set the name of the remote donefile whose timestamp is polled. '''
        self._remote_donefile_name = donefile_name

    def set_local_path(self, local_path):
        ''' Set the local base directory. '''
        self._local_path = local_path

    def set_local_model_name(self, model_name):
        ''' Set the name of the local model directory under the local path. '''
        self._local_model_name = model_name

    def set_local_timestamp_file(self, timestamp_file):
        ''' Set the name of the file touched after every local model update. '''
        self._local_timestamp_file = timestamp_file

    def set_local_tmp_path(self, tmp_path):
        ''' Set the temporary directory (joined onto the local path in `run`). '''
        self._local_tmp_path = tmp_path

    def set_unpacked_filename(self, unpacked_filename):
        ''' Set the name expected to exist after unpacking the remote tar file. '''
        self._unpacked_filename = unpacked_filename

    def _check_param_help(self, param_name, param_value):
        ''' Build the "please check parameter" hint appended to error messages. '''
        return "Please check the {}({}) parameter.".format(param_name,
                                                           param_value)

    def _check_params(self, params):
        '''
        Raise if any of the named attributes is missing or None.

        Args:
            params: iterable of attribute names on this object.

        Raises:
            Exception: naming the first attribute that is not set.
        '''
        for param in params:
            if getattr(self, param, None) is None:
                raise Exception('{} not set.'.format(param))

    def _print_params(self, params_name):
        ''' Validate the named attributes, then log each one at INFO level. '''
        self._check_params(params_name)
        for name in params_name:
            _LOGGER.info('{}: {}'.format(name, getattr(self, name)))

    def _decompress_model_file(self, local_tmp_path, model_name,
                               unpacked_filename):
        '''
        Unpack the downloaded tar file if the remote model is packed.

        Args:
            local_tmp_path: directory holding the downloaded file; also the
                extraction target.
            model_name: file name of the downloaded model inside that directory.
            unpacked_filename: name expected to exist after extraction, or
                None when the remote model is not packed.

        Returns:
            `model_name` when nothing has to be unpacked, otherwise
            `unpacked_filename`.

        Raises:
            Exception: the file is not a tar archive, extraction failed, or
                the expected unpacked name does not exist afterwards.
        '''
        if unpacked_filename is None:
            _LOGGER.debug('remote file({}) is already unpacked.'.format(
                model_name))
            return model_name
        tar_model_path = os.path.join(local_tmp_path, model_name)
        _LOGGER.info("try to unpack remote file({})".format(tar_model_path))
        if not tarfile.is_tarfile(tar_model_path):
            raise Exception('not a tar packaged file type. {}'.format(
                self._check_param_help('remote_model_name', model_name)))
        try:
            _LOGGER.info('unpack remote file({}).'.format(model_name))
            tar = tarfile.open(tar_model_path)
            try:
                # NOTE(review): extractall() trusts the member paths stored in
                # the archive; only point the monitor at trusted model sources.
                tar.extractall(local_tmp_path)
            finally:
                # Release the archive handle even when extraction fails.
                tar.close()
        except Exception:
            # `except Exception` (not a bare except) so Ctrl-C / SystemExit
            # are not rewritten into a misleading "decompressing failed".
            raise Exception(
                'Decompressing failed, maybe no disk space left. {}'.format(
                    self._check_param_help('local_tmp_path', local_tmp_path)))
        finally:
            # The packed file is removed whether or not extraction succeeded.
            os.remove(tar_model_path)
            _LOGGER.debug('remove packed file({}).'.format(tar_model_path))
        _LOGGER.info('using unpacked filename: {}.'.format(unpacked_filename))
        if not os.path.exists(os.path.join(local_tmp_path, unpacked_filename)):
            raise Exception('file not exist. {}'.format(
                self._check_param_help('unpacked_filename',
                                       unpacked_filename)))
        return unpacked_filename

    def run(self):
        '''
        Monitor the remote model by polling and update the local model.

        Loops forever: on every pass it asks the subclass whether the remote
        donefile exists, and when its timestamp differs from the last one seen
        it pulls the model, unpacks it if needed, copies it over the local
        model and touches the local timestamp file.

        Raises:
            Exception: a required parameter is not set, or any pull / unpack /
                update step fails.
        '''
        params = [
            '_remote_path', '_remote_model_name', '_remote_donefile_name',
            '_local_model_name', '_local_path', '_local_timestamp_file',
            '_local_tmp_path', '_interval'
        ]
        self._print_params(params)
        local_tmp_path = os.path.join(self._local_path, self._local_tmp_path)
        _LOGGER.info('local_tmp_path: {}'.format(local_tmp_path))
        if not os.path.exists(local_tmp_path):
            _LOGGER.info('mkdir: {}'.format(local_tmp_path))
            os.makedirs(local_tmp_path)
        while True:
            [flag, timestamp] = self._exist_remote_file(
                self._remote_path, self._remote_donefile_name, local_tmp_path)
            if flag:
                if self._remote_donefile_timestamp is None or \
                        timestamp != self._remote_donefile_timestamp:
                    _LOGGER.info('doneilfe({}) changed.'.format(
                        self._remote_donefile_name))
                    # Recorded before pulling, so a failed pull is not retried
                    # for the same donefile timestamp (the exception ends run).
                    self._remote_donefile_timestamp = timestamp
                    self._pull_remote_dir(self._remote_path,
                                          self._remote_model_name,
                                          local_tmp_path)
                    _LOGGER.info('pull remote model({}).'.format(
                        self._remote_model_name))
                    unpacked_filename = self._decompress_model_file(
                        local_tmp_path, self._remote_model_name,
                        self._unpacked_filename)
                    self._update_local_model(local_tmp_path, unpacked_filename,
                                             self._local_path,
                                             self._local_model_name)
                    _LOGGER.info('update local model({}).'.format(
                        self._local_model_name))
                    self._update_local_donefile(self._local_path,
                                                self._local_model_name,
                                                self._local_timestamp_file)
                    _LOGGER.info('update model timestamp({}).'.format(
                        self._local_timestamp_file))
            else:
                _LOGGER.info('remote({}) has no donefile.'.format(
                    self._remote_path))
            _LOGGER.info('sleep {}s.'.format(self._interval))
            time.sleep(self._interval)

    def _exist_remote_file(self, path, filename, local_tmp_path):
        '''
        Check whether `filename` exists under the remote `path`.

        Subclasses must override this and return `[exists, timestamp]`;
        `run` only compares the timestamp for inequality.
        '''
        raise Exception('This function must be inherited.')

    def _pull_remote_dir(self, remote_path, dirname, local_tmp_path):
        '''
        Fetch the remote `dirname` under `remote_path` into `local_tmp_path`.

        Subclasses must override this.
        '''
        raise Exception('This function must be inherited.')

    def _update_local_model(self, local_tmp_path, remote_model_name, local_path,
                            local_model_name):
        '''
        Copy the contents of the freshly pulled model over the local model.

        Runs `cp -r <tmp>/<remote_model_name>/* <local_path>/<local_model_name>`
        through the shell.

        Raises:
            Exception: the copy command exits with a non-zero status.
        '''
        tmp_model_path = os.path.join(local_tmp_path, remote_model_name)
        local_model_path = os.path.join(local_path, local_model_name)
        cmd = 'cp -r {}/* {}'.format(tmp_model_path, local_model_path)
        _LOGGER.debug('update model cmd: {}'.format(cmd))
        if os.system(cmd) != 0:
            raise Exception('update local model failed.')

    def _update_local_donefile(self, local_path, local_model_name,
                               local_timestamp_file):
        '''
        Touch the timestamp file inside the local model directory.

        Raises:
            Exception: the touch command exits with a non-zero status.
        '''
        donefile_path = os.path.join(local_path, local_model_name,
                                     local_timestamp_file)
        cmd = 'touch {}'.format(donefile_path)
        _LOGGER.debug('update timestamp cmd: {}'.format(cmd))
        if os.system(cmd) != 0:
            raise Exception('update local donefile failed.')
class HadoopMonitor(Monitor):
    ''' Monitor HDFS or AFS by Hadoop-client. '''

    def __init__(self, hadoop_bin, fs_name='', fs_ugi='', interval=10):
        '''
        Args:
            hadoop_bin: hadoop client executable used to build every command.
            fs_name: value for `-D fs.default.name=`; omitted when empty.
            fs_ugi: value for `-D hadoop.job.ugi=`; omitted when empty.
            interval: polling interval in seconds, forwarded to Monitor.
        '''
        super(HadoopMonitor, self).__init__(interval)
        self._hadoop_bin = hadoop_bin
        self._fs_name = fs_name
        self._fs_ugi = fs_ugi
        self._print_params(['_hadoop_bin', '_fs_name', '_fs_ugi'])
        # Assemble "<hadoop_bin> fs [-D ...] [-D ...] "; every piece keeps its
        # trailing space so sub-commands can be appended directly.
        prefix_parts = ['{} fs '.format(self._hadoop_bin)]
        if self._fs_name:
            prefix_parts.append('-D fs.default.name={} '.format(self._fs_name))
        if self._fs_ugi:
            prefix_parts.append('-D hadoop.job.ugi={} '.format(self._fs_ugi))
        self._cmd_prefix = ''.join(prefix_parts)
        _LOGGER.info('Hadoop prefix cmd: {}'.format(self._cmd_prefix))

    def _exist_remote_file(self, path, filename, local_tmp_path):
        '''
        Check the remote file with `hadoop fs -ls`.

        Returns:
            [True, mdate + mtime] taken from the last line of the listing when
            the command succeeds, otherwise [False, None].
        '''
        target = os.path.join(path, filename)
        cmd = '{} -ls {} 2>/dev/null'.format(self._cmd_prefix, target)
        _LOGGER.debug('check cmd: {}'.format(cmd))
        status, output = commands.getstatusoutput(cmd)
        _LOGGER.debug('resp: {}'.format(output))
        if status != 0:
            return [False, None]
        # The last output line is expected to have exactly eight columns;
        # the 6th and 7th are the modification date and time.
        last_line = output.split('\n')[-1]
        _, _, _, _, _, mdate, mtime, _ = last_line.split()
        return [True, mdate + mtime]

    def _pull_remote_dir(self, remote_path, dirname, local_tmp_path):
        '''
        Download `remote_path/dirname` into `local_tmp_path` with `hadoop fs -get`.

        Any previous local copy is deleted first.

        Raises:
            Exception: the get command exits with a non-zero status.
        '''
        # remove old file before pull remote dir
        local_target = os.path.join(local_tmp_path, dirname)
        if os.path.exists(local_target):
            _LOGGER.info('remove old temporary model file({}).'.format(dirname))
            remote_is_folder = self._unpacked_filename is None
            if remote_is_folder:
                # the remote file is model folder.
                shutil.rmtree(local_target)
            else:
                # the remote file is a packed model file
                os.remove(local_target)
        remote_source = os.path.join(remote_path, dirname)
        cmd = '{} -get {} {} 2>/dev/null'.format(self._cmd_prefix,
                                                 remote_source, local_target)
        _LOGGER.debug('pull cmd: {}'.format(cmd))
        exit_status = os.system(cmd)
        if exit_status != 0:
            raise Exception('pull remote dir failed. {}'.format(
                self._check_param_help('remote_model_name', dirname)))
class FTPMonitor(Monitor):
    ''' FTP Monitor. '''

    def __init__(self, host, port, username="", password="", interval=10):
        '''
        Connect and log in to the FTP server.

        Args:
            host: FTP server host.
            port: FTP server port.
            username: login user name.
            password: login password.
            interval: polling interval in seconds, forwarded to Monitor.
        '''
        super(FTPMonitor, self).__init__(interval)
        import ftplib
        self._ftp = ftplib.FTP()
        self._ftp_host = host
        self._ftp_port = port
        self._ftp_username = username
        self._ftp_password = password
        self._ftp.connect(self._ftp_host, self._ftp_port)
        self._ftp.login(self._ftp_username, self._ftp_password)
        # All four settings are still validated, but the password is kept out
        # of the log output.
        self._check_params(
            ['_ftp_host', '_ftp_port', '_ftp_username', '_ftp_password'])
        self._print_params(['_ftp_host', '_ftp_port', '_ftp_username'])

    def _exist_remote_file(self, path, filename, local_tmp_path):
        '''
        Check the remote file with the FTP `MDTM` command.

        Returns:
            [True, timestamp] where timestamp is the MDTM reply with its first
            four characters (the reply code and separator) stripped, or
            [False, None] when the server answers with a permanent error.
        '''
        import ftplib
        try:
            _LOGGER.debug('cwd: {}'.format(path))
            self._ftp.cwd(path)
            timestamp = self._ftp.voidcmd('MDTM {}'.format(filename))[4:].strip(
            )
            return [True, timestamp]
        except ftplib.error_perm:
            _LOGGER.debug('remote file({}) not exist.'.format(filename))
            return [False, None]

    def _download_remote_file(self,
                              remote_path,
                              remote_filename,
                              local_tmp_path,
                              overwrite=True):
        '''
        Download one remote file into `local_tmp_path`.

        Args:
            remote_path: remote directory containing the file.
            remote_filename: name of the file inside `remote_path`.
            local_tmp_path: local directory that receives the file.
            overwrite: when False, an already existing local file is kept and
                nothing is downloaded.
        '''
        local_fullpath = os.path.join(local_tmp_path, remote_filename)
        if not overwrite and os.path.isfile(local_fullpath):
            return
        with open(local_fullpath, 'wb') as f:
            _LOGGER.debug('cwd: {}'.format(remote_path))
            self._ftp.cwd(remote_path)
            _LOGGER.debug('download remote file({})'.format(remote_filename))
            self._ftp.retrbinary('RETR {}'.format(remote_filename), f.write)

    def _download_remote_files(self,
                               remote_path,
                               remote_dirname,
                               local_tmp_path,
                               overwrite=True):
        '''
        Download `remote_path/remote_dirname`, recursing into sub-folders.

        Whether the target is a folder is decided by trying to `cwd` into it:
        a permanent error means it is a plain file.

        Args:
            remote_path: remote parent directory.
            remote_dirname: file or folder name inside `remote_path`.
            local_tmp_path: local directory that receives the download.
            overwrite: forwarded to `_download_remote_file`.
        '''
        import ftplib
        remote_dirpath = os.path.join(remote_path, remote_dirname)
        # Check whether remote_dirpath is a file or a folder. Only the cwd
        # probe is guarded: a permission error raised later while downloading
        # an entry must not be mistaken for "this is a file".
        try:
            _LOGGER.debug('cwd: {}'.format(remote_dirpath))
            self._ftp.cwd(remote_dirpath)
        except ftplib.error_perm:
            _LOGGER.debug('{} is file.'.format(remote_dirname))
            self._download_remote_file(remote_path, remote_dirname,
                                       local_tmp_path, overwrite)
            return
        _LOGGER.debug('{} is folder.'.format(remote_dirname))
        local_dirpath = os.path.join(local_tmp_path, remote_dirname)
        if not os.path.exists(local_dirpath):
            _LOGGER.info('mkdir: {}'.format(local_dirpath))
            os.mkdir(local_dirpath)
        output = []
        self._ftp.dir(output.append)
        for line in output:
            # Split into at most nine columns so a name containing spaces
            # stays intact in the last one.
            fields = line.split(None, 8)
            if len(fields) < 9:
                _LOGGER.debug('skip unrecognized listing line: {}'.format(line))
                continue
            attr, name = fields[0], fields[8]
            if name in ('.', '..'):
                # Never recurse into the folder itself or its parent.
                continue
            if attr[0] == 'd':
                self._download_remote_files(remote_dirpath, name,
                                            local_dirpath, overwrite)
            else:
                self._download_remote_file(remote_dirpath, name,
                                           local_dirpath, overwrite)

    def _pull_remote_dir(self, remote_path, dirname, local_tmp_path):
        ''' Download the remote model, always overwriting local files. '''
        self._download_remote_files(
            remote_path, dirname, local_tmp_path, overwrite=True)
class GeneralMonitor(Monitor):
    ''' General Monitor.

    Pulls files from `host` by shelling out to `wget`.
    '''

    def __init__(self, host, interval=10):
        """Store the remote host and forward `interval` to the base class."""
        super(GeneralMonitor, self).__init__(interval)
        self._general_host = host
        self._print_params(['_general_host'])

    def _get_local_file_timestamp(self, filename):
        """Return the modification time of the local file `filename`."""
        return os.path.getmtime(filename)

    def _exist_remote_file(self, remote_path, filename, local_tmp_path):
        """Check for a remote file by downloading it into `local_tmp_path`.

        Returns:
            [True, mtime of the downloaded local copy] when wget exits with
            status 0, otherwise [False, None].
        """
        remote_filepath = os.path.join(remote_path, filename)
        url = '{}/{}'.format(self._general_host, remote_filepath)
        _LOGGER.debug('remote file url: {}'.format(url))
        # only for check donefile, which is not a folder.
        # Use the POSIX redirection form: os.system runs the command through
        # /bin/sh, where the bash-only `&>` backgrounds wget and always
        # reports status 0.
        cmd = 'wget -nd -N -P {} {} >/dev/null 2>&1'.format(local_tmp_path,
                                                            url)
        _LOGGER.debug('wget cmd: {}'.format(cmd))
        if os.system(cmd) != 0:
            _LOGGER.debug('remote file({}) not exist.'.format(remote_filepath))
            return [False, None]
        else:
            timestamp = self._get_local_file_timestamp(
                os.path.join(local_tmp_path, filename))
            return [True, timestamp]

    def _pull_remote_dir(self, remote_path, dirname, local_tmp_path):
        """Download the remote model `dirname` into `local_tmp_path`.

        When `self._unpacked_filename` is None the remote entry is fetched
        recursively as a folder; otherwise it is fetched as a single packed
        file. NOTE(review): `_unpacked_filename` and `_check_param_help` are
        presumably provided by the base class; confirm.

        Raises:
            Exception: if wget exits with a non-zero status.
        """
        remote_dirpath = os.path.join(remote_path, dirname)
        url = '{}/{}'.format(self._general_host, remote_dirpath)
        _LOGGER.debug('remote file url: {}'.format(url))
        if self._unpacked_filename is None:
            # the remote file is model folder.
            cmd = 'wget -nH -r -P {} {} >/dev/null 2>&1'.format(
                os.path.join(local_tmp_path, dirname), url)
        else:
            # the remote file is a packed model file
            cmd = 'wget -nd -N -P {} {} >/dev/null 2>&1'.format(
                local_tmp_path, url)
        _LOGGER.debug('wget cmd: {}'.format(cmd))
        if os.system(cmd) != 0:
            raise Exception('pull remote dir failed. {}'.format(
                self._check_param_help('remote_model_name', dirname)))
def parse_args():
    """ Build the command-line parser for the monitor and parse sys.argv.
    Returns:
        parser.parse_args().
    """
    # (flag, add_argument keyword arguments), in the order they are
    # registered and therefore shown in --help.
    option_table = [
        ("--type", dict(
            type=str, default='general', help="Type of remote server")),
        ("--remote_path", dict(
            type=str, required=True, help="The base path for the remote")),
        ("--remote_model_name", dict(
            type=str,
            required=True,
            help="The model name to be pulled from the remote")),
        ("--remote_donefile_name", dict(
            type=str,
            required=True,
            help="The donefile name that marks the completion of the remote model update"
        )),
        ("--local_path", dict(
            type=str, required=True, help="Local work path")),
        ("--local_model_name", dict(
            type=str, required=True, help="Local model name")),
        ("--local_timestamp_file", dict(
            type=str,
            default='fluid_time_file',
            help="The timestamp file used locally for hot loading, The file is considered to be placed in the `local_path/local_model_name` folder."
        )),
        ("--local_tmp_path", dict(
            type=str,
            default='_serving_monitor_tmp',
            help="The path of the folder where temporary files are stored locally. If it does not exist, it will be created automatically"
        )),
        ("--unpacked_filename", dict(
            type=str,
            default=None,
            help="If the model of the remote production is a packaged file, the unpacked file name should be set. Currently, only tar packaging format is supported."
        )),
        ("--interval", dict(
            type=int, default=10, help="The polling interval in seconds")),
        ("--debug", dict(
            action='store_true', help="If set, output more details")),
        # general monitor
        ("--general_host", dict(type=str, help="General remote host")),
        # ftp monitor
        ("--ftp_host", dict(type=str, help="FTP remote host")),
        ("--ftp_port", dict(type=int, help="FTP remote port")),
        ("--ftp_username", dict(
            type=str,
            default='',
            help="FTP username. Not used if anonymous access.")),
        ("--ftp_password", dict(
            type=str,
            default='',
            help="FTP password. Not used if anonymous access")),
        # afs/hdfs monitor
        ("--hadoop_bin", dict(type=str, help="Path of Hadoop binary file")),
        ("--fs_name", dict(
            type=str,
            default='',
            help="AFS/HDFS fs_name. Not used if set in Hadoop-client.")),
        ("--fs_ugi", dict(
            type=str,
            default='',
            help="AFS/HDFS fs_ugi, Not used if set in Hadoop-client")),
    ]

    parser = argparse.ArgumentParser(description="Monitor")
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    parser.set_defaults(debug=False)
    return parser.parse_args()
def get_monitor(mtype):
    """ Create the monitor instance matching `mtype`.

    Connection settings are read from the module-level `args` object, which
    this file assigns in its `__main__` section.

    Args:
        mtype: type of monitor ('ftp', 'general', 'afs' or 'hdfs')
    Returns:
        monitor instance.
    Raises:
        Exception: if `mtype` is none of the supported types.
    """
    if mtype == 'ftp':
        return FTPMonitor(
            args.ftp_host,
            args.ftp_port,
            username=args.ftp_username,
            password=args.ftp_password,
            interval=args.interval)
    if mtype == 'general':
        return GeneralMonitor(args.general_host, interval=args.interval)
    if mtype in ('afs', 'hdfs'):
        return HadoopMonitor(
            args.hadoop_bin, args.fs_name, args.fs_ugi, interval=args.interval)
    raise Exception('unsupport type.')
def start_monitor(monitor, args):
    """ Copy the path/name settings from `args` onto `monitor`, then run it.

    Each setting `x` is passed to `monitor.set_x(args.x)` in the order listed
    below, after which `monitor.run()` is called.
    """
    setting_names = (
        'remote_path',
        'remote_model_name',
        'remote_donefile_name',
        'local_path',
        'local_model_name',
        'local_timestamp_file',
        'local_tmp_path',
        'unpacked_filename', )
    for setting in setting_names:
        apply_setting = getattr(monitor, 'set_' + setting)
        apply_setting(getattr(args, setting))
    monitor.run()
if __name__ == "__main__":
    # `args` stays a module-level name: get_monitor() reads it directly.
    args = parse_args()
    # Format and date format are the same either way; only the verbosity
    # depends on --debug.
    logging.basicConfig(
        format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s',
        datefmt='%Y-%m-%d %H:%M',
        level=logging.DEBUG if args.debug else logging.INFO)
    monitor = get_monitor(args.type)
    start_monitor(monitor, args)
...@@ -35,6 +35,7 @@ def start_gpu_card_model(index, gpuid, args): # pylint: disable=doc-string-miss ...@@ -35,6 +35,7 @@ def start_gpu_card_model(index, gpuid, args): # pylint: disable=doc-string-miss
thread_num = args.thread thread_num = args.thread
model = args.model model = args.model
mem_optim = args.mem_optim mem_optim = args.mem_optim
max_body_size = args.max_body_size
workdir = "{}_{}".format(args.workdir, gpuid) workdir = "{}_{}".format(args.workdir, gpuid)
if model == "": if model == "":
...@@ -56,6 +57,7 @@ def start_gpu_card_model(index, gpuid, args): # pylint: disable=doc-string-miss ...@@ -56,6 +57,7 @@ def start_gpu_card_model(index, gpuid, args): # pylint: disable=doc-string-miss
server.set_op_sequence(op_seq_maker.get_op_sequence()) server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_num_threads(thread_num) server.set_num_threads(thread_num)
server.set_memory_optimize(mem_optim) server.set_memory_optimize(mem_optim)
server.set_max_body_size(max_body_size)
server.load_model_config(model) server.load_model_config(model)
server.prepare_server(workdir=workdir, port=port, device=device) server.prepare_server(workdir=workdir, port=port, device=device)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册