未验证 提交 fa6f8d16 编写于 作者: B Bin Long 提交者: GitHub

Merge branch 'develop' into add_search_tip

......@@ -16,12 +16,6 @@ jobs:
os: linux
python: 3.6
script: /bin/bash ./scripts/check_code_style.sh
- name: "CI on Linux/Python3.5"
os: linux
python: 3.5
- name: "CI on Linux/Python2.7"
os: linux
python: 2.7
env:
- PYTHONPATH=${PWD}
......@@ -30,10 +24,6 @@ install:
- pip install --upgrade paddlepaddle
- pip install -r requirements.txt
script:
- if [[ $TRAVIS_OS_NAME != "windows" ]]; then /bin/bash ./scripts/test_cml.sh; fi
- if [[ $TRAVIS_OS_NAME != "windows" ]]; then /bin/bash ./scripts/test_all_module.sh; fi
notifications:
email:
on_success: change
......
......@@ -68,7 +68,7 @@ $ pip install ujson
|模型|网络|
|:-|:-:|
|[ERNIE](https://paddlepaddle.org.cn/hubdetail?name=ERNIE&en_category=SemanticModel)|ERNIE|
|[ernie](https://paddlepaddle.org.cn/hubdetail?name=ERNIE&en_category=SemanticModel)|ERNIE|
|[ernie_tiny](https://paddlepaddle.org.cn/hubdetail?name=ernie_tiny&en_category=SemanticModel)|ERNIE|
|[ernie_v2_eng_large](https://paddlepaddle.org.cn/hubdetail?name=ernie_v2_eng_large&en_category=SemanticModel)|ERNIE|
|[ernie_v2_eng_base](https://paddlepaddle.org.cn/hubdetail?name=ernie_v2_eng_base&en_category=SemanticModel)|ERNIE|
......@@ -179,18 +179,22 @@ Server[baidu::paddle_serving::predictor::bert_service::BertServiceImpl] is servi
首先导入客户端依赖。
```python
from paddlehub.serving.bert_serving import bert_service
from paddlehub.serving.bert_serving import bs_client
```
接着输入文本信息。
接着启动并初始化`bert service`客户端`BSClient`(这里的server为虚拟地址,需根据自己实际ip设置)
```python
bc = bs_client.BSClient(module_name="ernie_tiny", server="127.0.0.1:8866")
```
然后输入文本信息。
```python
input_text = [["西风吹老洞庭波"], ["一夜湘君白发多"], ["醉后不知天在水"], ["满船清梦压星河"], ]
```
然后利用客户端接口发送文本到服务端,以获取embedding结果(server为虚拟地址,需根据自己实际ip设置)。
最后利用客户端接口`get_result`发送文本到服务端,以获取embedding结果。
```python
result = bert_service.connect(
input_text=input_text,
model_name="ernie_tiny",
server="127.0.0.1:8866")
result = bc.get_result(input_text=input_text)
```
最后即可得到embedding结果(此处只展示部分结果)。
```python
......@@ -221,8 +225,8 @@ Paddle Inference Server exit successfully!
> Q : 如何在一台服务器部署多个模型?
> A : 可通过多次启动`Bert Service`,分配不同端口实现。如果使用GPU,需要指定不同的显卡。如同时部署`ernie`和`bert_chinese_L-12_H-768_A-12`,分别执行命令如下:
> ```shell
> $ hub serving start bert_serving -m ernie -p 8866
> $ hub serving start bert_serving -m bert_serving -m bert_chinese_L-12_H-768_A-12 -p 8867
> $ hub serving start bert_service -m ernie -p 8866
> $ hub serving start bert_service -m bert_chinese_L-12_H-768_A-12 -p 8867
> ```
> Q : 启动时显示"Check out http://yq01-gpu-255-129-12-00.epc.baidu.com:8887 in web
......
# coding: utf8
from paddlehub.serving.bert_serving import bert_service
from paddlehub.serving.bert_serving import bs_client
if __name__ == "__main__":
# 初始化bert_service客户端BSClient
bc = bs_client.BSClient(module_name="ernie_tiny", server="127.0.0.1:8866")
# 输入要做embedding的文本
# 文本格式为[["文本1"], ["文本2"], ]
input_text = [
......@@ -10,10 +13,10 @@ if __name__ == "__main__":
["醉后不知天在水"],
["满船清梦压星河"],
]
# 调用客户端接口bert_service.connect()获取结果
result = bert_service.connect(
input_text=input_text, model_name="ernie_tiny", server="127.0.0.1:8866")
# 打印embedding结果
# BSClient.get_result()获取结果
result = bc.get_result(input_text=input_text)
# 打印输入文本的embedding结果
for item in result:
print(item)
......@@ -38,7 +38,7 @@ from .common.logger import logger
from .common.paddle_helper import connect_program
from .common.hub_server import default_hub_server
from .module.module import Module, create_module
from .module.module import Module
from .module.base_processor import BaseProcessor
from .module.signature import Signature, create_signature
from .module.manager import default_module_manager
......
......@@ -26,6 +26,7 @@ from tb_paddle import SummaryWriter
from paddlehub.common.logger import logger
from paddlehub.common.utils import mkdir
from paddlehub.autofinetune.evaluator import REWARD_SUM, TMP_HOME
from paddlehub.autofinetune.mpi_helper import MPIHelper
if six.PY3:
INF = math.inf
......@@ -75,6 +76,12 @@ class BaseTuningStrategy(object):
logdir=self._output_dir + '/visualization/pop_{}'.format(i))
self.writer_pop_trails.append(writer_pop_trail)
# for parallel on mpi
self.mpi = MPIHelper()
if self.mpi.multi_machine:
print("Autofinetune multimachine mode: running on {}".format(
self.mpi.gather(self.mpi.name)))
@property
def thread(self):
return self._num_thread
......@@ -177,16 +184,22 @@ class BaseTuningStrategy(object):
solutions_modeldirs = {}
mkdir(output_dir)
for idx, solution in enumerate(solutions):
solutions = self.mpi.bcast(solutions)
# split solutions to "solutions for me"
range_start, range_end = self.mpi.split_range(len(solutions))
my_solutions = solutions[range_start:range_end]
for idx, solution in enumerate(my_solutions):
cuda = self.is_cuda_free["free"][0]
modeldir = output_dir + "/model-" + str(idx) + "/"
log_file = output_dir + "/log-" + str(idx) + ".info"
params_cudas_dirs.append([solution, cuda, modeldir, log_file])
solutions_modeldirs[tuple(solution)] = modeldir
solutions_modeldirs[tuple(solution)] = (modeldir, self.mpi.rank)
self.is_cuda_free["free"].remove(cuda)
self.is_cuda_free["busy"].append(cuda)
if len(params_cudas_dirs
) == self.thread or idx == len(solutions) - 1:
) == self.thread or idx == len(my_solutions) - 1:
tp = ThreadPool(len(params_cudas_dirs))
solution_results += tp.map(self.evaluator.run,
params_cudas_dirs)
......@@ -198,13 +211,25 @@ class BaseTuningStrategy(object):
self.is_cuda_free["busy"].remove(param_cuda[1])
params_cudas_dirs = []
self.feedback(solutions, solution_results)
all_solution_results = self.mpi.gather(solution_results)
if self.mpi.rank == 0:
# only rank 0 need to feedback
all_solution_results = [y for x in all_solution_results for y in x]
self.feedback(solutions, all_solution_results)
# remove the tmp.txt which records the eval results for trials
tmp_file = os.path.join(TMP_HOME, "tmp.txt")
if os.path.exists(tmp_file):
os.remove(tmp_file)
return solutions_modeldirs
# collect all solutions_modeldirs
collected_solutions_modeldirs = self.mpi.allgather(solutions_modeldirs)
return_dict = {}
for i in collected_solutions_modeldirs:
return_dict.update(i)
return return_dict
class HAZero(BaseTuningStrategy):
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
class MPIHelper(object):
def __init__(self):
try:
from mpi4py import MPI
except:
# local run
self._size = 1
self._rank = 0
self._multi_machine = False
import socket
self._name = socket.gethostname()
else:
# in mpi environment
self._comm = MPI.COMM_WORLD
self._size = self._comm.Get_size()
self._rank = self._comm.Get_rank()
self._name = MPI.Get_processor_name()
if self._size > 1:
self._multi_machine = True
else:
self._multi_machine = False
@property
def multi_machine(self):
return self._multi_machine
@property
def rank(self):
return self._rank
@property
def size(self):
return self._size
@property
def name(self):
return self._name
def bcast(self, data):
if self._multi_machine:
# call real bcast
return self._comm.bcast(data, root=0)
else:
# do nothing
return data
def gather(self, data):
if self._multi_machine:
# call real gather
return self._comm.gather(data, root=0)
else:
# do nothing
return [data]
def allgather(self, data):
if self._multi_machine:
# call real allgather
return self._comm.allgather(data)
else:
# do nothing
return [data]
# calculate split range on mpi environment
def split_range(self, array_length):
if self._size == 1:
return 0, array_length
average_count = array_length / self._size
if array_length % self._size == 0:
return average_count * self._rank, average_count * (self._rank + 1)
else:
if self._rank < array_length % self._size:
return (average_count + 1) * self._rank, (average_count + 1) * (
self._rank + 1)
else:
start = (average_count + 1) * (array_length % self._size) \
+ average_count * (self._rank - array_length % self._size)
return start, start + average_count
if __name__ == "__main__":
mpi = MPIHelper()
print("Hello world from process {} of {} at {}.".format(
mpi.rank, mpi.size, mpi.name))
all_node_names = mpi.gather(mpi.name)
print("all node names using gather: {}".format(all_node_names))
all_node_names = mpi.allgather(mpi.name)
print("all node names using allgather: {}".format(all_node_names))
if mpi.rank == 0:
data = range(10)
else:
data = None
data = mpi.bcast(data)
print("after bcast, process {} have data {}".format(mpi.rank, data))
data = [i + mpi.rank for i in data]
print("after modify, process {} have data {}".format(mpi.rank, data))
new_data = mpi.gather(data)
print("after gather, process {} have data {}".format(mpi.rank, new_data))
# test for split
for i in range(12):
length = i + mpi.size # length should >= mpi.size
[start, end] = mpi.split_range(length)
split_result = mpi.gather([start, end])
print("length {}, split_result {}".format(length, split_result))
......@@ -188,37 +188,62 @@ class AutoFineTuneCommand(BaseCommand):
run_round_cnt = run_round_cnt + 1
print("PaddleHub Autofinetune ends.")
best_hparams_origin = autoft.get_best_hparams()
best_hparams_origin = autoft.mpi.bcast(best_hparams_origin)
with open(autoft._output_dir + "/log_file.txt", "w") as f:
best_hparams = evaluator.convert_params(autoft.get_best_hparams())
best_hparams = evaluator.convert_params(best_hparams_origin)
print("The final best hyperparameters:")
f.write("The final best hyperparameters:\n")
for index, hparam_name in enumerate(autoft.hparams_name_list):
print("%s=%s" % (hparam_name, best_hparams[index]))
f.write(hparam_name + "\t:\t" + str(best_hparams[index]) + "\n")
best_hparams_dir, best_hparams_rank = solutions_modeldirs[tuple(
best_hparams_origin)]
print("The final best eval score is %s." %
autoft.get_best_eval_value())
if autoft.mpi.multi_machine:
print("The final best model parameters are saved as " +
autoft._output_dir + "/best_model on rank " +
str(best_hparams_rank) + " .")
else:
print("The final best model parameters are saved as " +
autoft._output_dir + "/best_model .")
f.write("The final best eval score is %s.\n" %
autoft.get_best_eval_value())
f.write(
"The final best model parameters are saved as ./best_model .")
best_model_dir = autoft._output_dir + "/best_model"
shutil.copytree(
solutions_modeldirs[tuple(autoft.get_best_hparams())],
best_model_dir)
if autoft.mpi.rank == best_hparams_rank:
shutil.copytree(best_hparams_dir, best_model_dir)
if autoft.mpi.multi_machine:
f.write(
"The final best model parameters are saved as ./best_model on rank " \
+ str(best_hparams_rank) + " .")
f.write("\t".join(autoft.hparams_name_list) +
"\tsaved_params_dir\trank\n")
else:
f.write(
"The final best model parameters are saved as ./best_model ."
)
f.write("\t".join(autoft.hparams_name_list) +
"\tsaved_params_dir\n")
print(
"The related infomation about hyperparamemters searched are saved as %s/log_file.txt ."
% autoft._output_dir)
for solution, modeldir in solutions_modeldirs.items():
param = evaluator.convert_params(solution)
param = [str(p) for p in param]
f.write("\t".join(param) + "\t" + modeldir + "\n")
if autoft.mpi.multi_machine:
f.write("\t".join(param) + "\t" + modeldir[0] + "\t" +
str(modeldir[1]) + "\n")
else:
f.write("\t".join(param) + "\t" + modeldir[0] + "\n")
return True
......
......@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
import argparse
import os
from paddlehub.common import utils
from paddlehub.module.manager import default_module_manager
......@@ -42,14 +43,23 @@ class InstallCommand(BaseCommand):
print("ERROR: Please specify a module name.\n")
self.help()
return False
extra = {"command": "install"}
if argv[0].endswith("tar.gz") or argv[0].endswith("phm"):
result, tips, module_dir = default_module_manager.install_module(
module_package=argv[0], extra=extra)
elif os.path.exists(argv[0]) and os.path.isdir(argv[0]):
result, tips, module_dir = default_module_manager.install_module(
module_dir=argv[0], extra=extra)
else:
module_name = argv[0]
module_version = None if "==" not in module_name else module_name.split(
"==")[1]
module_name = module_name if "==" not in module_name else module_name.split(
"==")[0]
extra = {"command": "install"}
result, tips, module_dir = default_module_manager.install_module(
module_name=module_name, module_version=module_version, extra=extra)
module_name=module_name,
module_version=module_version,
extra=extra)
print(tips)
return True
......
......@@ -71,7 +71,7 @@ class RunCommand(BaseCommand):
if not result:
return None
return hub.Module(module_dir=module_dir)
return hub.Module(directory=module_dir[0])
def add_module_config_arg(self):
configs = self.module.processor.configs()
......@@ -105,7 +105,7 @@ class RunCommand(BaseCommand):
def add_module_input_arg(self):
module_type = self.module.type.lower()
expect_data_format = self.module.processor.data_format(
self.module.default_signature.name)
self.module.default_signature)
self.arg_input_group.add_argument(
'--input_file',
type=str,
......@@ -152,7 +152,7 @@ class RunCommand(BaseCommand):
def get_data(self):
module_type = self.module.type.lower()
expect_data_format = self.module.processor.data_format(
self.module.default_signature.name)
self.module.default_signature)
input_data = {}
if len(expect_data_format) == 1:
key = list(expect_data_format.keys())[0]
......@@ -177,7 +177,7 @@ class RunCommand(BaseCommand):
def check_data(self, data):
expect_data_format = self.module.processor.data_format(
self.module.default_signature.name)
self.module.default_signature)
if len(data.keys()) != len(expect_data_format.keys()):
print(
......@@ -236,10 +236,13 @@ class RunCommand(BaseCommand):
return False
# If the module is not executable, give an alarm and exit
if not self.module.default_signature:
if not self.module.is_runable:
print("ERROR! Module %s is not executable." % module_name)
return False
if self.module.code_version == "v2":
results = self.module(argv[1:])
else:
self.module.check_processor()
self.add_module_config_arg()
self.add_module_input_arg()
......@@ -260,7 +263,7 @@ class RunCommand(BaseCommand):
return False
results = self.module(
sign_name=self.module.default_signature.name,
sign_name=self.module.default_signature,
data=data,
use_gpu=self.args.use_gpu,
batch_size=self.args.batch_size,
......
......@@ -159,7 +159,7 @@ class ServingCommand(BaseCommand):
module = args.modules
if module is not None:
use_gpu = args.use_gpu
port = args.port[0]
port = args.port
if ServingCommand.is_port_occupied("127.0.0.1", port) is True:
print("Port %s is occupied, please change it." % (port))
return False
......@@ -206,10 +206,12 @@ class ServingCommand(BaseCommand):
if args.sub_command == "start":
if args.bert_service == "bert_service":
ServingCommand.start_bert_serving(args)
else:
elif args.bert_service is None:
ServingCommand.start_serving(args)
else:
ServingCommand.show_help()
else:
ServingCommand.show_help()
command = ServingCommand.instance()
......@@ -125,8 +125,6 @@ class ShowCommand(BaseCommand):
cwd = os.getcwd()
module_dir = default_module_manager.search_module(module_name)
module_dir = (os.path.join(cwd, module_name),
None) if not module_dir else module_dir
if not module_dir or not os.path.exists(module_dir[0]):
print("%s is not existed!" % module_name)
return True
......
......@@ -298,17 +298,23 @@ class CacheUpdater(threading.Thread):
api_url = srv_utils.uri_path(default_hub_server.get_server_url(),
'search')
cache_path = os.path.join(CACHE_HOME, RESOURCE_LIST_FILE)
if os.path.exists(cache_path):
extra = {
"command": "update_cache",
"mtime": os.stat(cache_path).st_mtime
}
else:
extra = {
"command": "update_cache",
"mtime": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
}
try:
r = srv_utils.hub_request(api_url, payload, extra)
except Exception as err:
pass
if r.get("update_cache", 0) == 1:
with open(cache_path, 'w+') as fp:
yaml.safe_dump({'resource_list': r['data']}, fp)
except Exception as err:
pass
def run(self):
self.update_resource_list_file(self.module, self.version)
......
......@@ -50,6 +50,7 @@ message CheckInfo {
string paddle_version = 1;
string hub_version = 2;
string module_proto_version = 3;
repeated FileInfo file_infos = 4;
repeated Requires requires = 5;
string module_code_version = 4;
repeated FileInfo file_infos = 5;
repeated Requires requires = 6;
};
#coding:utf-8
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: check_info.proto
......@@ -19,7 +18,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package='paddlehub.module.checkinfo',
syntax='proto3',
serialized_pb=_b(
'\n\x10\x63heck_info.proto\x12\x1apaddlehub.module.checkinfo\"\x85\x01\n\x08\x46ileInfo\x12\x11\n\tfile_name\x18\x01 \x01(\t\x12\x33\n\x04type\x18\x02 \x01(\x0e\x32%.paddlehub.module.checkinfo.FILE_TYPE\x12\x0f\n\x07is_need\x18\x03 \x01(\x08\x12\x0b\n\x03md5\x18\x04 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x05 \x01(\t\"\x84\x01\n\x08Requires\x12>\n\x0crequire_type\x18\x01 \x01(\x0e\x32(.paddlehub.module.checkinfo.REQUIRE_TYPE\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x12\n\ngreat_than\x18\x03 \x01(\x08\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\"\xc8\x01\n\tCheckInfo\x12\x16\n\x0epaddle_version\x18\x01 \x01(\t\x12\x13\n\x0bhub_version\x18\x02 \x01(\t\x12\x1c\n\x14module_proto_version\x18\x03 \x01(\t\x12\x38\n\nfile_infos\x18\x04 \x03(\x0b\x32$.paddlehub.module.checkinfo.FileInfo\x12\x36\n\x08requires\x18\x05 \x03(\x0b\x32$.paddlehub.module.checkinfo.Requires*\x1e\n\tFILE_TYPE\x12\x08\n\x04\x46ILE\x10\x00\x12\x07\n\x03\x44IR\x10\x01*[\n\x0cREQUIRE_TYPE\x12\x12\n\x0ePYTHON_PACKAGE\x10\x00\x12\x0e\n\nHUB_MODULE\x10\x01\x12\n\n\x06SYSTEM\x10\x02\x12\x0b\n\x07\x43OMMAND\x10\x03\x12\x0e\n\nPY_VERSION\x10\x04\x42\x02H\x03\x62\x06proto3'
'\n\x10\x63heck_info.proto\x12\x1apaddlehub.module.checkinfo\"\x85\x01\n\x08\x46ileInfo\x12\x11\n\tfile_name\x18\x01 \x01(\t\x12\x33\n\x04type\x18\x02 \x01(\x0e\x32%.paddlehub.module.checkinfo.FILE_TYPE\x12\x0f\n\x07is_need\x18\x03 \x01(\x08\x12\x0b\n\x03md5\x18\x04 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x05 \x01(\t\"\x84\x01\n\x08Requires\x12>\n\x0crequire_type\x18\x01 \x01(\x0e\x32(.paddlehub.module.checkinfo.REQUIRE_TYPE\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x12\n\ngreat_than\x18\x03 \x01(\x08\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\"\xe5\x01\n\tCheckInfo\x12\x16\n\x0epaddle_version\x18\x01 \x01(\t\x12\x13\n\x0bhub_version\x18\x02 \x01(\t\x12\x1c\n\x14module_proto_version\x18\x03 \x01(\t\x12\x1b\n\x13module_code_version\x18\x04 \x01(\t\x12\x38\n\nfile_infos\x18\x05 \x03(\x0b\x32$.paddlehub.module.checkinfo.FileInfo\x12\x36\n\x08requires\x18\x06 \x03(\x0b\x32$.paddlehub.module.checkinfo.Requires*\x1e\n\tFILE_TYPE\x12\x08\n\x04\x46ILE\x10\x00\x12\x07\n\x03\x44IR\x10\x01*[\n\x0cREQUIRE_TYPE\x12\x12\n\x0ePYTHON_PACKAGE\x10\x00\x12\x0e\n\nHUB_MODULE\x10\x01\x12\n\n\x06SYSTEM\x10\x02\x12\x0b\n\x07\x43OMMAND\x10\x03\x12\x0e\n\nPY_VERSION\x10\x04\x42\x02H\x03\x62\x06proto3'
))
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
......@@ -36,8 +35,8 @@ _FILE_TYPE = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=522,
serialized_end=552,
serialized_start=551,
serialized_end=581,
)
_sym_db.RegisterEnumDescriptor(_FILE_TYPE)
......@@ -61,8 +60,8 @@ _REQUIRE_TYPE = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=554,
serialized_end=645,
serialized_start=583,
serialized_end=674,
)
_sym_db.RegisterEnumDescriptor(_REQUIRE_TYPE)
......@@ -316,10 +315,26 @@ _CHECKINFO = _descriptor.Descriptor(
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='file_infos',
full_name='paddlehub.module.checkinfo.CheckInfo.file_infos',
name='module_code_version',
full_name='paddlehub.module.checkinfo.CheckInfo.module_code_version',
index=3,
number=4,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='file_infos',
full_name='paddlehub.module.checkinfo.CheckInfo.file_infos',
index=4,
number=5,
type=11,
cpp_type=10,
label=3,
......@@ -334,8 +349,8 @@ _CHECKINFO = _descriptor.Descriptor(
_descriptor.FieldDescriptor(
name='requires',
full_name='paddlehub.module.checkinfo.CheckInfo.requires',
index=4,
number=5,
index=5,
number=6,
type=11,
cpp_type=10,
label=3,
......@@ -357,7 +372,7 @@ _CHECKINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[],
serialized_start=320,
serialized_end=520,
serialized_end=549,
)
_FILEINFO.fields_by_name['type'].enum_type = _FILE_TYPE
......
......@@ -32,20 +32,22 @@ FILE_SEP = "/"
class ModuleChecker(object):
def __init__(self, module_path):
self.module_path = module_path
def __init__(self, directory):
self._directory = directory
self._pb_path = os.path.join(self.directory, CHECK_INFO_PB_FILENAME)
def generate_check_info(self):
check_info = check_info_pb2.CheckInfo()
check_info.paddle_version = paddle.__version__
check_info.hub_version = hub_version
check_info.module_proto_version = module_proto_version
check_info.module_code_version = "v2"
file_infos = check_info.file_infos
file_list = [file for file in os.listdir(self.module_path)]
file_list = [file for file in os.listdir(self.directory)]
while file_list:
file = file_list[0]
file_list = file_list[1:]
abs_path = os.path.join(self.module_path, file)
abs_path = os.path.join(self.directory, file)
if os.path.isdir(abs_path):
for sub_file in os.listdir(abs_path):
sub_file = os.path.join(file, sub_file)
......@@ -62,9 +64,12 @@ class ModuleChecker(object):
file_info.type = check_info_pb2.FILE
file_info.is_need = True
with open(os.path.join(self.module_path, CHECK_INFO_PB_FILENAME),
"wb") as fi:
fi.write(check_info.SerializeToString())
with open(self.pb_path, "wb") as file:
file.write(check_info.SerializeToString())
@property
def module_code_version(self):
return self.check_info.module_code_version
@property
def module_proto_version(self):
......@@ -82,20 +87,25 @@ class ModuleChecker(object):
def file_infos(self):
return self.check_info.file_infos
@property
def directory(self):
return self._directory
@property
def pb_path(self):
return self._pb_path
def check(self):
result = True
self.check_info_pb_path = os.path.join(self.module_path,
CHECK_INFO_PB_FILENAME)
if not (os.path.exists(self.check_info_pb_path)
or os.path.isfile(self.check_info_pb_path)):
if not (os.path.exists(self.pb_path) or os.path.isfile(self.pb_path)):
logger.warning(
"This module lacks core file %s" % CHECK_INFO_PB_FILENAME)
result = False
self.check_info = check_info_pb2.CheckInfo()
try:
with open(self.check_info_pb_path, "rb") as fi:
with open(self.pb_path, "rb") as fi:
pb_string = fi.read()
result = self.check_info.ParseFromString(pb_string)
if len(pb_string) == 0 or (result is not None
......@@ -182,7 +192,7 @@ class ModuleChecker(object):
for file_info in self.file_infos:
file_type = file_info.type
file_path = file_info.file_name.replace(FILE_SEP, os.sep)
file_path = os.path.join(self.module_path, file_path)
file_path = os.path.join(self.directory, file_path)
if not os.path.exists(file_path):
if file_info.is_need:
logger.warning(
......
......@@ -19,7 +19,9 @@ from __future__ import print_function
import os
import shutil
from functools import cmp_to_key
import tarfile
from paddlehub.common import utils
from paddlehub.common import srv_utils
......@@ -79,10 +81,15 @@ class LocalModuleManager(object):
return self.modules_dict.get(module_name, None)
def install_module(self,
module_name,
module_name=None,
module_dir=None,
module_package=None,
module_version=None,
upgrade=False,
extra=None):
md5_value = installed_module_version = None
from_user_dir = True if module_dir else False
if module_name:
self.all_modules(update=True)
module_info = self.modules_dict.get(module_name, None)
if module_info:
......@@ -95,6 +102,62 @@ class LocalModuleManager(object):
module_dir)
return True, tips, self.modules_dict[module_name]
search_result = hub.default_hub_server.get_module_url(
module_name, version=module_version, extra=extra)
name = search_result.get('name', None)
url = search_result.get('url', None)
md5_value = search_result.get('md5', None)
installed_module_version = search_result.get('version', None)
if not url or (module_version is not None
and installed_module_version != module_version) or (
name != module_name):
if default_hub_server._server_check() is False:
tips = "Request Hub-Server unsuccessfully, please check your network."
else:
tips = "Can't find module %s" % module_name
if module_version:
tips += " with version %s" % module_version
module_tag = module_name if not module_version else '%s-%s' % (
module_name, module_version)
return False, tips, None
result, tips, module_zip_file = default_downloader.download_file(
url=url,
save_path=hub.CACHE_HOME,
save_name=module_name,
replace=True,
print_progress=True)
result, tips, module_dir = default_downloader.uncompress(
file=module_zip_file,
dirname=MODULE_HOME,
delete_file=True,
print_progress=True)
if module_package:
with tarfile.open(module_package, "r:gz") as tar:
file_names = tar.getnames()
size = len(file_names) - 1
module_dir = os.path.split(file_names[0])[0]
module_dir = os.path.join(hub.CACHE_HOME, module_dir)
# remove cache
if os.path.exists(module_dir):
shutil.rmtree(module_dir)
for index, file_name in enumerate(file_names):
tar.extract(file_name, hub.CACHE_HOME)
if module_dir:
if not module_name:
module_name = hub.Module(directory=module_dir).name
self.all_modules(update=False)
module_info = self.modules_dict.get(module_name, None)
if module_info:
module_dir = self.modules_dict[module_name][0]
module_tag = module_name if not module_version else '%s-%s' % (
module_name, module_version)
tips = "Module %s already installed in %s" % (module_tag,
module_dir)
return True, tips, self.modules_dict[module_name]
search_result = hub.default_hub_server.get_module_url(
module_name, version=module_version, extra=extra)
name = search_result.get('name', None)
......@@ -162,9 +225,18 @@ class LocalModuleManager(object):
with open(os.path.join(MODULE_HOME, module_dir, "md5.txt"),
"w") as fp:
fp.write(md5_value)
if md5_value:
with open(
os.path.join(MODULE_HOME, module_dir, "md5.txt"),
"w") as fp:
fp.write(md5_value)
save_path = os.path.join(MODULE_HOME, module_name)
if os.path.exists(save_path):
shutil.rmtree(save_path)
shutil.move(save_path)
if from_user_dir:
shutil.copytree(module_dir, save_path)
else:
shutil.move(module_dir, save_path)
module_dir = save_path
tips = "Successfully installed %s" % module_name
......
......@@ -21,6 +21,10 @@ import os
import time
import sys
import functools
import inspect
import importlib
import tarfile
from collections import defaultdict
from shutil import copyfile
import paddle
......@@ -28,22 +32,19 @@ import paddle.fluid as fluid
from paddlehub.common import utils
from paddlehub.common import paddle_helper
from paddlehub.common.logger import logger
from paddlehub.common.dir import CACHE_HOME
from paddlehub.common.lock import lock
from paddlehub.common.downloader import default_downloader
from paddlehub.common.logger import logger
from paddlehub.common.hub_server import CacheUpdater
from paddlehub.module import module_desc_pb2
from paddlehub.common.dir import CONF_HOME
from paddlehub.module import check_info_pb2
from paddlehub.common.hub_server import CacheUpdater
from paddlehub.module.signature import Signature, create_signature
from paddlehub.module.checker import ModuleChecker
from paddlehub.module.manager import default_module_manager
from paddlehub.module.checker import ModuleChecker
from paddlehub.module.signature import Signature, create_signature
from paddlehub.module.base_processor import BaseProcessor
from paddlehub.io.parser import yaml_parser
from paddlehub import version
__all__ = ['Module', 'create_module']
# PaddleHub module dir name
ASSETS_DIRNAME = "assets"
MODEL_DIRNAME = "model"
......@@ -52,67 +53,227 @@ PYTHON_DIR = "python"
PROCESSOR_NAME = "processor"
# PaddleHub var prefix
HUB_VAR_PREFIX = "@HUB_%s@"
# PaddleHub Module package suffix
HUB_PACKAGE_SUFFIX = "phm"
def create_module(directory, name, author, email, module_type, summary,
version):
save_file_name = "{}-{}.{}".format(name, version, HUB_PACKAGE_SUFFIX)
# record module info and serialize
desc = module_desc_pb2.ModuleDesc()
attr = desc.attr
attr.type = module_desc_pb2.MAP
module_info = attr.map.data['module_info']
module_info.type = module_desc_pb2.MAP
utils.from_pyobj_to_module_attr(name, module_info.map.data['name'])
utils.from_pyobj_to_module_attr(author, module_info.map.data['author'])
utils.from_pyobj_to_module_attr(email, module_info.map.data['author_email'])
utils.from_pyobj_to_module_attr(module_type, module_info.map.data['type'])
utils.from_pyobj_to_module_attr(summary, module_info.map.data['summary'])
utils.from_pyobj_to_module_attr(version, module_info.map.data['version'])
module_desc_path = os.path.join(directory, "module_desc.pb")
with open(module_desc_path, "wb") as f:
f.write(desc.SerializeToString())
# generate check info
checker = ModuleChecker(directory)
checker.generate_check_info()
# add __init__
module_init_1 = os.path.join(directory, "__init__.py")
with open(module_init_1, "a") as file:
file.write("")
module_init_2 = os.path.join(directory, "python", "__init__.py")
with open(module_init_2, "a") as file:
file.write("")
# package the module
with tarfile.open(save_file_name, "w:gz") as tar:
for dirname, _, files in os.walk(directory):
for file in files:
tar.add(os.path.join(dirname, file))
os.remove(module_desc_path)
os.remove(checker.pb_path)
os.remove(module_init_1)
os.remove(module_init_2)
class Module(object):
def __new__(cls, name=None, directory=None, module_dir=None, version=None):
module = None
if cls.__name__ == "Module":
if name:
module = cls.init_with_name(name=name, version=version)
elif directory:
module = cls.init_with_directory(directory=directory)
elif module_dir:
logger.warning(
"Parameter module_dir is deprecated, please use directory to specify the path"
)
if isinstance(module_dir, list) or isinstance(
module_dir, tuple):
directory = module_dir[0]
version = module_dir[1]
else:
directory = module_dir
module = cls.init_with_directory(directory=directory)
if not module:
module = object.__new__(cls)
else:
CacheUpdater(module.name, module.version).start()
return module
def __init__(self, name=None, directory=None, module_dir=None,
version=None):
if not directory:
return
self._code_version = "v2"
self._directory = directory
self.module_desc_path = os.path.join(self.directory, MODULE_DESC_PBNAME)
self._desc = module_desc_pb2.ModuleDesc()
with open(self.module_desc_path, "rb") as file:
self._desc.ParseFromString(file.read())
module_info = self.desc.attr.map.data['module_info']
self._name = utils.from_module_attr_to_pyobj(
module_info.map.data['name'])
self._author = utils.from_module_attr_to_pyobj(
module_info.map.data['author'])
self._author_email = utils.from_module_attr_to_pyobj(
module_info.map.data['author_email'])
self._version = utils.from_module_attr_to_pyobj(
module_info.map.data['version'])
self._type = utils.from_module_attr_to_pyobj(
module_info.map.data['type'])
self._summary = utils.from_module_attr_to_pyobj(
module_info.map.data['summary'])
self._initialize()
@classmethod
def init_with_name(cls, name, version=None):
fp_lock = open(os.path.join(CACHE_HOME, name), "a")
lock.flock(fp_lock, lock.LOCK_EX)
log_msg = "Installing %s module" % name
if version:
log_msg += "-%s" % version
logger.info(log_msg)
extra = {"command": "install"}
result, tips, module_dir = default_module_manager.install_module(
module_name=name, module_version=version, extra=extra)
if not result:
logger.error(tips)
raise RuntimeError(tips)
logger.info(tips)
lock.flock(fp_lock, lock.LOCK_UN)
return cls.init_with_directory(directory=module_dir[0])
@classmethod
def init_with_directory(cls, directory):
desc_file = os.path.join(directory, MODULE_DESC_PBNAME)
checker = ModuleChecker(directory)
checker.check()
def create_module(sign_arr,
module_dir,
processor=None,
assets=None,
module_info=None,
exe=None,
extra_info=None):
sign_arr = utils.to_list(sign_arr)
module = Module(
signatures=sign_arr,
processor=processor,
assets=assets,
module_info=module_info,
extra_info=extra_info)
module.serialize_to_path(path=module_dir, exe=exe)
module_code_version = checker.module_code_version
if module_code_version == "v2":
basename = os.path.split(directory)[-1]
dirname = os.path.join(*list(os.path.split(directory)[:-1]))
sys.path.append(dirname)
pymodule = importlib.import_module(
"{}.python.module".format(basename))
return pymodule.HubModule(directory=directory)
return ModuleV1(directory=directory)
@property
def desc(self):
return self._desc
@property
def directory(self):
return self._directory
@property
def author(self):
return self._author
@property
def author_email(self):
return self._author_email
@property
def summary(self):
return self._summary
@property
def type(self):
return self._type
@property
def version(self):
return self._version
@property
def name(self):
return self._name
@property
def name_prefix(self):
return self._name_prefix
@property
def code_version(self):
return self._code_version
@property
def is_runable(self):
return False
def _initialize(self):
pass
class ModuleHelper(object):
def __init__(self, module_dir):
self.module_dir = module_dir
def __init__(self, directory):
self.directory = directory
def module_desc_path(self):
return os.path.join(self.module_dir, MODULE_DESC_PBNAME)
return os.path.join(self.directory, MODULE_DESC_PBNAME)
def model_path(self):
return os.path.join(self.module_dir, MODEL_DIRNAME)
return os.path.join(self.directory, MODEL_DIRNAME)
def processor_path(self):
return os.path.join(self.module_dir, PYTHON_DIR)
return os.path.join(self.directory, PYTHON_DIR)
def processor_name(self):
return PROCESSOR_NAME
def assets_path(self):
return os.path.join(self.module_dir, ASSETS_DIRNAME)
return os.path.join(self.directory, ASSETS_DIRNAME)
class Module(object):
def __init__(self,
name=None,
module_dir=None,
signatures=None,
module_info=None,
assets=None,
processor=None,
extra_info=None,
class ModuleV1(Module):
def __init__(self, name=None, directory=None, module_dir=None,
version=None):
self.desc = module_desc_pb2.ModuleDesc()
if not directory:
return
super(ModuleV1, self).__init__(name, directory, module_dir, version)
self._code_version = "v1"
self.program = None
self.assets = []
self.helper = None
self.signatures = {}
self.default_signature = None
self.module_info = None
self.processor = None
self.extra_info = {} if extra_info is None else extra_info
if not isinstance(self.extra_info, dict):
raise TypeError(
"The extra_info should be an instance of python dict")
self.extra_info = {}
# cache data
self.last_call_name = None
......@@ -120,62 +281,21 @@ class Module(object):
self.cache_fetch_dict = None
self.cache_program = None
fp_lock = open(os.path.join(CONF_HOME, 'config.json'))
lock.flock(fp_lock, lock.LOCK_EX)
if name:
self._init_with_name(name=name, version=version)
lock.flock(fp_lock, lock.LOCK_UN)
elif module_dir:
self._init_with_module_file(module_dir=module_dir[0])
lock.flock(fp_lock, lock.LOCK_UN)
name = module_dir[0].split("/")[-1]
if len(module_dir) > 1:
version = module_dir[1]
else:
version = default_module_manager.search_module(name)[1]
elif signatures:
if processor:
if not issubclass(processor, BaseProcessor):
raise TypeError(
"Processor shoule be an instance of paddlehub.BaseProcessor"
)
if assets:
self.assets = utils.to_list(assets)
# for asset in assets:
# utils.check_path(assets)
self.processor = processor
self._generate_module_info(module_info)
self._init_with_signature(signatures=signatures)
lock.flock(fp_lock, lock.LOCK_UN)
else:
lock.flock(fp_lock, lock.LOCK_UN)
raise ValueError("Module initialized parameter is empty")
CacheUpdater(name, version).start()
def _init_with_name(self, name, version=None):
log_msg = "Installing %s module" % name
if version:
log_msg += "-%s" % version
logger.info(log_msg)
extra = {"command": "install"}
result, tips, module_dir = default_module_manager.install_module(
module_name=name, module_version=version, extra=extra)
if not result:
logger.error(tips)
raise RuntimeError(tips)
else:
logger.info(tips)
self._init_with_module_file(module_dir[0])
def _init_with_url(self, url):
utils.check_url(url)
result, tips, module_dir = default_downloader.download_file_and_uncompress(
url, save_path=".")
if not result:
logger.error(tips)
raise RuntimeError(tips)
else:
self._init_with_module_file(module_dir)
self.helper = ModuleHelper(directory)
exe = fluid.Executor(fluid.CPUPlace())
self.program, _, _ = fluid.io.load_inference_model(
self.helper.model_path(), executor=exe)
for block in self.program.blocks:
for op in block.ops:
if "op_callstack" in op.all_attrs():
op._set_attr("op_callstack", [""])
self._load_processor()
self._load_assets()
self._recover_from_desc()
self._generate_sign_attr()
self._generate_extra_info()
self._restore_parameter(self.program)
self._recover_variable_info(self.program)
def _dump_processor(self):
import inspect
......@@ -216,52 +336,6 @@ class Module(object):
filepath = os.path.join(self.helper.assets_path(), file)
self.assets.append(filepath)
def _init_with_module_file(self, module_dir):
checker = ModuleChecker(module_dir)
checker.check()
self.helper = ModuleHelper(module_dir)
with open(self.helper.module_desc_path(), "rb") as fi:
self.desc.ParseFromString(fi.read())
exe = fluid.Executor(fluid.CPUPlace())
self.program, _, _ = fluid.io.load_inference_model(
self.helper.model_path(), executor=exe)
for block in self.program.blocks:
for op in block.ops:
if "op_callstack" in op.all_attrs():
op._set_attr("op_callstack", [""])
self._load_processor()
self._load_assets()
self._recover_from_desc()
self._generate_sign_attr()
self._generate_extra_info()
self._restore_parameter(self.program)
self._recover_variable_info(self.program)
def _init_with_signature(self, signatures):
self.name_prefix = HUB_VAR_PREFIX % self.name
self._process_signatures(signatures)
self._check_signatures()
self._generate_desc()
self._generate_sign_attr()
self._generate_extra_info()
def _init_with_program(self, program):
pass
def _process_signatures(self, signatures):
self.signatures = {}
self.program = signatures[0].inputs[0].block.program
for sign in signatures:
if sign.name in self.signatures:
raise ValueError(
"Error! Signature array contains duplicated signatrues %s" %
sign)
if self.default_signature is None and sign.for_predict:
self.default_signature = sign
self.signatures[sign.name] = sign
def _restore_parameter(self, program):
global_block = program.global_block()
param_attrs = self.desc.attr.map.data['param_attrs']
......@@ -302,21 +376,6 @@ class Module(object):
self.__dict__["get_%s" % key] = functools.partial(
self.get_extra_info, key=key)
def _generate_module_info(self, module_info=None):
if not module_info:
self.module_info = {}
else:
if not utils.is_yaml_file(module_info):
logger.critical("Module info file should be yaml format")
exit(1)
self.module_info = yaml_parser.parse(module_info)
self.author = self.module_info.get('author', 'UNKNOWN')
self.author_email = self.module_info.get('author_email', 'UNKNOWN')
self.summary = self.module_info.get('summary', 'UNKNOWN')
self.type = self.module_info.get('type', 'UNKNOWN')
self.version = self.module_info.get('version', 'UNKNOWN')
self.name = self.module_info.get('name', 'UNKNOWN')
def _generate_sign_attr(self):
self._check_signatures()
for sign in self.signatures:
......@@ -369,21 +428,21 @@ class Module(object):
default_signature_name = utils.from_module_attr_to_pyobj(
self.desc.attr.map.data['default_signature'])
self.default_signature = self.signatures[
default_signature_name] if default_signature_name else None
default_signature_name].name if default_signature_name else None
# recover module info
module_info = self.desc.attr.map.data['module_info']
self.name = utils.from_module_attr_to_pyobj(
self._name = utils.from_module_attr_to_pyobj(
module_info.map.data['name'])
self.author = utils.from_module_attr_to_pyobj(
self._author = utils.from_module_attr_to_pyobj(
module_info.map.data['author'])
self.author_email = utils.from_module_attr_to_pyobj(
self._author_email = utils.from_module_attr_to_pyobj(
module_info.map.data['author_email'])
self.version = utils.from_module_attr_to_pyobj(
self._version = utils.from_module_attr_to_pyobj(
module_info.map.data['version'])
self.type = utils.from_module_attr_to_pyobj(
self._type = utils.from_module_attr_to_pyobj(
module_info.map.data['type'])
self.summary = utils.from_module_attr_to_pyobj(
self._summary = utils.from_module_attr_to_pyobj(
module_info.map.data['summary'])
# recover extra info
......@@ -393,77 +452,9 @@ class Module(object):
self.extra_info[key] = utils.from_module_attr_to_pyobj(value)
# recover name prefix
self.name_prefix = utils.from_module_attr_to_pyobj(
self._name_prefix = utils.from_module_attr_to_pyobj(
self.desc.attr.map.data["name_prefix"])
def _generate_desc(self):
# save fluid Parameter
attr = self.desc.attr
attr.type = module_desc_pb2.MAP
param_attrs = attr.map.data['param_attrs']
param_attrs.type = module_desc_pb2.MAP
for param in self.program.global_block().iter_parameters():
param_attr = param_attrs.map.data[param.name]
paddle_helper.from_param_to_module_attr(param, param_attr)
# save Variable Info
var_infos = attr.map.data['var_infos']
var_infos.type = module_desc_pb2.MAP
for block in self.program.blocks:
for var in block.vars.values():
var_info = var_infos.map.data[var.name]
var_info.type = module_desc_pb2.MAP
utils.from_pyobj_to_module_attr(
var.stop_gradient, var_info.map.data['stop_gradient'])
utils.from_pyobj_to_module_attr(block.idx,
var_info.map.data['block_id'])
# save signarture info
for key, sign in self.signatures.items():
var = self.desc.sign2var[sign.name]
feed_desc = var.feed_desc
fetch_desc = var.fetch_desc
feed_names = sign.feed_names
fetch_names = sign.fetch_names
for index, input in enumerate(sign.inputs):
feed_var = feed_desc.add()
feed_var.var_name = self.get_var_name_with_prefix(input.name)
feed_var.alias = feed_names[index]
for index, output in enumerate(sign.outputs):
fetch_var = fetch_desc.add()
fetch_var.var_name = self.get_var_name_with_prefix(output.name)
fetch_var.alias = fetch_names[index]
# save default signature
utils.from_pyobj_to_module_attr(
self.default_signature.name if self.default_signature else None,
attr.map.data['default_signature'])
# save name prefix
utils.from_pyobj_to_module_attr(self.name_prefix,
self.desc.attr.map.data["name_prefix"])
# save module info
module_info = attr.map.data['module_info']
module_info.type = module_desc_pb2.MAP
utils.from_pyobj_to_module_attr(self.name, module_info.map.data['name'])
utils.from_pyobj_to_module_attr(self.version,
module_info.map.data['version'])
utils.from_pyobj_to_module_attr(self.author,
module_info.map.data['author'])
utils.from_pyobj_to_module_attr(self.author_email,
module_info.map.data['author_email'])
utils.from_pyobj_to_module_attr(self.type, module_info.map.data['type'])
utils.from_pyobj_to_module_attr(self.summary,
module_info.map.data['summary'])
# save extra info
extra_info = attr.map.data['extra_info']
extra_info.type = module_desc_pb2.MAP
for key, value in self.extra_info.items():
utils.from_pyobj_to_module_attr(value, extra_info.map.data[key])
def __call__(self, sign_name, data, use_gpu=False, batch_size=1, **kwargs):
self.check_processor()
......@@ -525,6 +516,10 @@ class Module(object):
if not self.processor:
raise ValueError("This Module is not callable!")
@property
def is_runable(self):
return self.default_signature != None
def context(self,
sign_name=None,
for_test=False,
......@@ -664,93 +659,3 @@ class Module(object):
raise ValueError(
"All input and outputs variables in signature should come from the same Program"
)
def serialize_to_path(self, path=None, exe=None):
self._check_signatures()
self._generate_desc()
# create module path for saving
if path is None:
path = os.path.join(".", self.name)
self.helper = ModuleHelper(path)
utils.mkdir(self.helper.module_dir)
# create module pb
module_desc = module_desc_pb2.ModuleDesc()
logger.info("PaddleHub version = %s" % version.hub_version)
logger.info("PaddleHub Module proto version = %s" %
version.module_proto_version)
logger.info("Paddle version = %s" % paddle.__version__)
feeded_var_names = [
input.name for key, sign in self.signatures.items()
for input in sign.inputs
]
target_vars = [
output for key, sign in self.signatures.items()
for output in sign.outputs
]
feeded_var_names = list(set(feeded_var_names))
target_vars = list(set(target_vars))
# save inference program
program = self.program.clone()
for block in program.blocks:
for op in block.ops:
if "op_callstack" in op.all_attrs():
op._set_attr("op_callstack", [""])
if not exe:
place = fluid.CPUPlace()
exe = fluid.Executor(place=place)
utils.mkdir(self.helper.model_path())
fluid.io.save_inference_model(
self.helper.model_path(),
feeded_var_names=list(feeded_var_names),
target_vars=list(target_vars),
main_program=program,
executor=exe)
with open(os.path.join(self.helper.model_path(), "__model__"),
"rb") as file:
program_desc_str = file.read()
rename_program = fluid.framework.Program.parse_from_string(
program_desc_str)
varlist = {
var: block
for block in rename_program.blocks for var in block.vars
if self.get_name_prefix() not in var
}
for var, block in varlist.items():
old_name = var
new_name = self.get_var_name_with_prefix(old_name)
block._rename_var(old_name, new_name)
utils.mkdir(self.helper.model_path())
with open(
os.path.join(self.helper.model_path(), "__model__"),
"wb") as f:
f.write(rename_program.desc.serialize_to_string())
for file in os.listdir(self.helper.model_path()):
if (file == "__model__" or self.get_name_prefix() in file):
continue
os.rename(
os.path.join(self.helper.model_path(), file),
os.path.join(self.helper.model_path(),
self.get_var_name_with_prefix(file)))
# create processor file
if self.processor:
self._dump_processor()
# create assets
self._dump_assets()
# create check info
checker = ModuleChecker(self.helper.module_dir)
checker.generate_check_info()
# Serialize module_desc pb
module_pb = self.desc.SerializeToString()
with open(self.helper.module_desc_path(), "wb") as f:
f.write(module_pb)
......@@ -14,7 +14,6 @@
# limitations under the License.
import sys
import time
import paddlehub as hub
import ujson
import random
......@@ -30,7 +29,7 @@ if is_py3:
import http.client as httplib
class BertService():
class BertService(object):
def __init__(self,
profile=False,
max_seq_len=128,
......@@ -42,7 +41,7 @@ class BertService():
load_balance='round_robin'):
self.process_id = process_id
self.reader_flag = False
self.batch_size = 16
self.batch_size = 0
self.max_seq_len = max_seq_len
self.profile = profile
self.model_name = model_name
......@@ -55,15 +54,6 @@ class BertService():
self.feed_var_names = ''
self.retry = retry
def connect(self, server='127.0.0.1:8010'):
self.server_list.append(server)
def connect_all_server(self, server_list):
for server_str in server_list:
self.server_list.append(server_str)
def data_convert(self, text):
if self.reader_flag == False:
module = hub.Module(name=self.model_name)
inputs, outputs, program = module.context(
trainable=True, max_seq_len=self.max_seq_len)
......@@ -79,10 +69,14 @@ class BertService():
do_lower_case=self.do_lower_case)
self.reader_flag = True
return self.reader.data_generator(
batch_size=self.batch_size, phase='predict', data=text)
def add_server(self, server='127.0.0.1:8010'):
self.server_list.append(server)
def infer(self, request_msg):
def add_server_list(self, server_list):
for server_str in server_list:
self.server_list.append(server_str)
def request_server(self, request_msg):
if self.load_balance == 'round_robin':
try:
cur_con = httplib.HTTPConnection(
......@@ -157,17 +151,13 @@ class BertService():
self.server_list)
return 'retry'
def encode(self, text):
if type(text) != list:
raise TypeError('Only support list')
def prepare_data(self, text):
self.batch_size = len(text)
data_generator = self.data_convert(text)
start = time.time()
request_time = 0
result = []
data_generator = self.reader.data_generator(
batch_size=self.batch_size, phase='predict', data=text)
request_msg = ""
for run_step, batch in enumerate(data_generator(), start=1):
request = []
copy_start = time.time()
token_list = batch[0][0].reshape(-1).tolist()
pos_list = batch[0][1].reshape(-1).tolist()
sent_list = batch[0][2].reshape(-1).tolist()
......@@ -184,54 +174,34 @@ class BertService():
si + 1) * self.max_seq_len]
request.append(instance_dict)
copy_time = time.time() - copy_start
request = {"instances": request}
request["max_seq_len"] = self.max_seq_len
request["feed_var_names"] = self.feed_var_names
request_msg = ujson.dumps(request)
if self.show_ids:
logger.info(request_msg)
request_start = time.time()
response_msg = self.infer(request_msg)
return request_msg
def encode(self, text):
if type(text) != list:
raise TypeError('Only support list')
request_msg = self.prepare_data(text)
response_msg = self.request_server(request_msg)
retry = 0
while type(response_msg) == str and response_msg == 'retry':
if retry < self.retry:
retry += 1
logger.info('Try to connect another servers')
response_msg = self.infer(request_msg)
response_msg = self.request_server(request_msg)
else:
logger.error('Infer failed after {} times retry'.format(
logger.error('Request failed after {} times retry'.format(
self.retry))
break
result = []
for msg in response_msg["instances"]:
for sample in msg["instances"]:
result.append(sample["values"])
request_time += time.time() - request_start
total_time = time.time() - start
if self.profile:
return [
total_time, request_time, response_msg['op_time'],
response_msg['infer_time'], copy_time
]
else:
return result
def connect(input_text,
model_name,
max_seq_len=128,
show_ids=False,
do_lower_case=True,
server="127.0.0.1:8866",
retry=3):
# format of input_text like [["As long as"],]
bc = BertService(
max_seq_len=max_seq_len,
model_name=model_name,
show_ids=show_ids,
do_lower_case=do_lower_case,
retry=retry)
bc.connect(server)
result = bc.encode(input_text)
return result
from paddlehub.serving.bert_serving import bert_service
class BSClient(object):
def __init__(self,
module_name,
server,
max_seq_len=20,
show_ids=False,
do_lower_case=True,
retry=3):
self.bs = bert_service.BertService(
model_name=module_name,
max_seq_len=max_seq_len,
show_ids=show_ids,
do_lower_case=do_lower_case,
retry=retry)
self.bs.add_server(server=server)
def get_result(self, input_text):
return self.bs.encode(input_text)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册