Commit 67b78422 authored by barrierye

Merge branch 'ensemble-support' of https://github.com/barrierye/Serving into ensemble-support

...
@@ -59,8 +59,11 @@ int GeneralInferOp::inference() {
   int64_t start = timeline.TimeStampUS();
   timeline.Start();
-  if (InferManager::instance().infer(GENERAL_MODEL_NAME, in, out, batch_size)) {
-    LOG(ERROR) << "Failed do infer in fluid model: " << GENERAL_MODEL_NAME;
+  BLOG("engine name: %s", engine_name().c_str());
+  if (InferManager::instance().infer(
+          GeneralInferOp::engine_name().c_str(), in, out, batch_size)) {
+    LOG(ERROR) << "Failed do infer in fluid model: "
+               << GeneralInferOp::engine_name();
     return -1;
   }
...
@@ -21,6 +21,9 @@
 #include "core/predictor/framework/infer.h"
 #include "core/predictor/framework/memory.h"
 #include "core/util/include/timer.h"
+#define BLOG(fmt, ...) \
+  printf( \
+      "[%s:%s]:%d " fmt "\n", __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__)

 namespace baidu {
 namespace paddle_serving {
...
@@ -98,6 +101,7 @@ int GeneralReaderOp::inference() {
       baidu::paddle_serving::predictor::Resource::instance();

   VLOG(2) << "get resource pointer done.";
+  BLOG("engine name: %s", engine_name().c_str());
   std::shared_ptr<PaddleGeneralModelConfig> model_config =
       resource.get_general_model_config();
...
@@ -18,6 +18,9 @@
 #include "core/predictor/common/inner_common.h"
 #include "core/predictor/framework/predictor_metric.h"  // PredictorMetric
 #include "core/predictor/op/op.h"
+#define BLOG(fmt, ...) \
+  printf( \
+      "[%s:%s]:%d " fmt "\n", __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__)

 namespace baidu {
 namespace paddle_serving {
...
@@ -199,25 +202,110 @@ const DagStage* Dag::stage_by_index(uint32_t index) { return _stages[index]; }

 int Dag::topo_sort() {
   std::stringstream ss;
-  for (uint32_t nid = 0; nid < _index_nodes.size(); nid++) {
-    DagStage* stage = new (std::nothrow) DagStage();
-    if (stage == NULL) {
-      LOG(ERROR) << "Invalid stage!";
-      return ERR_MEM_ALLOC_FAILURE;
-    }
-    stage->nodes.push_back(_index_nodes[nid]);
-    ss.str("");
-    ss << _stages.size();
-    stage->name = ss.str();
-    stage->full_name = full_name() + NAME_DELIMITER + stage->name;
-    _stages.push_back(stage);
-    // assign stage number after stage created
-    _index_nodes[nid]->stage = nid;
-    // assign dag node full name after stage created
-    _index_nodes[nid]->full_name =
-        stage->full_name + NAME_DELIMITER + _index_nodes[nid]->name;
-  }
+  uint32_t nodes_size = _index_nodes.size();
+  std::vector<uint32_t> in_degree(nodes_size, 0);
+  std::vector<std::vector<uint32_t>> in_edge(nodes_size);
+  for (uint32_t nid = 0; nid < nodes_size; nid++) {
+    in_degree[nid] += _index_nodes[nid]->depends.size();
+    for (auto it = _index_nodes[nid]->depends.begin();
+         it != _index_nodes[nid]->depends.end();
+         ++it) {
+      uint32_t pnid = Dag::node_by_name(it->first)->id -
+                      1;  // 0 is reserved for beginner-op
+      in_edge[pnid].push_back(nid);
+      BLOG("in_edge[%d]: %d", pnid, nid);
+    }
+  }
+  for (int i = 0; i < in_degree.size(); ++i) {
+    BLOG("(%s) in_degree[%d]: %d",
+         _index_nodes[i]->name.c_str(),
+         i,
+         in_degree[i]);
+  }
+  int sorted_num = 0;
+  DagStage* stage = new (std::nothrow) DagStage();
+  if (stage == NULL) {
+    LOG(ERROR) << "Invalid stage!";
+    return ERR_MEM_ALLOC_FAILURE;
+  }
+  ss.str("");
+  ss << _stages.size();
+  stage->name = ss.str();
+  stage->full_name = full_name() + NAME_DELIMITER + stage->name;
+  BLOG("stage->full_name: %s", stage->full_name.c_str());
+  for (uint32_t nid = 0; nid < nodes_size; ++nid) {
+    if (in_degree[nid] == 0) {
+      BLOG("nid: %d", nid);
+      ++sorted_num;
+      stage->nodes.push_back(_index_nodes[nid]);
+      // assign stage number after stage created
+      _index_nodes[nid]->stage = _stages.size();
+      // assign dag node full name after stage created
+      _index_nodes[nid]->full_name =
+          stage->full_name + NAME_DELIMITER + _index_nodes[nid]->name;
+    }
+  }
+  if (stage->nodes.size() == 0) {
+    LOG(ERROR) << "Invalid Dag!";
+    return ERR_INTERNAL_FAILURE;
+  }
+  _stages.push_back(stage);
+  while (sorted_num < nodes_size) {
+    auto pre_nodes = _stages.back()->nodes;
+    DagStage* stage = new (std::nothrow) DagStage();
+    ss.str("");
+    ss << _stages.size();
+    stage->name = ss.str();
+    stage->full_name = full_name() + NAME_DELIMITER + stage->name;
+    BLOG("stage->full_name: %s", stage->full_name.c_str());
+    for (uint32_t pi = 0; pi < pre_nodes.size(); ++pi) {
+      uint32_t pnid = pre_nodes[pi]->id - 1;
+      BLOG("pnid: %d", pnid);
+      for (uint32_t ei = 0; ei < in_edge[pnid].size(); ++ei) {
+        uint32_t nid = in_edge[pnid][ei];
+        --in_degree[nid];
+        BLOG("nid: %d, indeg: %d", nid, in_degree[nid]);
+        if (in_degree[nid] == 0) {
+          BLOG("nid: %d", nid);
+          ++sorted_num;
+          stage->nodes.push_back(_index_nodes[nid]);
+          // assign stage number after stage created
+          _index_nodes[nid]->stage = _stages.size();
+          // assign dag node full name after stage created
+          _index_nodes[nid]->full_name =
+              stage->full_name + NAME_DELIMITER + _index_nodes[nid]->name;
+        }
+      }
+    }
+    if (stage->nodes.size() == 0) {
+      LOG(ERROR) << "Invalid Dag!";
+      return ERR_INTERNAL_FAILURE;
+    }
+    _stages.push_back(stage);
+  }
+  /*std::stringstream ss;*/
+  // for (uint32_t nid = 0; nid < _index_nodes.size(); nid++) {
+  //   DagStage* stage = new (std::nothrow) DagStage();
+  //   if (stage == NULL) {
+  //     LOG(ERROR) << "Invalid stage!";
+  //     return ERR_MEM_ALLOC_FAILURE;
+  //   }
+  //   stage->nodes.push_back(_index_nodes[nid]);
+  //   ss.str("");
+  //   ss << _stages.size();
+  //   stage->name = ss.str();
+  //   stage->full_name = full_name() + NAME_DELIMITER + stage->name;
+  //   BLOG("stage->full_name: %s", stage->full_name.c_str());
+  //   _stages.push_back(stage);
+  //   // assign stage number after stage created
+  //   _index_nodes[nid]->stage = nid;
+  //   // assign dag node full name after stage created
+  //   _index_nodes[nid]->full_name =
+  //       stage->full_name + NAME_DELIMITER + _index_nodes[nid]->name;
+  /*}*/

   return ERR_OK;
 }
...
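The rewritten Dag::topo_sort() above is a stage-grouped variant of Kahn's algorithm: it counts each node's in-degree, collects every zero-in-degree node into the first stage, then repeatedly decrements the in-degrees of the last stage's successors and groups the newly freed nodes into the next stage, so ops with no mutual dependency land in the same stage. A minimal Python sketch of the same idea (illustrative names only, not the Serving API):

def topo_stages(deps):
    # deps maps each node to the list of nodes it depends on
    in_degree = {n: len(d) for n, d in deps.items()}
    out_edges = {n: [] for n in deps}
    for n, d in deps.items():
        for p in d:
            out_edges[p].append(n)  # edge p -> n, mirrors in_edge[pnid] above
    stages = []
    ready = [n for n, deg in in_degree.items() if deg == 0]
    sorted_num = 0
    while ready:
        stages.append(ready)
        sorted_num += len(ready)
        next_ready = []
        for p in ready:
            for n in out_edges[p]:
                in_degree[n] -= 1
                if in_degree[n] == 0:
                    next_ready.append(n)
        ready = next_ready
    if sorted_num != len(deps):
        raise ValueError("cycle detected, invalid DAG")
    return stages

For the ensemble example further down, topo_stages({'read': [], 'g1': ['read'], 'g2': ['read'], 'resp': ['g1', 'g2']}) returns [['read'], ['g1', 'g2'], ['resp']], which matches the three DagStages the C++ code builds.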
@@ -21,6 +21,12 @@
 #include <string>
 #include "core/predictor/common/inner_common.h"
 #include "core/predictor/framework/op_repository.h"
+#define BLOG(fmt, ...)            \
+  printf("[%s:%s]:%d " fmt "\n", \
+         __FILE__,               \
+         __FUNCTION__,           \
+         __LINE__,               \
+         ##__VA_ARGS__);

 namespace baidu {
 namespace paddle_serving {
...
@@ -76,6 +82,11 @@ int DagView::init(Dag* dag, const std::string& service_name) {
       }
       op->set_full_name(service_name + NAME_DELIMITER + node->full_name);
+      // Set the name of the Op as the key of the matching engine.
+      BLOG("op->set_engine_name(%s)", node->name.c_str());
+      op->set_engine_name(node->name);
+
       vnode->conf = node;
       vnode->op = op;
       vstage->nodes.push_back(vnode);
...
@@ -121,6 +132,7 @@ int DagView::deinit() {
 int DagView::execute(butil::IOBufBuilder* debug_os) {
   uint32_t stage_size = _view.size();
   for (uint32_t si = 0; si < stage_size; si++) {
+    BLOG("start to execute stage[%u] %s", si, _view[si]->full_name.c_str());
     TRACEPRINTF("start to execute stage[%u]", si);
     int errcode = execute_one_stage(_view[si], debug_os);
     TRACEPRINTF("finish to execute stage[%u]", si);
...
@@ -139,12 +151,16 @@ int DagView::execute_one_stage(ViewStage* vstage,
                                butil::IOBufBuilder* debug_os) {
   butil::Timer stage_time(butil::Timer::STARTED);
   uint32_t node_size = vstage->nodes.size();
+  BLOG("vstage->nodes.size(): %d", node_size);
   for (uint32_t ni = 0; ni < node_size; ni++) {
     ViewNode* vnode = vstage->nodes[ni];
     DagNode* conf = vnode->conf;
     Op* op = vnode->op;
+    BLOG("start to execute op[%s]", op->name());
+    BLOG("Op engine name: %s", op->engine_name().c_str());
     TRACEPRINTF("start to execute op[%s]", op->name());
     int errcode = op->process(debug_os != NULL);
+    BLOG("finish to execute op[%s]", op->name());
     TRACEPRINTF("finish to execute op[%s]", op->name());
     if (errcode < 0) {
       LOG(ERROR) << "Execute failed, Op:" << op->debug_string();
...
@@ -23,6 +23,9 @@
 #include "core/predictor/framework/bsf.h"
 #include "core/predictor/framework/factory.h"
 #include "core/predictor/framework/infer_data.h"
+#define BLOG(fmt, ...) \
+  printf( \
+      "[%s:%s]:%d " fmt "\n", __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__)

 namespace baidu {
 namespace paddle_serving {
...
@@ -765,6 +768,9 @@ class InferManager {
     }
     size_t engine_num = model_toolkit_conf.engines_size();
     for (size_t ei = 0; ei < engine_num; ++ei) {
+      BLOG("model_toolkit_conf.engines(%d).name: %s",
+           ei,
+           model_toolkit_conf.engines(ei).name().c_str());
       std::string engine_name = model_toolkit_conf.engines(ei).name();
       VersionedInferEngine* engine = new (std::nothrow) VersionedInferEngine();
       if (!engine) {
...
@@ -845,8 +851,10 @@ class InferManager {
             void* out,
             uint32_t batch_size = -1) {
     auto it = _map.find(model_name);
+    BLOG("find model_name: %s", model_name);
     if (it == _map.end()) {
       LOG(WARNING) << "Cannot find engine in map, model name:" << model_name;
+      BLOG("Cannot find engine in map, model name: %s", model_name);
       return -1;
     }
     return it->second->infer(in, out, batch_size);
...
@@ -30,6 +30,9 @@
 #include "core/predictor/framework/predictor_metric.h"  // PredictorMetric
 #include "core/predictor/framework/resource.h"
 #include "core/predictor/framework/server.h"
+#define BLOG(fmt, ...) \
+  printf( \
+      "[%s:%s]:%d " fmt "\n", __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__)

 namespace baidu {
 namespace paddle_serving {
...
@@ -161,6 +164,7 @@ int InferService::inference(const google::protobuf::Message* request,
     return ERR_INTERNAL_FAILURE;
   }
   TRACEPRINTF("start to execute workflow[%s]", workflow->name().c_str());
+  BLOG("start to execute workflow[%s]", workflow->name().c_str());
   int errcode = _execute_workflow(workflow, request, response, debug_os);
   TRACEPRINTF("finish to execute workflow[%s]", workflow->name().c_str());
   if (errcode < 0) {
...
@@ -220,6 +224,7 @@ int InferService::_execute_workflow(Workflow* workflow,
   // call actual inference interface
   int errcode = dv->execute(debug_os);
+  BLOG("execute_workflow");
   if (errcode < 0) {
     LOG(ERROR) << "Failed execute dag for workflow:" << workflow->name();
     return errcode;
...
@@ -25,6 +25,12 @@
 #include "core/predictor/common/utils.h"
 #include "core/predictor/framework/channel.h"
 #include "core/predictor/framework/dag.h"
+#define BLOG(fmt, ...)            \
+  printf("[%s:%s]:%d " fmt "\n", \
+         __FILE__,               \
+         __FUNCTION__,           \
+         __LINE__,               \
+         ##__VA_ARGS__);

 namespace baidu {
 namespace paddle_serving {
...
@@ -133,6 +139,7 @@ int Op::process(bool debug) {
   }

   // 2. current inference
+  BLOG("Op: %s->inference()", _name.c_str());
   if (inference() != 0) {
     return ERR_OP_INFER_FAILURE;
   }
...
@@ -144,6 +144,16 @@ class Op {
   uint32_t id() const;

+  // Set the name of the Op as the key of the matching engine.
+  // Note that this key is only used by infer_op (only the
+  // infer_op needs to find the corresponding engine).
+  // At present, there is only general_infer_op.
+  void set_engine_name(const std::string engine_name) {
+    _engine_name = engine_name;
+  }
+
+  const std::string& engine_name() const { return _engine_name; }
+
   // --------------- Default implements ----------------

   virtual int custom_init() { return 0; }
...
@@ -196,6 +206,7 @@ class Op {
   bool _has_calc;
   bool _has_init;
   TimerFlow* _timer;
+  std::string _engine_name;  // only for infer_op
 };

 template <typename T>
...
@@ -29,3 +29,4 @@ test_reader = paddle.batch(
 for data in test_reader():
     fetch_map = client.predict(feed={"x": data[0][0]}, fetch=["price"])
     print("{} {}".format(fetch_map["price"][0], data[0][1][0]))
+exit(0)
...
@@ -21,16 +21,20 @@ from paddle_serving_server import Server

 op_maker = OpMaker()
 read_op = op_maker.create('general_reader')
-general_infer_op = op_maker.create('general_infer')
+g1_infer_op = op_maker.create('general_infer', node_name='g1')
+g2_infer_op = op_maker.create('general_infer', node_name='g2')
 response_op = op_maker.create('general_response')

 op_seq_maker = OpSeqMaker()
 op_seq_maker.add_op(read_op)
-op_seq_maker.add_op(general_infer_op)
-op_seq_maker.add_op(response_op)
+op_seq_maker.add_op(g1_infer_op, dependent_nodes=[read_op])
+op_seq_maker.add_op(g2_infer_op, dependent_nodes=[read_op])
+op_seq_maker.add_op(response_op, dependent_nodes=[g1_infer_op, g2_infer_op])

 server = Server()
 server.set_op_sequence(op_seq_maker.get_op_sequence())
-server.load_model_config(sys.argv[1])
+# server.load_model_config(sys.argv[1])
+model_configs = {'g1': 'uci_housing_model', 'g2': 'uci_housing_model'}
+server.load_model_config(model_configs)
 server.prepare_server(workdir="work_dir1", port=9393, device="cpu")
 server.run_server()
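Note that node_name does double duty in the example above: it names the DAG node and it is the key that ties the op to its engine. DagView::init() passes node->name to op->set_engine_name(), GeneralInferOp then asks InferManager for an engine with exactly that name, and _prepare_engine() creates one engine per key of the model_configs dict, so 'g1' and 'g2' must appear in both places. A quick sketch of the naming rule from OpMaker.create (node_name_for is a hypothetical helper, for illustration only):

# explicit node_name wins; otherwise the node is named "<node_type>_op",
# which is why read_op above ends up as "general_reader_op"
def node_name_for(node_type, node_name=None):
    return node_name if node_name is not None else "{}_op".format(node_type)

assert node_name_for('general_infer', node_name='g1') == 'g1'
assert node_name_for('general_reader') == 'general_reader_op'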
...
@@ -39,13 +39,14 @@ class OpMaker(object):
     # currently, inputs and outputs are not used
     # when we have OpGraphMaker, inputs and outputs are necessary
-    def create(self, name, inputs=[], outputs=[]):
-        if name not in self.op_dict:
-            raise Exception("Op name {} is not supported right now".format(
-                name))
+    def create(self, node_type, node_name=None, inputs=[], outputs=[]):
+        if node_type not in self.op_dict:
+            raise Exception("Op type {} is not supported right now".format(
+                node_type))
         node = server_sdk.DAGNode()
-        node.name = "{}_op".format(name)
-        node.type = self.op_dict[name]
+        node.name = node_name if node_name is not None else "{}_op".format(
+            node_type)
+        node.type = self.op_dict[node_type]
         return node
...
@@ -55,12 +56,19 @@ class OpSeqMaker(object):
         self.workflow.name = "workflow1"
         self.workflow.workflow_type = "Sequence"

-    def add_op(self, node):
-        if len(self.workflow.nodes) >= 1:
-            dep = server_sdk.DAGNodeDependency()
-            dep.name = self.workflow.nodes[-1].name
-            dep.mode = "RO"
-            node.dependencies.extend([dep])
+    def add_op(self, node, dependent_nodes=None):
+        if dependent_nodes is None:
+            if len(self.workflow.nodes) >= 1:
+                dep = server_sdk.DAGNodeDependency()
+                dep.name = self.workflow.nodes[-1].name
+                dep.mode = "RO"
+                node.dependencies.extend([dep])
+        else:
+            for dep_node in dependent_nodes:
+                dep = server_sdk.DAGNodeDependency()
+                dep.name = dep_node.name
+                dep.mode = "RO"
+                node.dependencies.extend([dep])
         self.workflow.nodes.extend([node])

     def get_op_sequence(self):
...
@@ -75,7 +83,6 @@ class Server(object):
         self.infer_service_conf = None
         self.model_toolkit_conf = None
         self.resource_conf = None
-        self.engine = None
         self.memory_optimization = False
         self.model_conf = None
         self.workflow_fn = "workflow.prototxt"
...
@@ -93,6 +100,7 @@ class Server(object):
         self.cur_path = os.getcwd()
         self.use_local_bin = False
         self.mkl_flag = False
+        self.model_config_paths = None

     def set_max_concurrency(self, concurrency):
         self.max_concurrency = concurrency
...
@@ -117,32 +125,36 @@ class Server(object):
             self.use_local_bin = True
             self.bin_path = os.environ["SERVING_BIN"]

-    def _prepare_engine(self, model_config_path, device):
+    def _prepare_engine(self, model_config_paths, device):
         if self.model_toolkit_conf == None:
             self.model_toolkit_conf = server_sdk.ModelToolkitConf()
-        if self.engine == None:
-            self.engine = server_sdk.EngineDesc()
-
-        self.model_config_path = model_config_path
-        self.engine.name = "general_model"
-        self.engine.reloadable_meta = model_config_path + "/fluid_time_file"
-        os.system("touch {}".format(self.engine.reloadable_meta))
-        self.engine.reloadable_type = "timestamp_ne"
-        self.engine.runtime_thread_num = 0
-        self.engine.batch_infer_size = 0
-        self.engine.enable_batch_align = 0
-        self.engine.model_data_path = model_config_path
-        self.engine.enable_memory_optimization = self.memory_optimization
-        self.engine.static_optimization = False
-        self.engine.force_update_static_cache = False
-
-        if device == "cpu":
-            self.engine.type = "FLUID_CPU_ANALYSIS_DIR"
-        elif device == "gpu":
-            self.engine.type = "FLUID_GPU_ANALYSIS_DIR"
-
-        self.model_toolkit_conf.engines.extend([self.engine])
+        if isinstance(model_config_paths, str):
+            model_config_paths = {"general_infer_op": model_config_paths}
+        elif not isinstance(model_config_paths, dict):
+            raise Exception("model_config_paths can not be {}".format(
+                type(model_config_paths)))
+
+        for engine_name, model_config_path in model_config_paths.items():
+            engine = server_sdk.EngineDesc()
+            engine.name = engine_name
+            engine.reloadable_meta = model_config_path + "/fluid_time_file"
+            os.system("touch {}".format(engine.reloadable_meta))
+            engine.reloadable_type = "timestamp_ne"
+            engine.runtime_thread_num = 0
+            engine.batch_infer_size = 0
+            engine.enable_batch_align = 0
+            engine.model_data_path = model_config_path
+            engine.enable_memory_optimization = self.memory_optimization
+            engine.static_optimization = False
+            engine.force_update_static_cache = False
+
+            if device == "cpu":
+                engine.type = "FLUID_CPU_ANALYSIS_DIR"
+            elif device == "gpu":
+                engine.type = "FLUID_GPU_ANALYSIS_DIR"
+
+            self.model_toolkit_conf.engines.extend([engine])

     def _prepare_infer_service(self, port):
         if self.infer_service_conf == None:
...
@@ -175,7 +187,9 @@ class Server(object):
         with open(filepath, "w") as fout:
             fout.write(str(pb_obj))

-    def load_model_config(self, path):
+    def load_model_config(self, model_config_paths):
+        self.model_config_paths = model_config_paths
+        path = model_config_paths.items()[0][1]
         self.model_config_path = path
         self.model_conf = m_config.GeneralModelConfig()
         f = open("{}/serving_server_conf.prototxt".format(path), 'r')
...
@@ -249,7 +263,7 @@ class Server(object):
         if not self.port_is_available(port):
             raise SystemExit("Prot {} is already used".format(port))
         self._prepare_resource(workdir)
-        self._prepare_engine(self.model_config_path, device)
+        self._prepare_engine(self.model_config_paths, device)
         self._prepare_infer_service(port)
         self.workdir = workdir
...
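One caveat in load_model_config as committed: model_config_paths.items()[0][1] assumes a dict and Python 2 (in Python 3, dict.items() is a view and cannot be indexed), and it breaks the old single-path calling convention even though _prepare_engine still accepts a plain string. A version-safe sketch of the same intent (normalize_model_configs is a hypothetical helper, assuming legacy callers should keep passing a bare path):

def normalize_model_configs(model_config_paths):
    # accept either a single path (legacy) or {engine_name: model_config_path}
    if isinstance(model_config_paths, str):
        return {"general_infer_op": model_config_paths}
    if isinstance(model_config_paths, dict):
        return dict(model_config_paths)
    raise TypeError("model_config_paths can not be {}".format(
        type(model_config_paths)))

configs = normalize_model_configs({'g1': 'uci_housing_model',
                                   'g2': 'uci_housing_model'})
path = next(iter(configs.values()))  # works on Python 2 and 3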