提交 4f92105e 编写于 作者: W wuzewu

modify code directory

上级 09dc8dcd
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
......@@ -12,17 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from paddle_hub.module import Module
from paddle_hub.module import ModuleConfig
from paddle_hub.module import ModuleUtils
from paddle_hub.module import create_module
from paddle_hub.downloader import download_and_uncompress
from paddle_hub.signature import create_signature
from paddle_hub.version import __version__
connect_program = ModuleUtils.connect_program
from . import module
from . import tools
from . import data_process
from .module.module import Module, create_module
from .module.signature import Signature, create_signature
from .tools.logger import logger
from .tools.paddle_helper import connect_program
#/bin/bash
protoc -I=./ --python_out=./ module_desc.proto
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import numpy as np
import tempfile
import os
import copy
from collections import defaultdict
from paddle_hub.downloader import download_and_uncompress
from paddle_hub import module_desc_pb2
from paddle_hub.logger import logger
from paddle_hub.signature import Signature
from paddle_hub.utils import to_list, mkdir, from_pyobj_to_flexible_data, from_flexible_data_to_pyobj
from paddle_hub.paddle_helper import from_param_to_flexible_data, get_variable_info, from_flexible_data_to_param
from paddle_hub.version import __version__
__all__ = ["Module", "ModuleConfig", "ModuleUtils"]
# paddle hub module dir name
ASSETS_DIRNAME = "assets"
META_DIRNAME = "meta"
MODEL_DIRNAME = "model"
# paddle hub module serialze file name
DICT_FILENAME = "vocab.txt"
PARAM_FILENAME = "param.pkl"
MODULE_DESC_PBNAME = "module_desc.pb"
# paddle hub var prefix
HUB_VAR_PREFIX = "@HUB@"
class Module(object):
"""
Core object of PaddleHub
"""
def __init__(self, module_url=None, module_dir=None):
if module_url == None and module_dir == None:
raise Exception("Module:module_url and module_dir are None!")
self.module_dir = ""
self.module_name = ""
# donwload module
if module_url is not None and module_url.startswith("http"):
# if it's remote url link, then download and uncompress it
self.module_name, self.module_dir = download_and_uncompress(
module_url)
#TODO(ZeyuChen): check url link is valid url
elif module_dir is not None:
# otherwise it's local path, no need to deal with it
self.module_dir = module_dir
# use the path name as module name by default
self.module_name = module_dir.split("/")[-1]
#TODO(ZeyuChen) add more check about loading module from local path
def _process_parameter(self):
global_block = self.inference_program.global_block()
param_attrs = self.config.desc.extra_info.map.data['param_attrs']
for key, param_attr in param_attrs.map.data.items():
param = from_flexible_data_to_param(param_attr)
param['name'] = HUB_VAR_PREFIX + key
if (param['name'] not in global_block.vars):
continue
var = global_block.var(param['name'])
global_block.create_parameter(
**param,
shape=var.shape,
dtype=var.dtype,
type=var.type,
lod_level=var.lod_level,
error_clip=var.error_clip,
stop_gradient=var.stop_gradient,
is_data=var.is_data)
def _process_variable_info(self):
var_infos = self.config.desc.extra_info.map.data['var_infos']
for var_info in var_infos.map.data:
idx = from_flexible_data_to_pyobj(
var_infos.map.data[var_info].map.data['block_id'])
stop_gradient = from_flexible_data_to_pyobj(
var_infos.map.data[var_info].map.data['stop_gradient'])
block = self.inference_program.blocks[idx]
var_name = HUB_VAR_PREFIX + var_info
if var_name in block.vars:
var = block.vars[var_name]
var.stop_gradient = stop_gradient
def __call__(self, sign_name="default", trainable=False):
""" Call default signature and return results
"""
def _set_param_trainable(program, trainable=False):
for param in program.global_block().iter_parameters():
param.trainable = trainable
def _process_op_attr(program, is_test=False):
for op in program.global_block().ops:
if op.has_attr("is_test"):
op._set_attr("is_test", is_test)
def _process_input_output_key(module_desc, signature):
signature = module_desc.sign2var[signature]
feed_dict = {}
fetch_dict = {}
for index, feed in enumerate(signature.feed_desc):
if feed.alias != "":
feed_dict[feed.alias] = feed.var_name
feed_dict[index] = feed.var_name
for index, fetch in enumerate(signature.fetch_desc):
if fetch.alias != "":
fetch_dict[fetch.alias] = fetch.var_name
fetch_dict[index] = fetch.var_name
return feed_dict, fetch_dict
self.config = ModuleConfig(self.module_dir)
self.config.load()
# load paddle inference model
place = fluid.CPUPlace()
model_dir = os.path.join(self.module_dir, MODEL_DIRNAME)
self.exe = fluid.Executor(fluid.CPUPlace())
self.inference_program, self.feed_target_names, self.fetch_targets = fluid.io.load_inference_model(
model_dir, executor=self.exe)
feed_dict, fetch_dict = _process_input_output_key(
self.config.desc, sign_name)
# remove feed fetch operator and variable
ModuleUtils.remove_feed_fetch_op(self.inference_program)
logger.info("**feed_target_names**\n{}".format(self.feed_target_names))
logger.info("**fetch_targets**\n{}".format(self.fetch_targets))
self._process_parameter()
self._process_variable_info()
_process_op_attr(program=self.inference_program, is_test=False)
_set_param_trainable(
program=self.inference_program, trainable=trainable)
for key, value in feed_dict.items():
var = self.inference_program.global_block().var(HUB_VAR_PREFIX +
value)
feed_dict[key] = var
for key, value in fetch_dict.items():
var = self.inference_program.global_block().var(HUB_VAR_PREFIX +
value)
fetch_dict[key] = var
return feed_dict, fetch_dict, self.inference_program
def get_inference_program(self):
return self.inference_program
# for text sequence input, transform to lod tensor as paddle graph's input
def _preprocess_input(self, inputs):
# words id mapping and dealing with oov
# transform to lod tensor
seq = []
for s in inputs:
seq.append(self._word_id_mapping(s))
lod_tensor = self.seq2lod_tensor(seq)
return lod_tensor
def seq2lod_tensor(self, seq_inputs, place=fluid.CPUPlace()):
""" sequence to lod tensor, need to determine which space"""
lod = []
lod.append([])
for s in seq_inputs:
# generate lod
lod[0].append(len(s))
# print("seq", seq_inputs)
# print("lod", lod)
lod_tensor = fluid.create_lod_tensor(seq_inputs, lod, place)
return lod_tensor
def _word_id_mapping(self, inputs):
word_dict = self.config.get_assets_vocab()
return list(map(lambda x: word_dict[x], inputs))
class ModuleConfig(object):
def __init__(self, module_dir, module_name=None):
# generate model desc protobuf
self.module_dir = module_dir
self.desc = module_desc_pb2.ModuleDesc()
if module_name == None:
module_name = module_dir.split("/")[-1]
# initialize module config default value
self.desc.name = module_name
self.desc.contain_assets = True
self.desc.return_numpy = False
# init dict
self.dict = defaultdict(int)
self.dict.setdefault(0)
def get_assets_vocab(self):
""" Return dictionary in Module"""
return self.dict
def load(self):
"""
Load module config from module directory.
"""
#TODO(ZeyuChen): check module_desc.pb exsitance
with open(ModuleConfig.module_desc_path(self.module_dir), "rb") as fi:
self.desc.ParseFromString(fi.read())
if self.desc.contain_assets:
# load assets
word_id = 0
with open(ModuleConfig.assets_dict_path(self.module_dir)) as fi:
words = fi.readlines()
#TODO(ZeyuChen) check whether word id is duplicated and valid
for line in fi:
w, w_id = line.split()
self.dict[w] = int(w_id)
def return_numpy(self):
"""Return numpy or not according to the proto config.
"""
return self.desc.return_numpy
def save_dict(self, word_dict, dict_name=DICT_FILENAME):
""" Save dictionary for NLP module
"""
for w in word_dict:
self.dict[w] = word_dict[w]
@staticmethod
def module_desc_path(module_dir):
return os.path.join(module_dir, MODULE_DESC_PBNAME)
@staticmethod
def assets_dict_path(module_dir):
assets_path = os.path.join(module_dir, ASSETS_DIRNAME)
mkdir(assets_path)
return os.path.join(assets_path, DICT_FILENAME)
@staticmethod
def meta_param_path(module_dir):
meta_path = os.path.join(module_dir, META_DIRNAME)
mkdir(meta_path)
return os.path.join(meta_path, PARAM_FILENAME)
def create_module(sign_arr, module_dir=None, word_dict=None, exe=None):
""" Create a module from main program
"""
assert sign_arr, "signature array should not be None"
# check all variable
sign_arr = to_list(sign_arr)
program = sign_arr[0].get_inputs()[0].block.program
feeded_var_names = set()
target_vars = set()
for sign in sign_arr:
assert isinstance(sign,
Signature), "sign_arr should be list of Signature"
for input in sign.get_inputs():
feeded_var_names.add(input.name)
_tmp_program = input.block.program
assert program == _tmp_program, "all the variable should come from the same program"
for output in sign.get_outputs():
target_vars.add(output)
_tmp_program = output.block.program
assert program == _tmp_program, "all the variable should come from the same program"
# create module path for saving
if module_dir is None:
module_dir = os.path.join(".", "hub_module")
mkdir(module_dir)
# create module pb
module_desc = module_desc_pb2.ModuleDesc()
module_desc.auth_info.hub_version = __version__
module_desc.auth_info.paddle_version = paddle.__version__
logger.info("hub version is %s" % __version__)
logger.info("paddle version is %s" % paddle.__version__)
# save asset
if word_dict is None:
module_desc.contain_assets = False
else:
module_desc.contain_assets = True
with open(ModuleConfig.assets_dict_path(module_dir), "w") as fo:
for w in word_dict:
w_id = word_dict[w]
fo.write("{}\t{}\n".format(w, w_id))
# save fluid Parameter
extra_info = module_desc.extra_info
extra_info.type = module_desc_pb2.MAP
param_attrs = extra_info.map.data['param_attrs']
param_attrs.type = module_desc_pb2.MAP
for param in program.global_block().iter_parameters():
param_attr = param_attrs.map.data[param.name]
from_param_to_flexible_data(param, param_attr)
# save Variable Info
var_infos = extra_info.map.data['var_infos']
var_infos.type = module_desc_pb2.MAP
for block in program.blocks:
for var in block.vars.values():
var_info = var_infos.map.data[var.name]
var_info.type = module_desc_pb2.MAP
from_pyobj_to_flexible_data(var.stop_gradient,
var_info.map.data['stop_gradient'])
from_pyobj_to_flexible_data(block.idx,
var_info.map.data['block_id'])
# save signarture info
sign_map = module_desc.sign2var
for sign in sign_arr:
if sign.get_name() in sign_map:
raise "Error! sign_arr contains repeat signatrue %s" % sign
var = sign_map[sign.get_name()]
feed_desc = var.feed_desc
fetch_desc = var.fetch_desc
feed_names = sign.get_feed_names()
fetch_names = sign.get_fetch_names()
for index, input in enumerate(sign.get_inputs()):
feed_var = feed_desc.add()
feed_var.var_name = input.name
feed_var.alias = feed_names[index]
for index, output in enumerate(sign.get_outputs()):
fetch_var = fetch_desc.add()
fetch_var.var_name = output.name
fetch_var.alias = fetch_names[index]
# save inference program
program = program.clone()
if not exe:
place = fluid.CPUPlace()
exe = fluid.Executor(place=place)
save_model_dir = os.path.join(module_dir, "model")
mkdir(save_model_dir)
fluid.io.save_inference_model(
save_model_dir,
feeded_var_names=list(feeded_var_names),
target_vars=list(target_vars),
main_program=program,
executor=exe)
with open(os.path.join(save_model_dir, "__model__"), "rb") as file:
program_desc_str = file.read()
rename_program = fluid.framework.Program.parse_from_string(
program_desc_str)
varlist = {
var: block
for block in rename_program.blocks for var in block.vars
if HUB_VAR_PREFIX not in var
}
for var, block in varlist.items():
old_name = var
new_name = HUB_VAR_PREFIX + old_name
block._rename_var(old_name, new_name)
mkdir(save_model_dir)
with open(os.path.join(save_model_dir, "__model__"), "wb") as f:
f.write(rename_program.desc.serialize_to_string())
for file in os.listdir(save_model_dir):
if (file == "__model__" or HUB_VAR_PREFIX in file):
continue
os.rename(
os.path.join(save_model_dir, file),
os.path.join(save_model_dir, HUB_VAR_PREFIX + file))
# Serialize module_desc pb
module_pb = module_desc.SerializeToString()
with open(ModuleConfig.module_desc_path(module_dir), "wb") as f:
f.write(module_pb)
class ModuleUtils(object):
def __init__(self):
pass
@staticmethod
def connect_program(pre_program, next_program, input_dict=None):
def _copy_vars_and_ops_in_blocks(from_block, to_block):
for var in from_block.vars:
var = from_block.var(var)
var_info = copy.deepcopy(get_variable_info(var))
if isinstance(var, fluid.framework.Parameter):
to_block.create_parameter(**var_info)
else:
to_block.create_var(**var_info)
for op in from_block.ops:
op_info = {
'type': op.type,
'inputs': {
input: [block.var(var) for var in op.input(input)]
for input in op.input_names
},
'outputs': {
output: [block.var(var) for var in op.output(output)]
for output in op.output_names
},
'attrs': copy.deepcopy(op.all_attrs())
}
to_block.append_op(**op_info)
assert isinstance(pre_program,
fluid.Program), "pre_program should be fluid.Program"
assert isinstance(next_program,
fluid.Program), "next_program should be fluid.Program"
new_program = pre_program.clone()
if input_dict:
assert isinstance(
input_dict, dict
), "the input_dict should be a dict with string-Variable pair"
for key, var in input_dict.items():
assert isinstance(
var, fluid.framework.Variable
), "the input_dict should be a dict with string-Variable pair"
var_info = copy.deepcopy(get_variable_info(var))
input_var = new_program.global_block().create_var(**var_info)
output_var = next_program.global_block().var(key)
var_info = copy.deepcopy(get_variable_info(output_var))
output_var = new_program.global_block().create_var(**var_info)
new_program.global_block().append_op(
type="assign",
inputs={'X': input_var},
outputs={'Out': output_var})
block_map = {0: 0}
logger.info("start to connect program")
for index, block in enumerate(next_program.blocks):
if block.idx == 0:
_copy_vars_and_ops_in_blocks(block, new_program.global_block())
else:
block_map[index] = len(new_program.blocks)
logger.info(
"block_%d in next_program merge into block_%d in pre_program"
% (index, block_map[index]))
new_block = new_program._create_block(
parent_idx=block_map[block.parent_idx])
_copy_vars_and_ops_in_blocks(block, new_block)
logger.info("end of connect program")
return new_program
@staticmethod
def remove_feed_fetch_op(program):
""" remove feed and fetch operator and variable for fine-tuning
"""
logger.info("remove feed fetch op")
block = program.global_block()
need_to_remove_op_index = []
for i, op in enumerate(block.ops):
if op.type == "feed" or op.type == "fetch":
need_to_remove_op_index.append(i)
for index in need_to_remove_op_index[::-1]:
block._remove_op(index)
# TODO(wuzewu): get feed and fetch var by other way
block._remove_var(HUB_VAR_PREFIX + "feed")
block._remove_var(HUB_VAR_PREFIX + "fetch")
program.desc.flush()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import module
from . import signature
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle_hub.tools import utils
from paddle_hub.tools.logger import logger
from paddle_hub.tools import downloader
from paddle_hub.tools import paddle_helper
from paddle_hub.module import module_desc_pb2
from paddle_hub.module.signature import Signature, create_signature
from paddle_hub import version
import os
import paddle
import paddle.fluid as fluid
__all__ = ['Module', 'create_module']
def create_module(sign_arr, module_dir, exe=None):
sign_arr = utils.to_list(sign_arr)
module = Module(signatures=sign_arr)
module.serialize_to_path(path=module_dir, exe=exe)
# paddle hub module dir name
ASSETS_DIRNAME = "assets"
MODEL_DIRNAME = "model"
MODULE_DESC_PBNAME = "module_desc.pb"
# paddle hub var prefix
HUB_VAR_PREFIX = "@HUB@"
class ModuleWrapper:
def __init__(self, module, name):
self.module = module
self.name = name
def __call__(self, trainable=False):
return self.module(self.name, trainable)
class ModuleHelper:
def __init__(self, module_dir):
self.module_dir = module_dir
def module_desc_path(self):
return os.path.join(self.module_dir, MODULE_DESC_PBNAME)
def model_path(self):
return os.path.join(self.module_dir, MODEL_DIRNAME)
class Module:
def __init__(self, url=None, module_dir=None, signatures=None, name=None):
if not name:
name = "HubModule"
self.name = name
self.desc = module_desc_pb2.ModuleDesc()
self.program = None
self.assets = []
self.helper = None
self.signatures = {}
if url:
self._init_with_url(url=url)
elif module_dir:
self._init_with_module_file(module_dir=module_dir)
elif signatures:
self._init_with_signature(signatures=signatures)
else:
raise "Error! HubModule Can't init with nothing"
def _init_with_url(self, url):
utils.check_url_valid(url)
module_dir = downloader.download_and_uncompress(module_url)
self._init_with_module_file(module_dir)
def _init_with_module_file(self, module_dir):
self.helper = ModuleHelper(module_dir)
with open(self.helper.module_desc_path(), "rb") as fi:
self.desc.ParseFromString(fi.read())
exe = fluid.Executor(fluid.CPUPlace())
self.program, _, _ = fluid.io.load_inference_model(
self.helper.model_path(), executor=exe)
self._recovery_parameter(self.program)
self._recover_variable_info(self.program)
inputs = []
outputs = []
feed_names = []
fetch_names = []
for sign, module_var in self.desc.sign2var.items():
for var in module_var.feed_desc:
variable = self.program.global_block().vars[var.var_name]
inputs.append(variable)
feed_names.append(var.alias)
for var in module_var.fetch_desc:
variable = self.program.global_block().vars[var.var_name]
outputs.append(variable)
fetch_names.append(var.alias)
self.signatures[sign] = create_signature(
sign,
inputs=inputs,
outputs=outputs,
feed_names=feed_names,
fetch_names=fetch_names)
self._generate_sign_attr()
def _init_with_signature(self, signatures):
self._process_signatures(signatures)
self._check_signatures()
self._generate_desc()
self._generate_sign_attr()
def _init_with_program(self, program):
pass
def _process_signatures(self, signatures):
self.signatures = {}
self.program = signatures[0].inputs[0].block.program
for sign in signatures:
if sign.name in self.signatures:
raise "Error! signature array contains repeat signatrue %s" % sign
self.signatures[sign.name] = sign
def _recovery_parameter(self, program):
global_block = self.program.global_block()
param_attrs = self.desc.extra_info.map.data['param_attrs']
for key, param_attr in param_attrs.map.data.items():
param = paddle_helper.from_flexible_data_to_param(param_attr)
param['name'] = HUB_VAR_PREFIX + key
if (param['name'] not in global_block.vars):
continue
var = global_block.var(param['name'])
global_block.create_parameter(
**param,
shape=var.shape,
dtype=var.dtype,
type=var.type,
lod_level=var.lod_level,
error_clip=var.error_clip,
stop_gradient=var.stop_gradient,
is_data=var.is_data)
def _recover_variable_info(self, program):
var_infos = self.desc.extra_info.map.data['var_infos']
for var_info in var_infos.map.data:
idx = utils.from_flexible_data_to_pyobj(
var_infos.map.data[var_info].map.data['block_id'])
stop_gradient = utils.from_flexible_data_to_pyobj(
var_infos.map.data[var_info].map.data['stop_gradient'])
block = program.blocks[idx]
var_name = HUB_VAR_PREFIX + var_info
if var_name in block.vars:
var = block.vars[var_name]
var.stop_gradient = stop_gradient
def _generate_sign_attr(self):
self._check_signatures()
for sign in self.signatures:
self.__dict__[sign] = ModuleWrapper(self, sign)
def _generate_desc(self):
# save fluid Parameter
extra_info = self.desc.extra_info
extra_info.type = module_desc_pb2.MAP
param_attrs = extra_info.map.data['param_attrs']
param_attrs.type = module_desc_pb2.MAP
for param in self.program.global_block().iter_parameters():
param_attr = param_attrs.map.data[param.name]
paddle_helper.from_param_to_flexible_data(param, param_attr)
# save Variable Info
var_infos = extra_info.map.data['var_infos']
var_infos.type = module_desc_pb2.MAP
for block in self.program.blocks:
for var in block.vars.values():
var_info = var_infos.map.data[var.name]
var_info.type = module_desc_pb2.MAP
utils.from_pyobj_to_flexible_data(
var.stop_gradient, var_info.map.data['stop_gradient'])
utils.from_pyobj_to_flexible_data(block.idx,
var_info.map.data['block_id'])
# save signarture info
for key, sign in self.signatures.items():
var = self.desc.sign2var[sign.name]
feed_desc = var.feed_desc
fetch_desc = var.fetch_desc
feed_names = sign.feed_names
fetch_names = sign.fetch_names
for index, input in enumerate(sign.inputs):
feed_var = feed_desc.add()
feed_var.var_name = HUB_VAR_PREFIX + input.name
feed_var.alias = feed_names[index]
for index, output in enumerate(sign.outputs):
fetch_var = fetch_desc.add()
fetch_var.var_name = HUB_VAR_PREFIX + output.name
fetch_var.alias = fetch_names[index]
def __call__(self, sign_name, trainable=False):
assert sign_name in self.signatures, "module did not have a signature with name %s" % sign_name
signature = self.signatures[sign_name]
program = self.program.clone()
paddle_helper.remove_feed_fetch_op(program)
paddle_helper.set_parameter_trainable(program, trainable)
paddle_helper.set_op_attr(program, is_test=False)
self._recovery_parameter(program)
self._recover_variable_info(program)
feed_dict = {}
fetch_dict = {}
for index, var in enumerate(signature.inputs):
feed_dict[index] = program.global_block().var(var.name)
key = signature.feed_names[index]
if key:
feed_dict[key] = program.global_block().var(var.name)
for index, var in enumerate(signature.outputs):
fetch_dict[index] = program.global_block().var(var.name)
key = signature.fetch_names[index]
if key:
fetch_dict[key] = program.global_block().var(var.name)
return feed_dict, fetch_dict, program
def preprocess(self):
pass
def postprocess(self):
pass
def parameters(self):
pass
def parameter_attrs(self):
pass
def _check_signatures(self):
assert self.signatures, "signature array should not be None"
for key, sign in self.signatures.items():
assert isinstance(sign,
Signature), "sign_arr should be list of Signature"
for input in sign.inputs:
_tmp_program = input.block.program
assert self.program == _tmp_program, "all the variable should come from the same program"
for output in sign.outputs:
_tmp_program = output.block.program
assert self.program == _tmp_program, "all the variable should come from the same program"
def serialize_to_path(self, path=None, exe=None):
self._check_signatures()
self._generate_desc()
# create module path for saving
if path is None:
path = os.path.join(".", self.name)
self.helper = ModuleHelper(path)
utils.mkdir(self.helper.module_dir)
# create module pb
module_desc = module_desc_pb2.ModuleDesc()
logger.info("hub version is %s" % version.hub_version)
logger.info("proto version is %s" % version.proto_version)
logger.info("paddle version is %s" % paddle.__version__)
for asset in self.assets:
pass
feeded_var_names = [
input.name for key, sign in self.signatures.items()
for input in sign.inputs
]
target_vars = [
output for key, sign in self.signatures.items()
for output in sign.outputs
]
feeded_var_names = list(set(feeded_var_names))
target_vars = list(set(target_vars))
# save inference program
program = self.program.clone()
if not exe:
place = fluid.CPUPlace()
exe = fluid.Executor(place=place)
utils.mkdir(self.helper.model_path())
fluid.io.save_inference_model(
self.helper.model_path(),
feeded_var_names=list(feeded_var_names),
target_vars=list(target_vars),
main_program=program,
executor=exe)
with open(os.path.join(self.helper.model_path(), "__model__"),
"rb") as file:
program_desc_str = file.read()
rename_program = fluid.framework.Program.parse_from_string(
program_desc_str)
varlist = {
var: block
for block in rename_program.blocks for var in block.vars
if HUB_VAR_PREFIX not in var
}
for var, block in varlist.items():
old_name = var
new_name = HUB_VAR_PREFIX + old_name
block._rename_var(old_name, new_name)
utils.mkdir(self.helper.model_path())
with open(
os.path.join(self.helper.model_path(), "__model__"),
"wb") as f:
f.write(rename_program.desc.serialize_to_string())
for file in os.listdir(self.helper.model_path()):
if (file == "__model__" or HUB_VAR_PREFIX in file):
continue
os.rename(
os.path.join(self.helper.model_path(), file),
os.path.join(self.helper.model_path(),
HUB_VAR_PREFIX + file))
# Serialize module_desc pb
module_pb = self.desc.SerializeToString()
with open(self.helper.module_desc_path(), "wb") as f:
f.write(module_pb)
......@@ -16,7 +16,7 @@
syntax = "proto3";
option optimize_for = LITE_RUNTIME;
package paddle_hub;
package paddle_hub_module;
enum DataType {
NONE = 0;
......@@ -72,7 +72,7 @@ message AuthInfo {
string hub_version = 2;
}
// A Hub Module is stored in a directory with a file 'paddlehub.pb'
// A Hub Module is stored in a directory with a file 'module_desc.pb'
// containing a serialized protocol message of this type. The further contents
// of the directory depend on the storage format described by the message.
message ModuleDesc {
......
......@@ -15,16 +15,16 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='module_desc.proto',
package='paddle_hub',
package='paddle_hub_module',
syntax='proto3',
serialized_pb=_b(
'\n\x11module_desc.proto\x12\npaddle_hub\"\xf3\x01\n\x06KVData\x12\x30\n\x07keyType\x18\x01 \x03(\x0b\x32\x1f.paddle_hub.KVData.KeyTypeEntry\x12*\n\x04\x64\x61ta\x18\x02 \x03(\x0b\x32\x1c.paddle_hub.KVData.DataEntry\x1a\x44\n\x0cKeyTypeEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0e\x32\x14.paddle_hub.DataType:\x02\x38\x01\x1a\x45\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.paddle_hub.FlexibleData:\x02\x38\x01\"\x82\x02\n\x0c\x46lexibleData\x12\"\n\x04type\x18\x01 \x01(\x0e\x32\x14.paddle_hub.DataType\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\t\n\x01i\x18\x03 \x01(\x03\x12\t\n\x01\x66\x18\x04 \x01(\x01\x12\t\n\x01\x62\x18\x05 \x01(\x08\x12\t\n\x01s\x18\x06 \x01(\t\x12\x1f\n\x03map\x18\x07 \x01(\x0b\x32\x12.paddle_hub.KVData\x12 \n\x04list\x18\x08 \x01(\x0b\x32\x12.paddle_hub.KVData\x12\x1f\n\x03set\x18\t \x01(\x0b\x32\x12.paddle_hub.KVData\x12\"\n\x06object\x18\n \x01(\x0b\x32\x12.paddle_hub.KVData\x12\x0c\n\x04info\x18\x0b \x01(\t\"+\n\x08\x46\x65\x65\x64\x44\x65sc\x12\x10\n\x08var_name\x18\x01 \x01(\t\x12\r\n\x05\x61lias\x18\x02 \x01(\t\",\n\tFetchDesc\x12\x10\n\x08var_name\x18\x01 \x01(\t\x12\r\n\x05\x61lias\x18\x02 \x01(\t\"_\n\tModuleVar\x12)\n\nfetch_desc\x18\x01 \x03(\x0b\x32\x15.paddle_hub.FetchDesc\x12\'\n\tfeed_desc\x18\x02 \x03(\x0b\x32\x14.paddle_hub.FeedDesc\"7\n\x08\x41uthInfo\x12\x16\n\x0epaddle_version\x18\x01 \x01(\t\x12\x13\n\x0bhub_version\x18\x02 \x01(\t\"\x9f\x02\n\nModuleDesc\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x36\n\x08sign2var\x18\x02 \x03(\x0b\x32$.paddle_hub.ModuleDesc.Sign2varEntry\x12\x14\n\x0creturn_numpy\x18\x03 \x01(\x08\x12\x16\n\x0e\x63ontain_assets\x18\x04 \x01(\x08\x12\'\n\tauth_info\x18\x05 \x01(\x0b\x32\x14.paddle_hub.AuthInfo\x12,\n\nextra_info\x18\x06 \x01(\x0b\x32\x18.paddle_hub.FlexibleData\x1a\x46\n\rSign2varEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.paddle_hub.ModuleVar:\x02\x38\x01*i\n\x08\x44\x61taType\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03INT\x10\x01\x12\t\n\x05\x46LOAT\x10\x02\x12\n\n\x06STRING\x10\x03\x12\x0b\n\x07\x42OOLEAN\x10\x04\x12\x08\n\x04LIST\x10\x05\x12\x07\n\x03MAP\x10\x06\x12\x07\n\x03SET\x10\x07\x12\n\n\x06OBJECT\x10\x08\x42\x02H\x03\x62\x06proto3'
'\n\x11module_desc.proto\x12\x11paddle_hub_module\"\x8f\x02\n\x06KVData\x12\x37\n\x07keyType\x18\x01 \x03(\x0b\x32&.paddle_hub_module.KVData.KeyTypeEntry\x12\x31\n\x04\x64\x61ta\x18\x02 \x03(\x0b\x32#.paddle_hub_module.KVData.DataEntry\x1aK\n\x0cKeyTypeEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12*\n\x05value\x18\x02 \x01(\x0e\x32\x1b.paddle_hub_module.DataType:\x02\x38\x01\x1aL\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12.\n\x05value\x18\x02 \x01(\x0b\x32\x1f.paddle_hub_module.FlexibleData:\x02\x38\x01\"\xa5\x02\n\x0c\x46lexibleData\x12)\n\x04type\x18\x01 \x01(\x0e\x32\x1b.paddle_hub_module.DataType\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\t\n\x01i\x18\x03 \x01(\x03\x12\t\n\x01\x66\x18\x04 \x01(\x01\x12\t\n\x01\x62\x18\x05 \x01(\x08\x12\t\n\x01s\x18\x06 \x01(\t\x12&\n\x03map\x18\x07 \x01(\x0b\x32\x19.paddle_hub_module.KVData\x12\'\n\x04list\x18\x08 \x01(\x0b\x32\x19.paddle_hub_module.KVData\x12&\n\x03set\x18\t \x01(\x0b\x32\x19.paddle_hub_module.KVData\x12)\n\x06object\x18\n \x01(\x0b\x32\x19.paddle_hub_module.KVData\x12\x0c\n\x04info\x18\x0b \x01(\t\"+\n\x08\x46\x65\x65\x64\x44\x65sc\x12\x10\n\x08var_name\x18\x01 \x01(\t\x12\r\n\x05\x61lias\x18\x02 \x01(\t\",\n\tFetchDesc\x12\x10\n\x08var_name\x18\x01 \x01(\t\x12\r\n\x05\x61lias\x18\x02 \x01(\t\"m\n\tModuleVar\x12\x30\n\nfetch_desc\x18\x01 \x03(\x0b\x32\x1c.paddle_hub_module.FetchDesc\x12.\n\tfeed_desc\x18\x02 \x03(\x0b\x32\x1b.paddle_hub_module.FeedDesc\"7\n\x08\x41uthInfo\x12\x16\n\x0epaddle_version\x18\x01 \x01(\t\x12\x13\n\x0bhub_version\x18\x02 \x01(\t\"\xbb\x02\n\nModuleDesc\x12\x0c\n\x04name\x18\x01 \x01(\t\x12=\n\x08sign2var\x18\x02 \x03(\x0b\x32+.paddle_hub_module.ModuleDesc.Sign2varEntry\x12\x14\n\x0creturn_numpy\x18\x03 \x01(\x08\x12\x16\n\x0e\x63ontain_assets\x18\x04 \x01(\x08\x12.\n\tauth_info\x18\x05 \x01(\x0b\x32\x1b.paddle_hub_module.AuthInfo\x12\x33\n\nextra_info\x18\x06 \x01(\x0b\x32\x1f.paddle_hub_module.FlexibleData\x1aM\n\rSign2varEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12+\n\x05value\x18\x02 \x01(\x0b\x32\x1c.paddle_hub_module.ModuleVar:\x02\x38\x01*i\n\x08\x44\x61taType\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03INT\x10\x01\x12\t\n\x05\x46LOAT\x10\x02\x12\n\n\x06STRING\x10\x03\x12\x0b\n\x07\x42OOLEAN\x10\x04\x12\x08\n\x04LIST\x10\x05\x12\x07\n\x03MAP\x10\x06\x12\x07\n\x03SET\x10\x07\x12\n\n\x06OBJECT\x10\x08\x42\x02H\x03\x62\x06proto3'
))
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
_DATATYPE = _descriptor.EnumDescriptor(
name='DataType',
full_name='paddle_hub.DataType',
full_name='paddle_hub_module.DataType',
filename=None,
file=DESCRIPTOR,
values=[
......@@ -49,8 +49,8 @@ _DATATYPE = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=1075,
serialized_end=1180,
serialized_start=1187,
serialized_end=1292,
)
_sym_db.RegisterEnumDescriptor(_DATATYPE)
......@@ -67,14 +67,14 @@ OBJECT = 8
_KVDATA_KEYTYPEENTRY = _descriptor.Descriptor(
name='KeyTypeEntry',
full_name='paddle_hub.KVData.KeyTypeEntry',
full_name='paddle_hub_module.KVData.KeyTypeEntry',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='key',
full_name='paddle_hub.KVData.KeyTypeEntry.key',
full_name='paddle_hub_module.KVData.KeyTypeEntry.key',
index=0,
number=1,
type=9,
......@@ -90,7 +90,7 @@ _KVDATA_KEYTYPEENTRY = _descriptor.Descriptor(
options=None),
_descriptor.FieldDescriptor(
name='value',
full_name='paddle_hub.KVData.KeyTypeEntry.value',
full_name='paddle_hub_module.KVData.KeyTypeEntry.value',
index=1,
number=2,
type=14,
......@@ -114,20 +114,20 @@ _KVDATA_KEYTYPEENTRY = _descriptor.Descriptor(
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=138,
serialized_end=206,
serialized_start=159,
serialized_end=234,
)
_KVDATA_DATAENTRY = _descriptor.Descriptor(
name='DataEntry',
full_name='paddle_hub.KVData.DataEntry',
full_name='paddle_hub_module.KVData.DataEntry',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='key',
full_name='paddle_hub.KVData.DataEntry.key',
full_name='paddle_hub_module.KVData.DataEntry.key',
index=0,
number=1,
type=9,
......@@ -143,7 +143,7 @@ _KVDATA_DATAENTRY = _descriptor.Descriptor(
options=None),
_descriptor.FieldDescriptor(
name='value',
full_name='paddle_hub.KVData.DataEntry.value',
full_name='paddle_hub_module.KVData.DataEntry.value',
index=1,
number=2,
type=11,
......@@ -167,20 +167,20 @@ _KVDATA_DATAENTRY = _descriptor.Descriptor(
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=208,
serialized_end=277,
serialized_start=236,
serialized_end=312,
)
_KVDATA = _descriptor.Descriptor(
name='KVData',
full_name='paddle_hub.KVData',
full_name='paddle_hub_module.KVData',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='keyType',
full_name='paddle_hub.KVData.keyType',
full_name='paddle_hub_module.KVData.keyType',
index=0,
number=1,
type=11,
......@@ -196,7 +196,7 @@ _KVDATA = _descriptor.Descriptor(
options=None),
_descriptor.FieldDescriptor(
name='data',
full_name='paddle_hub.KVData.data',
full_name='paddle_hub_module.KVData.data',
index=1,
number=2,
type=11,
......@@ -222,20 +222,20 @@ _KVDATA = _descriptor.Descriptor(
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=34,
serialized_end=277,
serialized_start=41,
serialized_end=312,
)
_FLEXIBLEDATA = _descriptor.Descriptor(
name='FlexibleData',
full_name='paddle_hub.FlexibleData',
full_name='paddle_hub_module.FlexibleData',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='type',
full_name='paddle_hub.FlexibleData.type',
full_name='paddle_hub_module.FlexibleData.type',
index=0,
number=1,
type=14,
......@@ -251,7 +251,7 @@ _FLEXIBLEDATA = _descriptor.Descriptor(
options=None),
_descriptor.FieldDescriptor(
name='name',
full_name='paddle_hub.FlexibleData.name',
full_name='paddle_hub_module.FlexibleData.name',
index=1,
number=2,
type=9,
......@@ -267,7 +267,7 @@ _FLEXIBLEDATA = _descriptor.Descriptor(
options=None),
_descriptor.FieldDescriptor(
name='i',
full_name='paddle_hub.FlexibleData.i',
full_name='paddle_hub_module.FlexibleData.i',
index=2,
number=3,
type=3,
......@@ -283,7 +283,7 @@ _FLEXIBLEDATA = _descriptor.Descriptor(
options=None),
_descriptor.FieldDescriptor(
name='f',
full_name='paddle_hub.FlexibleData.f',
full_name='paddle_hub_module.FlexibleData.f',
index=3,
number=4,
type=1,
......@@ -299,7 +299,7 @@ _FLEXIBLEDATA = _descriptor.Descriptor(
options=None),
_descriptor.FieldDescriptor(
name='b',
full_name='paddle_hub.FlexibleData.b',
full_name='paddle_hub_module.FlexibleData.b',
index=4,
number=5,
type=8,
......@@ -315,7 +315,7 @@ _FLEXIBLEDATA = _descriptor.Descriptor(
options=None),
_descriptor.FieldDescriptor(
name='s',
full_name='paddle_hub.FlexibleData.s',
full_name='paddle_hub_module.FlexibleData.s',
index=5,
number=6,
type=9,
......@@ -331,7 +331,7 @@ _FLEXIBLEDATA = _descriptor.Descriptor(
options=None),
_descriptor.FieldDescriptor(
name='map',
full_name='paddle_hub.FlexibleData.map',
full_name='paddle_hub_module.FlexibleData.map',
index=6,
number=7,
type=11,
......@@ -347,7 +347,7 @@ _FLEXIBLEDATA = _descriptor.Descriptor(
options=None),
_descriptor.FieldDescriptor(
name='list',
full_name='paddle_hub.FlexibleData.list',
full_name='paddle_hub_module.FlexibleData.list',
index=7,
number=8,
type=11,
......@@ -363,7 +363,7 @@ _FLEXIBLEDATA = _descriptor.Descriptor(
options=None),
_descriptor.FieldDescriptor(
name='set',
full_name='paddle_hub.FlexibleData.set',
full_name='paddle_hub_module.FlexibleData.set',
index=8,
number=9,
type=11,
......@@ -379,7 +379,7 @@ _FLEXIBLEDATA = _descriptor.Descriptor(
options=None),
_descriptor.FieldDescriptor(
name='object',
full_name='paddle_hub.FlexibleData.object',
full_name='paddle_hub_module.FlexibleData.object',
index=9,
number=10,
type=11,
......@@ -395,7 +395,7 @@ _FLEXIBLEDATA = _descriptor.Descriptor(
options=None),
_descriptor.FieldDescriptor(
name='info',
full_name='paddle_hub.FlexibleData.info',
full_name='paddle_hub_module.FlexibleData.info',
index=10,
number=11,
type=9,
......@@ -418,20 +418,20 @@ _FLEXIBLEDATA = _descriptor.Descriptor(
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=280,
serialized_end=538,
serialized_start=315,
serialized_end=608,
)
_FEEDDESC = _descriptor.Descriptor(
name='FeedDesc',
full_name='paddle_hub.FeedDesc',
full_name='paddle_hub_module.FeedDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='var_name',
full_name='paddle_hub.FeedDesc.var_name',
full_name='paddle_hub_module.FeedDesc.var_name',
index=0,
number=1,
type=9,
......@@ -447,7 +447,7 @@ _FEEDDESC = _descriptor.Descriptor(
options=None),
_descriptor.FieldDescriptor(
name='alias',
full_name='paddle_hub.FeedDesc.alias',
full_name='paddle_hub_module.FeedDesc.alias',
index=1,
number=2,
type=9,
......@@ -470,20 +470,20 @@ _FEEDDESC = _descriptor.Descriptor(
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=540,
serialized_end=583,
serialized_start=610,
serialized_end=653,
)
_FETCHDESC = _descriptor.Descriptor(
name='FetchDesc',
full_name='paddle_hub.FetchDesc',
full_name='paddle_hub_module.FetchDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='var_name',
full_name='paddle_hub.FetchDesc.var_name',
full_name='paddle_hub_module.FetchDesc.var_name',
index=0,
number=1,
type=9,
......@@ -499,7 +499,7 @@ _FETCHDESC = _descriptor.Descriptor(
options=None),
_descriptor.FieldDescriptor(
name='alias',
full_name='paddle_hub.FetchDesc.alias',
full_name='paddle_hub_module.FetchDesc.alias',
index=1,
number=2,
type=9,
......@@ -522,20 +522,20 @@ _FETCHDESC = _descriptor.Descriptor(
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=585,
serialized_end=629,
serialized_start=655,
serialized_end=699,
)
_MODULEVAR = _descriptor.Descriptor(
name='ModuleVar',
full_name='paddle_hub.ModuleVar',
full_name='paddle_hub_module.ModuleVar',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='fetch_desc',
full_name='paddle_hub.ModuleVar.fetch_desc',
full_name='paddle_hub_module.ModuleVar.fetch_desc',
index=0,
number=1,
type=11,
......@@ -551,7 +551,7 @@ _MODULEVAR = _descriptor.Descriptor(
options=None),
_descriptor.FieldDescriptor(
name='feed_desc',
full_name='paddle_hub.ModuleVar.feed_desc',
full_name='paddle_hub_module.ModuleVar.feed_desc',
index=1,
number=2,
type=11,
......@@ -574,20 +574,20 @@ _MODULEVAR = _descriptor.Descriptor(
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=631,
serialized_end=726,
serialized_start=701,
serialized_end=810,
)
_AUTHINFO = _descriptor.Descriptor(
name='AuthInfo',
full_name='paddle_hub.AuthInfo',
full_name='paddle_hub_module.AuthInfo',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='paddle_version',
full_name='paddle_hub.AuthInfo.paddle_version',
full_name='paddle_hub_module.AuthInfo.paddle_version',
index=0,
number=1,
type=9,
......@@ -603,7 +603,7 @@ _AUTHINFO = _descriptor.Descriptor(
options=None),
_descriptor.FieldDescriptor(
name='hub_version',
full_name='paddle_hub.AuthInfo.hub_version',
full_name='paddle_hub_module.AuthInfo.hub_version',
index=1,
number=2,
type=9,
......@@ -626,20 +626,20 @@ _AUTHINFO = _descriptor.Descriptor(
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=728,
serialized_end=783,
serialized_start=812,
serialized_end=867,
)
_MODULEDESC_SIGN2VARENTRY = _descriptor.Descriptor(
name='Sign2varEntry',
full_name='paddle_hub.ModuleDesc.Sign2varEntry',
full_name='paddle_hub_module.ModuleDesc.Sign2varEntry',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='key',
full_name='paddle_hub.ModuleDesc.Sign2varEntry.key',
full_name='paddle_hub_module.ModuleDesc.Sign2varEntry.key',
index=0,
number=1,
type=9,
......@@ -655,7 +655,7 @@ _MODULEDESC_SIGN2VARENTRY = _descriptor.Descriptor(
options=None),
_descriptor.FieldDescriptor(
name='value',
full_name='paddle_hub.ModuleDesc.Sign2varEntry.value',
full_name='paddle_hub_module.ModuleDesc.Sign2varEntry.value',
index=1,
number=2,
type=11,
......@@ -679,20 +679,20 @@ _MODULEDESC_SIGN2VARENTRY = _descriptor.Descriptor(
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=1003,
serialized_end=1073,
serialized_start=1108,
serialized_end=1185,
)
_MODULEDESC = _descriptor.Descriptor(
name='ModuleDesc',
full_name='paddle_hub.ModuleDesc',
full_name='paddle_hub_module.ModuleDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='name',
full_name='paddle_hub.ModuleDesc.name',
full_name='paddle_hub_module.ModuleDesc.name',
index=0,
number=1,
type=9,
......@@ -708,7 +708,7 @@ _MODULEDESC = _descriptor.Descriptor(
options=None),
_descriptor.FieldDescriptor(
name='sign2var',
full_name='paddle_hub.ModuleDesc.sign2var',
full_name='paddle_hub_module.ModuleDesc.sign2var',
index=1,
number=2,
type=11,
......@@ -724,7 +724,7 @@ _MODULEDESC = _descriptor.Descriptor(
options=None),
_descriptor.FieldDescriptor(
name='return_numpy',
full_name='paddle_hub.ModuleDesc.return_numpy',
full_name='paddle_hub_module.ModuleDesc.return_numpy',
index=2,
number=3,
type=8,
......@@ -740,7 +740,7 @@ _MODULEDESC = _descriptor.Descriptor(
options=None),
_descriptor.FieldDescriptor(
name='contain_assets',
full_name='paddle_hub.ModuleDesc.contain_assets',
full_name='paddle_hub_module.ModuleDesc.contain_assets',
index=3,
number=4,
type=8,
......@@ -756,7 +756,7 @@ _MODULEDESC = _descriptor.Descriptor(
options=None),
_descriptor.FieldDescriptor(
name='auth_info',
full_name='paddle_hub.ModuleDesc.auth_info',
full_name='paddle_hub_module.ModuleDesc.auth_info',
index=4,
number=5,
type=11,
......@@ -772,7 +772,7 @@ _MODULEDESC = _descriptor.Descriptor(
options=None),
_descriptor.FieldDescriptor(
name='extra_info',
full_name='paddle_hub.ModuleDesc.extra_info',
full_name='paddle_hub_module.ModuleDesc.extra_info',
index=5,
number=6,
type=11,
......@@ -797,8 +797,8 @@ _MODULEDESC = _descriptor.Descriptor(
syntax='proto3',
extension_ranges=[],
oneofs=[],
serialized_start=786,
serialized_end=1073,
serialized_start=870,
serialized_end=1185,
)
_KVDATA_KEYTYPEENTRY.fields_by_name['value'].enum_type = _DATATYPE
......@@ -838,7 +838,7 @@ KVData = _reflection.GeneratedProtocolMessageType(
dict(
DESCRIPTOR=_KVDATA_KEYTYPEENTRY,
__module__='module_desc_pb2'
# @@protoc_insertion_point(class_scope:paddle_hub.KVData.KeyTypeEntry)
# @@protoc_insertion_point(class_scope:paddle_hub_module.KVData.KeyTypeEntry)
)),
DataEntry=_reflection.GeneratedProtocolMessageType(
'DataEntry',
......@@ -846,11 +846,11 @@ KVData = _reflection.GeneratedProtocolMessageType(
dict(
DESCRIPTOR=_KVDATA_DATAENTRY,
__module__='module_desc_pb2'
# @@protoc_insertion_point(class_scope:paddle_hub.KVData.DataEntry)
# @@protoc_insertion_point(class_scope:paddle_hub_module.KVData.DataEntry)
)),
DESCRIPTOR=_KVDATA,
__module__='module_desc_pb2'
# @@protoc_insertion_point(class_scope:paddle_hub.KVData)
# @@protoc_insertion_point(class_scope:paddle_hub_module.KVData)
))
_sym_db.RegisterMessage(KVData)
_sym_db.RegisterMessage(KVData.KeyTypeEntry)
......@@ -862,7 +862,7 @@ FlexibleData = _reflection.GeneratedProtocolMessageType(
dict(
DESCRIPTOR=_FLEXIBLEDATA,
__module__='module_desc_pb2'
# @@protoc_insertion_point(class_scope:paddle_hub.FlexibleData)
# @@protoc_insertion_point(class_scope:paddle_hub_module.FlexibleData)
))
_sym_db.RegisterMessage(FlexibleData)
......@@ -872,7 +872,7 @@ FeedDesc = _reflection.GeneratedProtocolMessageType(
dict(
DESCRIPTOR=_FEEDDESC,
__module__='module_desc_pb2'
# @@protoc_insertion_point(class_scope:paddle_hub.FeedDesc)
# @@protoc_insertion_point(class_scope:paddle_hub_module.FeedDesc)
))
_sym_db.RegisterMessage(FeedDesc)
......@@ -882,7 +882,7 @@ FetchDesc = _reflection.GeneratedProtocolMessageType(
dict(
DESCRIPTOR=_FETCHDESC,
__module__='module_desc_pb2'
# @@protoc_insertion_point(class_scope:paddle_hub.FetchDesc)
# @@protoc_insertion_point(class_scope:paddle_hub_module.FetchDesc)
))
_sym_db.RegisterMessage(FetchDesc)
......@@ -892,7 +892,7 @@ ModuleVar = _reflection.GeneratedProtocolMessageType(
dict(
DESCRIPTOR=_MODULEVAR,
__module__='module_desc_pb2'
# @@protoc_insertion_point(class_scope:paddle_hub.ModuleVar)
# @@protoc_insertion_point(class_scope:paddle_hub_module.ModuleVar)
))
_sym_db.RegisterMessage(ModuleVar)
......@@ -902,7 +902,7 @@ AuthInfo = _reflection.GeneratedProtocolMessageType(
dict(
DESCRIPTOR=_AUTHINFO,
__module__='module_desc_pb2'
# @@protoc_insertion_point(class_scope:paddle_hub.AuthInfo)
# @@protoc_insertion_point(class_scope:paddle_hub_module.AuthInfo)
))
_sym_db.RegisterMessage(AuthInfo)
......@@ -916,11 +916,11 @@ ModuleDesc = _reflection.GeneratedProtocolMessageType(
dict(
DESCRIPTOR=_MODULEDESC_SIGN2VARENTRY,
__module__='module_desc_pb2'
# @@protoc_insertion_point(class_scope:paddle_hub.ModuleDesc.Sign2varEntry)
# @@protoc_insertion_point(class_scope:paddle_hub_module.ModuleDesc.Sign2varEntry)
)),
DESCRIPTOR=_MODULEDESC,
__module__='module_desc_pb2'
# @@protoc_insertion_point(class_scope:paddle_hub.ModuleDesc)
# @@protoc_insertion_point(class_scope:paddle_hub_module.ModuleDesc)
))
_sym_db.RegisterMessage(ModuleDesc)
_sym_db.RegisterMessage(ModuleDesc.Sign2varEntry)
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
......@@ -16,7 +16,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle.fluid.framework import Variable
from paddle_hub.utils import to_list
from paddle_hub.tools.utils import to_list
class Signature:
......@@ -53,21 +53,6 @@ class Signature:
self.feed_names = feed_names
self.fetch_names = fetch_names
def get_name(self):
return self.name
def get_inputs(self):
return self.inputs
def get_outputs(self):
return self.outputs
def get_feed_names(self):
return self.feed_names
def get_fetch_names(self):
return self.fetch_names
def create_signature(name="default",
inputs=[],
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import utils
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
......@@ -105,7 +105,7 @@ def download_and_uncompress(url, save_name=None):
for file_name in file_names:
tar.extract(file_name, dirname)
return module_name, module_dir
return module_dir
if __name__ == "__main__":
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
......@@ -16,6 +16,7 @@ from __future__ import print_function
from __future__ import division
from __future__ import print_function
import logging
import math
class Logger:
......@@ -50,7 +51,7 @@ class Logger:
return self.logLevel
def __call__(self, type, msg):
def _get_log_arr(msg):
def _get_log_arr(msg, len_limit=30):
ph = Logger.PLACEHOLDER
lrspace = 2
lc = rc = " " * lrspace
......@@ -59,6 +60,22 @@ class Logger:
if len(msgarr) == 1:
return msgarr
temp_arr = msgarr
msgarr = []
for text in temp_arr:
if len(text) > len_limit:
for i in range(math.ceil(len(text) / len_limit)):
if i == 0:
msgarr.append(text[0:len_limit])
else:
fr = len_limit + (len_limit - 4) * (i - 1)
to = len_limit + (len_limit - 4) * i
if to > len(text):
to = len(text)
msgarr.append("===>" + text[fr:to])
else:
msgarr.append(text)
maxlen = -1
for text in msgarr:
if len(text) > maxlen:
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
......@@ -15,11 +15,12 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle_hub import module_desc_pb2
from paddle_hub.utils import from_pyobj_to_flexible_data, from_flexible_data_to_pyobj
from paddle_hub.logger import logger
from paddle_hub.module import module_desc_pb2
from paddle_hub.tools.utils import from_pyobj_to_flexible_data, from_flexible_data_to_pyobj
from paddle_hub.tools.logger import logger
import paddle
import paddle.fluid as fluid
import copy
def get_variable_info(var):
......@@ -116,3 +117,110 @@ def from_flexible_data_to_param(flexible_data):
(clip_type, clip_norm, group_name))
return param
def connect_program(pre_program, next_program, input_dict=None):
def _copy_vars_and_ops_in_blocks(from_block, to_block):
for var in from_block.vars:
var = from_block.var(var)
var_info = copy.deepcopy(get_variable_info(var))
if isinstance(var, fluid.framework.Parameter):
to_block.create_parameter(**var_info)
else:
to_block.create_var(**var_info)
for op in from_block.ops:
op_info = {
'type': op.type,
'inputs': {
input: [block.var(var) for var in op.input(input)]
for input in op.input_names
},
'outputs': {
output: [block.var(var) for var in op.output(output)]
for output in op.output_names
},
'attrs': copy.deepcopy(op.all_attrs())
}
to_block.append_op(**op_info)
assert isinstance(pre_program,
fluid.Program), "pre_program should be fluid.Program"
assert isinstance(next_program,
fluid.Program), "next_program should be fluid.Program"
new_program = pre_program.clone()
if input_dict:
assert isinstance(
input_dict,
dict), "the input_dict should be a dict with string-Variable pair"
for key, var in input_dict.items():
assert isinstance(
var, fluid.framework.Variable
), "the input_dict should be a dict with string-Variable pair"
var_info = copy.deepcopy(get_variable_info(var))
input_var = new_program.global_block().create_var(**var_info)
output_var = next_program.global_block().var(key)
var_info = copy.deepcopy(get_variable_info(output_var))
output_var = new_program.global_block().create_var(**var_info)
new_program.global_block().append_op(
type="assign",
inputs={'X': input_var},
outputs={'Out': output_var})
block_map = {0: 0}
logger.info("start to connect program")
for index, block in enumerate(next_program.blocks):
if block.idx == 0:
_copy_vars_and_ops_in_blocks(block, new_program.global_block())
else:
block_map[index] = len(new_program.blocks)
logger.info(
"block_%d in next_program merge into block_%d in pre_program" %
(index, block_map[index]))
new_block = new_program._create_block(
parent_idx=block_map[block.parent_idx])
_copy_vars_and_ops_in_blocks(block, new_block)
logger.info("end of connect program")
return new_program
def remove_feed_fetch_op(program):
""" remove feed and fetch operator and variable for fine-tuning
"""
logger.info("remove feed fetch op")
block = program.global_block()
need_to_remove_op_index = []
for i, op in enumerate(block.ops):
if op.type == "feed" or op.type == "fetch":
need_to_remove_op_index.append(i)
for index in need_to_remove_op_index[::-1]:
block._remove_op(index)
need_to_remove_var = []
for var in block.vars:
if var.endswith("feed"):
need_to_remove_var.append(var)
if var.endswith("fetch"):
need_to_remove_var.append(var)
for var in need_to_remove_var:
block._remove_var(var)
program.desc.flush()
def set_parameter_trainable(program, trainable=True):
for param in program.global_block().iter_parameters():
param.trainable = trainable
def set_parameter_regularization(program, regularization):
pass
def set_op_attr(program, is_test=False):
for block in program.blocks:
for op in block.ops:
if op.has_attr("is_test"):
op._set_attr("is_test", is_test)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
......@@ -17,8 +17,8 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle_hub import module_desc_pb2
from paddle_hub.logger import logger
from paddle_hub.module import module_desc_pb2
from paddle_hub.tools.logger import logger
import paddle
import paddle.fluid as fluid
import os
......
......@@ -12,4 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Store PaddleHub version string """
__version__ = "0.1.0-dev"
hub_version = "0.1.0.dev"
proto_version = "0.1.0"
#/bin/bash
protoc -I=../paddle_hub/module --python_out=../paddle_hub/module ../paddle_hub/module/module_desc.proto
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册