diff --git a/python/examples/bert/batching.py b/python/examples/bert/batching.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ec5f320cf5ec7bd0ab4624d9b39ef936553c774
--- /dev/null
+++ b/python/examples/bert/batching.py
@@ -0,0 +1,126 @@
+#coding:utf-8
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Mask, padding and batching."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+
+def prepare_batch_data(insts,
+                       total_token_num,
+                       max_seq_len=128,
+                       pad_id=None,
+                       cls_id=None,
+                       sep_id=None,
+                       mask_id=None,
+                       return_input_mask=True,
+                       return_max_len=True,
+                       return_num_token=False):
+    """
+    1. generate Tensor of data
+    2. generate Tensor of position
+    3. generate self attention mask, [shape: batch_size * max_len * max_len]
+    """
+
+    batch_src_ids = [inst[0] for inst in insts]
+    batch_sent_ids = [inst[1] for inst in insts]
+    batch_pos_ids = [inst[2] for inst in insts]
+    labels_list = []
+    # compatible with squad, whose example includes start/end positions,
+    # or unique id
+
+    for i in range(3, len(insts[0]), 1):
+        labels = [inst[i] for inst in insts]
+        labels = np.array(labels).astype("int64").reshape([-1, 1])
+        labels_list.append(labels)
+
+    out = batch_src_ids
+    # Second step: padding
+    src_id, self_input_mask = pad_batch_data(
+        out, pad_idx=pad_id, max_seq_len=max_seq_len, return_input_mask=True)
+    pos_id = pad_batch_data(
+        batch_pos_ids,
+        pad_idx=pad_id,
+        max_seq_len=max_seq_len,
+        return_pos=False,
+        return_input_mask=False)
+    sent_id = pad_batch_data(
+        batch_sent_ids,
+        pad_idx=pad_id,
+        max_seq_len=max_seq_len,
+        return_pos=False,
+        return_input_mask=False)
+
+    return_list = [src_id, pos_id, sent_id, self_input_mask] + labels_list
+
+    return return_list if len(return_list) > 1 else return_list[0]
+
+
+def pad_batch_data(insts,
+                   pad_idx=0,
+                   max_seq_len=128,
+                   return_pos=False,
+                   return_input_mask=False,
+                   return_max_len=False,
+                   return_num_token=False,
+                   return_seq_lens=False):
+    """
+    Pad the instances to the max sequence length in batch, and generate the
+    corresponding position data and input mask.
+    """
+    return_list = []
+    #max_len = max(len(inst) for inst in insts)
+    max_len = max_seq_len
+    # Any token included in dict can be used to pad, since the paddings' loss
+    # will be masked out by weights and make no effect on parameter gradients.
+
+    inst_data = np.array([
+        list(inst) + list([pad_idx] * (max_len - len(inst))) for inst in insts
+    ])
+    return_list += [inst_data.astype("int64").reshape([-1, max_len, 1])]
+
+    # position data
+    if return_pos:
+        inst_pos = np.array([
+            list(range(0, len(inst))) + [pad_idx] * (max_len - len(inst))
+            for inst in insts
+        ])
+
+        return_list += [inst_pos.astype("int64").reshape([-1, max_len, 1])]
+
+    if return_input_mask:
+        # This is used to avoid attention on paddings.
+        input_mask_data = np.array(
+            [[1] * len(inst) + [0] * (max_len - len(inst)) for inst in insts])
+        input_mask_data = np.expand_dims(input_mask_data, axis=-1)
+        return_list += [input_mask_data.astype("float32")]
+
+    if return_max_len:
+        return_list += [max_len]
+
+    if return_num_token:
+        num_token = 0
+        for inst in insts:
+            num_token += len(inst)
+        return_list += [num_token]
+
+    if return_seq_lens:
+        seq_lens = np.array([len(inst) for inst in insts])
+        return_list += [seq_lens.astype("int64").reshape([-1, 1])]
+
+    return return_list if len(return_list) > 1 else return_list[0]
diff --git a/python/examples/bert/benchmark.py b/python/examples/bert/benchmark.py
index 9e8af21cc622986e0cf8233a827341030900ff6b..dd11f1248a4a1bc80882039b1ea0dfaeb9a29079 100644
--- a/python/examples/bert/benchmark.py
+++ b/python/examples/bert/benchmark.py
@@ -1,3 +1,5 @@
+# -*- coding: utf-8 -*-
+#
 # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,54 +14,53 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import unicode_literals, absolute_import
