提交 4f92105e 编写于 作者: W wuzewu

modify code directory

上级 09dc8dcd
...@@ -12,17 +12,10 @@ ...@@ -12,17 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import from . import module
from __future__ import division from . import tools
from __future__ import print_function from . import data_process
from .module.module import Module, create_module
import paddle.fluid as fluid from .module.signature import Signature, create_signature
from .tools.logger import logger
from paddle_hub.module import Module from .tools.paddle_helper import connect_program
from paddle_hub.module import ModuleConfig
from paddle_hub.module import ModuleUtils
from paddle_hub.module import create_module
from paddle_hub.downloader import download_and_uncompress
from paddle_hub.signature import create_signature
from paddle_hub.version import __version__
connect_program = ModuleUtils.connect_program
#/bin/bash
protoc -I=./ --python_out=./ module_desc.proto
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import numpy as np
import tempfile
import os
import copy
from collections import defaultdict
from paddle_hub.downloader import download_and_uncompress
from paddle_hub import module_desc_pb2
from paddle_hub.logger import logger
from paddle_hub.signature import Signature
from paddle_hub.utils import to_list, mkdir, from_pyobj_to_flexible_data, from_flexible_data_to_pyobj
from paddle_hub.paddle_helper import from_param_to_flexible_data, get_variable_info, from_flexible_data_to_param
from paddle_hub.version import __version__
__all__ = ["Module", "ModuleConfig", "ModuleUtils"]
# paddle hub module dir name
ASSETS_DIRNAME = "assets"
META_DIRNAME = "meta"
MODEL_DIRNAME = "model"
# paddle hub module serialze file name
DICT_FILENAME = "vocab.txt"
PARAM_FILENAME = "param.pkl"
MODULE_DESC_PBNAME = "module_desc.pb"
# paddle hub var prefix
HUB_VAR_PREFIX = "@HUB@"
class Module(object):
"""
Core object of PaddleHub
"""
def __init__(self, module_url=None, module_dir=None):
if module_url == None and module_dir == None:
raise Exception("Module:module_url and module_dir are None!")
self.module_dir = ""
self.module_name = ""
# donwload module
if module_url is not None and module_url.startswith("http"):
# if it's remote url link, then download and uncompress it
self.module_name, self.module_dir = download_and_uncompress(
module_url)
#TODO(ZeyuChen): check url link is valid url
elif module_dir is not None:
# otherwise it's local path, no need to deal with it
self.module_dir = module_dir
# use the path name as module name by default
self.module_name = module_dir.split("/")[-1]
#TODO(ZeyuChen) add more check about loading module from local path
def _process_parameter(self):
global_block = self.inference_program.global_block()
param_attrs = self.config.desc.extra_info.map.data['param_attrs']
for key, param_attr in param_attrs.map.data.items():
param = from_flexible_data_to_param(param_attr)
param['name'] = HUB_VAR_PREFIX + key
if (param['name'] not in global_block.vars):
continue
var = global_block.var(param['name'])
global_block.create_parameter(
**param,
shape=var.shape,
dtype=var.dtype,
type=var.type,
lod_level=var.lod_level,
error_clip=var.error_clip,
stop_gradient=var.stop_gradient,
is_data=var.is_data)
def _process_variable_info(self):
var_infos = self.config.desc.extra_info.map.data['var_infos']
for var_info in var_infos.map.data:
idx = from_flexible_data_to_pyobj(
var_infos.map.data[var_info].map.data['block_id'])
stop_gradient = from_flexible_data_to_pyobj(
var_infos.map.data[var_info].map.data['stop_gradient'])
block = self.inference_program.blocks[idx]
var_name = HUB_VAR_PREFIX + var_info
if var_name in block.vars:
var = block.vars[var_name]
var.stop_gradient = stop_gradient
def __call__(self, sign_name="default", trainable=False):
""" Call default signature and return results
"""
def _set_param_trainable(program, trainable=False):
for param in program.global_block().iter_parameters():
param.trainable = trainable
def _process_op_attr(program, is_test=False):
for op in program.global_block().ops:
if op.has_attr("is_test"):
op._set_attr("is_test", is_test)
def _process_input_output_key(module_desc, signature):
signature = module_desc.sign2var[signature]
feed_dict = {}
fetch_dict = {}
for index, feed in enumerate(signature.feed_desc):
if feed.alias != "":
feed_dict[feed.alias] = feed.var_name
feed_dict[index] = feed.var_name
for index, fetch in enumerate(signature.fetch_desc):
if fetch.alias != "":
fetch_dict[fetch.alias] = fetch.var_name
fetch_dict[index] = fetch.var_name
return feed_dict, fetch_dict
self.config = ModuleConfig(self.module_dir)
self.config.load()
# load paddle inference model
place = fluid.CPUPlace()
model_dir = os.path.join(self.module_dir, MODEL_DIRNAME)
self.exe = fluid.Executor(fluid.CPUPlace())
self.inference_program, self.feed_target_names, self.fetch_targets = fluid.io.load_inference_model(
model_dir, executor=self.exe)
feed_dict, fetch_dict = _process_input_output_key(
self.config.desc, sign_name)
# remove feed fetch operator and variable
ModuleUtils.remove_feed_fetch_op(self.inference_program)
logger.info("**feed_target_names**\n{}".format(self.feed_target_names))
logger.info("**fetch_targets**\n{}".format(self.fetch_targets))
self._process_parameter()
self._process_variable_info()
_process_op_attr(program=self.inference_program, is_test=False)
_set_param_trainable(
program=self.inference_program, trainable=trainable)
for key, value in feed_dict.items():
var = self.inference_program.global_block().var(HUB_VAR_PREFIX +
value)
feed_dict[key] = var
for key, value in fetch_dict.items():
var = self.inference_program.global_block().var(HUB_VAR_PREFIX +
value)
fetch_dict[key] = var
return feed_dict, fetch_dict, self.inference_program
def get_inference_program(self):
return self.inference_program
# for text sequence input, transform to lod tensor as paddle graph's input
def _preprocess_input(self, inputs):
# words id mapping and dealing with oov
# transform to lod tensor
seq = []
for s in inputs:
seq.append(self._word_id_mapping(s))
lod_tensor = self.seq2lod_tensor(seq)
return lod_tensor
def seq2lod_tensor(self, seq_inputs, place=fluid.CPUPlace()):
""" sequence to lod tensor, need to determine which space"""
lod = []
lod.append([])
for s in seq_inputs:
# generate lod
lod[0].append(len(s))
# print("seq", seq_inputs)
# print("lod", lod)
lod_tensor = fluid.create_lod_tensor(seq_inputs, lod, place)
return lod_tensor
def _word_id_mapping(self, inputs):
word_dict = self.config.get_assets_vocab()
return list(map(lambda x: word_dict[x], inputs))
class ModuleConfig(object):
def __init__(self, module_dir, module_name=None):
# generate model desc protobuf
self.module_dir = module_dir
self.desc = module_desc_pb2.ModuleDesc()
if module_name == None:
module_name = module_dir.split("/")[-1]
# initialize module config default value
self.desc.name = module_name
self.desc.contain_assets = True
self.desc.return_numpy = False
# init dict
self.dict = defaultdict(int)
self.dict.setdefault(0)
def get_assets_vocab(self):
""" Return dictionary in Module"""
return self.dict
def load(self):
"""
Load module config from module directory.
"""
#TODO(ZeyuChen): check module_desc.pb exsitance
with open(ModuleConfig.module_desc_path(self.module_dir), "rb") as fi:
self.desc.ParseFromString(fi.read())
if self.desc.contain_assets:
# load assets
word_id = 0
with open(ModuleConfig.assets_dict_path(self.module_dir)) as fi:
words = fi.readlines()
#TODO(ZeyuChen) check whether word id is duplicated and valid
for line in fi:
w, w_id = line.split()
self.dict[w] = int(w_id)
def return_numpy(self):
"""Return numpy or not according to the proto config.
"""
return self.desc.return_numpy
def save_dict(self, word_dict, dict_name=DICT_FILENAME):
""" Save dictionary for NLP module
"""
for w in word_dict:
self.dict[w] = word_dict[w]
@staticmethod
def module_desc_path(module_dir):
return os.path.join(module_dir, MODULE_DESC_PBNAME)
@staticmethod
def assets_dict_path(module_dir):
assets_path = os.path.join(module_dir, ASSETS_DIRNAME)
mkdir(assets_path)
return os.path.join(assets_path, DICT_FILENAME)
@staticmethod
def meta_param_path(module_dir):
meta_path = os.path.join(module_dir, META_DIRNAME)
mkdir(meta_path)
return os.path.join(meta_path, PARAM_FILENAME)
def create_module(sign_arr, module_dir=None, word_dict=None, exe=None):
""" Create a module from main program
"""
assert sign_arr, "signature array should not be None"
# check all variable
sign_arr = to_list(sign_arr)
program = sign_arr[0].get_inputs()[0].block.program
feeded_var_names = set()
target_vars = set()
for sign in sign_arr:
assert isinstance(sign,
Signature), "sign_arr should be list of Signature"
for input in sign.get_inputs():
feeded_var_names.add(input.name)
_tmp_program = input.block.program
assert program == _tmp_program, "all the variable should come from the same program"
for output in sign.get_outputs():
target_vars.add(output)
_tmp_program = output.block.program
assert program == _tmp_program, "all the variable should come from the same program"
# create module path for saving
if module_dir is None:
module_dir = os.path.join(".", "hub_module")
mkdir(module_dir)
# create module pb
module_desc = module_desc_pb2.ModuleDesc()
module_desc.auth_info.hub_version = __version__
module_desc.auth_info.paddle_version = paddle.__version__
logger.info("hub version is %s" % __version__)
logger.info("paddle version is %s" % paddle.__version__)
# save asset
if word_dict is None:
module_desc.contain_assets = False
else:
module_desc.contain_assets = True
with open(ModuleConfig.assets_dict_path(module_dir), "w") as fo:
for w in word_dict:
w_id = word_dict[w]
fo.write("{}\t{}\n".format(w, w_id))
# save fluid Parameter
extra_info = module_desc.extra_info
extra_info.type = module_desc_pb2.MAP
param_attrs = extra_info.map.data['param_attrs']
param_attrs.type = module_desc_pb2.MAP
for param in program.global_block().iter_parameters():
param_attr = param_attrs.map.data[param.name]
from_param_to_flexible_data(param, param_attr)
# save Variable Info
var_infos = extra_info.map.data['var_infos']
var_infos.type = module_desc_pb2.MAP
for block in program.blocks:
for var in block.vars.values():
var_info = var_infos.map.data[var.name]
var_info.type = module_desc_pb2.MAP
from_pyobj_to_flexible_data(var.stop_gradient,
var_info.map.data['stop_gradient'])
from_pyobj_to_flexible_data(block.idx,
var_info.map.data['block_id'])
# save signarture info
sign_map = module_desc.sign2var
for sign in sign_arr:
if sign.get_name() in sign_map:
raise "Error! sign_arr contains repeat signatrue %s" % sign
var = sign_map[sign.get_name()]
feed_desc = var.feed_desc
fetch_desc = var.fetch_desc
feed_names = sign.get_feed_names()
fetch_names = sign.get_fetch_names()
for index, input in enumerate(sign.get_inputs()):
feed_var = feed_desc.add()
feed_var.var_name = input.name
feed_var.alias = feed_names[index]
for index, output in enumerate(sign.get_outputs()):
fetch_var = fetch_desc.add()
fetch_var.var_name = output.name
fetch_var.alias = fetch_names[index]
# save inference program
program = program.clone()
if not exe:
place = fluid.CPUPlace()
exe = fluid.Executor(place=place)
save_model_dir = os.path.join(module_dir, "model")
mkdir(save_model_dir)
fluid.io.save_inference_model(
save_model_dir,
feeded_var_names=list(feeded_var_names),
target_vars=list(target_vars),
main_program=program,
executor=exe)
with open(os.path.join(save_model_dir, "__model__"), "rb") as file:
program_desc_str = file.read()
rename_program = fluid.framework.Program.parse_from_string(
program_desc_str)
varlist = {
var: block
for block in rename_program.blocks for var in block.vars
if HUB_VAR_PREFIX not in var
}
for var, block in varlist.items():
old_name = var
new_name = HUB_VAR_PREFIX + old_name
block._rename_var(old_name, new_name)
mkdir(save_model_dir)
with open(os.path.join(save_model_dir, "__model__"), "wb") as f:
f.write(rename_program.desc.serialize_to_string())
for file in os.listdir(save_model_dir):
if (file == "__model__" or HUB_VAR_PREFIX in file):
continue
os.rename(
os.path.join(save_model_dir, file),
os.path.join(save_model_dir, HUB_VAR_PREFIX + file))
# Serialize module_desc pb
module_pb = module_desc.SerializeToString()
with open(ModuleConfig.module_desc_path(module_dir), "wb") as f:
f.write(module_pb)
class ModuleUtils(object):
def __init__(self):
pass
@staticmethod
def connect_program(pre_program, next_program, input_dict=None):
def _copy_vars_and_ops_in_blocks(from_block, to_block):
for var in from_block.vars:
var = from_block.var(var)
var_info = copy.deepcopy(get_variable_info(var))
if isinstance(var, fluid.framework.Parameter):
to_block.create_parameter(**var_info)
else:
to_block.create_var(**var_info)
for op in from_block.ops:
op_info = {
'type': op.type,
'inputs': {
input: [block.var(var) for var in op.input(input)]
for input in op.input_names
},
'outputs': {
output: [block.var(var) for var in op.output(output)]
for output in op.output_names
},
'attrs': copy.deepcopy(op.all_attrs())
}
to_block.append_op(**op_info)
assert isinstance(pre_program,
fluid.Program), "pre_program should be fluid.Program"
assert isinstance(next_program,
fluid.Program), "next_program should be fluid.Program"
new_program = pre_program.clone()
if input_dict:
assert isinstance(
input_dict, dict
), "the input_dict should be a dict with string-Variable pair"
for key, var in input_dict.items():
assert isinstance(
var, fluid.framework.Variable
), "the input_dict should be a dict with string-Variable pair"
var_info = copy.deepcopy(get_variable_info(var))
input_var = new_program.global_block().create_var(**var_info)
output_var = next_program.global_block().var(key)
var_info = copy.deepcopy(get_variable_info(output_var))
output_var = new_program.global_block().create_var(**var_info)
new_program.global_block().append_op(
type="assign",
inputs={'X': input_var},
outputs={'Out': output_var})
block_map = {0: 0}
logger.info("start to connect program")
for index, block in enumerate(next_program.blocks):
if block.idx == 0:
_copy_vars_and_ops_in_blocks(block, new_program.global_block())
else:
block_map[index] = len(new_program.blocks)
logger.info(
"block_%d in next_program merge into block_%d in pre_program"
% (index, block_map[index]))
new_block = new_program._create_block(
parent_idx=block_map[block.parent_idx])
_copy_vars_and_ops_in_blocks(block, new_block)
logger.info("end of connect program")
return new_program
@staticmethod
def remove_feed_fetch_op(program):
""" remove feed and fetch operator and variable for fine-tuning
"""
logger.info("remove feed fetch op")
block = program.global_block()
need_to_remove_op_index = []
for i, op in enumerate(block.ops):
if op.type == "feed" or op.type == "fetch":
need_to_remove_op_index.append(i)
for index in need_to_remove_op_index[::-1]:
block._remove_op(index)
# TODO(wuzewu): get feed and fetch var by other way
block._remove_var(HUB_VAR_PREFIX + "feed")
block._remove_var(HUB_VAR_PREFIX + "fetch")
program.desc.flush()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import module
from . import signature
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle_hub.tools import utils
from paddle_hub.tools.logger import logger
from paddle_hub.tools import downloader
from paddle_hub.tools import paddle_helper
from paddle_hub.module import module_desc_pb2
from paddle_hub.module.signature import Signature, create_signature
from paddle_hub import version
import os
import paddle
import paddle.fluid as fluid
__all__ = ['Module', 'create_module']
def create_module(sign_arr, module_dir, exe=None):
sign_arr = utils.to_list(sign_arr)
module = Module(signatures=sign_arr)
module.serialize_to_path(path=module_dir, exe=exe)
# paddle hub module dir name
ASSETS_DIRNAME = "assets"
MODEL_DIRNAME = "model"
MODULE_DESC_PBNAME = "module_desc.pb"
# paddle hub var prefix
HUB_VAR_PREFIX = "@HUB@"
class ModuleWrapper:
def __init__(self, module, name):
self.module = module
self.name = name
def __call__(self, trainable=False):
return self.module(self.name, trainable)
class ModuleHelper:
def __init__(self, module_dir):
self.module_dir = module_dir
def module_desc_path(self):
return os.path.join(self.module_dir, MODULE_DESC_PBNAME)
def model_path(self):
return os.path.join(self.module_dir, MODEL_DIRNAME)
class Module:
def __init__(self, url=None, module_dir=None, signatures=None, name=None):
if not name:
name = "HubModule"
self.name = name
self.desc = module_desc_pb2.ModuleDesc()
self.program = None
self.assets = []
self.helper = None
self.signatures = {}
if url:
self._init_with_url(url=url)
elif module_dir:
self._init_with_module_file(module_dir=module_dir)
elif signatures:
self._init_with_signature(signatures=signatures)
else:
raise "Error! HubModule Can't init with nothing"
def _init_with_url(self, url):
utils.check_url_valid(url)
module_dir = downloader.download_and_uncompress(module_url)
self._init_with_module_file(module_dir)
def _init_with_module_file(self, module_dir):
self.helper = ModuleHelper(module_dir)
with open(self.helper.module_desc_path(), "rb") as fi:
self.desc.ParseFromString(fi.read())
exe = fluid.Executor(fluid.CPUPlace())
self.program, _, _ = fluid.io.load_inference_model(
self.helper.model_path(), executor=exe)
self._recovery_parameter(self.program)
self._recover_variable_info(self.program)
inputs = []
outputs = []
feed_names = []
fetch_names = []
for sign, module_var in self.desc.sign2var.items():
for var in module_var.feed_desc:
variable = self.program.global_block().vars[var.var_name]
inputs.append(variable)
feed_names.append(var.alias)
for var in module_var.fetch_desc:
variable = self.program.global_block().vars[var.var_name]
outputs.append(variable)
fetch_names.append(var.alias)
self.signatures[sign] = create_signature(
sign,
inputs=inputs,
outputs=outputs,
feed_names=feed_names,
fetch_names=fetch_names)
self._generate_sign_attr()
def _init_with_signature(self, signatures):
self._process_signatures(signatures)
self._check_signatures()
self._generate_desc()
self._generate_sign_attr()
def _init_with_program(self, program):
pass
def _process_signatures(self, signatures):
self.signatures = {}
self.program = signatures[0].inputs[0].block.program
for sign in signatures:
if sign.name in self.signatures:
raise "Error! signature array contains repeat signatrue %s" % sign
self.signatures[sign.name] = sign
def _recovery_parameter(self, program):
global_block = self.program.global_block()
param_attrs = self.desc.extra_info.map.data['param_attrs']
for key, param_attr in param_attrs.map.data.items():
param = paddle_helper.from_flexible_data_to_param(param_attr)
param['name'] = HUB_VAR_PREFIX + key
if (param['name'] not in global_block.vars):
continue
var = global_block.var(param['name'])
global_block.create_parameter(
**param,
shape=var.shape,
dtype=var.dtype,
type=var.type,
lod_level=var.lod_level,
error_clip=var.error_clip,
stop_gradient=var.stop_gradient,
is_data=var.is_data)
def _recover_variable_info(self, program):
var_infos = self.desc.extra_info.map.data['var_infos']
for var_info in var_infos.map.data:
idx = utils.from_flexible_data_to_pyobj(
var_infos.map.data[var_info].map.data['block_id'])
stop_gradient = utils.from_flexible_data_to_pyobj(
var_infos.map.data[var_info].map.data['stop_gradient'])
block = program.blocks[idx]
var_name = HUB_VAR_PREFIX + var_info
if var_name in block.vars:
var = block.vars[var_name]
var.stop_gradient = stop_gradient
def _generate_sign_attr(self):
self._check_signatures()
for sign in self.signatures:
self.__dict__[sign] = ModuleWrapper(self, sign)
def _generate_desc(self):
# save fluid Parameter
extra_info = self.desc.extra_info
extra_info.type = module_desc_pb2.MAP
param_attrs = extra_info.map.data['param_attrs']
param_attrs.type = module_desc_pb2.MAP
for param in self.program.global_block().iter_parameters():
param_attr = param_attrs.map.data[param.name]
paddle_helper.from_param_to_flexible_data(param, param_attr)
# save Variable Info
var_infos = extra_info.map.data['var_infos']
var_infos.type = module_desc_pb2.MAP
for block in self.program.blocks:
for var in block.vars.values():
var_info = var_infos.map.data[var.name]
var_info.type = module_desc_pb2.MAP
utils.from_pyobj_to_flexible_data(
var.stop_gradient, var_info.map.data['stop_gradient'])
utils.from_pyobj_to_flexible_data(block.idx,
var_info.map.data['block_id'])
# save signarture info
for key, sign in self.signatures.items():
var = self.desc.sign2var[sign.name]
feed_desc = var.feed_desc
fetch_desc = var.fetch_desc
feed_names = sign.feed_names
fetch_names = sign.fetch_names
for index, input in enumerate(sign.inputs):
feed_var = feed_desc.add()
feed_var.var_name = HUB_VAR_PREFIX + input.name
feed_var.alias = feed_names[index]
for index, output in enumerate(sign.outputs):
fetch_var = fetch_desc.add()
fetch_var.var_name = HUB_VAR_PREFIX + output.name
fetch_var.alias = fetch_names[index]
def __call__(self, sign_name, trainable=False):
assert sign_name in self.signatures, "module did not have a signature with name %s" % sign_name
signature = self.signatures[sign_name]
program = self.program.clone()
paddle_helper.remove_feed_fetch_op(program)
paddle_helper.set_parameter_trainable(program, trainable)
paddle_helper.set_op_attr(program, is_test=False)
self._recovery_parameter(program)
self._recover_variable_info(program)
feed_dict = {}
fetch_dict = {}
for index, var in enumerate(signature.inputs):
feed_dict[index] = program.global_block().var(var.name)
key = signature.feed_names[index]
if key:
feed_dict[key] = program.global_block().var(var.name)
for index, var in enumerate(signature.outputs):
fetch_dict[index] = program.global_block().var(var.name)
key = signature.fetch_names[index]
if key:
fetch_dict[key] = program.global_block().var(var.name)
return feed_dict, fetch_dict, program
def preprocess(self):
pass
def postprocess(self):
pass
def parameters(self):
pass
def parameter_attrs(self):
pass
def _check_signatures(self):
assert self.signatures, "signature array should not be None"
for key, sign in self.signatures.items():
assert isinstance(sign,
Signature), "sign_arr should be list of Signature"
for input in sign.inputs:
_tmp_program = input.block.program
assert self.program == _tmp_program, "all the variable should come from the same program"
for output in sign.outputs:
_tmp_program = output.block.program
assert self.program == _tmp_program, "all the variable should come from the same program"
def serialize_to_path(self, path=None, exe=None):
self._check_signatures()
self._generate_desc()
# create module path for saving
if path is None:
path = os.path.join(".", self.name)
self.helper = ModuleHelper(path)
utils.mkdir(self.helper.module_dir)
# create module pb
module_desc = module_desc_pb2.ModuleDesc()
logger.info("hub version is %s" % version.hub_version)
logger.info("proto version is %s" % version.proto_version)
logger.info("paddle version is %s" % paddle.__version__)
for asset in self.assets:
pass
feeded_var_names = [
input.name for key, sign in self.signatures.items()
for input in sign.inputs
]
target_vars = [
output for key, sign in self.signatures.items()
for output in sign.outputs
]
feeded_var_names = list(set(feeded_var_names))
target_vars = list(set(target_vars))
# save inference program
program = self.program.clone()
if not exe:
place = fluid.CPUPlace()
exe = fluid.Executor(place=place)
utils.mkdir(self.helper.model_path())
fluid.io.save_inference_model(
self.helper.model_path(),
feeded_var_names=list(feeded_var_names),
target_vars=list(target_vars),
main_program=program,
executor=exe)
with open(os.path.join(self.helper.model_path(), "__model__"),
"rb") as file:
program_desc_str = file.read()
rename_program = fluid.framework.Program.parse_from_string(
program_desc_str)
varlist = {
var: block
for block in rename_program.blocks for var in block.vars
if HUB_VAR_PREFIX not in var
}
for var, block in varlist.items():
old_name = var
new_name = HUB_VAR_PREFIX + old_name
block._rename_var(old_name, new_name)
utils.mkdir(self.helper.model_path())
with open(
os.path.join(self.helper.model_path(), "__model__"),
"wb") as f:
f.write(rename_program.desc.serialize_to_string())
for file in os.listdir(self.helper.model_path()):
if (file == "__model__" or HUB_VAR_PREFIX in file):
continue
os.rename(
os.path.join(self.helper.model_path(), file),
os.path.join(self.helper.model_path(),
HUB_VAR_PREFIX + file))
# Serialize module_desc pb
module_pb = self.desc.SerializeToString()
with open(self.helper.module_desc_path(), "wb") as f:
f.write(module_pb)
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
syntax = "proto3"; syntax = "proto3";
option optimize_for = LITE_RUNTIME; option optimize_for = LITE_RUNTIME;
package paddle_hub; package paddle_hub_module;
enum DataType { enum DataType {
NONE = 0; NONE = 0;
...@@ -72,7 +72,7 @@ message AuthInfo { ...@@ -72,7 +72,7 @@ message AuthInfo {
string hub_version = 2; string hub_version = 2;
} }
// A Hub Module is stored in a directory with a file 'paddlehub.pb' // A Hub Module is stored in a directory with a file 'module_desc.pb'
// containing a serialized protocol message of this type. The further contents // containing a serialized protocol message of this type. The further contents
// of the directory depend on the storage format described by the message. // of the directory depend on the storage format described by the message.
message ModuleDesc { message ModuleDesc {
......
...@@ -16,7 +16,7 @@ from __future__ import absolute_import ...@@ -16,7 +16,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from paddle.fluid.framework import Variable from paddle.fluid.framework import Variable
from paddle_hub.utils import to_list from paddle_hub.tools.utils import to_list
class Signature: class Signature:
...@@ -53,21 +53,6 @@ class Signature: ...@@ -53,21 +53,6 @@ class Signature:
self.feed_names = feed_names self.feed_names = feed_names
self.fetch_names = fetch_names self.fetch_names = fetch_names
def get_name(self):
return self.name
def get_inputs(self):
return self.inputs
def get_outputs(self):
return self.outputs
def get_feed_names(self):
return self.feed_names
def get_fetch_names(self):
return self.fetch_names
def create_signature(name="default", def create_signature(name="default",
inputs=[], inputs=[],
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import utils
...@@ -105,7 +105,7 @@ def download_and_uncompress(url, save_name=None): ...@@ -105,7 +105,7 @@ def download_and_uncompress(url, save_name=None):
for file_name in file_names: for file_name in file_names:
tar.extract(file_name, dirname) tar.extract(file_name, dirname)
return module_name, module_dir return module_dir
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -16,6 +16,7 @@ from __future__ import print_function ...@@ -16,6 +16,7 @@ from __future__ import print_function
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import logging import logging
import math
class Logger: class Logger:
...@@ -50,7 +51,7 @@ class Logger: ...@@ -50,7 +51,7 @@ class Logger:
return self.logLevel return self.logLevel
def __call__(self, type, msg): def __call__(self, type, msg):
def _get_log_arr(msg): def _get_log_arr(msg, len_limit=30):
ph = Logger.PLACEHOLDER ph = Logger.PLACEHOLDER
lrspace = 2 lrspace = 2
lc = rc = " " * lrspace lc = rc = " " * lrspace
...@@ -59,6 +60,22 @@ class Logger: ...@@ -59,6 +60,22 @@ class Logger:
if len(msgarr) == 1: if len(msgarr) == 1:
return msgarr return msgarr
temp_arr = msgarr
msgarr = []
for text in temp_arr:
if len(text) > len_limit:
for i in range(math.ceil(len(text) / len_limit)):
if i == 0:
msgarr.append(text[0:len_limit])
else:
fr = len_limit + (len_limit - 4) * (i - 1)
to = len_limit + (len_limit - 4) * i
if to > len(text):
to = len(text)
msgarr.append("===>" + text[fr:to])
else:
msgarr.append(text)
maxlen = -1 maxlen = -1
for text in msgarr: for text in msgarr:
if len(text) > maxlen: if len(text) > maxlen:
......
...@@ -15,11 +15,12 @@ ...@@ -15,11 +15,12 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from paddle_hub import module_desc_pb2 from paddle_hub.module import module_desc_pb2
from paddle_hub.utils import from_pyobj_to_flexible_data, from_flexible_data_to_pyobj from paddle_hub.tools.utils import from_pyobj_to_flexible_data, from_flexible_data_to_pyobj
from paddle_hub.logger import logger from paddle_hub.tools.logger import logger
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import copy
def get_variable_info(var): def get_variable_info(var):
...@@ -116,3 +117,110 @@ def from_flexible_data_to_param(flexible_data): ...@@ -116,3 +117,110 @@ def from_flexible_data_to_param(flexible_data):
(clip_type, clip_norm, group_name)) (clip_type, clip_norm, group_name))
return param return param
def connect_program(pre_program, next_program, input_dict=None):
def _copy_vars_and_ops_in_blocks(from_block, to_block):
for var in from_block.vars:
var = from_block.var(var)
var_info = copy.deepcopy(get_variable_info(var))
if isinstance(var, fluid.framework.Parameter):
to_block.create_parameter(**var_info)
else:
to_block.create_var(**var_info)
for op in from_block.ops:
op_info = {
'type': op.type,
'inputs': {
input: [block.var(var) for var in op.input(input)]
for input in op.input_names
},
'outputs': {
output: [block.var(var) for var in op.output(output)]
for output in op.output_names
},
'attrs': copy.deepcopy(op.all_attrs())
}
to_block.append_op(**op_info)
assert isinstance(pre_program,
fluid.Program), "pre_program should be fluid.Program"
assert isinstance(next_program,
fluid.Program), "next_program should be fluid.Program"
new_program = pre_program.clone()
if input_dict:
assert isinstance(
input_dict,
dict), "the input_dict should be a dict with string-Variable pair"
for key, var in input_dict.items():
assert isinstance(
var, fluid.framework.Variable
), "the input_dict should be a dict with string-Variable pair"
var_info = copy.deepcopy(get_variable_info(var))
input_var = new_program.global_block().create_var(**var_info)
output_var = next_program.global_block().var(key)
var_info = copy.deepcopy(get_variable_info(output_var))
output_var = new_program.global_block().create_var(**var_info)
new_program.global_block().append_op(
type="assign",
inputs={'X': input_var},
outputs={'Out': output_var})
block_map = {0: 0}
logger.info("start to connect program")
for index, block in enumerate(next_program.blocks):
if block.idx == 0:
_copy_vars_and_ops_in_blocks(block, new_program.global_block())
else:
block_map[index] = len(new_program.blocks)
logger.info(
"block_%d in next_program merge into block_%d in pre_program" %
(index, block_map[index]))
new_block = new_program._create_block(
parent_idx=block_map[block.parent_idx])
_copy_vars_and_ops_in_blocks(block, new_block)
logger.info("end of connect program")
return new_program
def remove_feed_fetch_op(program):
""" remove feed and fetch operator and variable for fine-tuning
"""
logger.info("remove feed fetch op")
block = program.global_block()
need_to_remove_op_index = []
for i, op in enumerate(block.ops):
if op.type == "feed" or op.type == "fetch":
need_to_remove_op_index.append(i)
for index in need_to_remove_op_index[::-1]:
block._remove_op(index)
need_to_remove_var = []
for var in block.vars:
if var.endswith("feed"):
need_to_remove_var.append(var)
if var.endswith("fetch"):
need_to_remove_var.append(var)
for var in need_to_remove_var:
block._remove_var(var)
program.desc.flush()
def set_parameter_trainable(program, trainable=True):
for param in program.global_block().iter_parameters():
param.trainable = trainable
def set_parameter_regularization(program, regularization):
pass
def set_op_attr(program, is_test=False):
for block in program.blocks:
for op in block.ops:
if op.has_attr("is_test"):
op._set_attr("is_test", is_test)
...@@ -17,8 +17,8 @@ ...@@ -17,8 +17,8 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from paddle_hub import module_desc_pb2 from paddle_hub.module import module_desc_pb2
from paddle_hub.logger import logger from paddle_hub.tools.logger import logger
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import os import os
......
...@@ -12,4 +12,5 @@ ...@@ -12,4 +12,5 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" Store PaddleHub version string """ """ Store PaddleHub version string """
__version__ = "0.1.0-dev" hub_version = "0.1.0.dev"
proto_version = "0.1.0"
#/bin/bash
protoc -I=../paddle_hub/module --python_out=../paddle_hub/module ../paddle_hub/module/module_desc.proto
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册