未验证 提交 fa6f8d16 编写于 作者: B Bin Long 提交者: GitHub

Merge branch 'develop' into add_search_tip

...@@ -16,12 +16,6 @@ jobs: ...@@ -16,12 +16,6 @@ jobs:
os: linux os: linux
python: 3.6 python: 3.6
script: /bin/bash ./scripts/check_code_style.sh script: /bin/bash ./scripts/check_code_style.sh
- name: "CI on Linux/Python3.5"
os: linux
python: 3.5
- name: "CI on Linux/Python2.7"
os: linux
python: 2.7
env: env:
- PYTHONPATH=${PWD} - PYTHONPATH=${PWD}
...@@ -30,10 +24,6 @@ install: ...@@ -30,10 +24,6 @@ install:
- pip install --upgrade paddlepaddle - pip install --upgrade paddlepaddle
- pip install -r requirements.txt - pip install -r requirements.txt
script:
- if [[ $TRAVIS_OS_NAME != "windows" ]]; then /bin/bash ./scripts/test_cml.sh; fi
- if [[ $TRAVIS_OS_NAME != "windows" ]]; then /bin/bash ./scripts/test_all_module.sh; fi
notifications: notifications:
email: email:
on_success: change on_success: change
......
...@@ -68,7 +68,7 @@ $ pip install ujson ...@@ -68,7 +68,7 @@ $ pip install ujson
|模型|网络| |模型|网络|
|:-|:-:| |:-|:-:|
|[ERNIE](https://paddlepaddle.org.cn/hubdetail?name=ERNIE&en_category=SemanticModel)|ERNIE| |[ernie](https://paddlepaddle.org.cn/hubdetail?name=ERNIE&en_category=SemanticModel)|ERNIE|
|[ernie_tiny](https://paddlepaddle.org.cn/hubdetail?name=ernie_tiny&en_category=SemanticModel)|ERNIE| |[ernie_tiny](https://paddlepaddle.org.cn/hubdetail?name=ernie_tiny&en_category=SemanticModel)|ERNIE|
|[ernie_v2_eng_large](https://paddlepaddle.org.cn/hubdetail?name=ernie_v2_eng_large&en_category=SemanticModel)|ERNIE| |[ernie_v2_eng_large](https://paddlepaddle.org.cn/hubdetail?name=ernie_v2_eng_large&en_category=SemanticModel)|ERNIE|
|[ernie_v2_eng_base](https://paddlepaddle.org.cn/hubdetail?name=ernie_v2_eng_base&en_category=SemanticModel)|ERNIE| |[ernie_v2_eng_base](https://paddlepaddle.org.cn/hubdetail?name=ernie_v2_eng_base&en_category=SemanticModel)|ERNIE|
...@@ -179,18 +179,22 @@ Server[baidu::paddle_serving::predictor::bert_service::BertServiceImpl] is servi ...@@ -179,18 +179,22 @@ Server[baidu::paddle_serving::predictor::bert_service::BertServiceImpl] is servi
首先导入客户端依赖。 首先导入客户端依赖。
```python ```python
from paddlehub.serving.bert_serving import bert_service from paddlehub.serving.bert_serving import bs_client
``` ```
接着输入文本信息。
接着启动并初始化`bert service`客户端`BSClient`(这里的server为虚拟地址,需根据自己实际ip设置)
```python
bc = bs_client.BSClient(module_name="ernie_tiny", server="127.0.0.1:8866")
```
然后输入文本信息。
```python ```python
input_text = [["西风吹老洞庭波"], ["一夜湘君白发多"], ["醉后不知天在水"], ["满船清梦压星河"], ] input_text = [["西风吹老洞庭波"], ["一夜湘君白发多"], ["醉后不知天在水"], ["满船清梦压星河"], ]
``` ```
然后利用客户端接口发送文本到服务端,以获取embedding结果(server为虚拟地址,需根据自己实际ip设置)。
最后利用客户端接口`get_result`发送文本到服务端,以获取embedding结果。
```python ```python
result = bert_service.connect( result = bc.get_result(input_text=input_text)
input_text=input_text,
model_name="ernie_tiny",
server="127.0.0.1:8866")
``` ```
最后即可得到embedding结果(此处只展示部分结果)。 最后即可得到embedding结果(此处只展示部分结果)。
```python ```python
...@@ -221,8 +225,8 @@ Paddle Inference Server exit successfully! ...@@ -221,8 +225,8 @@ Paddle Inference Server exit successfully!
> Q : 如何在一台服务器部署多个模型? > Q : 如何在一台服务器部署多个模型?
> A : 可通过多次启动`Bert Service`,分配不同端口实现。如果使用GPU,需要指定不同的显卡。如同时部署`ernie`和`bert_chinese_L-12_H-768_A-12`,分别执行命令如下: > A : 可通过多次启动`Bert Service`,分配不同端口实现。如果使用GPU,需要指定不同的显卡。如同时部署`ernie`和`bert_chinese_L-12_H-768_A-12`,分别执行命令如下:
> ```shell > ```shell
> $ hub serving start bert_serving -m ernie -p 8866 > $ hub serving start bert_service -m ernie -p 8866
> $ hub serving start bert_serving -m bert_serving -m bert_chinese_L-12_H-768_A-12 -p 8867 > $ hub serving start bert_service -m bert_chinese_L-12_H-768_A-12 -p 8867
> ``` > ```
> Q : 启动时显示"Check out http://yq01-gpu-255-129-12-00.epc.baidu.com:8887 in web > Q : 启动时显示"Check out http://yq01-gpu-255-129-12-00.epc.baidu.com:8887 in web
......
# coding: utf8 # coding: utf8
from paddlehub.serving.bert_serving import bert_service from paddlehub.serving.bert_serving import bs_client
if __name__ == "__main__": if __name__ == "__main__":
# 初始化bert_service客户端BSClient
bc = bs_client.BSClient(module_name="ernie_tiny", server="127.0.0.1:8866")
# 输入要做embedding的文本 # 输入要做embedding的文本
# 文本格式为[["文本1"], ["文本2"], ] # 文本格式为[["文本1"], ["文本2"], ]
input_text = [ input_text = [
...@@ -10,10 +13,10 @@ if __name__ == "__main__": ...@@ -10,10 +13,10 @@ if __name__ == "__main__":
["醉后不知天在水"], ["醉后不知天在水"],
["满船清梦压星河"], ["满船清梦压星河"],
] ]
# 调用客户端接口bert_service.connect()获取结果
result = bert_service.connect(
input_text=input_text, model_name="ernie_tiny", server="127.0.0.1:8866")
# 打印embedding结果 # BSClient.get_result()获取结果
result = bc.get_result(input_text=input_text)
# 打印输入文本的embedding结果
for item in result: for item in result:
print(item) print(item)
...@@ -38,7 +38,7 @@ from .common.logger import logger ...@@ -38,7 +38,7 @@ from .common.logger import logger
from .common.paddle_helper import connect_program from .common.paddle_helper import connect_program
from .common.hub_server import default_hub_server from .common.hub_server import default_hub_server
from .module.module import Module, create_module from .module.module import Module
from .module.base_processor import BaseProcessor from .module.base_processor import BaseProcessor
from .module.signature import Signature, create_signature from .module.signature import Signature, create_signature
from .module.manager import default_module_manager from .module.manager import default_module_manager
......
...@@ -26,6 +26,7 @@ from tb_paddle import SummaryWriter ...@@ -26,6 +26,7 @@ from tb_paddle import SummaryWriter
from paddlehub.common.logger import logger from paddlehub.common.logger import logger
from paddlehub.common.utils import mkdir from paddlehub.common.utils import mkdir
from paddlehub.autofinetune.evaluator import REWARD_SUM, TMP_HOME from paddlehub.autofinetune.evaluator import REWARD_SUM, TMP_HOME
from paddlehub.autofinetune.mpi_helper import MPIHelper
if six.PY3: if six.PY3:
INF = math.inf INF = math.inf
...@@ -75,6 +76,12 @@ class BaseTuningStrategy(object): ...@@ -75,6 +76,12 @@ class BaseTuningStrategy(object):
logdir=self._output_dir + '/visualization/pop_{}'.format(i)) logdir=self._output_dir + '/visualization/pop_{}'.format(i))
self.writer_pop_trails.append(writer_pop_trail) self.writer_pop_trails.append(writer_pop_trail)
# for parallel on mpi
self.mpi = MPIHelper()
if self.mpi.multi_machine:
print("Autofinetune multimachine mode: running on {}".format(
self.mpi.gather(self.mpi.name)))
@property @property
def thread(self): def thread(self):
return self._num_thread return self._num_thread
...@@ -177,16 +184,22 @@ class BaseTuningStrategy(object): ...@@ -177,16 +184,22 @@ class BaseTuningStrategy(object):
solutions_modeldirs = {} solutions_modeldirs = {}
mkdir(output_dir) mkdir(output_dir)
for idx, solution in enumerate(solutions): solutions = self.mpi.bcast(solutions)
# split solutions to "solutions for me"
range_start, range_end = self.mpi.split_range(len(solutions))
my_solutions = solutions[range_start:range_end]
for idx, solution in enumerate(my_solutions):
cuda = self.is_cuda_free["free"][0] cuda = self.is_cuda_free["free"][0]
modeldir = output_dir + "/model-" + str(idx) + "/" modeldir = output_dir + "/model-" + str(idx) + "/"
log_file = output_dir + "/log-" + str(idx) + ".info" log_file = output_dir + "/log-" + str(idx) + ".info"
params_cudas_dirs.append([solution, cuda, modeldir, log_file]) params_cudas_dirs.append([solution, cuda, modeldir, log_file])
solutions_modeldirs[tuple(solution)] = modeldir solutions_modeldirs[tuple(solution)] = (modeldir, self.mpi.rank)
self.is_cuda_free["free"].remove(cuda) self.is_cuda_free["free"].remove(cuda)
self.is_cuda_free["busy"].append(cuda) self.is_cuda_free["busy"].append(cuda)
if len(params_cudas_dirs if len(params_cudas_dirs
) == self.thread or idx == len(solutions) - 1: ) == self.thread or idx == len(my_solutions) - 1:
tp = ThreadPool(len(params_cudas_dirs)) tp = ThreadPool(len(params_cudas_dirs))
solution_results += tp.map(self.evaluator.run, solution_results += tp.map(self.evaluator.run,
params_cudas_dirs) params_cudas_dirs)
...@@ -198,13 +211,25 @@ class BaseTuningStrategy(object): ...@@ -198,13 +211,25 @@ class BaseTuningStrategy(object):
self.is_cuda_free["busy"].remove(param_cuda[1]) self.is_cuda_free["busy"].remove(param_cuda[1])
params_cudas_dirs = [] params_cudas_dirs = []
self.feedback(solutions, solution_results) all_solution_results = self.mpi.gather(solution_results)
if self.mpi.rank == 0:
# only rank 0 need to feedback
all_solution_results = [y for x in all_solution_results for y in x]
self.feedback(solutions, all_solution_results)
# remove the tmp.txt which records the eval results for trials # remove the tmp.txt which records the eval results for trials
tmp_file = os.path.join(TMP_HOME, "tmp.txt") tmp_file = os.path.join(TMP_HOME, "tmp.txt")
if os.path.exists(tmp_file): if os.path.exists(tmp_file):
os.remove(tmp_file) os.remove(tmp_file)
return solutions_modeldirs # collect all solutions_modeldirs
collected_solutions_modeldirs = self.mpi.allgather(solutions_modeldirs)
return_dict = {}
for i in collected_solutions_modeldirs:
return_dict.update(i)
return return_dict
class HAZero(BaseTuningStrategy): class HAZero(BaseTuningStrategy):
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
class MPIHelper(object):
    """Thin wrapper around mpi4py collectives that degrades gracefully.

    When ``mpi4py`` is importable the helper proxies bcast/gather/allgather
    to ``MPI.COMM_WORLD``; otherwise it behaves as a single-process "world"
    so the same calling code works both locally and on a multi-machine run.
    """

    def __init__(self):
        try:
            from mpi4py import MPI
        except ImportError:
            # mpi4py is absent: fall back to single-process (local) mode.
            self._size = 1
            self._rank = 0
            self._multi_machine = False
            import socket
            self._name = socket.gethostname()
        else:
            # Inside an MPI environment: bind to the global communicator.
            self._comm = MPI.COMM_WORLD
            self._size = self._comm.Get_size()
            self._rank = self._comm.Get_rank()
            self._name = MPI.Get_processor_name()
            # Only treat more than one process as a real multi-machine run.
            self._multi_machine = self._size > 1

    @property
    def multi_machine(self):
        # True when running under MPI with more than one process.
        return self._multi_machine

    @property
    def rank(self):
        # 0-based rank of this process in the world communicator.
        return self._rank

    @property
    def size(self):
        # Total number of processes in the world communicator.
        return self._size

    @property
    def name(self):
        # Host / processor name of this process.
        return self._name

    def bcast(self, data):
        """Broadcast ``data`` from rank 0 to all ranks (identity locally)."""
        if self._multi_machine:
            return self._comm.bcast(data, root=0)
        # Local mode: nothing to do.
        return data

    def gather(self, data):
        """Gather ``data`` from every rank onto rank 0.

        Returns a list of per-rank values on rank 0 (``None`` on other ranks
        under real MPI); locally returns ``[data]`` so callers can always
        iterate over the result.
        """
        if self._multi_machine:
            return self._comm.gather(data, root=0)
        return [data]

    def allgather(self, data):
        """Gather ``data`` from every rank onto every rank (local: ``[data]``)."""
        if self._multi_machine:
            return self._comm.allgather(data)
        return [data]

    def split_range(self, array_length):
        """Return the half-open ``(start, end)`` slice of ``array_length``
        items owned by this rank, giving any remainder to the lowest ranks.

        Bug fix: use floor division ``//`` — the original ``/`` produced
        float endpoints under Python 3, which break list slicing in callers
        (e.g. ``solutions[range_start:range_end]``).
        """
        if self._size == 1:
            return 0, array_length
        average_count = array_length // self._size
        remainder = array_length % self._size
        if remainder == 0:
            return average_count * self._rank, average_count * (self._rank + 1)
        if self._rank < remainder:
            # The first `remainder` ranks each take one extra item.
            start = (average_count + 1) * self._rank
            return start, start + average_count + 1
        start = (average_count + 1) * remainder \
            + average_count * (self._rank - remainder)
        return start, start + average_count
if __name__ == "__main__":
    # Smoke test: exercise every MPIHelper primitive and print the results.
    mpi = MPIHelper()
    print("Hello world from process {} of {} at {}.".format(
        mpi.rank, mpi.size, mpi.name))

    all_node_names = mpi.gather(mpi.name)
    print("all node names using gather: {}".format(all_node_names))

    all_node_names = mpi.allgather(mpi.name)
    print("all node names using allgather: {}".format(all_node_names))

    # Only rank 0 owns the payload before the broadcast.
    data = range(10) if mpi.rank == 0 else None
    data = mpi.bcast(data)
    print("after bcast, process {} have data {}".format(mpi.rank, data))

    data = [i + mpi.rank for i in data]
    print("after modify, process {} have data {}".format(mpi.rank, data))

    new_data = mpi.gather(data)
    print("after gather, process {} have data {}".format(mpi.rank, new_data))

    # test for split
    for i in range(12):
        length = i + mpi.size  # length should >= mpi.size
        [start, end] = mpi.split_range(length)
        split_result = mpi.gather([start, end])
        print("length {}, split_result {}".format(length, split_result))
...@@ -188,37 +188,62 @@ class AutoFineTuneCommand(BaseCommand): ...@@ -188,37 +188,62 @@ class AutoFineTuneCommand(BaseCommand):
run_round_cnt = run_round_cnt + 1 run_round_cnt = run_round_cnt + 1
print("PaddleHub Autofinetune ends.") print("PaddleHub Autofinetune ends.")
best_hparams_origin = autoft.get_best_hparams()
best_hparams_origin = autoft.mpi.bcast(best_hparams_origin)
with open(autoft._output_dir + "/log_file.txt", "w") as f: with open(autoft._output_dir + "/log_file.txt", "w") as f:
best_hparams = evaluator.convert_params(autoft.get_best_hparams()) best_hparams = evaluator.convert_params(best_hparams_origin)
print("The final best hyperparameters:") print("The final best hyperparameters:")
f.write("The final best hyperparameters:\n") f.write("The final best hyperparameters:\n")
for index, hparam_name in enumerate(autoft.hparams_name_list): for index, hparam_name in enumerate(autoft.hparams_name_list):
print("%s=%s" % (hparam_name, best_hparams[index])) print("%s=%s" % (hparam_name, best_hparams[index]))
f.write(hparam_name + "\t:\t" + str(best_hparams[index]) + "\n") f.write(hparam_name + "\t:\t" + str(best_hparams[index]) + "\n")
best_hparams_dir, best_hparams_rank = solutions_modeldirs[tuple(
best_hparams_origin)]
print("The final best eval score is %s." % print("The final best eval score is %s." %
autoft.get_best_eval_value()) autoft.get_best_eval_value())
if autoft.mpi.multi_machine:
print("The final best model parameters are saved as " +
autoft._output_dir + "/best_model on rank " +
str(best_hparams_rank) + " .")
else:
print("The final best model parameters are saved as " + print("The final best model parameters are saved as " +
autoft._output_dir + "/best_model .") autoft._output_dir + "/best_model .")
f.write("The final best eval score is %s.\n" % f.write("The final best eval score is %s.\n" %
autoft.get_best_eval_value()) autoft.get_best_eval_value())
f.write(
"The final best model parameters are saved as ./best_model .")
best_model_dir = autoft._output_dir + "/best_model" best_model_dir = autoft._output_dir + "/best_model"
shutil.copytree(
solutions_modeldirs[tuple(autoft.get_best_hparams())],
best_model_dir)
if autoft.mpi.rank == best_hparams_rank:
shutil.copytree(best_hparams_dir, best_model_dir)
if autoft.mpi.multi_machine:
f.write(
"The final best model parameters are saved as ./best_model on rank " \
+ str(best_hparams_rank) + " .")
f.write("\t".join(autoft.hparams_name_list) +
"\tsaved_params_dir\trank\n")
else:
f.write(
"The final best model parameters are saved as ./best_model ."
)
f.write("\t".join(autoft.hparams_name_list) + f.write("\t".join(autoft.hparams_name_list) +
"\tsaved_params_dir\n") "\tsaved_params_dir\n")
print( print(
"The related infomation about hyperparamemters searched are saved as %s/log_file.txt ." "The related infomation about hyperparamemters searched are saved as %s/log_file.txt ."
% autoft._output_dir) % autoft._output_dir)
for solution, modeldir in solutions_modeldirs.items(): for solution, modeldir in solutions_modeldirs.items():
param = evaluator.convert_params(solution) param = evaluator.convert_params(solution)
param = [str(p) for p in param] param = [str(p) for p in param]
f.write("\t".join(param) + "\t" + modeldir + "\n") if autoft.mpi.multi_machine:
f.write("\t".join(param) + "\t" + modeldir[0] + "\t" +
str(modeldir[1]) + "\n")
else:
f.write("\t".join(param) + "\t" + modeldir[0] + "\n")
return True return True
......
...@@ -18,6 +18,7 @@ from __future__ import division ...@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import argparse import argparse
import os
from paddlehub.common import utils from paddlehub.common import utils
from paddlehub.module.manager import default_module_manager from paddlehub.module.manager import default_module_manager
...@@ -42,14 +43,23 @@ class InstallCommand(BaseCommand): ...@@ -42,14 +43,23 @@ class InstallCommand(BaseCommand):
print("ERROR: Please specify a module name.\n") print("ERROR: Please specify a module name.\n")
self.help() self.help()
return False return False
extra = {"command": "install"}
if argv[0].endswith("tar.gz") or argv[0].endswith("phm"):
result, tips, module_dir = default_module_manager.install_module(
module_package=argv[0], extra=extra)
elif os.path.exists(argv[0]) and os.path.isdir(argv[0]):
result, tips, module_dir = default_module_manager.install_module(
module_dir=argv[0], extra=extra)
else:
module_name = argv[0] module_name = argv[0]
module_version = None if "==" not in module_name else module_name.split( module_version = None if "==" not in module_name else module_name.split(
"==")[1] "==")[1]
module_name = module_name if "==" not in module_name else module_name.split( module_name = module_name if "==" not in module_name else module_name.split(
"==")[0] "==")[0]
extra = {"command": "install"}
result, tips, module_dir = default_module_manager.install_module( result, tips, module_dir = default_module_manager.install_module(
module_name=module_name, module_version=module_version, extra=extra) module_name=module_name,
module_version=module_version,
extra=extra)
print(tips) print(tips)
return True return True
......
...@@ -71,7 +71,7 @@ class RunCommand(BaseCommand): ...@@ -71,7 +71,7 @@ class RunCommand(BaseCommand):
if not result: if not result:
return None return None
return hub.Module(module_dir=module_dir) return hub.Module(directory=module_dir[0])
def add_module_config_arg(self): def add_module_config_arg(self):
configs = self.module.processor.configs() configs = self.module.processor.configs()
...@@ -105,7 +105,7 @@ class RunCommand(BaseCommand): ...@@ -105,7 +105,7 @@ class RunCommand(BaseCommand):
def add_module_input_arg(self): def add_module_input_arg(self):
module_type = self.module.type.lower() module_type = self.module.type.lower()
expect_data_format = self.module.processor.data_format( expect_data_format = self.module.processor.data_format(
self.module.default_signature.name) self.module.default_signature)
self.arg_input_group.add_argument( self.arg_input_group.add_argument(
'--input_file', '--input_file',
type=str, type=str,
...@@ -152,7 +152,7 @@ class RunCommand(BaseCommand): ...@@ -152,7 +152,7 @@ class RunCommand(BaseCommand):
def get_data(self): def get_data(self):
module_type = self.module.type.lower() module_type = self.module.type.lower()
expect_data_format = self.module.processor.data_format( expect_data_format = self.module.processor.data_format(
self.module.default_signature.name) self.module.default_signature)
input_data = {} input_data = {}
if len(expect_data_format) == 1: if len(expect_data_format) == 1:
key = list(expect_data_format.keys())[0] key = list(expect_data_format.keys())[0]
...@@ -177,7 +177,7 @@ class RunCommand(BaseCommand): ...@@ -177,7 +177,7 @@ class RunCommand(BaseCommand):
def check_data(self, data): def check_data(self, data):
expect_data_format = self.module.processor.data_format( expect_data_format = self.module.processor.data_format(
self.module.default_signature.name) self.module.default_signature)
if len(data.keys()) != len(expect_data_format.keys()): if len(data.keys()) != len(expect_data_format.keys()):
print( print(
...@@ -236,10 +236,13 @@ class RunCommand(BaseCommand): ...@@ -236,10 +236,13 @@ class RunCommand(BaseCommand):
return False return False
# If the module is not executable, give an alarm and exit # If the module is not executable, give an alarm and exit
if not self.module.default_signature: if not self.module.is_runable:
print("ERROR! Module %s is not executable." % module_name) print("ERROR! Module %s is not executable." % module_name)
return False return False
if self.module.code_version == "v2":
results = self.module(argv[1:])
else:
self.module.check_processor() self.module.check_processor()
self.add_module_config_arg() self.add_module_config_arg()
self.add_module_input_arg() self.add_module_input_arg()
...@@ -260,7 +263,7 @@ class RunCommand(BaseCommand): ...@@ -260,7 +263,7 @@ class RunCommand(BaseCommand):
return False return False
results = self.module( results = self.module(
sign_name=self.module.default_signature.name, sign_name=self.module.default_signature,
data=data, data=data,
use_gpu=self.args.use_gpu, use_gpu=self.args.use_gpu,
batch_size=self.args.batch_size, batch_size=self.args.batch_size,
......
...@@ -159,7 +159,7 @@ class ServingCommand(BaseCommand): ...@@ -159,7 +159,7 @@ class ServingCommand(BaseCommand):
module = args.modules module = args.modules
if module is not None: if module is not None:
use_gpu = args.use_gpu use_gpu = args.use_gpu
port = args.port[0] port = args.port
if ServingCommand.is_port_occupied("127.0.0.1", port) is True: if ServingCommand.is_port_occupied("127.0.0.1", port) is True:
print("Port %s is occupied, please change it." % (port)) print("Port %s is occupied, please change it." % (port))
return False return False
...@@ -206,10 +206,12 @@ class ServingCommand(BaseCommand): ...@@ -206,10 +206,12 @@ class ServingCommand(BaseCommand):
if args.sub_command == "start": if args.sub_command == "start":
if args.bert_service == "bert_service": if args.bert_service == "bert_service":
ServingCommand.start_bert_serving(args) ServingCommand.start_bert_serving(args)
else: elif args.bert_service is None:
ServingCommand.start_serving(args) ServingCommand.start_serving(args)
else: else:
ServingCommand.show_help() ServingCommand.show_help()
else:
ServingCommand.show_help()
command = ServingCommand.instance() command = ServingCommand.instance()
...@@ -125,8 +125,6 @@ class ShowCommand(BaseCommand): ...@@ -125,8 +125,6 @@ class ShowCommand(BaseCommand):
cwd = os.getcwd() cwd = os.getcwd()
module_dir = default_module_manager.search_module(module_name) module_dir = default_module_manager.search_module(module_name)
module_dir = (os.path.join(cwd, module_name),
None) if not module_dir else module_dir
if not module_dir or not os.path.exists(module_dir[0]): if not module_dir or not os.path.exists(module_dir[0]):
print("%s is not existed!" % module_name) print("%s is not existed!" % module_name)
return True return True
......
...@@ -298,17 +298,23 @@ class CacheUpdater(threading.Thread): ...@@ -298,17 +298,23 @@ class CacheUpdater(threading.Thread):
api_url = srv_utils.uri_path(default_hub_server.get_server_url(), api_url = srv_utils.uri_path(default_hub_server.get_server_url(),
'search') 'search')
cache_path = os.path.join(CACHE_HOME, RESOURCE_LIST_FILE) cache_path = os.path.join(CACHE_HOME, RESOURCE_LIST_FILE)
if os.path.exists(cache_path):
extra = { extra = {
"command": "update_cache", "command": "update_cache",
"mtime": os.stat(cache_path).st_mtime "mtime": os.stat(cache_path).st_mtime
} }
else:
extra = {
"command": "update_cache",
"mtime": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
}
try: try:
r = srv_utils.hub_request(api_url, payload, extra) r = srv_utils.hub_request(api_url, payload, extra)
except Exception as err:
pass
if r.get("update_cache", 0) == 1: if r.get("update_cache", 0) == 1:
with open(cache_path, 'w+') as fp: with open(cache_path, 'w+') as fp:
yaml.safe_dump({'resource_list': r['data']}, fp) yaml.safe_dump({'resource_list': r['data']}, fp)
except Exception as err:
pass
def run(self): def run(self):
self.update_resource_list_file(self.module, self.version) self.update_resource_list_file(self.module, self.version)
......
...@@ -50,6 +50,7 @@ message CheckInfo { ...@@ -50,6 +50,7 @@ message CheckInfo {
string paddle_version = 1; string paddle_version = 1;
string hub_version = 2; string hub_version = 2;
string module_proto_version = 3; string module_proto_version = 3;
repeated FileInfo file_infos = 4; string module_code_version = 4;
repeated Requires requires = 5; repeated FileInfo file_infos = 5;
repeated Requires requires = 6;
}; };
#coding:utf-8
# Generated by the protocol buffer compiler. DO NOT EDIT! # Generated by the protocol buffer compiler. DO NOT EDIT!
# source: check_info.proto # source: check_info.proto
...@@ -19,7 +18,7 @@ DESCRIPTOR = _descriptor.FileDescriptor( ...@@ -19,7 +18,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package='paddlehub.module.checkinfo', package='paddlehub.module.checkinfo',
syntax='proto3', syntax='proto3',
serialized_pb=_b( serialized_pb=_b(
'\n\x10\x63heck_info.proto\x12\x1apaddlehub.module.checkinfo\"\x85\x01\n\x08\x46ileInfo\x12\x11\n\tfile_name\x18\x01 \x01(\t\x12\x33\n\x04type\x18\x02 \x01(\x0e\x32%.paddlehub.module.checkinfo.FILE_TYPE\x12\x0f\n\x07is_need\x18\x03 \x01(\x08\x12\x0b\n\x03md5\x18\x04 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x05 \x01(\t\"\x84\x01\n\x08Requires\x12>\n\x0crequire_type\x18\x01 \x01(\x0e\x32(.paddlehub.module.checkinfo.REQUIRE_TYPE\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x12\n\ngreat_than\x18\x03 \x01(\x08\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\"\xc8\x01\n\tCheckInfo\x12\x16\n\x0epaddle_version\x18\x01 \x01(\t\x12\x13\n\x0bhub_version\x18\x02 \x01(\t\x12\x1c\n\x14module_proto_version\x18\x03 \x01(\t\x12\x38\n\nfile_infos\x18\x04 \x03(\x0b\x32$.paddlehub.module.checkinfo.FileInfo\x12\x36\n\x08requires\x18\x05 \x03(\x0b\x32$.paddlehub.module.checkinfo.Requires*\x1e\n\tFILE_TYPE\x12\x08\n\x04\x46ILE\x10\x00\x12\x07\n\x03\x44IR\x10\x01*[\n\x0cREQUIRE_TYPE\x12\x12\n\x0ePYTHON_PACKAGE\x10\x00\x12\x0e\n\nHUB_MODULE\x10\x01\x12\n\n\x06SYSTEM\x10\x02\x12\x0b\n\x07\x43OMMAND\x10\x03\x12\x0e\n\nPY_VERSION\x10\x04\x42\x02H\x03\x62\x06proto3' '\n\x10\x63heck_info.proto\x12\x1apaddlehub.module.checkinfo\"\x85\x01\n\x08\x46ileInfo\x12\x11\n\tfile_name\x18\x01 \x01(\t\x12\x33\n\x04type\x18\x02 \x01(\x0e\x32%.paddlehub.module.checkinfo.FILE_TYPE\x12\x0f\n\x07is_need\x18\x03 \x01(\x08\x12\x0b\n\x03md5\x18\x04 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x05 \x01(\t\"\x84\x01\n\x08Requires\x12>\n\x0crequire_type\x18\x01 \x01(\x0e\x32(.paddlehub.module.checkinfo.REQUIRE_TYPE\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x12\n\ngreat_than\x18\x03 \x01(\x08\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\"\xe5\x01\n\tCheckInfo\x12\x16\n\x0epaddle_version\x18\x01 \x01(\t\x12\x13\n\x0bhub_version\x18\x02 \x01(\t\x12\x1c\n\x14module_proto_version\x18\x03 \x01(\t\x12\x1b\n\x13module_code_version\x18\x04 \x01(\t\x12\x38\n\nfile_infos\x18\x05 
\x03(\x0b\x32$.paddlehub.module.checkinfo.FileInfo\x12\x36\n\x08requires\x18\x06 \x03(\x0b\x32$.paddlehub.module.checkinfo.Requires*\x1e\n\tFILE_TYPE\x12\x08\n\x04\x46ILE\x10\x00\x12\x07\n\x03\x44IR\x10\x01*[\n\x0cREQUIRE_TYPE\x12\x12\n\x0ePYTHON_PACKAGE\x10\x00\x12\x0e\n\nHUB_MODULE\x10\x01\x12\n\n\x06SYSTEM\x10\x02\x12\x0b\n\x07\x43OMMAND\x10\x03\x12\x0e\n\nPY_VERSION\x10\x04\x42\x02H\x03\x62\x06proto3'
)) ))
_sym_db.RegisterFileDescriptor(DESCRIPTOR) _sym_db.RegisterFileDescriptor(DESCRIPTOR)
...@@ -36,8 +35,8 @@ _FILE_TYPE = _descriptor.EnumDescriptor( ...@@ -36,8 +35,8 @@ _FILE_TYPE = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=522, serialized_start=551,
serialized_end=552, serialized_end=581,
) )
_sym_db.RegisterEnumDescriptor(_FILE_TYPE) _sym_db.RegisterEnumDescriptor(_FILE_TYPE)
...@@ -61,8 +60,8 @@ _REQUIRE_TYPE = _descriptor.EnumDescriptor( ...@@ -61,8 +60,8 @@ _REQUIRE_TYPE = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=554, serialized_start=583,
serialized_end=645, serialized_end=674,
) )
_sym_db.RegisterEnumDescriptor(_REQUIRE_TYPE) _sym_db.RegisterEnumDescriptor(_REQUIRE_TYPE)
...@@ -316,10 +315,26 @@ _CHECKINFO = _descriptor.Descriptor( ...@@ -316,10 +315,26 @@ _CHECKINFO = _descriptor.Descriptor(
extension_scope=None, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='file_infos', name='module_code_version',
full_name='paddlehub.module.checkinfo.CheckInfo.file_infos', full_name='paddlehub.module.checkinfo.CheckInfo.module_code_version',
index=3, index=3,
number=4, number=4,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='file_infos',
full_name='paddlehub.module.checkinfo.CheckInfo.file_infos',
index=4,
number=5,
type=11, type=11,
cpp_type=10, cpp_type=10,
label=3, label=3,
...@@ -334,8 +349,8 @@ _CHECKINFO = _descriptor.Descriptor( ...@@ -334,8 +349,8 @@ _CHECKINFO = _descriptor.Descriptor(
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='requires', name='requires',
full_name='paddlehub.module.checkinfo.CheckInfo.requires', full_name='paddlehub.module.checkinfo.CheckInfo.requires',
index=4, index=5,
number=5, number=6,
type=11, type=11,
cpp_type=10, cpp_type=10,
label=3, label=3,
...@@ -357,7 +372,7 @@ _CHECKINFO = _descriptor.Descriptor( ...@@ -357,7 +372,7 @@ _CHECKINFO = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=320, serialized_start=320,
serialized_end=520, serialized_end=549,
) )
_FILEINFO.fields_by_name['type'].enum_type = _FILE_TYPE _FILEINFO.fields_by_name['type'].enum_type = _FILE_TYPE
......
...@@ -32,20 +32,22 @@ FILE_SEP = "/" ...@@ -32,20 +32,22 @@ FILE_SEP = "/"
class ModuleChecker(object): class ModuleChecker(object):
def __init__(self, module_path): def __init__(self, directory):
self.module_path = module_path self._directory = directory
self._pb_path = os.path.join(self.directory, CHECK_INFO_PB_FILENAME)
def generate_check_info(self): def generate_check_info(self):
check_info = check_info_pb2.CheckInfo() check_info = check_info_pb2.CheckInfo()
check_info.paddle_version = paddle.__version__ check_info.paddle_version = paddle.__version__
check_info.hub_version = hub_version check_info.hub_version = hub_version
check_info.module_proto_version = module_proto_version check_info.module_proto_version = module_proto_version
check_info.module_code_version = "v2"
file_infos = check_info.file_infos file_infos = check_info.file_infos
file_list = [file for file in os.listdir(self.module_path)] file_list = [file for file in os.listdir(self.directory)]
while file_list: while file_list:
file = file_list[0] file = file_list[0]
file_list = file_list[1:] file_list = file_list[1:]
abs_path = os.path.join(self.module_path, file) abs_path = os.path.join(self.directory, file)
if os.path.isdir(abs_path): if os.path.isdir(abs_path):
for sub_file in os.listdir(abs_path): for sub_file in os.listdir(abs_path):
sub_file = os.path.join(file, sub_file) sub_file = os.path.join(file, sub_file)
...@@ -62,9 +64,12 @@ class ModuleChecker(object): ...@@ -62,9 +64,12 @@ class ModuleChecker(object):
file_info.type = check_info_pb2.FILE file_info.type = check_info_pb2.FILE
file_info.is_need = True file_info.is_need = True
with open(os.path.join(self.module_path, CHECK_INFO_PB_FILENAME), with open(self.pb_path, "wb") as file:
"wb") as fi: file.write(check_info.SerializeToString())
fi.write(check_info.SerializeToString())
@property
def module_code_version(self):
return self.check_info.module_code_version
@property @property
def module_proto_version(self): def module_proto_version(self):
...@@ -82,20 +87,25 @@ class ModuleChecker(object): ...@@ -82,20 +87,25 @@ class ModuleChecker(object):
def file_infos(self): def file_infos(self):
return self.check_info.file_infos return self.check_info.file_infos
@property
def directory(self):
return self._directory
@property
def pb_path(self):
return self._pb_path
def check(self): def check(self):
result = True result = True
self.check_info_pb_path = os.path.join(self.module_path,
CHECK_INFO_PB_FILENAME)
if not (os.path.exists(self.check_info_pb_path) if not (os.path.exists(self.pb_path) or os.path.isfile(self.pb_path)):
or os.path.isfile(self.check_info_pb_path)):
logger.warning( logger.warning(
"This module lacks core file %s" % CHECK_INFO_PB_FILENAME) "This module lacks core file %s" % CHECK_INFO_PB_FILENAME)
result = False result = False
self.check_info = check_info_pb2.CheckInfo() self.check_info = check_info_pb2.CheckInfo()
try: try:
with open(self.check_info_pb_path, "rb") as fi: with open(self.pb_path, "rb") as fi:
pb_string = fi.read() pb_string = fi.read()
result = self.check_info.ParseFromString(pb_string) result = self.check_info.ParseFromString(pb_string)
if len(pb_string) == 0 or (result is not None if len(pb_string) == 0 or (result is not None
...@@ -182,7 +192,7 @@ class ModuleChecker(object): ...@@ -182,7 +192,7 @@ class ModuleChecker(object):
for file_info in self.file_infos: for file_info in self.file_infos:
file_type = file_info.type file_type = file_info.type
file_path = file_info.file_name.replace(FILE_SEP, os.sep) file_path = file_info.file_name.replace(FILE_SEP, os.sep)
file_path = os.path.join(self.module_path, file_path) file_path = os.path.join(self.directory, file_path)
if not os.path.exists(file_path): if not os.path.exists(file_path):
if file_info.is_need: if file_info.is_need:
logger.warning( logger.warning(
......
...@@ -19,7 +19,9 @@ from __future__ import print_function ...@@ -19,7 +19,9 @@ from __future__ import print_function
import os import os
import shutil import shutil
from functools import cmp_to_key from functools import cmp_to_key
import tarfile
from paddlehub.common import utils from paddlehub.common import utils
from paddlehub.common import srv_utils from paddlehub.common import srv_utils
...@@ -79,10 +81,15 @@ class LocalModuleManager(object): ...@@ -79,10 +81,15 @@ class LocalModuleManager(object):
return self.modules_dict.get(module_name, None) return self.modules_dict.get(module_name, None)
def install_module(self, def install_module(self,
module_name, module_name=None,
module_dir=None,
module_package=None,
module_version=None, module_version=None,
upgrade=False, upgrade=False,
extra=None): extra=None):
md5_value = installed_module_version = None
from_user_dir = True if module_dir else False
if module_name:
self.all_modules(update=True) self.all_modules(update=True)
module_info = self.modules_dict.get(module_name, None) module_info = self.modules_dict.get(module_name, None)
if module_info: if module_info:
...@@ -95,6 +102,62 @@ class LocalModuleManager(object): ...@@ -95,6 +102,62 @@ class LocalModuleManager(object):
module_dir) module_dir)
return True, tips, self.modules_dict[module_name] return True, tips, self.modules_dict[module_name]
search_result = hub.default_hub_server.get_module_url(
module_name, version=module_version, extra=extra)
name = search_result.get('name', None)
url = search_result.get('url', None)
md5_value = search_result.get('md5', None)
installed_module_version = search_result.get('version', None)
if not url or (module_version is not None
and installed_module_version != module_version) or (
name != module_name):
if default_hub_server._server_check() is False:
tips = "Request Hub-Server unsuccessfully, please check your network."
else:
tips = "Can't find module %s" % module_name
if module_version:
tips += " with version %s" % module_version
module_tag = module_name if not module_version else '%s-%s' % (
module_name, module_version)
return False, tips, None
result, tips, module_zip_file = default_downloader.download_file(
url=url,
save_path=hub.CACHE_HOME,
save_name=module_name,
replace=True,
print_progress=True)
result, tips, module_dir = default_downloader.uncompress(
file=module_zip_file,
dirname=MODULE_HOME,
delete_file=True,
print_progress=True)
if module_package:
with tarfile.open(module_package, "r:gz") as tar:
file_names = tar.getnames()
size = len(file_names) - 1
module_dir = os.path.split(file_names[0])[0]
module_dir = os.path.join(hub.CACHE_HOME, module_dir)
# remove cache
if os.path.exists(module_dir):
shutil.rmtree(module_dir)
for index, file_name in enumerate(file_names):
tar.extract(file_name, hub.CACHE_HOME)
if module_dir:
if not module_name:
module_name = hub.Module(directory=module_dir).name
self.all_modules(update=False)
module_info = self.modules_dict.get(module_name, None)
if module_info:
module_dir = self.modules_dict[module_name][0]
module_tag = module_name if not module_version else '%s-%s' % (
module_name, module_version)
tips = "Module %s already installed in %s" % (module_tag,
module_dir)
return True, tips, self.modules_dict[module_name]
search_result = hub.default_hub_server.get_module_url( search_result = hub.default_hub_server.get_module_url(
module_name, version=module_version, extra=extra) module_name, version=module_version, extra=extra)
name = search_result.get('name', None) name = search_result.get('name', None)
...@@ -162,9 +225,18 @@ class LocalModuleManager(object): ...@@ -162,9 +225,18 @@ class LocalModuleManager(object):
with open(os.path.join(MODULE_HOME, module_dir, "md5.txt"), with open(os.path.join(MODULE_HOME, module_dir, "md5.txt"),
"w") as fp: "w") as fp:
fp.write(md5_value) fp.write(md5_value)
if md5_value:
with open(
os.path.join(MODULE_HOME, module_dir, "md5.txt"),
"w") as fp:
fp.write(md5_value)
save_path = os.path.join(MODULE_HOME, module_name) save_path = os.path.join(MODULE_HOME, module_name)
if os.path.exists(save_path): if os.path.exists(save_path):
shutil.rmtree(save_path) shutil.move(save_path)
if from_user_dir:
shutil.copytree(module_dir, save_path)
else:
shutil.move(module_dir, save_path) shutil.move(module_dir, save_path)
module_dir = save_path module_dir = save_path
tips = "Successfully installed %s" % module_name tips = "Successfully installed %s" % module_name
......
...@@ -21,6 +21,10 @@ import os ...@@ -21,6 +21,10 @@ import os
import time import time
import sys import sys
import functools import functools
import inspect
import importlib
import tarfile
from collections import defaultdict
from shutil import copyfile from shutil import copyfile
import paddle import paddle
...@@ -28,22 +32,19 @@ import paddle.fluid as fluid ...@@ -28,22 +32,19 @@ import paddle.fluid as fluid
from paddlehub.common import utils from paddlehub.common import utils
from paddlehub.common import paddle_helper from paddlehub.common import paddle_helper
from paddlehub.common.logger import logger from paddlehub.common.dir import CACHE_HOME
from paddlehub.common.lock import lock from paddlehub.common.lock import lock
from paddlehub.common.downloader import default_downloader from paddlehub.common.logger import logger
from paddlehub.common.hub_server import CacheUpdater
from paddlehub.module import module_desc_pb2 from paddlehub.module import module_desc_pb2
from paddlehub.common.dir import CONF_HOME
from paddlehub.module import check_info_pb2 from paddlehub.module import check_info_pb2
from paddlehub.common.hub_server import CacheUpdater
from paddlehub.module.signature import Signature, create_signature
from paddlehub.module.checker import ModuleChecker
from paddlehub.module.manager import default_module_manager from paddlehub.module.manager import default_module_manager
from paddlehub.module.checker import ModuleChecker
from paddlehub.module.signature import Signature, create_signature
from paddlehub.module.base_processor import BaseProcessor from paddlehub.module.base_processor import BaseProcessor
from paddlehub.io.parser import yaml_parser from paddlehub.io.parser import yaml_parser
from paddlehub import version from paddlehub import version
__all__ = ['Module', 'create_module']
# PaddleHub module dir name # PaddleHub module dir name
ASSETS_DIRNAME = "assets" ASSETS_DIRNAME = "assets"
MODEL_DIRNAME = "model" MODEL_DIRNAME = "model"
...@@ -52,67 +53,227 @@ PYTHON_DIR = "python" ...@@ -52,67 +53,227 @@ PYTHON_DIR = "python"
PROCESSOR_NAME = "processor" PROCESSOR_NAME = "processor"
# PaddleHub var prefix # PaddleHub var prefix
HUB_VAR_PREFIX = "@HUB_%s@" HUB_VAR_PREFIX = "@HUB_%s@"
# PaddleHub Module package suffix
HUB_PACKAGE_SUFFIX = "phm"
def create_module(directory, name, author, email, module_type, summary,
version):
save_file_name = "{}-{}.{}".format(name, version, HUB_PACKAGE_SUFFIX)
# record module info and serialize
desc = module_desc_pb2.ModuleDesc()
attr = desc.attr
attr.type = module_desc_pb2.MAP
module_info = attr.map.data['module_info']
module_info.type = module_desc_pb2.MAP
utils.from_pyobj_to_module_attr(name, module_info.map.data['name'])
utils.from_pyobj_to_module_attr(author, module_info.map.data['author'])
utils.from_pyobj_to_module_attr(email, module_info.map.data['author_email'])
utils.from_pyobj_to_module_attr(module_type, module_info.map.data['type'])
utils.from_pyobj_to_module_attr(summary, module_info.map.data['summary'])
utils.from_pyobj_to_module_attr(version, module_info.map.data['version'])
module_desc_path = os.path.join(directory, "module_desc.pb")
with open(module_desc_path, "wb") as f:
f.write(desc.SerializeToString())
# generate check info
checker = ModuleChecker(directory)
checker.generate_check_info()
# add __init__
module_init_1 = os.path.join(directory, "__init__.py")
with open(module_init_1, "a") as file:
file.write("")
module_init_2 = os.path.join(directory, "python", "__init__.py")
with open(module_init_2, "a") as file:
file.write("")
# package the module
with tarfile.open(save_file_name, "w:gz") as tar:
for dirname, _, files in os.walk(directory):
for file in files:
tar.add(os.path.join(dirname, file))
os.remove(module_desc_path)
os.remove(checker.pb_path)
os.remove(module_init_1)
os.remove(module_init_2)
class Module(object):
def __new__(cls, name=None, directory=None, module_dir=None, version=None):
module = None
if cls.__name__ == "Module":
if name:
module = cls.init_with_name(name=name, version=version)
elif directory:
module = cls.init_with_directory(directory=directory)
elif module_dir:
logger.warning(
"Parameter module_dir is deprecated, please use directory to specify the path"
)
if isinstance(module_dir, list) or isinstance(
module_dir, tuple):
directory = module_dir[0]
version = module_dir[1]
else:
directory = module_dir
module = cls.init_with_directory(directory=directory)
if not module:
module = object.__new__(cls)
else:
CacheUpdater(module.name, module.version).start()
return module
def __init__(self, name=None, directory=None, module_dir=None,
version=None):
if not directory:
return
self._code_version = "v2"
self._directory = directory
self.module_desc_path = os.path.join(self.directory, MODULE_DESC_PBNAME)
self._desc = module_desc_pb2.ModuleDesc()
with open(self.module_desc_path, "rb") as file:
self._desc.ParseFromString(file.read())
module_info = self.desc.attr.map.data['module_info']
self._name = utils.from_module_attr_to_pyobj(
module_info.map.data['name'])
self._author = utils.from_module_attr_to_pyobj(
module_info.map.data['author'])
self._author_email = utils.from_module_attr_to_pyobj(
module_info.map.data['author_email'])
self._version = utils.from_module_attr_to_pyobj(
module_info.map.data['version'])
self._type = utils.from_module_attr_to_pyobj(
module_info.map.data['type'])
self._summary = utils.from_module_attr_to_pyobj(
module_info.map.data['summary'])
self._initialize()
@classmethod
def init_with_name(cls, name, version=None):
fp_lock = open(os.path.join(CACHE_HOME, name), "a")
lock.flock(fp_lock, lock.LOCK_EX)
log_msg = "Installing %s module" % name
if version:
log_msg += "-%s" % version
logger.info(log_msg)
extra = {"command": "install"}
result, tips, module_dir = default_module_manager.install_module(
module_name=name, module_version=version, extra=extra)
if not result:
logger.error(tips)
raise RuntimeError(tips)
logger.info(tips)
lock.flock(fp_lock, lock.LOCK_UN)
return cls.init_with_directory(directory=module_dir[0])
@classmethod
def init_with_directory(cls, directory):
desc_file = os.path.join(directory, MODULE_DESC_PBNAME)
checker = ModuleChecker(directory)
checker.check()
def create_module(sign_arr, module_code_version = checker.module_code_version
module_dir, if module_code_version == "v2":
processor=None, basename = os.path.split(directory)[-1]
assets=None, dirname = os.path.join(*list(os.path.split(directory)[:-1]))
module_info=None, sys.path.append(dirname)
exe=None, pymodule = importlib.import_module(
extra_info=None): "{}.python.module".format(basename))
sign_arr = utils.to_list(sign_arr) return pymodule.HubModule(directory=directory)
module = Module( return ModuleV1(directory=directory)
signatures=sign_arr,
processor=processor, @property
assets=assets, def desc(self):
module_info=module_info, return self._desc
extra_info=extra_info)
module.serialize_to_path(path=module_dir, exe=exe) @property
def directory(self):
return self._directory
@property
def author(self):
return self._author
@property
def author_email(self):
return self._author_email
@property
def summary(self):
return self._summary
@property
def type(self):
return self._type
@property
def version(self):
return self._version
@property
def name(self):
return self._name
@property
def name_prefix(self):
return self._name_prefix
@property
def code_version(self):
return self._code_version
@property
def is_runable(self):
return False
def _initialize(self):
pass
class ModuleHelper(object): class ModuleHelper(object):
def __init__(self, module_dir): def __init__(self, directory):
self.module_dir = module_dir self.directory = directory
def module_desc_path(self): def module_desc_path(self):
return os.path.join(self.module_dir, MODULE_DESC_PBNAME) return os.path.join(self.directory, MODULE_DESC_PBNAME)
def model_path(self): def model_path(self):
return os.path.join(self.module_dir, MODEL_DIRNAME) return os.path.join(self.directory, MODEL_DIRNAME)
def processor_path(self): def processor_path(self):
return os.path.join(self.module_dir, PYTHON_DIR) return os.path.join(self.directory, PYTHON_DIR)
def processor_name(self): def processor_name(self):
return PROCESSOR_NAME return PROCESSOR_NAME
def assets_path(self): def assets_path(self):
return os.path.join(self.module_dir, ASSETS_DIRNAME) return os.path.join(self.directory, ASSETS_DIRNAME)
class Module(object): class ModuleV1(Module):
def __init__(self, def __init__(self, name=None, directory=None, module_dir=None,
name=None,
module_dir=None,
signatures=None,
module_info=None,
assets=None,
processor=None,
extra_info=None,
version=None): version=None):
self.desc = module_desc_pb2.ModuleDesc() if not directory:
return
super(ModuleV1, self).__init__(name, directory, module_dir, version)
self._code_version = "v1"
self.program = None self.program = None
self.assets = [] self.assets = []
self.helper = None self.helper = None
self.signatures = {} self.signatures = {}
self.default_signature = None self.default_signature = None
self.module_info = None
self.processor = None self.processor = None
self.extra_info = {} if extra_info is None else extra_info self.extra_info = {}
if not isinstance(self.extra_info, dict):
raise TypeError(
"The extra_info should be an instance of python dict")
# cache data # cache data
self.last_call_name = None self.last_call_name = None
...@@ -120,62 +281,21 @@ class Module(object): ...@@ -120,62 +281,21 @@ class Module(object):
self.cache_fetch_dict = None self.cache_fetch_dict = None
self.cache_program = None self.cache_program = None
fp_lock = open(os.path.join(CONF_HOME, 'config.json')) self.helper = ModuleHelper(directory)
lock.flock(fp_lock, lock.LOCK_EX) exe = fluid.Executor(fluid.CPUPlace())
if name: self.program, _, _ = fluid.io.load_inference_model(
self._init_with_name(name=name, version=version) self.helper.model_path(), executor=exe)
lock.flock(fp_lock, lock.LOCK_UN) for block in self.program.blocks:
elif module_dir: for op in block.ops:
self._init_with_module_file(module_dir=module_dir[0]) if "op_callstack" in op.all_attrs():
lock.flock(fp_lock, lock.LOCK_UN) op._set_attr("op_callstack", [""])
name = module_dir[0].split("/")[-1] self._load_processor()
if len(module_dir) > 1: self._load_assets()
version = module_dir[1] self._recover_from_desc()
else: self._generate_sign_attr()
version = default_module_manager.search_module(name)[1] self._generate_extra_info()
elif signatures: self._restore_parameter(self.program)
if processor: self._recover_variable_info(self.program)
if not issubclass(processor, BaseProcessor):
raise TypeError(
"Processor shoule be an instance of paddlehub.BaseProcessor"
)
if assets:
self.assets = utils.to_list(assets)
# for asset in assets:
# utils.check_path(assets)
self.processor = processor
self._generate_module_info(module_info)
self._init_with_signature(signatures=signatures)
lock.flock(fp_lock, lock.LOCK_UN)
else:
lock.flock(fp_lock, lock.LOCK_UN)
raise ValueError("Module initialized parameter is empty")
CacheUpdater(name, version).start()
def _init_with_name(self, name, version=None):
log_msg = "Installing %s module" % name
if version:
log_msg += "-%s" % version
logger.info(log_msg)
extra = {"command": "install"}
result, tips, module_dir = default_module_manager.install_module(
module_name=name, module_version=version, extra=extra)
if not result:
logger.error(tips)
raise RuntimeError(tips)
else:
logger.info(tips)
self._init_with_module_file(module_dir[0])
def _init_with_url(self, url):
utils.check_url(url)
result, tips, module_dir = default_downloader.download_file_and_uncompress(
url, save_path=".")
if not result:
logger.error(tips)
raise RuntimeError(tips)
else:
self._init_with_module_file(module_dir)
def _dump_processor(self): def _dump_processor(self):
import inspect import inspect
...@@ -216,52 +336,6 @@ class Module(object): ...@@ -216,52 +336,6 @@ class Module(object):
filepath = os.path.join(self.helper.assets_path(), file) filepath = os.path.join(self.helper.assets_path(), file)
self.assets.append(filepath) self.assets.append(filepath)
def _init_with_module_file(self, module_dir):
checker = ModuleChecker(module_dir)
checker.check()
self.helper = ModuleHelper(module_dir)
with open(self.helper.module_desc_path(), "rb") as fi:
self.desc.ParseFromString(fi.read())
exe = fluid.Executor(fluid.CPUPlace())
self.program, _, _ = fluid.io.load_inference_model(
self.helper.model_path(), executor=exe)
for block in self.program.blocks:
for op in block.ops:
if "op_callstack" in op.all_attrs():
op._set_attr("op_callstack", [""])
self._load_processor()
self._load_assets()
self._recover_from_desc()
self._generate_sign_attr()
self._generate_extra_info()
self._restore_parameter(self.program)
self._recover_variable_info(self.program)
def _init_with_signature(self, signatures):
self.name_prefix = HUB_VAR_PREFIX % self.name
self._process_signatures(signatures)
self._check_signatures()
self._generate_desc()
self._generate_sign_attr()
self._generate_extra_info()
def _init_with_program(self, program):
pass
def _process_signatures(self, signatures):
self.signatures = {}
self.program = signatures[0].inputs[0].block.program
for sign in signatures:
if sign.name in self.signatures:
raise ValueError(
"Error! Signature array contains duplicated signatrues %s" %
sign)
if self.default_signature is None and sign.for_predict:
self.default_signature = sign
self.signatures[sign.name] = sign
def _restore_parameter(self, program): def _restore_parameter(self, program):
global_block = program.global_block() global_block = program.global_block()
param_attrs = self.desc.attr.map.data['param_attrs'] param_attrs = self.desc.attr.map.data['param_attrs']
...@@ -302,21 +376,6 @@ class Module(object): ...@@ -302,21 +376,6 @@ class Module(object):
self.__dict__["get_%s" % key] = functools.partial( self.__dict__["get_%s" % key] = functools.partial(
self.get_extra_info, key=key) self.get_extra_info, key=key)
def _generate_module_info(self, module_info=None):
if not module_info:
self.module_info = {}
else:
if not utils.is_yaml_file(module_info):
logger.critical("Module info file should be yaml format")
exit(1)
self.module_info = yaml_parser.parse(module_info)
self.author = self.module_info.get('author', 'UNKNOWN')
self.author_email = self.module_info.get('author_email', 'UNKNOWN')
self.summary = self.module_info.get('summary', 'UNKNOWN')
self.type = self.module_info.get('type', 'UNKNOWN')
self.version = self.module_info.get('version', 'UNKNOWN')
self.name = self.module_info.get('name', 'UNKNOWN')
def _generate_sign_attr(self): def _generate_sign_attr(self):
self._check_signatures() self._check_signatures()
for sign in self.signatures: for sign in self.signatures:
...@@ -369,21 +428,21 @@ class Module(object): ...@@ -369,21 +428,21 @@ class Module(object):
default_signature_name = utils.from_module_attr_to_pyobj( default_signature_name = utils.from_module_attr_to_pyobj(
self.desc.attr.map.data['default_signature']) self.desc.attr.map.data['default_signature'])
self.default_signature = self.signatures[ self.default_signature = self.signatures[
default_signature_name] if default_signature_name else None default_signature_name].name if default_signature_name else None
# recover module info # recover module info
module_info = self.desc.attr.map.data['module_info'] module_info = self.desc.attr.map.data['module_info']
self.name = utils.from_module_attr_to_pyobj( self._name = utils.from_module_attr_to_pyobj(
module_info.map.data['name']) module_info.map.data['name'])
self.author = utils.from_module_attr_to_pyobj( self._author = utils.from_module_attr_to_pyobj(
module_info.map.data['author']) module_info.map.data['author'])
self.author_email = utils.from_module_attr_to_pyobj( self._author_email = utils.from_module_attr_to_pyobj(
module_info.map.data['author_email']) module_info.map.data['author_email'])
self.version = utils.from_module_attr_to_pyobj( self._version = utils.from_module_attr_to_pyobj(
module_info.map.data['version']) module_info.map.data['version'])
self.type = utils.from_module_attr_to_pyobj( self._type = utils.from_module_attr_to_pyobj(
module_info.map.data['type']) module_info.map.data['type'])
self.summary = utils.from_module_attr_to_pyobj( self._summary = utils.from_module_attr_to_pyobj(
module_info.map.data['summary']) module_info.map.data['summary'])
# recover extra info # recover extra info
...@@ -393,77 +452,9 @@ class Module(object): ...@@ -393,77 +452,9 @@ class Module(object):
self.extra_info[key] = utils.from_module_attr_to_pyobj(value) self.extra_info[key] = utils.from_module_attr_to_pyobj(value)
# recover name prefix # recover name prefix
self.name_prefix = utils.from_module_attr_to_pyobj( self._name_prefix = utils.from_module_attr_to_pyobj(
self.desc.attr.map.data["name_prefix"]) self.desc.attr.map.data["name_prefix"])
def _generate_desc(self):
# save fluid Parameter
attr = self.desc.attr
attr.type = module_desc_pb2.MAP
param_attrs = attr.map.data['param_attrs']
param_attrs.type = module_desc_pb2.MAP
for param in self.program.global_block().iter_parameters():
param_attr = param_attrs.map.data[param.name]
paddle_helper.from_param_to_module_attr(param, param_attr)
# save Variable Info
var_infos = attr.map.data['var_infos']
var_infos.type = module_desc_pb2.MAP
for block in self.program.blocks:
for var in block.vars.values():
var_info = var_infos.map.data[var.name]
var_info.type = module_desc_pb2.MAP
utils.from_pyobj_to_module_attr(
var.stop_gradient, var_info.map.data['stop_gradient'])
utils.from_pyobj_to_module_attr(block.idx,
var_info.map.data['block_id'])
# save signarture info
for key, sign in self.signatures.items():
var = self.desc.sign2var[sign.name]
feed_desc = var.feed_desc
fetch_desc = var.fetch_desc
feed_names = sign.feed_names
fetch_names = sign.fetch_names
for index, input in enumerate(sign.inputs):
feed_var = feed_desc.add()
feed_var.var_name = self.get_var_name_with_prefix(input.name)
feed_var.alias = feed_names[index]
for index, output in enumerate(sign.outputs):
fetch_var = fetch_desc.add()
fetch_var.var_name = self.get_var_name_with_prefix(output.name)
fetch_var.alias = fetch_names[index]
# save default signature
utils.from_pyobj_to_module_attr(
self.default_signature.name if self.default_signature else None,
attr.map.data['default_signature'])
# save name prefix
utils.from_pyobj_to_module_attr(self.name_prefix,
self.desc.attr.map.data["name_prefix"])
# save module info
module_info = attr.map.data['module_info']
module_info.type = module_desc_pb2.MAP
utils.from_pyobj_to_module_attr(self.name, module_info.map.data['name'])
utils.from_pyobj_to_module_attr(self.version,
module_info.map.data['version'])
utils.from_pyobj_to_module_attr(self.author,
module_info.map.data['author'])
utils.from_pyobj_to_module_attr(self.author_email,
module_info.map.data['author_email'])
utils.from_pyobj_to_module_attr(self.type, module_info.map.data['type'])
utils.from_pyobj_to_module_attr(self.summary,
module_info.map.data['summary'])
# save extra info
extra_info = attr.map.data['extra_info']
extra_info.type = module_desc_pb2.MAP
for key, value in self.extra_info.items():
utils.from_pyobj_to_module_attr(value, extra_info.map.data[key])
def __call__(self, sign_name, data, use_gpu=False, batch_size=1, **kwargs): def __call__(self, sign_name, data, use_gpu=False, batch_size=1, **kwargs):
self.check_processor() self.check_processor()
...@@ -525,6 +516,10 @@ class Module(object): ...@@ -525,6 +516,10 @@ class Module(object):
if not self.processor: if not self.processor:
raise ValueError("This Module is not callable!") raise ValueError("This Module is not callable!")
@property
def is_runable(self):
return self.default_signature != None
def context(self, def context(self,
sign_name=None, sign_name=None,
for_test=False, for_test=False,
...@@ -664,93 +659,3 @@ class Module(object): ...@@ -664,93 +659,3 @@ class Module(object):
raise ValueError( raise ValueError(
"All input and outputs variables in signature should come from the same Program" "All input and outputs variables in signature should come from the same Program"
) )
def serialize_to_path(self, path=None, exe=None):
self._check_signatures()
self._generate_desc()
# create module path for saving
if path is None:
path = os.path.join(".", self.name)
self.helper = ModuleHelper(path)
utils.mkdir(self.helper.module_dir)
# create module pb
module_desc = module_desc_pb2.ModuleDesc()
logger.info("PaddleHub version = %s" % version.hub_version)
logger.info("PaddleHub Module proto version = %s" %
version.module_proto_version)
logger.info("Paddle version = %s" % paddle.__version__)
feeded_var_names = [
input.name for key, sign in self.signatures.items()
for input in sign.inputs
]
target_vars = [
output for key, sign in self.signatures.items()
for output in sign.outputs
]
feeded_var_names = list(set(feeded_var_names))
target_vars = list(set(target_vars))
# save inference program
program = self.program.clone()
for block in program.blocks:
for op in block.ops:
if "op_callstack" in op.all_attrs():
op._set_attr("op_callstack", [""])
if not exe:
place = fluid.CPUPlace()
exe = fluid.Executor(place=place)
utils.mkdir(self.helper.model_path())
fluid.io.save_inference_model(
self.helper.model_path(),
feeded_var_names=list(feeded_var_names),
target_vars=list(target_vars),
main_program=program,
executor=exe)
with open(os.path.join(self.helper.model_path(), "__model__"),
"rb") as file:
program_desc_str = file.read()
rename_program = fluid.framework.Program.parse_from_string(
program_desc_str)
varlist = {
var: block
for block in rename_program.blocks for var in block.vars
if self.get_name_prefix() not in var
}
for var, block in varlist.items():
old_name = var
new_name = self.get_var_name_with_prefix(old_name)
block._rename_var(old_name, new_name)
utils.mkdir(self.helper.model_path())
with open(
os.path.join(self.helper.model_path(), "__model__"),
"wb") as f:
f.write(rename_program.desc.serialize_to_string())
for file in os.listdir(self.helper.model_path()):
if (file == "__model__" or self.get_name_prefix() in file):
continue
os.rename(
os.path.join(self.helper.model_path(), file),
os.path.join(self.helper.model_path(),
self.get_var_name_with_prefix(file)))
# create processor file
if self.processor:
self._dump_processor()
# create assets
self._dump_assets()
# create check info
checker = ModuleChecker(self.helper.module_dir)
checker.generate_check_info()
# Serialize module_desc pb
module_pb = self.desc.SerializeToString()
with open(self.helper.module_desc_path(), "wb") as f:
f.write(module_pb)
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
import sys import sys
import time
import paddlehub as hub import paddlehub as hub
import ujson import ujson
import random import random
...@@ -30,7 +29,7 @@ if is_py3: ...@@ -30,7 +29,7 @@ if is_py3:
import http.client as httplib import http.client as httplib
class BertService(): class BertService(object):
def __init__(self, def __init__(self,
profile=False, profile=False,
max_seq_len=128, max_seq_len=128,
...@@ -42,7 +41,7 @@ class BertService(): ...@@ -42,7 +41,7 @@ class BertService():
load_balance='round_robin'): load_balance='round_robin'):
self.process_id = process_id self.process_id = process_id
self.reader_flag = False self.reader_flag = False
self.batch_size = 16 self.batch_size = 0
self.max_seq_len = max_seq_len self.max_seq_len = max_seq_len
self.profile = profile self.profile = profile
self.model_name = model_name self.model_name = model_name
...@@ -55,15 +54,6 @@ class BertService(): ...@@ -55,15 +54,6 @@ class BertService():
self.feed_var_names = '' self.feed_var_names = ''
self.retry = retry self.retry = retry
def connect(self, server='127.0.0.1:8010'):
self.server_list.append(server)
def connect_all_server(self, server_list):
for server_str in server_list:
self.server_list.append(server_str)
def data_convert(self, text):
if self.reader_flag == False:
module = hub.Module(name=self.model_name) module = hub.Module(name=self.model_name)
inputs, outputs, program = module.context( inputs, outputs, program = module.context(
trainable=True, max_seq_len=self.max_seq_len) trainable=True, max_seq_len=self.max_seq_len)
...@@ -79,10 +69,14 @@ class BertService(): ...@@ -79,10 +69,14 @@ class BertService():
do_lower_case=self.do_lower_case) do_lower_case=self.do_lower_case)
self.reader_flag = True self.reader_flag = True
return self.reader.data_generator( def add_server(self, server='127.0.0.1:8010'):
batch_size=self.batch_size, phase='predict', data=text) self.server_list.append(server)
def infer(self, request_msg): def add_server_list(self, server_list):
for server_str in server_list:
self.server_list.append(server_str)
def request_server(self, request_msg):
if self.load_balance == 'round_robin': if self.load_balance == 'round_robin':
try: try:
cur_con = httplib.HTTPConnection( cur_con = httplib.HTTPConnection(
...@@ -157,17 +151,13 @@ class BertService(): ...@@ -157,17 +151,13 @@ class BertService():
self.server_list) self.server_list)
return 'retry' return 'retry'
def encode(self, text): def prepare_data(self, text):
if type(text) != list:
raise TypeError('Only support list')
self.batch_size = len(text) self.batch_size = len(text)
data_generator = self.data_convert(text) data_generator = self.reader.data_generator(
start = time.time() batch_size=self.batch_size, phase='predict', data=text)
request_time = 0 request_msg = ""
result = []
for run_step, batch in enumerate(data_generator(), start=1): for run_step, batch in enumerate(data_generator(), start=1):
request = [] request = []
copy_start = time.time()
token_list = batch[0][0].reshape(-1).tolist() token_list = batch[0][0].reshape(-1).tolist()
pos_list = batch[0][1].reshape(-1).tolist() pos_list = batch[0][1].reshape(-1).tolist()
sent_list = batch[0][2].reshape(-1).tolist() sent_list = batch[0][2].reshape(-1).tolist()
...@@ -184,54 +174,34 @@ class BertService(): ...@@ -184,54 +174,34 @@ class BertService():
si + 1) * self.max_seq_len] si + 1) * self.max_seq_len]
request.append(instance_dict) request.append(instance_dict)
copy_time = time.time() - copy_start
request = {"instances": request} request = {"instances": request}
request["max_seq_len"] = self.max_seq_len request["max_seq_len"] = self.max_seq_len
request["feed_var_names"] = self.feed_var_names request["feed_var_names"] = self.feed_var_names
request_msg = ujson.dumps(request) request_msg = ujson.dumps(request)
if self.show_ids: if self.show_ids:
logger.info(request_msg) logger.info(request_msg)
request_start = time.time()
response_msg = self.infer(request_msg) return request_msg
def encode(self, text):
if type(text) != list:
raise TypeError('Only support list')
request_msg = self.prepare_data(text)
response_msg = self.request_server(request_msg)
retry = 0 retry = 0
while type(response_msg) == str and response_msg == 'retry': while type(response_msg) == str and response_msg == 'retry':
if retry < self.retry: if retry < self.retry:
retry += 1 retry += 1
logger.info('Try to connect another servers') logger.info('Try to connect another servers')
response_msg = self.infer(request_msg) response_msg = self.request_server(request_msg)
else: else:
logger.error('Infer failed after {} times retry'.format( logger.error('Request failed after {} times retry'.format(
self.retry)) self.retry))
break break
result = []
for msg in response_msg["instances"]: for msg in response_msg["instances"]:
for sample in msg["instances"]: for sample in msg["instances"]:
result.append(sample["values"]) result.append(sample["values"])
request_time += time.time() - request_start
total_time = time.time() - start
if self.profile:
return [
total_time, request_time, response_msg['op_time'],
response_msg['infer_time'], copy_time
]
else:
return result
def connect(input_text,
model_name,
max_seq_len=128,
show_ids=False,
do_lower_case=True,
server="127.0.0.1:8866",
retry=3):
# format of input_text like [["As long as"],]
bc = BertService(
max_seq_len=max_seq_len,
model_name=model_name,
show_ids=show_ids,
do_lower_case=do_lower_case,
retry=retry)
bc.connect(server)
result = bc.encode(input_text)
return result return result
from paddlehub.serving.bert_serving import bert_service
class BSClient(object):
def __init__(self,
module_name,
server,
max_seq_len=20,
show_ids=False,
do_lower_case=True,
retry=3):
self.bs = bert_service.BertService(
model_name=module_name,
max_seq_len=max_seq_len,
show_ids=show_ids,
do_lower_case=do_lower_case,
retry=retry)
self.bs.add_server(server=server)
def get_result(self, input_text):
return self.bs.encode(input_text)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册