+import os
 import sys
+import time
 from paddle_serving_client import Client
-from paddle_serving_client.metric import auc
 from paddle_serving_client.utils import MultiThreadRunner
-import time
-from bert_client import BertService
+from paddle_serving_client.utils import benchmark_args
+from batching import pad_batch_data
+import tokenization
+import requests
+import json
+from bert_reader import BertReader
 
+args = benchmark_args()
 
-def predict(thr_id, resource):
-    bc = BertService(
-        model_name="bert_chinese_L-12_H-768_A-12",
-        max_seq_len=20,
-        do_lower_case=True)
-    bc.load_client(resource["conf_file"], resource["server_endpoint"])
-    thread_num = resource["thread_num"]
-    file_list = resource["filelist"]
-    line_id = 0
-    result = []
-    label_list = []
-    dataset = []
-    for fn in file_list:
-        fin = open(fn)
+
+def single_func(idx, resource):
+    fin = open("data-c.txt")
+    if args.request == "rpc":
+        reader = BertReader(vocab_file="vocab.txt", max_seq_len=20)
+        config_file = './serving_client_conf/serving_client_conf.prototxt'
+        fetch = ["pooled_output"]
+        client = Client()
+        client.load_client_config(args.model)
+        client.connect([resource["endpoint"][idx % 4]])
+
+        start = time.time()
         for line in fin:
-            if line_id % thread_num == thr_id - 1:
-                dataset.append(line.strip())
-            line_id += 1
-        fin.close()
-
-    start = time.time()
-    fetch = ["pooled_output"]
-    for inst in dataset:
-        fetch_map = bc.run_general([[inst]], fetch)
-        result.append(fetch_map["pooled_output"])
-    end = time.time()
-    return [result, label_list, [end - start]]
-
+            feed_dict = reader.process(line)
+            result = client.predict(feed=feed_dict,
+                                    fetch=fetch)
+        end = time.time()
+    elif args.request == "http":
+        start = time.time()
+        header = {"Content-Type": "application/json"}
+        for line in fin:
+            #dict_data = {"words": "this is for output ", "fetch": ["pooled_output"]}
+            dict_data = {"words": line, "fetch": ["pooled_output"]}
+            r = requests.post('http://{}/bert/prediction'.format(resource["endpoint"][0]),
+                              data=json.dumps(dict_data), headers=header)
+        end = time.time()
+    return [[end - start]]
 
 if __name__ == '__main__':
-    conf_file = sys.argv[1]
-    data_file = sys.argv[2]
-    thread_num = sys.argv[3]
-    resource = {}
-    resource["conf_file"] = conf_file
-    resource["server_endpoint"] = ["127.0.0.1:9292"]
-    resource["filelist"] = [data_file]
-    resource["thread_num"] = int(thread_num)
-
-    thread_runner = MultiThreadRunner()
-    result = thread_runner.run(predict, int(sys.argv[3]), resource)
+    multi_thread_runner = MultiThreadRunner()
+    endpoint_list = ["127.0.0.1:9494", "127.0.0.1:9495", "127.0.0.1:9496", "127.0.0.1:9497"]
+    #endpoint_list = endpoint_list + endpoint_list + endpoint_list
+    #result = multi_thread_runner.run(single_func, args.thread, {"endpoint": endpoint_list})
+    result = single_func(0, {"endpoint": endpoint_list})
+    print(result)
 
-    print("total time {} s".format(sum(result[-1]) / len(result[-1])))
diff --git a/python/examples/bert/bert_client.py b/python/examples/bert/bert_client.py
index 7ac7f4fd8676e853547aa05b99222820c764b5ab..cdb75ff8444fce749c99c7598cac5b27b67fbe78 100644
--- a/python/examples/bert/bert_client.py
+++ b/python/examples/bert/bert_client.py
@@ -9,6 +9,9 @@ import time
 from paddlehub.common.logger import logger
 import socket
 from paddle_serving_client import Client
+from paddle_serving_client.utils import MultiThreadRunner
+from paddle_serving_client.utils import benchmark_args
+args = benchmark_args()
 
 _ver = sys.version_info
 is_py2 = (_ver[0] == 2)
@@ -122,36 +125,27 @@ class BertService():
         return fetch_map_batch
 
 
