diff --git a/.travis.yml b/.travis.yml index d9e6958a21adef675bb586a9561188004bcd5b48..b9308b0258039a3d6b6c69183b0e37d3cfec48ff 100644 --- a/.travis.yml +++ b/.travis.yml @@ -16,12 +16,6 @@ jobs: os: linux python: 3.6 script: /bin/bash ./scripts/check_code_style.sh - - name: "CI on Linux/Python3.5" - os: linux - python: 3.5 - - name: "CI on Linux/Python2.7" - os: linux - python: 2.7 env: - PYTHONPATH=${PWD} @@ -30,10 +24,6 @@ install: - pip install --upgrade paddlepaddle - pip install -r requirements.txt -script: - - if [[ $TRAVIS_OS_NAME != "windows" ]]; then /bin/bash ./scripts/test_cml.sh; fi - - if [[ $TRAVIS_OS_NAME != "windows" ]]; then /bin/bash ./scripts/test_all_module.sh; fi - notifications: email: on_success: change diff --git a/demo/serving/bert_service/README.md b/demo/serving/bert_service/README.md index 8c8d52d0372032c3a5c92b34f63e48c0193bba66..57c769e1bd159245dc55c023dffd968ec4f70335 100644 --- a/demo/serving/bert_service/README.md +++ b/demo/serving/bert_service/README.md @@ -68,7 +68,7 @@ $ pip install ujson |模型|网络| |:-|:-:| -|[ERNIE](https://paddlepaddle.org.cn/hubdetail?name=ERNIE&en_category=SemanticModel)|ERNIE| +|[ernie](https://paddlepaddle.org.cn/hubdetail?name=ERNIE&en_category=SemanticModel)|ERNIE| |[ernie_tiny](https://paddlepaddle.org.cn/hubdetail?name=ernie_tiny&en_category=SemanticModel)|ERNIE| |[ernie_v2_eng_large](https://paddlepaddle.org.cn/hubdetail?name=ernie_v2_eng_large&en_category=SemanticModel)|ERNIE| |[ernie_v2_eng_base](https://paddlepaddle.org.cn/hubdetail?name=ernie_v2_eng_base&en_category=SemanticModel)|ERNIE| @@ -179,18 +179,22 @@ Server[baidu::paddle_serving::predictor::bert_service::BertServiceImpl] is servi 首先导入客户端依赖。 ```python -from paddlehub.serving.bert_serving import bert_service +from paddlehub.serving.bert_serving import bs_client ``` -接着输入文本信息。 + +接着启动并初始化`bert service`客户端`BSClient`(这里的server为虚拟地址,需根据自己实际ip设置) +```python +bc = bs_client.BSClient(module_name="ernie_tiny", server="127.0.0.1:8866") +``` + 
+然后输入文本信息。 ```python input_text = [["西风吹老洞庭波"], ["一夜湘君白发多"], ["醉后不知天在水"], ["满船清梦压星河"], ] ``` -然后利用客户端接口发送文本到服务端,以获取embedding结果(server为虚拟地址,需根据自己实际ip设置)。 + +最后利用客户端接口`get_result`发送文本到服务端,以获取embedding结果。 ```python -result = bert_service.connect( - input_text=input_text, - model_name="ernie_tiny", - server="127.0.0.1:8866") +result = bc.get_result(input_text=input_text) ``` 最后即可得到embedding结果(此处只展示部分结果)。 ```python @@ -221,16 +225,16 @@ Paddle Inference Server exit successfully! > Q : 如何在一台服务器部署多个模型? > A : 可通过多次启动`Bert Service`,分配不同端口实现。如果使用GPU,需要指定不同的显卡。如同时部署`ernie`和`bert_chinese_L-12_H-768_A-12`,分别执行命令如下: > ```shell -> $ hub serving start bert_serving -m ernie -p 8866 -> $ hub serving start bert_serving -m bert_serving -m bert_chinese_L-12_H-768_A-12 -p 8867 +> $ hub serving start bert_service -m ernie -p 8866 +> $ hub serving start bert_service -m bert_chinese_L-12_H-768_A-12 -p 8867 > ``` > Q : 启动时显示"Check out http://yq01-gpu-255-129-12-00.epc.baidu.com:8887 in web browser.",这个页面有什么作用。 > A : 这是`BRPC`的内置服务,主要用于查看请求数、资源占用等信息,可对server端性能有大致了解,具体信息可查看[BRPC内置服务](https://github.com/apache/incubator-brpc/blob/master/docs/cn/builtin_service.md)。 -> Q : 为什么输入文本的格式为[["文本1"], ["文本2"], ],而不是["文本1", "文本2", ]? -> A : 因为Bert模型可以对一轮对话生成向量表示,例如[["问题1","回答1"],["问题2","回答2"]],为了防止使用时混乱,每个样本使用一个list表示,一个样本list内部可以是1条string或2条string,如下面的文本: +> Q : 为什么输入文本的格式为[["文本1"], ["文本2"], ],而不是["文本1", "文本2", ]? 
+> A : 因为Bert模型可以对一轮对话生成向量表示,例如[["问题1","回答1"],["问题2","回答2"]],为了防止使用时混乱,每个样本使用一个list表示,一个样本list内部可以是1条string或2条string,如下面的文本: > ```python > input_text = [ > ["你今天吃饭了吗","我已经吃过饭了"], diff --git a/demo/serving/bert_service/bert_service_client.py b/demo/serving/bert_service/bert_service_client.py index a8c2533641301bc0659699cd54a7e15fcce9bda3..a7a02183ea7279707070c380ec909b82de0ea0db 100644 --- a/demo/serving/bert_service/bert_service_client.py +++ b/demo/serving/bert_service/bert_service_client.py @@ -1,7 +1,10 @@ # coding: utf8 -from paddlehub.serving.bert_serving import bert_service +from paddlehub.serving.bert_serving import bs_client if __name__ == "__main__": + # 初始化bert_service客户端BSClient + bc = bs_client.BSClient(module_name="ernie_tiny", server="127.0.0.1:8866") + # 输入要做embedding的文本 # 文本格式为[["文本1"], ["文本2"], ] input_text = [ @@ -10,10 +13,10 @@ if __name__ == "__main__": ["醉后不知天在水"], ["满船清梦压星河"], ] - # 调用客户端接口bert_service.connect()获取结果 - result = bert_service.connect( - input_text=input_text, model_name="ernie_tiny", server="127.0.0.1:8866") - # 打印embedding结果 + # BSClient.get_result()获取结果 + result = bc.get_result(input_text=input_text) + + # 打印输入文本的embedding结果 for item in result: print(item) diff --git a/paddlehub/__init__.py b/paddlehub/__init__.py index b940e22925c18fb6a050a85e5270331a73d03510..e511d44dd1e399a9b0da56992c13064b1a127311 100644 --- a/paddlehub/__init__.py +++ b/paddlehub/__init__.py @@ -38,7 +38,7 @@ from .common.logger import logger from .common.paddle_helper import connect_program from .common.hub_server import default_hub_server -from .module.module import Module, create_module +from .module.module import Module from .module.base_processor import BaseProcessor from .module.signature import Signature, create_signature from .module.manager import default_module_manager diff --git a/paddlehub/autofinetune/autoft.py b/paddlehub/autofinetune/autoft.py index 11e10c23baa69c9a3182a16e7f693418eb287d37..e7aba48d5be0b768587a6e69a15e72c89976e0c3 100644 
--- a/paddlehub/autofinetune/autoft.py +++ b/paddlehub/autofinetune/autoft.py @@ -26,6 +26,7 @@ from tb_paddle import SummaryWriter from paddlehub.common.logger import logger from paddlehub.common.utils import mkdir from paddlehub.autofinetune.evaluator import REWARD_SUM, TMP_HOME +from paddlehub.autofinetune.mpi_helper import MPIHelper if six.PY3: INF = math.inf @@ -75,6 +76,12 @@ class BaseTuningStrategy(object): logdir=self._output_dir + '/visualization/pop_{}'.format(i)) self.writer_pop_trails.append(writer_pop_trail) + # for parallel on mpi + self.mpi = MPIHelper() + if self.mpi.multi_machine: + print("Autofinetune multimachine mode: running on {}".format( + self.mpi.gather(self.mpi.name))) + @property def thread(self): return self._num_thread @@ -177,16 +184,22 @@ class BaseTuningStrategy(object): solutions_modeldirs = {} mkdir(output_dir) - for idx, solution in enumerate(solutions): + solutions = self.mpi.bcast(solutions) + + # split solutions to "solutions for me" + range_start, range_end = self.mpi.split_range(len(solutions)) + my_solutions = solutions[range_start:range_end] + + for idx, solution in enumerate(my_solutions): cuda = self.is_cuda_free["free"][0] modeldir = output_dir + "/model-" + str(idx) + "/" log_file = output_dir + "/log-" + str(idx) + ".info" params_cudas_dirs.append([solution, cuda, modeldir, log_file]) - solutions_modeldirs[tuple(solution)] = modeldir + solutions_modeldirs[tuple(solution)] = (modeldir, self.mpi.rank) self.is_cuda_free["free"].remove(cuda) self.is_cuda_free["busy"].append(cuda) if len(params_cudas_dirs - ) == self.thread or idx == len(solutions) - 1: + ) == self.thread or idx == len(my_solutions) - 1: tp = ThreadPool(len(params_cudas_dirs)) solution_results += tp.map(self.evaluator.run, params_cudas_dirs) @@ -198,13 +211,25 @@ class BaseTuningStrategy(object): self.is_cuda_free["busy"].remove(param_cuda[1]) params_cudas_dirs = [] - self.feedback(solutions, solution_results) + all_solution_results = 
#!/usr/bin/env python
# -*- coding: utf-8 -*-


class MPIHelper(object):
    """Small wrapper around ``mpi4py`` that falls back to a single process.

    If ``mpi4py`` can be imported, the helper binds to ``MPI.COMM_WORLD`` and
    exposes the communicator size, this process's rank and the processor
    name. If the import fails, it behaves like a one-process job (size 1,
    rank 0, hostname as name) and every collective simply returns its input,
    so callers can use the same code path with or without MPI.
    """

    def __init__(self):
        try:
            from mpi4py import MPI
        except Exception:
            # Local run: mpi4py is missing or unusable. ``Exception`` (rather
            # than a bare ``except``) keeps the best-effort fallback without
            # swallowing KeyboardInterrupt / SystemExit.
            self._size = 1
            self._rank = 0
            self._multi_machine = False

            import socket
            self._name = socket.gethostname()
        else:
            # Running inside an MPI environment.
            self._comm = MPI.COMM_WORLD
            self._size = self._comm.Get_size()
            self._rank = self._comm.Get_rank()
            self._name = MPI.Get_processor_name()
            # A communicator of size 1 is treated like a local run so the
            # collectives below take their no-op branch.
            self._multi_machine = self._size > 1

    @property
    def multi_machine(self):
        """True when more than one MPI process takes part in the job."""
        return self._multi_machine

    @property
    def rank(self):
        """Rank of this process (0 in a local run)."""
        return self._rank

    @property
    def size(self):
        """Number of participating processes (1 in a local run)."""
        return self._size

    @property
    def name(self):
        """MPI processor name, or the hostname in a local run."""
        return self._name

    def bcast(self, data):
        """Broadcast ``data`` from rank 0; return ``data`` unchanged locally."""
        if self._multi_machine:
            return self._comm.bcast(data, root=0)
        return data

    def gather(self, data):
        """Gather ``data`` from every process onto rank 0.

        In a local run the result is ``[data]``. In multi-process mode the
        value comes straight from ``comm.gather(data, root=0)``; per mpi4py's
        documented behaviour only the root receives the list.
        """
        if self._multi_machine:
            return self._comm.gather(data, root=0)
        return [data]

    def allgather(self, data):
        """Gather ``data`` from every process onto every process.

        Returns ``[data]`` in a local run.
        """
        if self._multi_machine:
            return self._comm.allgather(data)
        return [data]

    def split_range(self, array_length):
        """Return the ``(start, end)`` slice of ``array_length`` items for this rank.

        Items are split as evenly as possible: the first
        ``array_length % size`` ranks get one extra item. Both bounds are
        always ``int`` so they can be used directly as slice indices.
        """
        if self._size == 1:
            return 0, array_length

        # Floor division is required here: with "/" Python 3 yields floats,
        # which give wrong bounds for uneven splits and cannot index a list.
        average_count, remainder = divmod(array_length, self._size)

        if self._rank < remainder:
            # This rank is one of the "big" chunks of average_count + 1 items.
            start = (average_count + 1) * self._rank
            return start, start + average_count + 1

        # Skip all big chunks, then the regular chunks before this rank.
        start = (average_count + 1) * remainder \
            + average_count * (self._rank - remainder)
        return start, start + average_count


if __name__ == "__main__":

    mpi = MPIHelper()
    print("Hello world from process {} of {} at {}.".format(
        mpi.rank, mpi.size, mpi.name))

    all_node_names = mpi.gather(mpi.name)
    print("all node names using gather: {}".format(all_node_names))

    all_node_names = mpi.allgather(mpi.name)
    print("all node names using allgather: {}".format(all_node_names))

    if mpi.rank == 0:
        data = range(10)
    else:
        data = None
    data = mpi.bcast(data)
    print("after bcast, process {} have data {}".format(mpi.rank, data))

    data = [i + mpi.rank for i in data]
    print("after modify, process {} have data {}".format(mpi.rank, data))

    new_data = mpi.gather(data)
    print("after gather, process {} have data {}".format(mpi.rank, new_data))

    # test for split
    for i in range(12):
        length = i + mpi.size  # length should >= mpi.size
        [start, end] = mpi.split_range(length)
        split_result = mpi.gather([start, end])
        print("length {}, split_result {}".format(length, split_result))
@@ -188,37 +188,62 @@ class AutoFineTuneCommand(BaseCommand): run_round_cnt = run_round_cnt + 1 print("PaddleHub Autofinetune ends.") + best_hparams_origin = autoft.get_best_hparams() + best_hparams_origin = autoft.mpi.bcast(best_hparams_origin) + with open(autoft._output_dir + "/log_file.txt", "w") as f: - best_hparams = evaluator.convert_params(autoft.get_best_hparams()) + best_hparams = evaluator.convert_params(best_hparams_origin) print("The final best hyperparameters:") f.write("The final best hyperparameters:\n") for index, hparam_name in enumerate(autoft.hparams_name_list): print("%s=%s" % (hparam_name, best_hparams[index])) f.write(hparam_name + "\t:\t" + str(best_hparams[index]) + "\n") + best_hparams_dir, best_hparams_rank = solutions_modeldirs[tuple( + best_hparams_origin)] + print("The final best eval score is %s." % autoft.get_best_eval_value()) - print("The final best model parameters are saved as " + - autoft._output_dir + "/best_model .") + + if autoft.mpi.multi_machine: + print("The final best model parameters are saved as " + + autoft._output_dir + "/best_model on rank " + + str(best_hparams_rank) + " .") + else: + print("The final best model parameters are saved as " + + autoft._output_dir + "/best_model .") f.write("The final best eval score is %s.\n" % autoft.get_best_eval_value()) - f.write( - "The final best model parameters are saved as ./best_model .") best_model_dir = autoft._output_dir + "/best_model" - shutil.copytree( - solutions_modeldirs[tuple(autoft.get_best_hparams())], - best_model_dir) - f.write("\t".join(autoft.hparams_name_list) + - "\tsaved_params_dir\n") + if autoft.mpi.rank == best_hparams_rank: + shutil.copytree(best_hparams_dir, best_model_dir) + + if autoft.mpi.multi_machine: + f.write( + "The final best model parameters are saved as ./best_model on rank " \ + + str(best_hparams_rank) + " .") + f.write("\t".join(autoft.hparams_name_list) + + "\tsaved_params_dir\trank\n") + else: + f.write( + "The final best model 
parameters are saved as ./best_model ." + ) + f.write("\t".join(autoft.hparams_name_list) + + "\tsaved_params_dir\n") + print( - "The related infomation about hyperparamemters searched are saved as %s/log_file.txt ." + "The related infomation about hyperparamemters searched are saved as %s/log_file.txt ." % autoft._output_dir) for solution, modeldir in solutions_modeldirs.items(): param = evaluator.convert_params(solution) param = [str(p) for p in param] - f.write("\t".join(param) + "\t" + modeldir + "\n") + if autoft.mpi.multi_machine: + f.write("\t".join(param) + "\t" + modeldir[0] + "\t" + + str(modeldir[1]) + "\n") + else: + f.write("\t".join(param) + "\t" + modeldir[0] + "\n") return True diff --git a/paddlehub/commands/install.py b/paddlehub/commands/install.py index e9ba9ba46fca22106736699675ca9824a857916b..1e6407a49b79d0594c37f5e202d30675f328ed5f 100644 --- a/paddlehub/commands/install.py +++ b/paddlehub/commands/install.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function import argparse +import os from paddlehub.common import utils from paddlehub.module.manager import default_module_manager @@ -42,14 +43,23 @@ class InstallCommand(BaseCommand): print("ERROR: Please specify a module name.\n") self.help() return False - module_name = argv[0] - module_version = None if "==" not in module_name else module_name.split( - "==")[1] - module_name = module_name if "==" not in module_name else module_name.split( - "==")[0] extra = {"command": "install"} - result, tips, module_dir = default_module_manager.install_module( - module_name=module_name, module_version=module_version, extra=extra) + if argv[0].endswith("tar.gz") or argv[0].endswith("phm"): + result, tips, module_dir = default_module_manager.install_module( + module_package=argv[0], extra=extra) + elif os.path.exists(argv[0]) and os.path.isdir(argv[0]): + result, tips, module_dir = default_module_manager.install_module( + module_dir=argv[0], extra=extra) + else: + 
module_name = argv[0] + module_version = None if "==" not in module_name else module_name.split( + "==")[1] + module_name = module_name if "==" not in module_name else module_name.split( + "==")[0] + result, tips, module_dir = default_module_manager.install_module( + module_name=module_name, + module_version=module_version, + extra=extra) print(tips) return True diff --git a/paddlehub/commands/run.py b/paddlehub/commands/run.py index 30876bc7ca5dcb84b3712fbafd905539e9ff2c5e..9754d27f2d4ae862f0d3e4ad44bce63ab7bbadc5 100644 --- a/paddlehub/commands/run.py +++ b/paddlehub/commands/run.py @@ -71,7 +71,7 @@ class RunCommand(BaseCommand): if not result: return None - return hub.Module(module_dir=module_dir) + return hub.Module(directory=module_dir[0]) def add_module_config_arg(self): configs = self.module.processor.configs() @@ -105,7 +105,7 @@ class RunCommand(BaseCommand): def add_module_input_arg(self): module_type = self.module.type.lower() expect_data_format = self.module.processor.data_format( - self.module.default_signature.name) + self.module.default_signature) self.arg_input_group.add_argument( '--input_file', type=str, @@ -152,7 +152,7 @@ class RunCommand(BaseCommand): def get_data(self): module_type = self.module.type.lower() expect_data_format = self.module.processor.data_format( - self.module.default_signature.name) + self.module.default_signature) input_data = {} if len(expect_data_format) == 1: key = list(expect_data_format.keys())[0] @@ -177,7 +177,7 @@ class RunCommand(BaseCommand): def check_data(self, data): expect_data_format = self.module.processor.data_format( - self.module.default_signature.name) + self.module.default_signature) if len(data.keys()) != len(expect_data_format.keys()): print( @@ -236,35 +236,38 @@ class RunCommand(BaseCommand): return False # If the module is not executable, give an alarm and exit - if not self.module.default_signature: + if not self.module.is_runable: print("ERROR! Module %s is not executable." 
% module_name) return False - self.module.check_processor() - self.add_module_config_arg() - self.add_module_input_arg() + if self.module.code_version == "v2": + results = self.module(argv[1:]) + else: + self.module.check_processor() + self.add_module_config_arg() + self.add_module_input_arg() - if not argv[1:]: - self.help() - return False + if not argv[1:]: + self.help() + return False - self.args = self.parser.parse_args(argv[1:]) + self.args = self.parser.parse_args(argv[1:]) - config = self.get_config() - data = self.get_data() + config = self.get_config() + data = self.get_data() - try: - self.check_data(data) - except DataFormatError: - self.help() - return False - - results = self.module( - sign_name=self.module.default_signature.name, - data=data, - use_gpu=self.args.use_gpu, - batch_size=self.args.batch_size, - **config) + try: + self.check_data(data) + except DataFormatError: + self.help() + return False + + results = self.module( + sign_name=self.module.default_signature, + data=data, + use_gpu=self.args.use_gpu, + batch_size=self.args.batch_size, + **config) if six.PY2: try: diff --git a/paddlehub/commands/serving.py b/paddlehub/commands/serving.py index ab6725dd560fdd04fa5580798d2c419a59764d7e..fd39bf90d12a17d237ee9b22a00a05636683f4c0 100644 --- a/paddlehub/commands/serving.py +++ b/paddlehub/commands/serving.py @@ -159,7 +159,7 @@ class ServingCommand(BaseCommand): module = args.modules if module is not None: use_gpu = args.use_gpu - port = args.port[0] + port = args.port if ServingCommand.is_port_occupied("127.0.0.1", port) is True: print("Port %s is occupied, please change it." 
% (port)) return False @@ -206,8 +206,10 @@ class ServingCommand(BaseCommand): if args.sub_command == "start": if args.bert_service == "bert_service": ServingCommand.start_bert_serving(args) - else: + elif args.bert_service is None: ServingCommand.start_serving(args) + else: + ServingCommand.show_help() else: ServingCommand.show_help() diff --git a/paddlehub/commands/show.py b/paddlehub/commands/show.py index 1c3f45434fd0cee632e574874792ce5a04c0bc59..5caa8de7d99a4e9a03137288e40902fb77724708 100644 --- a/paddlehub/commands/show.py +++ b/paddlehub/commands/show.py @@ -125,8 +125,6 @@ class ShowCommand(BaseCommand): cwd = os.getcwd() module_dir = default_module_manager.search_module(module_name) - module_dir = (os.path.join(cwd, module_name), - None) if not module_dir else module_dir if not module_dir or not os.path.exists(module_dir[0]): print("%s is not existed!" % module_name) return True diff --git a/paddlehub/common/hub_server.py b/paddlehub/common/hub_server.py index 7039a2db090006de11a79f1523e9f72b3ce572b3..236b2e5be13969adcc3df3cdf42328bf490d4f82 100644 --- a/paddlehub/common/hub_server.py +++ b/paddlehub/common/hub_server.py @@ -298,17 +298,23 @@ class CacheUpdater(threading.Thread): api_url = srv_utils.uri_path(default_hub_server.get_server_url(), 'search') cache_path = os.path.join(CACHE_HOME, RESOURCE_LIST_FILE) - extra = { - "command": "update_cache", - "mtime": os.stat(cache_path).st_mtime - } + if os.path.exists(cache_path): + extra = { + "command": "update_cache", + "mtime": os.stat(cache_path).st_mtime + } + else: + extra = { + "command": "update_cache", + "mtime": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + } try: r = srv_utils.hub_request(api_url, payload, extra) + if r.get("update_cache", 0) == 1: + with open(cache_path, 'w+') as fp: + yaml.safe_dump({'resource_list': r['data']}, fp) except Exception as err: pass - if r.get("update_cache", 0) == 1: - with open(cache_path, 'w+') as fp: - yaml.safe_dump({'resource_list': r['data']}, fp) 
def run(self): self.update_resource_list_file(self.module, self.version) diff --git a/paddlehub/module/check_info.proto b/paddlehub/module/check_info.proto index 923de58cefe8bffd6499378733e29bbb2e7a508f..56c1b584de7afcd958eb3edaffc0fdef8b0d7363 100644 --- a/paddlehub/module/check_info.proto +++ b/paddlehub/module/check_info.proto @@ -50,6 +50,7 @@ message CheckInfo { string paddle_version = 1; string hub_version = 2; string module_proto_version = 3; - repeated FileInfo file_infos = 4; - repeated Requires requires = 5; + string module_code_version = 4; + repeated FileInfo file_infos = 5; + repeated Requires requires = 6; }; diff --git a/paddlehub/module/check_info_pb2.py b/paddlehub/module/check_info_pb2.py index 78f5546c49c417508d26fa0f809340459987fc66..8ed17a9ac532ad5bd7a7242d27793ca53a235b40 100644 --- a/paddlehub/module/check_info_pb2.py +++ b/paddlehub/module/check_info_pb2.py @@ -1,4 +1,3 @@ -#coding:utf-8 # Generated by the protocol buffer compiler. DO NOT EDIT! # source: check_info.proto @@ -19,7 +18,7 @@ DESCRIPTOR = _descriptor.FileDescriptor( package='paddlehub.module.checkinfo', syntax='proto3', serialized_pb=_b( - '\n\x10\x63heck_info.proto\x12\x1apaddlehub.module.checkinfo\"\x85\x01\n\x08\x46ileInfo\x12\x11\n\tfile_name\x18\x01 \x01(\t\x12\x33\n\x04type\x18\x02 \x01(\x0e\x32%.paddlehub.module.checkinfo.FILE_TYPE\x12\x0f\n\x07is_need\x18\x03 \x01(\x08\x12\x0b\n\x03md5\x18\x04 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x05 \x01(\t\"\x84\x01\n\x08Requires\x12>\n\x0crequire_type\x18\x01 \x01(\x0e\x32(.paddlehub.module.checkinfo.REQUIRE_TYPE\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x12\n\ngreat_than\x18\x03 \x01(\x08\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\"\xc8\x01\n\tCheckInfo\x12\x16\n\x0epaddle_version\x18\x01 \x01(\t\x12\x13\n\x0bhub_version\x18\x02 \x01(\t\x12\x1c\n\x14module_proto_version\x18\x03 \x01(\t\x12\x38\n\nfile_infos\x18\x04 \x03(\x0b\x32$.paddlehub.module.checkinfo.FileInfo\x12\x36\n\x08requires\x18\x05 
\x03(\x0b\x32$.paddlehub.module.checkinfo.Requires*\x1e\n\tFILE_TYPE\x12\x08\n\x04\x46ILE\x10\x00\x12\x07\n\x03\x44IR\x10\x01*[\n\x0cREQUIRE_TYPE\x12\x12\n\x0ePYTHON_PACKAGE\x10\x00\x12\x0e\n\nHUB_MODULE\x10\x01\x12\n\n\x06SYSTEM\x10\x02\x12\x0b\n\x07\x43OMMAND\x10\x03\x12\x0e\n\nPY_VERSION\x10\x04\x42\x02H\x03\x62\x06proto3' + '\n\x10\x63heck_info.proto\x12\x1apaddlehub.module.checkinfo\"\x85\x01\n\x08\x46ileInfo\x12\x11\n\tfile_name\x18\x01 \x01(\t\x12\x33\n\x04type\x18\x02 \x01(\x0e\x32%.paddlehub.module.checkinfo.FILE_TYPE\x12\x0f\n\x07is_need\x18\x03 \x01(\x08\x12\x0b\n\x03md5\x18\x04 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x05 \x01(\t\"\x84\x01\n\x08Requires\x12>\n\x0crequire_type\x18\x01 \x01(\x0e\x32(.paddlehub.module.checkinfo.REQUIRE_TYPE\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x12\n\ngreat_than\x18\x03 \x01(\x08\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\"\xe5\x01\n\tCheckInfo\x12\x16\n\x0epaddle_version\x18\x01 \x01(\t\x12\x13\n\x0bhub_version\x18\x02 \x01(\t\x12\x1c\n\x14module_proto_version\x18\x03 \x01(\t\x12\x1b\n\x13module_code_version\x18\x04 \x01(\t\x12\x38\n\nfile_infos\x18\x05 \x03(\x0b\x32$.paddlehub.module.checkinfo.FileInfo\x12\x36\n\x08requires\x18\x06 \x03(\x0b\x32$.paddlehub.module.checkinfo.Requires*\x1e\n\tFILE_TYPE\x12\x08\n\x04\x46ILE\x10\x00\x12\x07\n\x03\x44IR\x10\x01*[\n\x0cREQUIRE_TYPE\x12\x12\n\x0ePYTHON_PACKAGE\x10\x00\x12\x0e\n\nHUB_MODULE\x10\x01\x12\n\n\x06SYSTEM\x10\x02\x12\x0b\n\x07\x43OMMAND\x10\x03\x12\x0e\n\nPY_VERSION\x10\x04\x42\x02H\x03\x62\x06proto3' )) _sym_db.RegisterFileDescriptor(DESCRIPTOR) @@ -36,8 +35,8 @@ _FILE_TYPE = _descriptor.EnumDescriptor( ], containing_type=None, options=None, - serialized_start=522, - serialized_end=552, + serialized_start=551, + serialized_end=581, ) _sym_db.RegisterEnumDescriptor(_FILE_TYPE) @@ -61,8 +60,8 @@ _REQUIRE_TYPE = _descriptor.EnumDescriptor( ], containing_type=None, options=None, - serialized_start=554, - serialized_end=645, + serialized_start=583, + 
serialized_end=674, ) _sym_db.RegisterEnumDescriptor(_REQUIRE_TYPE) @@ -316,10 +315,26 @@ _CHECKINFO = _descriptor.Descriptor( extension_scope=None, options=None), _descriptor.FieldDescriptor( - name='file_infos', - full_name='paddlehub.module.checkinfo.CheckInfo.file_infos', + name='module_code_version', + full_name='paddlehub.module.checkinfo.CheckInfo.module_code_version', index=3, number=4, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='file_infos', + full_name='paddlehub.module.checkinfo.CheckInfo.file_infos', + index=4, + number=5, type=11, cpp_type=10, label=3, @@ -334,8 +349,8 @@ _CHECKINFO = _descriptor.Descriptor( _descriptor.FieldDescriptor( name='requires', full_name='paddlehub.module.checkinfo.CheckInfo.requires', - index=4, - number=5, + index=5, + number=6, type=11, cpp_type=10, label=3, @@ -357,7 +372,7 @@ _CHECKINFO = _descriptor.Descriptor( extension_ranges=[], oneofs=[], serialized_start=320, - serialized_end=520, + serialized_end=549, ) _FILEINFO.fields_by_name['type'].enum_type = _FILE_TYPE diff --git a/paddlehub/module/checker.py b/paddlehub/module/checker.py index d76ca6bd74ab4d354ba5fa72e8f1e7215c0ed6f0..b1470af4d16774688af53a0a7293691d30bc3e6c 100644 --- a/paddlehub/module/checker.py +++ b/paddlehub/module/checker.py @@ -32,20 +32,22 @@ FILE_SEP = "/" class ModuleChecker(object): - def __init__(self, module_path): - self.module_path = module_path + def __init__(self, directory): + self._directory = directory + self._pb_path = os.path.join(self.directory, CHECK_INFO_PB_FILENAME) def generate_check_info(self): check_info = check_info_pb2.CheckInfo() check_info.paddle_version = paddle.__version__ check_info.hub_version = hub_version check_info.module_proto_version = module_proto_version + check_info.module_code_version = "v2" 
file_infos = check_info.file_infos - file_list = [file for file in os.listdir(self.module_path)] + file_list = [file for file in os.listdir(self.directory)] while file_list: file = file_list[0] file_list = file_list[1:] - abs_path = os.path.join(self.module_path, file) + abs_path = os.path.join(self.directory, file) if os.path.isdir(abs_path): for sub_file in os.listdir(abs_path): sub_file = os.path.join(file, sub_file) @@ -62,9 +64,12 @@ class ModuleChecker(object): file_info.type = check_info_pb2.FILE file_info.is_need = True - with open(os.path.join(self.module_path, CHECK_INFO_PB_FILENAME), - "wb") as fi: - fi.write(check_info.SerializeToString()) + with open(self.pb_path, "wb") as file: + file.write(check_info.SerializeToString()) + + @property + def module_code_version(self): + return self.check_info.module_code_version @property def module_proto_version(self): @@ -82,20 +87,25 @@ class ModuleChecker(object): def file_infos(self): return self.check_info.file_infos + @property + def directory(self): + return self._directory + + @property + def pb_path(self): + return self._pb_path + def check(self): result = True - self.check_info_pb_path = os.path.join(self.module_path, - CHECK_INFO_PB_FILENAME) - if not (os.path.exists(self.check_info_pb_path) - or os.path.isfile(self.check_info_pb_path)): + if not (os.path.exists(self.pb_path) or os.path.isfile(self.pb_path)): logger.warning( "This module lacks core file %s" % CHECK_INFO_PB_FILENAME) result = False self.check_info = check_info_pb2.CheckInfo() try: - with open(self.check_info_pb_path, "rb") as fi: + with open(self.pb_path, "rb") as fi: pb_string = fi.read() result = self.check_info.ParseFromString(pb_string) if len(pb_string) == 0 or (result is not None @@ -182,7 +192,7 @@ class ModuleChecker(object): for file_info in self.file_infos: file_type = file_info.type file_path = file_info.file_name.replace(FILE_SEP, os.sep) - file_path = os.path.join(self.module_path, file_path) + file_path = 
os.path.join(self.directory, file_path) if not os.path.exists(file_path): if file_info.is_need: logger.warning( diff --git a/paddlehub/module/manager.py b/paddlehub/module/manager.py index 433ccdaaab190f34efbbc7a391eb7058cb66001c..037c67750672ac7fce3c8f5ee1b18d06e6c3f00f 100644 --- a/paddlehub/module/manager.py +++ b/paddlehub/module/manager.py @@ -19,7 +19,9 @@ from __future__ import print_function import os import shutil + from functools import cmp_to_key +import tarfile from paddlehub.common import utils from paddlehub.common import srv_utils @@ -79,15 +81,76 @@ class LocalModuleManager(object): return self.modules_dict.get(module_name, None) def install_module(self, - module_name, + module_name=None, + module_dir=None, + module_package=None, module_version=None, upgrade=False, extra=None): - self.all_modules(update=True) - module_info = self.modules_dict.get(module_name, None) - if module_info: - if not module_version or module_version == self.modules_dict[ - module_name][1]: + md5_value = installed_module_version = None + from_user_dir = True if module_dir else False + if module_name: + self.all_modules(update=True) + module_info = self.modules_dict.get(module_name, None) + if module_info: + if not module_version or module_version == self.modules_dict[ + module_name][1]: + module_dir = self.modules_dict[module_name][0] + module_tag = module_name if not module_version else '%s-%s' % ( + module_name, module_version) + tips = "Module %s already installed in %s" % (module_tag, + module_dir) + return True, tips, self.modules_dict[module_name] + + search_result = hub.default_hub_server.get_module_url( + module_name, version=module_version, extra=extra) + name = search_result.get('name', None) + url = search_result.get('url', None) + md5_value = search_result.get('md5', None) + installed_module_version = search_result.get('version', None) + if not url or (module_version is not None + and installed_module_version != module_version) or ( + name != module_name): + if 
default_hub_server._server_check() is False: + tips = "Request Hub-Server unsuccessfully, please check your network." + else: + tips = "Can't find module %s" % module_name + if module_version: + tips += " with version %s" % module_version + module_tag = module_name if not module_version else '%s-%s' % ( + module_name, module_version) + return False, tips, None + + result, tips, module_zip_file = default_downloader.download_file( + url=url, + save_path=hub.CACHE_HOME, + save_name=module_name, + replace=True, + print_progress=True) + result, tips, module_dir = default_downloader.uncompress( + file=module_zip_file, + dirname=MODULE_HOME, + delete_file=True, + print_progress=True) + + if module_package: + with tarfile.open(module_package, "r:gz") as tar: + file_names = tar.getnames() + size = len(file_names) - 1 + module_dir = os.path.split(file_names[0])[0] + module_dir = os.path.join(hub.CACHE_HOME, module_dir) + # remove cache + if os.path.exists(module_dir): + shutil.rmtree(module_dir) + for index, file_name in enumerate(file_names): + tar.extract(file_name, hub.CACHE_HOME) + + if module_dir: + if not module_name: + module_name = hub.Module(directory=module_dir).name + self.all_modules(update=False) + module_info = self.modules_dict.get(module_name, None) + if module_info: module_dir = self.modules_dict[module_name][0] module_tag = module_name if not module_version else '%s-%s' % ( module_name, module_version) @@ -162,10 +225,19 @@ class LocalModuleManager(object): with open(os.path.join(MODULE_HOME, module_dir, "md5.txt"), "w") as fp: fp.write(md5_value) + if md5_value: + with open( + os.path.join(MODULE_HOME, module_dir, "md5.txt"), + "w") as fp: + fp.write(md5_value) + save_path = os.path.join(MODULE_HOME, module_name) if os.path.exists(save_path): - shutil.rmtree(save_path) - shutil.move(module_dir, save_path) + shutil.move(save_path) + if from_user_dir: + shutil.copytree(module_dir, save_path) + else: + shutil.move(module_dir, save_path) module_dir = save_path 
tips = "Successfully installed %s" % module_name if installed_module_version: diff --git a/paddlehub/module/module.py b/paddlehub/module/module.py index 18fcec7366c11042a305ca93ac6ea2d25b3a81a0..45607b0b374315a96e4ae1b927e21ba0bb0f6759 100644 --- a/paddlehub/module/module.py +++ b/paddlehub/module/module.py @@ -21,6 +21,10 @@ import os import time import sys import functools +import inspect +import importlib +import tarfile +from collections import defaultdict from shutil import copyfile import paddle @@ -28,22 +32,19 @@ import paddle.fluid as fluid from paddlehub.common import utils from paddlehub.common import paddle_helper -from paddlehub.common.logger import logger +from paddlehub.common.dir import CACHE_HOME from paddlehub.common.lock import lock -from paddlehub.common.downloader import default_downloader +from paddlehub.common.logger import logger +from paddlehub.common.hub_server import CacheUpdater from paddlehub.module import module_desc_pb2 -from paddlehub.common.dir import CONF_HOME from paddlehub.module import check_info_pb2 -from paddlehub.common.hub_server import CacheUpdater -from paddlehub.module.signature import Signature, create_signature -from paddlehub.module.checker import ModuleChecker from paddlehub.module.manager import default_module_manager +from paddlehub.module.checker import ModuleChecker +from paddlehub.module.signature import Signature, create_signature from paddlehub.module.base_processor import BaseProcessor from paddlehub.io.parser import yaml_parser from paddlehub import version -__all__ = ['Module', 'create_module'] - # PaddleHub module dir name ASSETS_DIRNAME = "assets" MODEL_DIRNAME = "model" @@ -52,67 +53,227 @@ PYTHON_DIR = "python" PROCESSOR_NAME = "processor" # PaddleHub var prefix HUB_VAR_PREFIX = "@HUB_%s@" +# PaddleHub Module package suffix +HUB_PACKAGE_SUFFIX = "phm" + + +def create_module(directory, name, author, email, module_type, summary, + version): + save_file_name = "{}-{}.{}".format(name, version, 
HUB_PACKAGE_SUFFIX) + + # record module info and serialize + desc = module_desc_pb2.ModuleDesc() + attr = desc.attr + attr.type = module_desc_pb2.MAP + module_info = attr.map.data['module_info'] + module_info.type = module_desc_pb2.MAP + utils.from_pyobj_to_module_attr(name, module_info.map.data['name']) + utils.from_pyobj_to_module_attr(author, module_info.map.data['author']) + utils.from_pyobj_to_module_attr(email, module_info.map.data['author_email']) + utils.from_pyobj_to_module_attr(module_type, module_info.map.data['type']) + utils.from_pyobj_to_module_attr(summary, module_info.map.data['summary']) + utils.from_pyobj_to_module_attr(version, module_info.map.data['version']) + + module_desc_path = os.path.join(directory, "module_desc.pb") + with open(module_desc_path, "wb") as f: + f.write(desc.SerializeToString()) + + # generate check info + checker = ModuleChecker(directory) + checker.generate_check_info() + + # add __init__ + module_init_1 = os.path.join(directory, "__init__.py") + with open(module_init_1, "a") as file: + file.write("") + + module_init_2 = os.path.join(directory, "python", "__init__.py") + with open(module_init_2, "a") as file: + file.write("") + + # package the module + with tarfile.open(save_file_name, "w:gz") as tar: + for dirname, _, files in os.walk(directory): + for file in files: + tar.add(os.path.join(dirname, file)) + + os.remove(module_desc_path) + os.remove(checker.pb_path) + os.remove(module_init_1) + os.remove(module_init_2) + +class Module(object): + def __new__(cls, name=None, directory=None, module_dir=None, version=None): + module = None + + if cls.__name__ == "Module": + if name: + module = cls.init_with_name(name=name, version=version) + elif directory: + module = cls.init_with_directory(directory=directory) + elif module_dir: + logger.warning( + "Parameter module_dir is deprecated, please use directory to specify the path" + ) + if isinstance(module_dir, list) or isinstance( + module_dir, tuple): + directory = 
module_dir[0] + version = module_dir[1] + else: + directory = module_dir + module = cls.init_with_directory(directory=directory) + + if not module: + module = object.__new__(cls) + else: + CacheUpdater(module.name, module.version).start() + return module + + def __init__(self, name=None, directory=None, module_dir=None, + version=None): + if not directory: + return + self._code_version = "v2" + self._directory = directory + self.module_desc_path = os.path.join(self.directory, MODULE_DESC_PBNAME) + self._desc = module_desc_pb2.ModuleDesc() + with open(self.module_desc_path, "rb") as file: + self._desc.ParseFromString(file.read()) -def create_module(sign_arr, - module_dir, - processor=None, - assets=None, - module_info=None, - exe=None, - extra_info=None): - sign_arr = utils.to_list(sign_arr) - module = Module( - signatures=sign_arr, - processor=processor, - assets=assets, - module_info=module_info, - extra_info=extra_info) - module.serialize_to_path(path=module_dir, exe=exe) + module_info = self.desc.attr.map.data['module_info'] + self._name = utils.from_module_attr_to_pyobj( + module_info.map.data['name']) + self._author = utils.from_module_attr_to_pyobj( + module_info.map.data['author']) + self._author_email = utils.from_module_attr_to_pyobj( + module_info.map.data['author_email']) + self._version = utils.from_module_attr_to_pyobj( + module_info.map.data['version']) + self._type = utils.from_module_attr_to_pyobj( + module_info.map.data['type']) + self._summary = utils.from_module_attr_to_pyobj( + module_info.map.data['summary']) + + self._initialize() + + @classmethod + def init_with_name(cls, name, version=None): + fp_lock = open(os.path.join(CACHE_HOME, name), "a") + lock.flock(fp_lock, lock.LOCK_EX) + log_msg = "Installing %s module" % name + if version: + log_msg += "-%s" % version + logger.info(log_msg) + extra = {"command": "install"} + result, tips, module_dir = default_module_manager.install_module( + module_name=name, module_version=version, extra=extra) 
+ if not result: + logger.error(tips) + raise RuntimeError(tips) + + logger.info(tips) + lock.flock(fp_lock, lock.LOCK_UN) + return cls.init_with_directory(directory=module_dir[0]) + + @classmethod + def init_with_directory(cls, directory): + desc_file = os.path.join(directory, MODULE_DESC_PBNAME) + checker = ModuleChecker(directory) + checker.check() + + module_code_version = checker.module_code_version + if module_code_version == "v2": + basename = os.path.split(directory)[-1] + dirname = os.path.join(*list(os.path.split(directory)[:-1])) + sys.path.append(dirname) + pymodule = importlib.import_module( + "{}.python.module".format(basename)) + return pymodule.HubModule(directory=directory) + return ModuleV1(directory=directory) + + @property + def desc(self): + return self._desc + + @property + def directory(self): + return self._directory + + @property + def author(self): + return self._author + + @property + def author_email(self): + return self._author_email + + @property + def summary(self): + return self._summary + + @property + def type(self): + return self._type + + @property + def version(self): + return self._version + + @property + def name(self): + return self._name + + @property + def name_prefix(self): + return self._name_prefix + + @property + def code_version(self): + return self._code_version + + @property + def is_runable(self): + return False + + def _initialize(self): + pass class ModuleHelper(object): - def __init__(self, module_dir): - self.module_dir = module_dir + def __init__(self, directory): + self.directory = directory def module_desc_path(self): - return os.path.join(self.module_dir, MODULE_DESC_PBNAME) + return os.path.join(self.directory, MODULE_DESC_PBNAME) def model_path(self): - return os.path.join(self.module_dir, MODEL_DIRNAME) + return os.path.join(self.directory, MODEL_DIRNAME) def processor_path(self): - return os.path.join(self.module_dir, PYTHON_DIR) + return os.path.join(self.directory, PYTHON_DIR) def processor_name(self): 
return PROCESSOR_NAME def assets_path(self): - return os.path.join(self.module_dir, ASSETS_DIRNAME) + return os.path.join(self.directory, ASSETS_DIRNAME) -class Module(object): - def __init__(self, - name=None, - module_dir=None, - signatures=None, - module_info=None, - assets=None, - processor=None, - extra_info=None, +class ModuleV1(Module): + def __init__(self, name=None, directory=None, module_dir=None, version=None): - self.desc = module_desc_pb2.ModuleDesc() + if not directory: + return + super(ModuleV1, self).__init__(name, directory, module_dir, version) + self._code_version = "v1" self.program = None self.assets = [] self.helper = None self.signatures = {} self.default_signature = None - self.module_info = None self.processor = None - self.extra_info = {} if extra_info is None else extra_info - if not isinstance(self.extra_info, dict): - raise TypeError( - "The extra_info should be an instance of python dict") + self.extra_info = {} # cache data self.last_call_name = None @@ -120,62 +281,21 @@ class Module(object): self.cache_fetch_dict = None self.cache_program = None - fp_lock = open(os.path.join(CONF_HOME, 'config.json')) - lock.flock(fp_lock, lock.LOCK_EX) - if name: - self._init_with_name(name=name, version=version) - lock.flock(fp_lock, lock.LOCK_UN) - elif module_dir: - self._init_with_module_file(module_dir=module_dir[0]) - lock.flock(fp_lock, lock.LOCK_UN) - name = module_dir[0].split("/")[-1] - if len(module_dir) > 1: - version = module_dir[1] - else: - version = default_module_manager.search_module(name)[1] - elif signatures: - if processor: - if not issubclass(processor, BaseProcessor): - raise TypeError( - "Processor shoule be an instance of paddlehub.BaseProcessor" - ) - if assets: - self.assets = utils.to_list(assets) - # for asset in assets: - # utils.check_path(assets) - self.processor = processor - self._generate_module_info(module_info) - self._init_with_signature(signatures=signatures) - lock.flock(fp_lock, lock.LOCK_UN) - else: - 
lock.flock(fp_lock, lock.LOCK_UN) - raise ValueError("Module initialized parameter is empty") - CacheUpdater(name, version).start() - - def _init_with_name(self, name, version=None): - log_msg = "Installing %s module" % name - if version: - log_msg += "-%s" % version - logger.info(log_msg) - extra = {"command": "install"} - result, tips, module_dir = default_module_manager.install_module( - module_name=name, module_version=version, extra=extra) - if not result: - logger.error(tips) - raise RuntimeError(tips) - else: - logger.info(tips) - self._init_with_module_file(module_dir[0]) - - def _init_with_url(self, url): - utils.check_url(url) - result, tips, module_dir = default_downloader.download_file_and_uncompress( - url, save_path=".") - if not result: - logger.error(tips) - raise RuntimeError(tips) - else: - self._init_with_module_file(module_dir) + self.helper = ModuleHelper(directory) + exe = fluid.Executor(fluid.CPUPlace()) + self.program, _, _ = fluid.io.load_inference_model( + self.helper.model_path(), executor=exe) + for block in self.program.blocks: + for op in block.ops: + if "op_callstack" in op.all_attrs(): + op._set_attr("op_callstack", [""]) + self._load_processor() + self._load_assets() + self._recover_from_desc() + self._generate_sign_attr() + self._generate_extra_info() + self._restore_parameter(self.program) + self._recover_variable_info(self.program) def _dump_processor(self): import inspect @@ -216,52 +336,6 @@ class Module(object): filepath = os.path.join(self.helper.assets_path(), file) self.assets.append(filepath) - def _init_with_module_file(self, module_dir): - checker = ModuleChecker(module_dir) - checker.check() - - self.helper = ModuleHelper(module_dir) - with open(self.helper.module_desc_path(), "rb") as fi: - self.desc.ParseFromString(fi.read()) - - exe = fluid.Executor(fluid.CPUPlace()) - self.program, _, _ = fluid.io.load_inference_model( - self.helper.model_path(), executor=exe) - for block in self.program.blocks: - for op in 
block.ops: - if "op_callstack" in op.all_attrs(): - op._set_attr("op_callstack", [""]) - self._load_processor() - self._load_assets() - self._recover_from_desc() - self._generate_sign_attr() - self._generate_extra_info() - self._restore_parameter(self.program) - self._recover_variable_info(self.program) - - def _init_with_signature(self, signatures): - self.name_prefix = HUB_VAR_PREFIX % self.name - self._process_signatures(signatures) - self._check_signatures() - self._generate_desc() - self._generate_sign_attr() - self._generate_extra_info() - - def _init_with_program(self, program): - pass - - def _process_signatures(self, signatures): - self.signatures = {} - self.program = signatures[0].inputs[0].block.program - for sign in signatures: - if sign.name in self.signatures: - raise ValueError( - "Error! Signature array contains duplicated signatrues %s" % - sign) - if self.default_signature is None and sign.for_predict: - self.default_signature = sign - self.signatures[sign.name] = sign - def _restore_parameter(self, program): global_block = program.global_block() param_attrs = self.desc.attr.map.data['param_attrs'] @@ -302,21 +376,6 @@ class Module(object): self.__dict__["get_%s" % key] = functools.partial( self.get_extra_info, key=key) - def _generate_module_info(self, module_info=None): - if not module_info: - self.module_info = {} - else: - if not utils.is_yaml_file(module_info): - logger.critical("Module info file should be yaml format") - exit(1) - self.module_info = yaml_parser.parse(module_info) - self.author = self.module_info.get('author', 'UNKNOWN') - self.author_email = self.module_info.get('author_email', 'UNKNOWN') - self.summary = self.module_info.get('summary', 'UNKNOWN') - self.type = self.module_info.get('type', 'UNKNOWN') - self.version = self.module_info.get('version', 'UNKNOWN') - self.name = self.module_info.get('name', 'UNKNOWN') - def _generate_sign_attr(self): self._check_signatures() for sign in self.signatures: @@ -369,21 +428,21 @@ 
class Module(object): default_signature_name = utils.from_module_attr_to_pyobj( self.desc.attr.map.data['default_signature']) self.default_signature = self.signatures[ - default_signature_name] if default_signature_name else None + default_signature_name].name if default_signature_name else None # recover module info module_info = self.desc.attr.map.data['module_info'] - self.name = utils.from_module_attr_to_pyobj( + self._name = utils.from_module_attr_to_pyobj( module_info.map.data['name']) - self.author = utils.from_module_attr_to_pyobj( + self._author = utils.from_module_attr_to_pyobj( module_info.map.data['author']) - self.author_email = utils.from_module_attr_to_pyobj( + self._author_email = utils.from_module_attr_to_pyobj( module_info.map.data['author_email']) - self.version = utils.from_module_attr_to_pyobj( + self._version = utils.from_module_attr_to_pyobj( module_info.map.data['version']) - self.type = utils.from_module_attr_to_pyobj( + self._type = utils.from_module_attr_to_pyobj( module_info.map.data['type']) - self.summary = utils.from_module_attr_to_pyobj( + self._summary = utils.from_module_attr_to_pyobj( module_info.map.data['summary']) # recover extra info @@ -393,77 +452,9 @@ class Module(object): self.extra_info[key] = utils.from_module_attr_to_pyobj(value) # recover name prefix - self.name_prefix = utils.from_module_attr_to_pyobj( + self._name_prefix = utils.from_module_attr_to_pyobj( self.desc.attr.map.data["name_prefix"]) - def _generate_desc(self): - # save fluid Parameter - attr = self.desc.attr - attr.type = module_desc_pb2.MAP - param_attrs = attr.map.data['param_attrs'] - param_attrs.type = module_desc_pb2.MAP - for param in self.program.global_block().iter_parameters(): - param_attr = param_attrs.map.data[param.name] - paddle_helper.from_param_to_module_attr(param, param_attr) - - # save Variable Info - var_infos = attr.map.data['var_infos'] - var_infos.type = module_desc_pb2.MAP - for block in self.program.blocks: - for var in 
block.vars.values(): - var_info = var_infos.map.data[var.name] - var_info.type = module_desc_pb2.MAP - utils.from_pyobj_to_module_attr( - var.stop_gradient, var_info.map.data['stop_gradient']) - utils.from_pyobj_to_module_attr(block.idx, - var_info.map.data['block_id']) - - # save signarture info - for key, sign in self.signatures.items(): - var = self.desc.sign2var[sign.name] - feed_desc = var.feed_desc - fetch_desc = var.fetch_desc - feed_names = sign.feed_names - fetch_names = sign.fetch_names - for index, input in enumerate(sign.inputs): - feed_var = feed_desc.add() - feed_var.var_name = self.get_var_name_with_prefix(input.name) - feed_var.alias = feed_names[index] - - for index, output in enumerate(sign.outputs): - fetch_var = fetch_desc.add() - fetch_var.var_name = self.get_var_name_with_prefix(output.name) - fetch_var.alias = fetch_names[index] - - # save default signature - utils.from_pyobj_to_module_attr( - self.default_signature.name if self.default_signature else None, - attr.map.data['default_signature']) - - # save name prefix - utils.from_pyobj_to_module_attr(self.name_prefix, - self.desc.attr.map.data["name_prefix"]) - - # save module info - module_info = attr.map.data['module_info'] - module_info.type = module_desc_pb2.MAP - utils.from_pyobj_to_module_attr(self.name, module_info.map.data['name']) - utils.from_pyobj_to_module_attr(self.version, - module_info.map.data['version']) - utils.from_pyobj_to_module_attr(self.author, - module_info.map.data['author']) - utils.from_pyobj_to_module_attr(self.author_email, - module_info.map.data['author_email']) - utils.from_pyobj_to_module_attr(self.type, module_info.map.data['type']) - utils.from_pyobj_to_module_attr(self.summary, - module_info.map.data['summary']) - - # save extra info - extra_info = attr.map.data['extra_info'] - extra_info.type = module_desc_pb2.MAP - for key, value in self.extra_info.items(): - utils.from_pyobj_to_module_attr(value, extra_info.map.data[key]) - def __call__(self, sign_name, 
data, use_gpu=False, batch_size=1, **kwargs): self.check_processor() @@ -525,6 +516,10 @@ class Module(object): if not self.processor: raise ValueError("This Module is not callable!") + @property + def is_runable(self): + return self.default_signature != None + def context(self, sign_name=None, for_test=False, @@ -664,93 +659,3 @@ class Module(object): raise ValueError( "All input and outputs variables in signature should come from the same Program" ) - - def serialize_to_path(self, path=None, exe=None): - self._check_signatures() - self._generate_desc() - # create module path for saving - if path is None: - path = os.path.join(".", self.name) - self.helper = ModuleHelper(path) - utils.mkdir(self.helper.module_dir) - - # create module pb - module_desc = module_desc_pb2.ModuleDesc() - logger.info("PaddleHub version = %s" % version.hub_version) - logger.info("PaddleHub Module proto version = %s" % - version.module_proto_version) - logger.info("Paddle version = %s" % paddle.__version__) - - feeded_var_names = [ - input.name for key, sign in self.signatures.items() - for input in sign.inputs - ] - target_vars = [ - output for key, sign in self.signatures.items() - for output in sign.outputs - ] - feeded_var_names = list(set(feeded_var_names)) - target_vars = list(set(target_vars)) - - # save inference program - program = self.program.clone() - - for block in program.blocks: - for op in block.ops: - if "op_callstack" in op.all_attrs(): - op._set_attr("op_callstack", [""]) - - if not exe: - place = fluid.CPUPlace() - exe = fluid.Executor(place=place) - utils.mkdir(self.helper.model_path()) - fluid.io.save_inference_model( - self.helper.model_path(), - feeded_var_names=list(feeded_var_names), - target_vars=list(target_vars), - main_program=program, - executor=exe) - - with open(os.path.join(self.helper.model_path(), "__model__"), - "rb") as file: - program_desc_str = file.read() - rename_program = fluid.framework.Program.parse_from_string( - program_desc_str) - varlist = 
{ - var: block - for block in rename_program.blocks for var in block.vars - if self.get_name_prefix() not in var - } - for var, block in varlist.items(): - old_name = var - new_name = self.get_var_name_with_prefix(old_name) - block._rename_var(old_name, new_name) - utils.mkdir(self.helper.model_path()) - with open( - os.path.join(self.helper.model_path(), "__model__"), - "wb") as f: - f.write(rename_program.desc.serialize_to_string()) - - for file in os.listdir(self.helper.model_path()): - if (file == "__model__" or self.get_name_prefix() in file): - continue - os.rename( - os.path.join(self.helper.model_path(), file), - os.path.join(self.helper.model_path(), - self.get_var_name_with_prefix(file))) - - # create processor file - if self.processor: - self._dump_processor() - - # create assets - self._dump_assets() - - # create check info - checker = ModuleChecker(self.helper.module_dir) - checker.generate_check_info() - - # Serialize module_desc pb - module_pb = self.desc.SerializeToString() - with open(self.helper.module_desc_path(), "wb") as f: - f.write(module_pb) diff --git a/paddlehub/serving/bert_serving/bert_service.py b/paddlehub/serving/bert_serving/bert_service.py index fa873d9d1b7010a847caddbc7f75da56311f772c..d78698fa965fabe9664e5089792f56093e33e792 100644 --- a/paddlehub/serving/bert_serving/bert_service.py +++ b/paddlehub/serving/bert_serving/bert_service.py @@ -14,7 +14,6 @@ # limitations under the License. 
import sys -import time import paddlehub as hub import ujson import random @@ -30,7 +29,7 @@ if is_py3: import http.client as httplib -class BertService(): +class BertService(object): def __init__(self, profile=False, max_seq_len=128, @@ -42,7 +41,7 @@ class BertService(): load_balance='round_robin'): self.process_id = process_id self.reader_flag = False - self.batch_size = 16 + self.batch_size = 0 self.max_seq_len = max_seq_len self.profile = profile self.model_name = model_name @@ -55,34 +54,29 @@ class BertService(): self.feed_var_names = '' self.retry = retry - def connect(self, server='127.0.0.1:8010'): + module = hub.Module(name=self.model_name) + inputs, outputs, program = module.context( + trainable=True, max_seq_len=self.max_seq_len) + input_ids = inputs["input_ids"] + position_ids = inputs["position_ids"] + segment_ids = inputs["segment_ids"] + input_mask = inputs["input_mask"] + self.feed_var_names = input_ids.name + ';' + position_ids.name + ';' + segment_ids.name + ';' + input_mask.name + self.reader = hub.reader.ClassifyReader( + vocab_path=module.get_vocab_path(), + dataset=None, + max_seq_len=self.max_seq_len, + do_lower_case=self.do_lower_case) + self.reader_flag = True + + def add_server(self, server='127.0.0.1:8010'): self.server_list.append(server) - def connect_all_server(self, server_list): + def add_server_list(self, server_list): for server_str in server_list: self.server_list.append(server_str) - def data_convert(self, text): - if self.reader_flag == False: - module = hub.Module(name=self.model_name) - inputs, outputs, program = module.context( - trainable=True, max_seq_len=self.max_seq_len) - input_ids = inputs["input_ids"] - position_ids = inputs["position_ids"] - segment_ids = inputs["segment_ids"] - input_mask = inputs["input_mask"] - self.feed_var_names = input_ids.name + ';' + position_ids.name + ';' + segment_ids.name + ';' + input_mask.name - self.reader = hub.reader.ClassifyReader( - vocab_path=module.get_vocab_path(), - 
dataset=None, - max_seq_len=self.max_seq_len, - do_lower_case=self.do_lower_case) - self.reader_flag = True - - return self.reader.data_generator( - batch_size=self.batch_size, phase='predict', data=text) - - def infer(self, request_msg): + def request_server(self, request_msg): if self.load_balance == 'round_robin': try: cur_con = httplib.HTTPConnection( @@ -157,17 +151,13 @@ class BertService(): self.server_list) return 'retry' - def encode(self, text): - if type(text) != list: - raise TypeError('Only support list') + def prepare_data(self, text): self.batch_size = len(text) - data_generator = self.data_convert(text) - start = time.time() - request_time = 0 - result = [] + data_generator = self.reader.data_generator( + batch_size=self.batch_size, phase='predict', data=text) + request_msg = "" for run_step, batch in enumerate(data_generator(), start=1): request = [] - copy_start = time.time() token_list = batch[0][0].reshape(-1).tolist() pos_list = batch[0][1].reshape(-1).tolist() sent_list = batch[0][2].reshape(-1).tolist() @@ -184,54 +174,34 @@ class BertService(): si + 1) * self.max_seq_len] request.append(instance_dict) - copy_time = time.time() - copy_start request = {"instances": request} request["max_seq_len"] = self.max_seq_len request["feed_var_names"] = self.feed_var_names request_msg = ujson.dumps(request) if self.show_ids: logger.info(request_msg) - request_start = time.time() - response_msg = self.infer(request_msg) - retry = 0 - while type(response_msg) == str and response_msg == 'retry': - if retry < self.retry: - retry += 1 - logger.info('Try to connect another servers') - response_msg = self.infer(request_msg) - else: - logger.error('Infer failed after {} times retry'.format( - self.retry)) - break - for msg in response_msg["instances"]: - for sample in msg["instances"]: - result.append(sample["values"]) - - request_time += time.time() - request_start - total_time = time.time() - start - if self.profile: - return [ - total_time, request_time, 
response_msg['op_time'], - response_msg['infer_time'], copy_time - ] - else: - return result - - -def connect(input_text, - model_name, - max_seq_len=128, - show_ids=False, - do_lower_case=True, - server="127.0.0.1:8866", - retry=3): - # format of input_text like [["As long as"],] - bc = BertService( - max_seq_len=max_seq_len, - model_name=model_name, - show_ids=show_ids, - do_lower_case=do_lower_case, - retry=retry) - bc.connect(server) - result = bc.encode(input_text) - return result + + return request_msg + + def encode(self, text): + if type(text) != list: + raise TypeError('Only support list') + request_msg = self.prepare_data(text) + + response_msg = self.request_server(request_msg) + retry = 0 + while type(response_msg) == str and response_msg == 'retry': + if retry < self.retry: + retry += 1 + logger.info('Try to connect another servers') + response_msg = self.request_server(request_msg) + else: + logger.error('Request failed after {} times retry'.format( + self.retry)) + break + result = [] + for msg in response_msg["instances"]: + for sample in msg["instances"]: + result.append(sample["values"]) + + return result diff --git a/paddlehub/serving/bert_serving/bs_client.py b/paddlehub/serving/bert_serving/bs_client.py new file mode 100644 index 0000000000000000000000000000000000000000..f2871a6109e15d11cc534431a698bdd471c74652 --- /dev/null +++ b/paddlehub/serving/bert_serving/bs_client.py @@ -0,0 +1,21 @@ +from paddlehub.serving.bert_serving import bert_service + + +class BSClient(object): + def __init__(self, + module_name, + server, + max_seq_len=20, + show_ids=False, + do_lower_case=True, + retry=3): + self.bs = bert_service.BertService( + model_name=module_name, + max_seq_len=max_seq_len, + show_ids=show_ids, + do_lower_case=do_lower_case, + retry=retry) + self.bs.add_server(server=server) + + def get_result(self, input_text): + return self.bs.encode(input_text)