Commit 1021499c authored by dongdaxiang

add pybind and python api function to make get_general_model_conf doable

Parent c585f678
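This commit lets a client be initialized from an in-memory GeneralModelConfig text string instead of a config file path. A minimal usage sketch, assuming the pybind module serving_client built from this branch (the file name and error handling are illustrative, not part of the commit):

from serving_client import PredictorClient  # pybind module defined in this diff

client = PredictorClient()
with open("serving_client_conf.prototxt") as f:  # hypothetical config file name
    conf_str = f.read()
# init_from_string() returns 0 on success and -1 if the text proto fails to parse
if client.init_from_string(conf_str) != 0:
    raise RuntimeError("failed to parse GeneralModelConfig from string")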
......@@ -158,6 +158,8 @@ class PredictorClient {
  int init(const std::string& client_conf);
  int init_from_string(const std::string& conf_string);
  void set_predictor_conf(const std::string& conf_path,
                          const std::string& conf_file);
......
......@@ -18,6 +18,7 @@
#include "core/sdk-cpp/include/common.h"
#include "core/sdk-cpp/include/predictor_sdk.h"
#include "core/util/include/timer.h"
#include "google/protobuf/text_format.h"
DEFINE_bool(profile_client, false, "");
DEFINE_bool(profile_server, false, "");
......@@ -104,6 +105,57 @@ int PredictorClient::init(const std::string &conf_file) {
  return 0;
}

int PredictorClient::init_from_string(const std::string &conf_string) {
  try {
    GeneralModelConfig model_config;
    bool success =
        google::protobuf::TextFormat::ParseFromString(conf_string, &model_config);
    if (!success) {
      LOG(ERROR) << "Failed to parse config string into GeneralModelConfig";
      return -1;
    }
    _feed_name_to_idx.clear();
    _fetch_name_to_idx.clear();
    _shape.clear();
    int feed_var_num = model_config.feed_var_size();
    int fetch_var_num = model_config.fetch_var_size();
    VLOG(2) << "feed var num: " << feed_var_num
            << " fetch var num: " << fetch_var_num;
    for (int i = 0; i < feed_var_num; ++i) {
      _feed_name_to_idx[model_config.feed_var(i).alias_name()] = i;
      VLOG(2) << "feed alias name: " << model_config.feed_var(i).alias_name()
              << " index: " << i;
      std::vector<int> tmp_feed_shape;
      VLOG(2) << "feed[" << i << "] shape:";
      for (int j = 0; j < model_config.feed_var(i).shape_size(); ++j) {
        tmp_feed_shape.push_back(model_config.feed_var(i).shape(j));
        VLOG(2) << "shape[" << j << "]: " << model_config.feed_var(i).shape(j);
      }
      _type.push_back(model_config.feed_var(i).feed_type());
      VLOG(2) << "feed[" << i
              << "] feed type: " << model_config.feed_var(i).feed_type();
      _shape.push_back(tmp_feed_shape);
    }
    for (int i = 0; i < fetch_var_num; ++i) {
      _fetch_name_to_idx[model_config.fetch_var(i).alias_name()] = i;
      VLOG(2) << "fetch [" << i << "]"
              << " alias name: " << model_config.fetch_var(i).alias_name();
      _fetch_name_to_var_name[model_config.fetch_var(i).alias_name()] =
          model_config.fetch_var(i).name();
      _fetch_name_to_type[model_config.fetch_var(i).alias_name()] =
          model_config.fetch_var(i).fetch_type();
    }
  } catch (std::exception &e) {
    LOG(ERROR) << "Failed to load general model config: " << e.what();
    return -1;
  }
  return 0;
}
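For reference, a sketch of the text-format config this parser consumes, using only fields the loops above read (alias_name, name, shape, feed_type, fetch_type); the variable names and dtype codes are illustrative. The Python side can produce the same parse with the generated protobuf module (imported as m_config in the client code below):

import google.protobuf.text_format as text_format
import general_model_config_pb2 as m_config  # assumed generated-module name

conf_str = """
feed_var {
  name: "x"
  alias_name: "x"
  feed_type: 1   # illustrative dtype code
  shape: 13
}
fetch_var {
  name: "fc_0.tmp_1"
  alias_name: "price"
  fetch_type: 1
}
"""
model_config = m_config.GeneralModelConfig()
text_format.Merge(conf_str, model_config)  # mirrors the C++ TextFormat parse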
void PredictorClient::set_predictor_conf(const std::string &conf_path,
                                         const std::string &conf_file) {
  _predictor_path = conf_path;
......
......@@ -66,6 +66,10 @@ PYBIND11_MODULE(serving_client, m) {
      [](PredictorClient &self, const std::string &conf) {
        return self.init(conf);
      })
      .def("init_from_string",
           [](PredictorClient &self, const std::string &conf_str) {
             return self.init_from_string(conf_str);
           })
      .def("set_predictor_conf",
           [](PredictorClient &self,
              const std::string &conf_path,
......
......@@ -128,51 +128,8 @@ class Client(object):
        os.system('patchelf --set-rpath {} {}'.format(lib_path, client_path))

    def load_client_config(self, path):
        from .serving_client import PredictorClient
        from .serving_client import PredictorRes
        model_conf = m_config.GeneralModelConfig()
        f = open(path, 'r')
        model_conf = google.protobuf.text_format.Merge(
            str(f.read()), model_conf)
        # load configuration here
        # get feed vars, fetch vars
        # get feed shapes, feed types
        # map feed names to index
        self.result_handle_ = PredictorRes()
        self.client_handle_ = PredictorClient()
        self.client_handle_.init(path)
        if "FLAGS_max_body_size" not in os.environ:
            os.environ["FLAGS_max_body_size"] = str(512 * 1024 * 1024)
        read_env_flags = ["profile_client", "profile_server", "max_body_size"]
        self.client_handle_.init_gflags(
            [sys.argv[0]] + ["--tryfromenv=" + ",".join(read_env_flags)])
        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
        self.feed_names_to_idx_ = {}
        self.fetch_names_to_type_ = {}
        self.fetch_names_to_idx_ = {}
        self.lod_tensor_set = set()
        self.feed_tensor_len = {}
        for i, var in enumerate(model_conf.feed_var):
            self.feed_names_to_idx_[var.alias_name] = i
            self.feed_types_[var.alias_name] = var.feed_type
            self.feed_shapes_[var.alias_name] = var.shape
            if var.is_lod_tensor:
                self.lod_tensor_set.add(var.alias_name)
            else:
                counter = 1
                for dim in self.feed_shapes_[var.alias_name]:
                    counter *= dim
                self.feed_tensor_len[var.alias_name] = counter
        for i, var in enumerate(model_conf.fetch_var):
            self.fetch_names_to_idx_[var.alias_name] = i
            self.fetch_names_to_type_[var.alias_name] = var.fetch_type
            if var.is_lod_tensor:
                self.lod_tensor_set.add(var.alias_name)
        return
        # reserve this interface, but do nothing
        pass
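The feed_tensor_len bookkeeping above is just the flattened element count of each dense (non-LoD) feed, e.g. for a hypothetical image feed declared with shape [3, 224, 224]:

# illustrative only: how feed_tensor_len is derived for a dense feed var
shape = [3, 224, 224]
counter = 1
for dim in shape:
    counter *= dim
assert counter == 3 * 224 * 224 == 150528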
def add_variant(self, tag, cluster, variant_weight):
if self.predictor_sdk_ is None:
......
......@@ -99,6 +99,16 @@ class OpSeqMaker(object):
    def get_op_sequence(self):
        workflow_conf = server_sdk.WorkflowConf()
        workflow_conf.workflows.extend([self.workflow])
        # prepare GetConf workflow
        get_conf_workflow = server_sdk.Workflow()
        get_conf_workflow.name = "workflow2"
        get_conf_workflow.workflow_type = "Sequence"
        node = server_sdk.DAGNode()
        node.name = "general_get_conf_0"
        node.type = "GeneralGetConfOp"
        get_conf_workflow.nodes.extend([node])
        workflow_conf.workflows.extend([get_conf_workflow])
        return workflow_conf
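Rendered with google.protobuf.text_format.MessageToString, the returned WorkflowConf should now look roughly like this (a sketch, assuming the user-built DAG is registered as "workflow1", as the Server section below implies):

workflows {
  name: "workflow1"
  workflow_type: "Sequence"
  ...                      # nodes added via add_op
}
workflows {
  name: "workflow2"
  workflow_type: "Sequence"
  nodes {
    name: "general_get_conf_0"
    type: "GeneralGetConfOp"
  }
}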
......@@ -117,6 +127,15 @@ class OpGraphMaker(object):
    def get_op_graph(self):
        workflow_conf = server_sdk.WorkflowConf()
        workflow_conf.workflows.extend([self.workflow])
        # prepare GetConf workflow
        get_conf_workflow = server_sdk.Workflow()
        get_conf_workflow.name = "workflow2"
        get_conf_workflow.workflow_type = "Sequence"
        node = server_sdk.DAGNode()
        node.name = "general_get_conf_0"
        node.type = "GeneralGetConfOp"
        get_conf_workflow.nodes.extend([node])
        workflow_conf.workflows.extend([get_conf_workflow])
        return workflow_conf
......@@ -210,8 +229,22 @@ class Server(object):
        self.infer_service_conf = server_sdk.InferServiceConf()
        self.infer_service_conf.port = port
        infer_service = server_sdk.InferService()
        infer_service.enable_map_request_to_workflow = True
        infer_service.request_field_key = "request_type"
        # map request_type values to workflows: "Predict" runs the model DAG
        # ("workflow1"); "GetConf" runs the GeneralGetConfOp workflow ("workflow2")
        kv_workflow1 = server_sdk.ValueMappedWorkflow()
        kv_workflow1.request_field_value = "Predict"
        kv_workflow1.workflows = "workflow1"
        infer_service.value_mapped_workflows.extend([kv_workflow1])
        kv_workflow2 = server_sdk.ValueMappedWorkflow()
        kv_workflow2.request_field_value = "GetConf"
        kv_workflow2.workflows = "workflow2"
        infer_service.value_mapped_workflows.extend([kv_workflow2])
        infer_service.name = "GeneralModelService"
        infer_service.workflows.extend(["workflow1"])
        self.infer_service_conf.services.extend([infer_service])
    def _prepare_resource(self, workdir):
......
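With request mapping enabled, the services section of the generated config should come out roughly as follows (text-proto sketch under the routing above; field names follow the server_sdk usage in this diff):

services {
  name: "GeneralModelService"
  workflows: "workflow1"
  enable_map_request_to_workflow: true
  request_field_key: "request_type"
  value_mapped_workflows {
    request_field_value: "Predict"
    workflows: "workflow1"
  }
  value_mapped_workflows {
    request_field_value: "GetConf"
    workflows: "workflow2"
  }
}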