-def test():
+def single_func(idx, resource):
     bc = BertService(
         model_name='bert_chinese_L-12_H-768_A-12',
         max_seq_len=20,
         show_ids=False,
         do_lower_case=True)
-    server_addr = ["127.0.0.1:9292"]
     config_file = './serving_client_conf/serving_client_conf.prototxt'
     fetch = ["pooled_output"]
+    server_addr = [resource["endpoint"][idx]]
     bc.load_client(config_file, server_addr)
     batch_size = 1
-    batch = []
-    for line in sys.stdin:
-        if batch_size == 1:
-            result = bc.run_general([[line.strip()]], fetch)
-            print(result)
-        else:
-            if len(batch) < batch_size:
-                batch.append([line.strip()])
-            else:
-                result = bc.run_batch_general(batch, fetch)
-                batch = []
-                for r in result:
-                    print(r)
-    if len(batch) > 0:
-        result = bc.run_batch_general(batch, fetch)
-        batch = []
-        for r in result:
-            print(r)
-
+    start = time.time()
+    fin = open("data-c.txt")
+    for line in fin:
+        result = bc.run_general([[line.strip()]], fetch)
+    end = time.time()
+    return [[end - start]]
 
 if __name__ == '__main__':
-    test()
+    multi_thread_runner = MultiThreadRunner()
+    result = multi_thread_runner.run(single_func, args.thread, {"endpoint": ["127.0.0.1:9494", "127.0.0.1:9495", "127.0.0.1:9496", "127.0.0.1:9497"]})
diff --git a/python/examples/bert/bert_reader.py b/python/examples/bert/bert_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..52bb6ebeb2d291087e9528c7a722eb2d2a77d27c
--- /dev/null
+++ b/python/examples/bert/bert_reader.py
@@ -0,0 +1,55 @@
+from batching import pad_batch_data
+import tokenization
+
+
+class BertReader():
+    def __init__(self, vocab_file="", max_seq_len=128):
+        self.vocab_file = vocab_file
+        self.tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file)
+        self.max_seq_len = max_seq_len
+        self.vocab = self.tokenizer.vocab
+        self.pad_id = self.vocab["[PAD]"]
+        self.cls_id = self.vocab["[CLS]"]
+        self.sep_id = self.vocab["[SEP]"]
+        self.mask_id = self.vocab["[MASK]"]
+
+    def pad_batch(self, token_ids, text_type_ids, position_ids):
+        batch_token_ids = [token_ids]
+        batch_text_type_ids = [text_type_ids]
+        batch_position_ids = [position_ids]
+
+        padded_token_ids, input_mask = pad_batch_data(
+            batch_token_ids,
+            max_seq_len=self.max_seq_len,
+            pad_idx=self.pad_id,
+            return_input_mask=True)
+        padded_text_type_ids = pad_batch_data(
+            batch_text_type_ids,
+            max_seq_len=self.max_seq_len,
+            pad_idx=self.pad_id)
+        padded_position_ids = pad_batch_data(
+            batch_position_ids,
+            max_seq_len=self.max_seq_len,
+            pad_idx=self.pad_id)
+        return padded_token_ids, padded_position_ids, padded_text_type_ids, input_mask
+
+    def process(self, sent):
+        text_a = tokenization.convert_to_unicode(sent)
+        tokens_a = self.tokenizer.tokenize(text_a)
+        if len(tokens_a) > self.max_seq_len - 2:
+            tokens_a = tokens_a[0:(self.max_seq_len - 2)]
+        tokens = []
+        text_type_ids = []
+        tokens.append("[CLS]")
+        text_type_ids.append(0)
+        for token in tokens_a:
+            tokens.append(token)
+            text_type_ids.append(0)
+        # close the single-sentence input with [SEP]; the truncation above
+        # reserves room for both [CLS] and [SEP]
+        tokens.append("[SEP]")
+        text_type_ids.append(0)
+        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
+        position_ids = list(range(len(token_ids)))
+        p_token_ids, p_pos_ids, p_text_type_ids, input_mask = \
+            self.pad_batch(token_ids, text_type_ids, position_ids)
+        feed_result = {"input_ids": p_token_ids.reshape(-1).tolist(),
+                       "position_ids": p_pos_ids.reshape(-1).tolist(),
+                       "segment_ids": p_text_type_ids.reshape(-1).tolist(),
+                       "input_mask": input_mask.reshape(-1).tolist()}
+        return feed_result
diff --git a/python/examples/bert/bert_web_service.py b/python/examples/bert/bert_web_service.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b52d2fbe1e5cabe993b1ac40261eda8c51fd561
--- /dev/null
+++ b/python/examples/bert/bert_web_service.py
@@ -0,0 +1,37 @@
+# coding=utf-8
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle_serving_server_gpu.web_service import WebService
+from bert_reader import BertReader
+import sys
+import os
+
+
+class BertService(WebService):
+    def load(self):
+        self.reader = BertReader(vocab_file="vocab.txt", max_seq_len=20)
+
+    def preprocess(self, feed={}, fetch=[]):
+        feed_res = self.reader.process(feed["words"].encode("utf-8"))
+        return feed_res, fetch
+
+
+bert_service = BertService(name="bert")
+bert_service.load()
+bert_service.load_model_config(sys.argv[1])
+gpu_ids = os.environ["CUDA_VISIBLE_DEVICES"]
+gpus = [int(x) for x in gpu_ids.split(",")]
+bert_service.set_gpus(gpus)
+bert_service.prepare_server(
+    workdir="workdir", port=9494, device="gpu")
+bert_service.run_server()
diff --git a/python/examples/bert/prepare_model.py b/python/examples/bert/prepare_model.py
index b5f80a78feb07a0617bec3833bbe1cf3884d7dea..5a7c5893d9dfc9473c5fabb8f111f6798a6a20d3 100644
--- a/python/examples/bert/prepare_model.py
+++ b/python/examples/bert/prepare_model.py
@@ -14,11 +14,12 @@
 import paddlehub as hub
 import paddle.fluid as fluid
