diff --git a/paddle_hub/__init__.py b/paddle_hub/__init__.py
index 15f5f5adb8b04292e1f84bafa3e80f68f2532c12..b136876aabd669bc9aba6626ef06de10f785f7f6 100644
--- a/paddle_hub/__init__.py
+++ b/paddle_hub/__init__.py
@@ -25,3 +25,5 @@ from paddle_hub.module import create_module
 from paddle_hub.downloader import download_and_uncompress
 from paddle_hub.signature import create_signature
 from paddle_hub.version import __version__
+from paddle_hub.module import ModuleUtils
+connect_program = ModuleUtils.connect_program
diff --git a/paddle_hub/module.py b/paddle_hub/module.py
index 12874a36c23dd2be55cfe64e0d9c9270013ea9d5..6bdc47f8759478ba11e8e586db2732a9615830c1 100644
--- a/paddle_hub/module.py
+++ b/paddle_hub/module.py
@@ -23,14 +23,14 @@ import paddle.fluid as fluid
 import numpy as np
 import tempfile
 import os
-import pickle
+import copy
 
 from collections import defaultdict
 from paddle_hub.downloader import download_and_uncompress
 from paddle_hub import module_desc_pb2
 from paddle_hub.logger import logger
 from paddle_hub.signature import Signature
-from paddle_hub.utils import to_list
+from paddle_hub.utils import to_list, get_variable_info
 from paddle_hub.version import __version__
 
 __all__ = ["Module", "ModuleConfig", "ModuleUtils"]
@@ -235,25 +235,6 @@
         word_dict = self.config.get_assets_vocab()
         return list(map(lambda x: word_dict[x], inputs))
 
-    def set_input(self, input_dict):
-        assert isinstance(input_dict, dict), "input_dict must be a dict"
-        if not input_dict:
-            logger.warning("the input_dict is empty")
-
-        for key, val in input_dict.items():
-            assert isinstance(
-                val, fluid.framework.Variable
-            ), "the input_dict should be a dict with string-Variable pair"
-            program = val.block.program
-            assert key in program.global_block(
-            ).vars, "can't found input %s in the module" % key
-            input_var = val
-            output_var = program.global_block().var(key)
-            program.global_block()._prepend_op(
-                type="assign",
-                inputs={'X': input_var},
-                outputs={'Out': output_var})
-
 
 class ModuleConfig(object):
     def __init__(self, module_dir, module_name=None):
@@ -476,6 +457,71 @@ class ModuleUtils(object):
     def __init__(self):
         pass
 
+    @staticmethod
+    def connect_program(pre_program, next_program, input_dict=None):
+        def _copy_vars_and_ops_in_blocks(from_block, to_block):
+            for var in from_block.vars:
+                var = from_block.var(var)
+                var_info = copy.deepcopy(get_variable_info(var))
+                if isinstance(var, fluid.framework.Parameter):
+                    to_block.create_parameter(**var_info)
+                else:
+                    to_block.create_var(**var_info)
+
+            for op in from_block.ops:
+                op_info = {
+                    'type': op.type,
+                    'inputs': {
+                        input: [block.var(var) for var in op.input(input)]
+                        for input in op.input_names
+                    },
+                    'outputs': {
+                        output: [block.var(var) for var in op.output(output)]
+                        for output in op.output_names
+                    },
+                    'attrs': copy.deepcopy(op.all_attrs())
+                }
+                to_block.append_op(**op_info)
+
+        assert isinstance(pre_program,
+                          fluid.Program), "pre_program should be fluid.Program"
+        assert isinstance(next_program,
+                          fluid.Program), "next_program should be fluid.Program"
+        new_program = pre_program.clone()
+        if input_dict:
+            assert isinstance(
+                input_dict, dict
+            ), "the input_dict should be a dict with string-Variable pair"
+            for key, var in input_dict.items():
+                assert isinstance(
+                    var, fluid.framework.Variable
+                ), "the input_dict should be a dict with string-Variable pair"
+                var_info = copy.deepcopy(get_variable_info(var))
+                input_var = new_program.global_block().create_var(**var_info)
+                output_var = next_program.global_block().var(key)
+                var_info = copy.deepcopy(get_variable_info(output_var))
+                output_var = new_program.global_block().create_var(**var_info)
+                new_program.global_block()._prepend_op(
+                    type="assign",
+                    inputs={'X': input_var},
+                    outputs={'Out': output_var})
+
+        block_map = {0: 0}
+        logger.info("start to connect program")
+        for index, block in enumerate(next_program.blocks):
+            if block.idx == 0:
+                _copy_vars_and_ops_in_blocks(block, new_program.global_block())
+            else:
+                block_map[index] = len(new_program.blocks)
+                logger.info(
+                    "block_%d in next_program merge into block_%d in pre_program"
+                    % (index, block_map[index]))
+                new_block = new_program._create_block(
+                    parent_idx=block_map[block.parent_idx])
+                _copy_vars_and_ops_in_blocks(block, new_block)
+        logger.info("end of connect program")
+        return new_program
+
     @staticmethod
     def remove_feed_fetch_op(program):
         """ remove feed and fetch operator and variable for fine-tuning
diff --git a/paddle_hub/utils.py b/paddle_hub/utils.py
index 43b1aa8c08e00503785d88afc0ae80b4ab9e3841..ff6139d4113d1ede4e9eec0f2b1eb16c361e5cda 100644
--- a/paddle_hub/utils.py
+++ b/paddle_hub/utils.py
@@ -17,6 +17,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import paddle
+import paddle.fluid as fluid
 
 
 def to_list(input):
@@ -25,3 +27,29 @@
         input = [input]
 
     return input
+
+
+def get_variable_info(var):
+    assert isinstance(
+        var,
+        fluid.framework.Variable), "var should be a fluid.framework.Variable"
+    var_info = {
+        'type': var.type,
+        'name': var.name,
+        'dtype': var.dtype,
+        'lod_level': var.lod_level,
+        'shape': var.shape,
+        'stop_gradient': var.stop_gradient,
+        'is_data': var.is_data,
+        'error_clip': var.error_clip
+    }
+    if isinstance(var, fluid.framework.Parameter):
+        var_info['trainable'] = var.trainable
+        var_info['optimize_attr'] = var.optimize_attr
+        var_info['regularizer'] = var.regularizer
+        var_info['gradient_clip_attr'] = var.gradient_clip_attr
+        var_info['do_model_average'] = var.do_model_average
+    else:
+        var_info['persistable'] = var.persistable
+
+    return var_info