Commit 743892b6 authored by barrierye

add gpu part && test=serving

Parent d7057e40
@@ -27,6 +27,13 @@ import argparse
import collections
import fcntl
import numpy as np
import grpc
from .proto import multi_lang_general_model_service_pb2
from .proto import multi_lang_general_model_service_pb2_grpc
from multiprocessing import Pool, Process
from concurrent import futures
def serve_args():
    parser = argparse.ArgumentParser("serve")
@@ -469,3 +476,152 @@ class Server(object):
        print(command)
        os.system(command)
class MultiLangServerService(
        multi_lang_general_model_service_pb2_grpc.MultiLangGeneralModelServiceServicer):
    def __init__(self, model_config_path, endpoints):
        from paddle_serving_client import Client
        self._parse_model_config(model_config_path)
        # internal brpc client that forwards unpacked gRPC requests
        # to the wrapped brpc server
        self.bclient_ = Client()
        self.bclient_.load_client_config(
            "{}/serving_server_conf.prototxt".format(model_config_path))
        self.bclient_.connect(endpoints)
    def _parse_model_config(self, model_config_path):
        model_conf = m_config.GeneralModelConfig()
        # use a context manager so the config file handle is closed after reading
        with open("{}/serving_server_conf.prototxt".format(model_config_path),
                  'r') as f:
            model_conf = google.protobuf.text_format.Merge(
                str(f.read()), model_conf)
self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
self.feed_types_ = {}
self.feed_shapes_ = {}
self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
self.fetch_types_ = {}
self.type_map_ = {0: "int64", 1: "float32"}
self.lod_tensor_set_ = set()
for i, var in enumerate(model_conf.feed_var):
self.feed_types_[var.alias_name] = var.feed_type
self.feed_shapes_[var.alias_name] = var.shape
if var.is_lod_tensor:
self.lod_tensor_set_.add(var.alias_name)
for i, var in enumerate(model_conf.fetch_var):
self.fetch_types_[var.alias_name] = var.fetch_type
if var.is_lod_tensor:
self.lod_tensor_set_.add(var.alias_name)
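    # For reference, the serving_server_conf.prototxt parsed above holds
    # feed_var/fetch_var entries; the values below are illustrative only
    # (not taken from this commit), but the field names match what the
    # loops above read:
    #
    #   feed_var {
    #     name: "x"
    #     alias_name: "x"
    #     is_lod_tensor: false
    #     feed_type: 1
    #     shape: 13
    #   }
    #   fetch_var {
    #     name: "fc_0.tmp_1"
    #     alias_name: "price"
    #     is_lod_tensor: false
    #     fetch_type: 1
    #   }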
def _flatten_list(self, nested_list):
for item in nested_list:
if isinstance(item, (list, tuple)):
for sub_item in self._flatten_list(item):
yield sub_item
else:
yield item
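    # Example: list(self._flatten_list([1, [2, (3, 4)], 5])) == [1, 2, 3, 4, 5];
    # a helper, presumably for packing nested shape/lod lists into flat
    # repeated proto fields.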
def _unpack_request(self, request):
feed_names = list(request.feed_var_names)
fetch_names = list(request.fetch_var_names)
feed_batch = []
for feed_inst in request.insts:
feed_dict = {}
for idx, name in enumerate(feed_names):
v_type = self.feed_types_[name]
data = None
if v_type == 0: # int64
data = np.array(
list(feed_inst.tensor_array[idx].int64_data),
dtype="int64")
                elif v_type == 1:  # float32
                    data = np.array(
                        list(feed_inst.tensor_array[idx].float_data),
                        dtype="float32")
                else:
                    raise Exception("unsupported feed type: {}".format(v_type))
shape = list(feed_inst.tensor_array[idx].shape)
data.shape = shape
feed_dict[name] = data
feed_batch.append(feed_dict)
return feed_batch, fetch_names
def _pack_resp_package(self, result, fetch_names, tag):
resp = multi_lang_general_model_service_pb2.Response()
# Only one model is supported temporarily
model_output = multi_lang_general_model_service_pb2.ModelOutput()
inst = multi_lang_general_model_service_pb2.FetchInst()
for idx, name in enumerate(fetch_names):
# model_output.fetch_var_names.append(name)
tensor = multi_lang_general_model_service_pb2.Tensor()
v_type = self.fetch_types_[name]
if v_type == 0: # int64
tensor.int64_data.extend(result[name].reshape(-1).tolist())
elif v_type == 1: # float32
tensor.float_data.extend(result[name].reshape(-1).tolist())
            else:
                raise Exception("unsupported fetch type: {}".format(v_type))
tensor.shape.extend(list(result[name].shape))
if name in self.lod_tensor_set_:
tensor.lod.extend(result["{}.lod".format(name)].tolist())
inst.tensor_array.append(tensor)
model_output.insts.append(inst)
resp.outputs.append(model_output)
resp.tag = tag
return resp
def inference(self, request, context):
feed_dict, fetch_names = self._unpack_request(request)
data, tag = self.bclient_.predict(
feed=feed_dict, fetch=fetch_names, need_variant_tag=True)
return self._pack_resp_package(data, fetch_names, tag)
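# A minimal client-side sketch against this servicer (not part of this
# commit). It assumes the generated stub is MultiLangGeneralModelServiceStub,
# that the RPC is named inference like the handler above, and that the
# Request message holds feed_var_names, fetch_var_names and insts (FeedInst
# entries with a tensor_array), mirroring the fields read in _unpack_request:
#
#   channel = grpc.insecure_channel("127.0.0.1:9292")
#   stub = multi_lang_general_model_service_pb2_grpc.\
#       MultiLangGeneralModelServiceStub(channel)
#   req = multi_lang_general_model_service_pb2.Request()
#   req.feed_var_names.append("x")
#   req.fetch_var_names.append("price")
#   inst = multi_lang_general_model_service_pb2.FeedInst()
#   tensor = multi_lang_general_model_service_pb2.Tensor()
#   tensor.float_data.extend([0.0] * 13)   # float32 feed (feed_type == 1)
#   tensor.shape.extend([1, 13])
#   inst.tensor_array.append(tensor)
#   req.insts.append(inst)
#   resp = stub.inference(req)             # returns the Response packed above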
class MultiLangServer(object):
def __init__(self, worker_num=2):
self.bserver_ = Server()
self.worker_num_ = worker_num
def set_op_sequence(self, op_seq):
self.bserver_.set_op_sequence(op_seq)
def load_model_config(self, model_config_path):
        if not isinstance(model_config_path, str):
            raise Exception(
                "MultiLangServer only supports a single model temporarily")
self.bserver_.load_model_config(model_config_path)
self.model_config_path_ = model_config_path
    def prepare_server(self, workdir=None, port=9292, device="cpu"):
        # The public gRPC port is `port`; scan from 12000 for a free
        # internal port for the wrapped brpc service.
        default_port = 12000
        self.port_list_ = []
        for i in range(1000):
            if default_port + i != port and self._port_is_available(
                    default_port + i):
                self.port_list_.append(default_port + i)
                break
self.bserver_.prepare_server(
workdir=workdir, port=self.port_list_[0], device=device)
self.gport_ = port
def _launch_brpc_service(self, bserver):
bserver.run_server()
    def _port_is_available(self, port):
        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
            sock.settimeout(2)
            # connect_ex returns 0 when something already listens on the port
            result = sock.connect_ex(('0.0.0.0', port))
        return result != 0
def run_server(self):
p_bserver = Process(
target=self._launch_brpc_service, args=(self.bserver_, ))
p_bserver.start()
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=self.worker_num_))
multi_lang_general_model_service_pb2_grpc.add_MultiLangGeneralModelServiceServicer_to_server(
MultiLangServerService(self.model_config_path_,
["0.0.0.0:{}".format(self.port_list_[0])]),
server)
server.add_insecure_port('[::]:{}'.format(self.gport_))
server.start()
p_bserver.join()
server.wait_for_termination()
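# A minimal usage sketch for the new gRPC entry point (not part of this
# commit); the op names and model directory are illustrative, and the op
# sequence is built the same way as for the plain Server class:
#
#   import paddle_serving_server as serving
#   op_maker = serving.OpMaker()
#   op_seq_maker = serving.OpSeqMaker()
#   op_seq_maker.add_op(op_maker.create('general_reader'))
#   op_seq_maker.add_op(op_maker.create('general_infer'))
#   op_seq_maker.add_op(op_maker.create('general_response'))
#
#   server = MultiLangServer(worker_num=4)
#   server.set_op_sequence(op_seq_maker.get_op_sequence())
#   server.load_model_config("serving_server_model")
#   server.prepare_server(workdir="workdir", port=9292, device="cpu")
#   server.run_server()  # gRPC on 9292, internal brpc on a scanned free port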