+import sys
 import paddle_serving_client.io as serving_io
 
 model_name = "bert_chinese_L-12_H-768_A-12"
 module = hub.Module(model_name)
-inputs, outputs, program = module.context(trainable=True, max_seq_len=20)
+inputs, outputs, program = module.context(trainable=True, max_seq_len=int(sys.argv[1]))
 place = fluid.core_avx.CPUPlace()
 exe = fluid.Executor(place)
 input_ids = inputs["input_ids"]
@@ -34,7 +35,7 @@ feed_var_names = [
 
 target_vars = [pooled_output, sequence_output]
 
-serving_io.save_model("serving_server_model", "serving_client_conf", {
+serving_io.save_model("bert_seq{}_model".format(sys.argv[1]), "bert_seq{}_client".format(sys.argv[1]), {
     "input_ids": input_ids,
     "position_ids": position_ids,
     "segment_ids": segment_ids,
diff --git a/python/examples/imdb/benchmark.py b/python/examples/imdb/benchmark.py
index 1254ed21fd8ff30acdb9e8192b26b7918da315bc..e2e45f446f27547bff3a5edae6edbc69305fc021 100644
--- a/python/examples/imdb/benchmark.py
+++ b/python/examples/imdb/benchmark.py
@@ -48,7 +48,7 @@ def single_func(idx, resource):
         for line in fin:
             word_ids, label = imdb_dataset.get_words_and_label(line)
             r = requests.post("http://{}/imdb/prediction".format(args.endpoint),
-                              data={"words": word_ids})
+                              data={"words": word_ids, "fetch": ["prediction"]})
         end = time.time()
     return [[end - start]]
diff --git a/python/paddle_serving_server/web_service.py b/python/paddle_serving_server/web_service.py
index 00c060f1b207a2de3f45a46483819ef4ff9aa5c9..d54dc7765656f726397a6553498b4d6da72e59dd 100755
--- a/python/paddle_serving_server/web_service.py
+++ b/python/paddle_serving_server/web_service.py
@@ -60,6 +60,8 @@ class WebService(object):
             if "fetch" not in request.json:
                 abort(400)
             feed, fetch = self.preprocess(request.json, request.json["fetch"])
+            if "fetch" in feed:
+                del feed["fetch"]
             fetch_map = client_service.predict(feed=feed, fetch=fetch)
             fetch_map = self.postprocess(feed=request.json, fetch=fetch, fetch_map=fetch_map)
             return fetch_map
diff --git a/python/paddle_serving_server_gpu/__init__.py b/python/paddle_serving_server_gpu/__init__.py
index 4372884a290bb98faf4416c2cb0e872bc4c76003..d3bed0eed7d4912d8dd5b1cabee1c387fc055816 100644
--- a/python/paddle_serving_server_gpu/__init__.py
+++ b/python/paddle_serving_server_gpu/__init__.py
@@ -22,6 +22,26 @@ import paddle_serving_server_gpu as paddle_serving_server
 from version import serving_server_version
 from contextlib import closing
 
+def serve_args():
+    parser = argparse.ArgumentParser("serve")
+    parser.add_argument(
+        "--thread", type=int, default=10, help="Concurrency of server")
+    parser.add_argument(
+        "--model", type=str, default="", help="Model for serving")
+    parser.add_argument(
+        "--port", type=int, default=9292,
+        help="Base port of the server; each GPU worker uses an offset from it")
+    parser.add_argument(
+        "--workdir",
+        type=str,
+        default="workdir",
+        help="Working dir of current service")
+    parser.add_argument(
+        "--device", type=str, default="gpu", help="Type of device")
+    parser.add_argument(
+        "--gpu_ids", type=str, default="", help="Comma-separated GPU ids")
+    parser.add_argument(
+        "--name", type=str, default="default", help="Default service name")
+    return parser.parse_args()
+
 
 class OpMaker(object):
     def __init__(self):
@@ -126,7 +146,8 @@ class Server(object):
         self.model_config_path = model_config_path
         self.engine.name = "general_model"
-        self.engine.reloadable_meta = model_config_path + "/fluid_time_file"
+        #self.engine.reloadable_meta = model_config_path + "/fluid_time_file"
+        self.engine.reloadable_meta = self.workdir + "/fluid_time_file"
         os.system("touch {}".format(self.engine.reloadable_meta))
         self.engine.reloadable_type = "timestamp_ne"
         self.engine.runtime_thread_num = 0
@@ -154,6 +175,7 @@ class Server(object):
         self.infer_service_conf.services.extend([infer_service])
 
     def _prepare_resource(self, workdir):
+        self.workdir = workdir
         if self.resource_conf == None:
             with open("{}/{}".format(workdir, self.general_model_config_fn),
                       "w") as fout:
@@ -217,6 +239,7 @@ class Server(object):
         if not self.check_port(port):
             raise SystemExit("Prot {} is already used".format(port))
+        self.set_port(port)
         self._prepare_resource(workdir)
         self._prepare_engine(self.model_config_path, device)
         self._prepare_infer_service(port)
diff --git a/python/paddle_serving_server_gpu/serve.py b/python/paddle_serving_server_gpu/serve.py
index 1a88797b285c0b168e52e54755da3b7ea5bad434..00da8784cccae3e8c937a4168cd31e03d46b739f 100644
--- a/python/paddle_serving_server_gpu/serve.py
+++ b/python/paddle_serving_server_gpu/serve.py
@@ -18,35 +18,21 @@ Usage:
     Example:
         python -m paddle_serving_server.serve --model ./serving_server_model --port 9292
 """
 import argparse
+import os
+from multiprocessing import Pool, Process
+from paddle_serving_server_gpu import serve_args
 
-def parse_args():
-    parser = argparse.ArgumentParser("serve")
-    parser.add_argument(
-        "--thread", type=int, default=10, help="Concurrency of server")
-    parser.add_argument(
-        "--model", type=str, default="", help="Model for serving")
-    parser.add_argument(
-        "--port", type=int, default=9292, help="Port the server")
-    parser.add_argument(
-        "--workdir",
-        type=str,
-        default="workdir",
-        help="Working dir of current service")
-    parser.add_argument(
-        "--device", type=str, default="gpu", help="Type of device")
-    parser.add_argument("--gpuid", type=int, default=0, help="Index of GPU")
-    return parser.parse_args()
-
-
-def start_standard_model():
-    args = parse_args()
+
+def start_gpu_card_model(gpuid, args):
+    gpuid = int(gpuid)
+    device = "gpu"
+    port = args.port
+    if gpuid == -1:
+        device = "cpu"
+    elif gpuid >= 0:
+        port = args.port + gpuid
     thread_num = args.thread
     model = args.model
-    port = args.port
-    workdir = args.workdir
-    device = args.device
-    gpuid = args.gpuid
+    workdir = "{}_{}".format(args.workdir, gpuid)
 
     if model == "":
         print("You must specify your serving model")
@@ -57,7 +43,7 @@
     read_op = op_maker.create('general_reader')
     general_infer_op = op_maker.create('general_infer')
    general_response_op = op_maker.create('general_response')
-
+
     op_seq_maker = serving.OpSeqMaker()
     op_seq_maker.add_op(read_op)
     op_seq_maker.add_op(general_infer_op)
@@ -69,9 +55,28 @@
     server.load_model_config(model)
     server.prepare_server(workdir=workdir, port=port, device=device)
-    server.set_gpuid(gpuid)
+    if gpuid >= 0:
+        server.set_gpuid(gpuid)
     server.run_server()
 
-
+
+def start_multi_card(args):
+    gpus = ""
+    if args.gpu_ids == "":
+        # CUDA_VISIBLE_DEVICES is a comma-separated string, split it into ids
+        gpus = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
+    else:
+        gpus = args.gpu_ids.split(",")
+    if len(gpus) <= 0:
+        start_gpu_card_model(-1, args)
+    else:
+        gpu_processes = []
+        for i, gpu_id in enumerate(gpus):
+            p = Process(target=start_gpu_card_model, args=(i, args, ))
+            gpu_processes.append(p)
+        for p in gpu_processes:
+            p.start()
+        for p in gpu_processes:
+            p.join()
+
+
 if __name__ == "__main__":
-    start_standard_model()
+    args = serve_args()
+    start_multi_card(args)
diff --git a/python/paddle_serving_server_gpu/web_serve.py b/python/paddle_serving_server_gpu/web_serve.py
index e7b44034797a8de75ca8dc5d97f7dc93c9671954..c270997e228a07bc24a9214b93e8fe3494bcc1c7 100644
--- a/python/paddle_serving_server_gpu/web_serve.py
+++ b/python/paddle_serving_server_gpu/web_serve.py
@@ -17,35 +17,20 @@ Usage:
     Example:
         python -m paddle_serving_server.web_serve --model ./serving_server_model --port 9292
 """
-import argparse
+import os
 from multiprocessing import Pool, Process
 from .web_service import WebService
-
-
-def parse_args():
-    parser = argparse.ArgumentParser("web_serve")
-    parser.add_argument(
-        "--thread", type=int, default=10, help="Concurrency of server")
-    parser.add_argument(
-        "--model", type=str, default="", help="Model for serving")
-    parser.add_argument(
-        "--port", type=int, default=9292, help="Port the server")
-    parser.add_argument(
-        "--workdir",
-        type=str,
-        default="workdir",
-        help="Working dir of current service")
-    parser.add_argument(
-        "--device", type=str, default="cpu", help="Type of device")
-    parser.add_argument(
-        "--name", type=str, default="default", help="Default service name")
-    return parser.parse_args()
-
+import paddle_serving_server_gpu as serving
+from paddle_serving_server_gpu import serve_args
 
 if __name__ == "__main__":
-    args = parse_args()
-    service = WebService(name=args.name)
-    service.load_model_config(args.model)
-    service.prepare_server(
+    args = serve_args()
+    web_service = WebService(name=args.name)
+    web_service.load_model_config(args.model)
+    if args.gpu_ids == "":
+        gpu_ids = os.environ["CUDA_VISIBLE_DEVICES"]
+    else:
+        gpu_ids = args.gpu_ids
+    gpus = [int(x) for x in gpu_ids.split(",")]
+    web_service.set_gpus(gpus)
+    web_service.prepare_server(
         workdir=args.workdir, port=args.port, device=args.device)
-    service.run_server()
+    web_service.run_server()
diff --git a/python/paddle_serving_server_gpu/web_service.py b/python/paddle_serving_server_gpu/web_service.py
index 3f129a45853b02711f96953b0b902015d2f2d3e8..b77d0a698a829eaa3c1e3cd7304219f4285c1a64 100755
--- a/python/paddle_serving_server_gpu/web_service.py
+++ b/python/paddle_serving_server_gpu/web_service.py
@@ -15,46 +15,82 @@
 from flask import Flask, request, abort
 from multiprocessing import Pool, Process
 from paddle_serving_server_gpu import OpMaker, OpSeqMaker, Server
+import paddle_serving_server_gpu as serving
 from paddle_serving_client import Client
+from .serve import start_multi_card
+import time
+import random
 
 
 class WebService(object):
     def __init__(self, name="default_service"):
         self.name = name
+        self.gpus = []
+        self.rpc_service_list = []
 
     def load_model_config(self, model_config):
         self.model_config = model_config
 
-    def _launch_rpc_service(self):
-        op_maker = OpMaker()
+    def set_gpus(self, gpus):
+        self.gpus = gpus
+
+    def default_rpc_service(self, workdir="conf", port=9292,
+                            gpuid=0, thread_num=10):
+        device = "gpu"
+        if gpuid == -1:
+            device = "cpu"
+        op_maker = serving.OpMaker()
         read_op = op_maker.create('general_reader')
         general_infer_op = op_maker.create('general_infer')
         general_response_op = op_maker.create('general_response')
-        op_seq_maker = OpSeqMaker()
+
+        op_seq_maker = serving.OpSeqMaker()
         op_seq_maker.add_op(read_op)
         op_seq_maker.add_op(general_infer_op)
         op_seq_maker.add_op(general_response_op)
-        server = Server()
+
+        server = serving.Server()
         server.set_op_sequence(op_seq_maker.get_op_sequence())
-        server.set_num_threads(16)
-        server.set_gpuid = self.gpuid
+        server.set_num_threads(thread_num)
+
         server.load_model_config(self.model_config)
-        server.prepare_server(
-            workdir=self.workdir, port=self.port + 1, device=self.device)
-        server.run_server()
+        if gpuid >= 0:
+            server.set_gpuid(gpuid)
+        server.prepare_server(workdir=workdir, port=port, device=device)
+        return server
+
+    def _launch_rpc_service(self, service_idx):
+        self.rpc_service_list[service_idx].run_server()
 
     def prepare_server(self, workdir="", port=9393, device="gpu", gpuid=0):
         self.workdir = workdir
         self.port = port
         self.device = device
         self.gpuid = gpuid
+        if len(self.gpus) == 0:
+            # init cpu service
+            self.rpc_service_list.append(
+                self.default_rpc_service(self.workdir, self.port + 1,
+                                         -1, thread_num=10))
+        else:
+            for i, gpuid in enumerate(self.gpus):
+                self.rpc_service_list.append(
+                    self.default_rpc_service("{}_{}".format(self.workdir, i),
+                                             self.port + 1 + i,
+                                             gpuid, thread_num=10))
 
-    def _launch_web_service(self):
+    def _launch_web_service(self, gpu_num):
         app_instance = Flask(__name__)
-        client_service = Client()
-        client_service.load_client_config(
-            "{}/serving_server_conf.prototxt".format(self.model_config))
-        client_service.connect(["127.0.0.1:{}".format(self.port + 1)])
+        client_list = []
+        if gpu_num < 1:
+            # a single cpu rpc service is started in that case, keep one client
+            gpu_num = 1
+        for i in range(gpu_num):
+            client_service = Client()
+            client_service.load_client_config(
+                "{}/serving_server_conf.prototxt".format(self.model_config))
+            client_service.connect(["127.0.0.1:{}".format(self.port + i + 1)])
+            client_list.append(client_service)
+            time.sleep(1)
         service_name = "/" + self.name + "/prediction"
 
         @app_instance.route(service_name, methods=['POST'])
@@ -64,7 +100,8 @@
             if "fetch" not in request.json:
                 abort(400)
             feed, fetch = self.preprocess(request.json, request.json["fetch"])
-            fetch_map = client_service.predict(feed=feed, fetch=fetch)
+            fetch_map = client_list[0].predict(
+                feed=feed, fetch=fetch)
             fetch_map = self.postprocess(
                 feed=request.json, fetch=fetch, fetch_map=fetch_map)
             return fetch_map
@@ -80,12 +117,20 @@
         print("web service address:")
         print("http://{}:{}/{}/prediction".format(localIP, self.port,
                                                   self.name))
-        p_rpc = Process(target=self._launch_rpc_service)
-        p_web = Process(target=self._launch_web_service)
-        p_rpc.start()
+
+        rpc_processes = []
+        for idx in range(len(self.rpc_service_list)):
+            p_rpc = Process(target=self._launch_rpc_service, args=(idx, ))
+            rpc_processes.append(p_rpc)
+
+        for p in rpc_processes:
+            p.start()
+
+        p_web = Process(target=self._launch_web_service, args=(len(self.gpus), ))
         p_web.start()
+        for p in rpc_processes:
+            p.join()
         p_web.join()
-        p_rpc.join()
 
     def preprocess(self, feed={}, fetch=[]):
         return feed, fetch