Commit 7309ad63 authored by barrierye

change ndarray to list

Parent 6acca6f7
@@ -392,6 +392,14 @@ class GClient(object):
         self.stub_ = gserver_general_model_service_pb2_grpc.GServerGeneralModelServiceStub(
             self.channel_)
 
+    def _flatten_list(self, nested_list):
+        for item in nested_list:
+            if isinstance(item, (list, tuple)):
+                for sub_item in self._flatten_list(item):
+                    yield sub_item
+            else:
+                yield item
+
     def _parse_model_config(self, model_config_path):
         model_conf = m_config.GeneralModelConfig()
         f = open(model_config_path, 'r')
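The new `_flatten_list` helper (added identically on the client here and on the server below) lazily flattens arbitrarily nested lists and tuples into a single stream of scalars, which is what the flat repeated fields of the new `Tensor` message expect. A minimal standalone sketch with illustrative data:

```python
def _flatten_list(nested_list):
    # Recursively yield the leaf elements of a nested list/tuple.
    for item in nested_list:
        if isinstance(item, (list, tuple)):
            for sub_item in _flatten_list(item):
                yield sub_item
        else:
            yield item

# A 2 x 3 batch flattens to one stream of scalars.
print(list(_flatten_list([[1, 2, 3], [4, 5, 6]])))  # [1, 2, 3, 4, 5, 6]
```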
@@ -407,24 +415,21 @@ class GClient(object):
         for i, var in enumerate(model_conf.feed_var):
             self.feed_types_[var.alias_name] = var.feed_type
             self.feed_shapes_[var.alias_name] = var.shape
-            if self.feed_types_[var.alias_name] == 'float':
-                self.feed_types_[var.alias_name] = 'float32'
             if var.is_lod_tensor:
-                self.lod_tensor_set.add(var.alias_name)
+                self.lod_tensor_set_.add(var.alias_name)
             else:
                 counter = 1
                 for dim in self.feed_shapes_[var.alias_name]:
                     counter *= dim
         for i, var in enumerate(model_conf.fetch_var):
             self.fetch_types_[var.alias_name] = var.fetch_type
-            if self.fetch_types_[var.alias_name] == 'float':
-                self.fetch_types_[var.alias_name] = 'float32'
             if var.is_lod_tensor:
                 self.lod_tensor_set_.add(var.alias_name)
 
     def _pack_feed_data(self, feed, fetch):
         req = gserver_general_model_service_pb2.Request()
         req.fetch_var_names.extend(fetch)
+        req.feed_var_names.extend(feed.keys())
         feed_batch = None
         if isinstance(feed, dict):
             feed_batch = [feed]
@@ -432,39 +437,55 @@ class GClient(object):
             feed_batch = feed
         else:
             raise Exception("{} not support".format(type(feed)))
-        init_feed_names = False
         for feed_data in feed_batch:
-            inst = gserver_general_model_service_pb2.Inst()
-            for name, var in feed_data.items():
-                inst.names.append(name)
-                itype = self.type_map_[self.feed_types_[name]]
-                data = np.array(var, dtype=itype)
-                inst.data.append(data.tobytes())
+            inst = gserver_general_model_service_pb2.FeedInst()
+            for name in req.feed_var_names:
+                tensor = gserver_general_model_service_pb2.Tensor()
+                var = feed_data[name]
+                v_type = self.feed_types_[name]
+                if v_type == 0:  # int64
+                    if isinstance(var, np.ndarray):
+                        tensor.int64_data.extend(
+                            self._flatten_list(var.tolist()))
+                    else:
+                        tensor.int64_data.extend(self._flatten_list(var))
+                elif v_type == 1:  # float32
+                    if isinstance(var, np.ndarray):
+                        tensor.float_data.extend(
+                            self._flatten_list(var.tolist()))
+                    else:
+                        tensor.float_data.extend(self._flatten_list(var))
+                else:
+                    raise Exception("error type.")
                 if isinstance(var, np.ndarray):
-                    inst.shape.append(
-                        np.array(
-                            list(var.shape), dtype="int32").tobytes())
+                    tensor.shape.extend(list(var.shape))
                 else:
-                    inst.shape.append(
-                        np.array(
-                            self.feed_shapes_[name], dtype="int32").tobytes())
-            req.feed_insts.append(inst)
+                    tensor.shape.extend(self.feed_shapes_[name])
+                inst.tensor_array.append(tensor)
+            req.insts.append(inst)
         return req
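The packing logic above replaces the old `tobytes()` buffers with typed repeated fields: int64 feeds go to `int64_data`, float32 feeds to `float_data`; an ndarray carries its own shape, while a plain list falls back to the shape from the model config. A hedged sketch of that rule in isolation, with a plain dict standing in for the generated `Tensor` message:

```python
import numpy as np

def flatten(nested):
    # Same idea as the _flatten_list generator above.
    for item in nested:
        if isinstance(item, (list, tuple)):
            yield from flatten(item)
        else:
            yield item

def pack_tensor(var, v_type, config_shape):
    # Illustrative stand-in for the Tensor message; v_type 0 means int64
    # and 1 means float32, mirroring feed_type in the model config.
    if v_type not in (0, 1):
        raise Exception("error type.")
    tensor = {"int64_data": [], "float_data": [], "shape": []}
    values = var.tolist() if isinstance(var, np.ndarray) else var
    field = "int64_data" if v_type == 0 else "float_data"
    tensor[field].extend(flatten(values))
    # ndarrays know their own shape; plain lists use the configured one.
    shape = list(var.shape) if isinstance(var, np.ndarray) else config_shape
    tensor["shape"].extend(shape)
    return tensor

# pack_tensor(np.ones((2, 3)), 1, [2, 3])
# -> float_data of length 6, shape [2, 3]
```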
 
-    def _unpack_resp(self, resp):
+    def _unpack_resp(self, resp, fetch):
         result_map = {}
-        inst = resp.fetch_insts[0]
-        for i, name in enumerate(inst.names):
-            if name not in self.fetch_names_:
-                continue
-            itype = self.type_map_[self.fetch_types_[name]]
-            result_map[name] = np.frombuffer(inst.data[i], dtype=itype)
-            result_map[name].shape = np.frombuffer(inst.shape[i], dtype="int32")
+        inst = resp.outputs[0].insts[0]
+        tag = resp.tag
+        for i, name in enumerate(fetch):
+            var = inst.tensor_array[i]
+            v_type = self.fetch_types_[name]
+            if v_type == 0:  # int64
+                result_map[name] = np.array(list(var.int64_data))
+            elif v_type == 1:  # float32
+                result_map[name] = np.array(list(var.float_data))
+            else:
+                raise Exception("error type.")
+            result_map[name].shape = list(var.shape)
             if name in self.lod_tensor_set_:
-                result_map["{}.lod".format(name)] = np.frombuffer(
-                    inst.lod[i], dtype="int32")
-        return result_map
+                result_map["{}.lod".format(name)] = np.array(list(var.lod))
+        return result_map, tag
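On the fetch path the client now rebuilds each array from the typed repeated field instead of calling `np.frombuffer` on opaque bytes, restores its dimensions from `Tensor.shape`, and reads sequence info from `Tensor.lod`. A minimal decoding sketch operating on one generated `Tensor` message `var`:

```python
import numpy as np

def unpack_tensor(var, v_type, is_lod):
    # v_type 0 -> int64, 1 -> float32, as in fetch_type.
    if v_type == 0:
        arr = np.array(list(var.int64_data))
    elif v_type == 1:
        arr = np.array(list(var.float_data))
    else:
        raise Exception("error type.")
    arr.shape = list(var.shape)  # restore the original dimensions
    # LoD describes variable-length sequence boundaries, fetch side only.
    lod = np.array(list(var.lod)) if is_lod else None
    return arr, lod
```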
 
-    def predict(self, feed, fetch):
+    def predict(self, feed, fetch, need_variant_tag=False):
         req = self._pack_feed_data(feed, fetch)
         resp = self.stub_.inference(req)
-        return self._unpack_resp(resp)
+        result_map, tag = self._unpack_resp(resp, fetch)
+        return result_map if not need_variant_tag else [result_map, tag]
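With this change `predict` can also report which server variant served the request. A hedged usage sketch; the config path and endpoint are placeholders, and the `load_client_config`/`connect` method names are assumed from the usual Paddle Serving client API rather than shown in this diff:

```python
client = GClient()
client.load_client_config("serving_client_conf")  # placeholder path
client.connect(["127.0.0.1:9393"])                # placeholder endpoint

# Default behavior: only the fetch results come back.
fetch_map = client.predict(feed={"x": [0.1] * 13}, fetch=["price"])

# With need_variant_tag=True the call returns [result_map, tag].
fetch_map, tag = client.predict(
    feed={"x": [0.1] * 13}, fetch=["price"], need_variant_tag=True)
```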
@@ -31,6 +31,7 @@ import gserver_general_model_service_pb2
 import gserver_general_model_service_pb2_grpc
 from multiprocessing import Pool, Process
 from concurrent import futures
+import itertools
 
 
 class OpMaker(object):
@@ -463,52 +464,77 @@ class GServerService(
         for i, var in enumerate(model_conf.feed_var):
             self.feed_types_[var.alias_name] = var.feed_type
             self.feed_shapes_[var.alias_name] = var.shape
-            if self.feed_types_[var.alias_name] == 'float':
-                self.feed_types_[var.alias_name] = 'float32'
             if var.is_lod_tensor:
                 self.lod_tensor_set_.add(var.alias_name)
         for i, var in enumerate(model_conf.fetch_var):
             self.fetch_types_[var.alias_name] = var.fetch_type
-            if self.fetch_types_[var.alias_name] == 'float':
-                self.fetch_types_[var.alias_name] = 'float32'
             if var.is_lod_tensor:
                 self.lod_tensor_set_.add(var.alias_name)
 
+    def _flatten_list(self, nested_list):
+        for item in nested_list:
+            if isinstance(item, (list, tuple)):
+                for sub_item in self._flatten_list(item):
+                    yield sub_item
+            else:
+                yield item
+
     def _unpack_request(self, request):
+        feed_names = list(request.feed_var_names)
         fetch_names = list(request.fetch_var_names)
         feed_batch = []
-        for feed_inst in request.feed_insts:
+        for feed_inst in request.insts:
             feed_dict = {}
-            for idx, name in enumerate(feed_inst.names):
-                data = feed_inst.data[idx]
-                shape = feed_inst.shape[idx]
-                itype = self.type_map_[self.feed_types_[name]]
-                feed_dict[name] = np.frombuffer(data, dtype=itype)
-                feed_dict[name].shape = np.frombuffer(shape, dtype="int32")
+            for idx, name in enumerate(feed_names):
+                v_type = self.feed_types_[name]
+                data = None
+                if v_type == 0:  # int64
+                    data = np.array(
+                        list(feed_inst.tensor_array[idx].int64_data),
+                        dtype="int64")
+                elif v_type == 1:  # float32
+                    data = np.array(
+                        list(feed_inst.tensor_array[idx].float_data),
+                        dtype="float")
+                else:
+                    raise Exception("error type.")
+                shape = list(feed_inst.tensor_array[idx].shape)
+                data.shape = shape
+                feed_dict[name] = data
             feed_batch.append(feed_dict)
         return feed_batch, fetch_names
 
-    def _pack_resp_package(self, result, fetch_names):
+    def _pack_resp_package(self, result, fetch_names, tag):
         resp = gserver_general_model_service_pb2.Response()
-        inst = gserver_general_model_service_pb2.Inst()
-        for name in fetch_names:
-            inst.names.append(name)
-            inst.data.append(result[name].tobytes())
-            inst.shape.append(
-                np.array(
-                    result[name].shape, dtype="int32").tobytes())
-            if name in self.lod_tensor_set_:
-                inst.lod.append(result["{}.lod".format(name)].tobytes())
-            else:
-                # TODO
-                inst.lod.append(bytes(0))
-        resp.fetch_insts.append(inst)
+        # Only one model is supported temporarily
+        model_output = gserver_general_model_service_pb2.ModelOutput()
+        inst = gserver_general_model_service_pb2.FetchInst()
+        for idx, name in enumerate(fetch_names):
+            # model_output.fetch_var_names.append(name)
+            tensor = gserver_general_model_service_pb2.Tensor()
+            v_type = self.fetch_types_[name]
+            if v_type == 0:  # int64
+                tensor.int64_data.extend(
+                    self._flatten_list(result[name].tolist()))
+            elif v_type == 1:  # float32
+                tensor.float_data.extend(
+                    self._flatten_list(result[name].tolist()))
+            else:
+                raise Exception("error type.")
+            tensor.shape.extend(list(result[name].shape))
+            if name in self.lod_tensor_set_:
+                tensor.lod.extend(result["{}.lod".format(name)].tolist())
+            inst.tensor_array.append(tensor)
+        model_output.insts.append(inst)
+        resp.outputs.append(model_output)
+        resp.tag = tag
         return resp
 
     def inference(self, request, context):
         feed_dict, fetch_names = self._unpack_request(request)
-        data = self.bclient_.predict(feed=feed_dict, fetch=fetch_names)
-        return self._pack_resp_package(data, fetch_names)
+        data, tag = self.bclient_.predict(
+            feed=feed_dict, fetch=fetch_names, need_variant_tag=True)
+        return self._pack_resp_package(data, fetch_names, tag)
 
 
 class GServer(object):
......
@@ -14,19 +14,35 @@
 syntax = "proto2";
 
-message Inst {
+message Tensor {
   repeated bytes data = 1;
-  repeated string names = 2;
-  repeated bytes lod = 3;
-  repeated bytes shape = 4;
-}
+  repeated int32 int_data = 2;
+  repeated int64 int64_data = 3;
+  repeated float float_data = 4;
+  optional int32 elem_type = 5;
+  repeated int32 shape = 6;
+  repeated int32 lod = 7; // only for fetch tensor currently
+};
+
+message FeedInst { repeated Tensor tensor_array = 1; };
+
+message FetchInst { repeated Tensor tensor_array = 1; };
 
 message Request {
-  repeated Inst feed_insts = 1;
-  repeated string fetch_var_names = 2;
+  repeated FeedInst insts = 1;
+  repeated string feed_var_names = 2;
+  repeated string fetch_var_names = 3;
 };
 
-message Response { repeated Inst fetch_insts = 1; }
+message Response {
+  repeated ModelOutput outputs = 1;
+  optional string tag = 2;
+};
+
+message ModelOutput {
+  repeated FetchInst insts = 1;
+  optional string engine_name = 2;
+}
 
 service GServerGeneralModelService {
   rpc inference(Request) returns (Response) {}
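For reference, a hedged sketch of building a `Request` by hand with the classes generated from this proto (module name taken from the Python imports above; the feed name, values, and shape are illustrative):

```python
import gserver_general_model_service_pb2 as pb2

req = pb2.Request()
req.feed_var_names.append("x")       # illustrative feed name
req.fetch_var_names.append("price")  # illustrative fetch name

inst = pb2.FeedInst()
tensor = pb2.Tensor()
tensor.float_data.extend([0.1] * 13)  # one float32 feed, pre-flattened
tensor.shape.extend([13])
inst.tensor_array.append(tensor)
req.insts.append(inst)
```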
......
@@ -37,7 +37,7 @@ def python_version():
 max_version, mid_version, min_version = python_version()
 
 REQUIRED_PACKAGES = [
-    'six >= 1.10.0', 'protobuf >= 3.1.0',
+    'six >= 1.10.0', 'protobuf >= 3.1.0', 'grpcio >= 1.28.1',
     'paddle_serving_client', 'flask >= 1.1.1', 'paddle_serving_app'
 ]
......