提交 dfa96d6d 编写于 作者: W wuzewu

add a func to connect program

上级 0da9baf5
......@@ -25,3 +25,4 @@ from paddle_hub.module import create_module
from paddle_hub.downloader import download_and_uncompress
from paddle_hub.signature import create_signature
from paddle_hub.version import __version__
connect_program = ModuleUtils.connect_program
......@@ -23,14 +23,14 @@ import paddle.fluid as fluid
import numpy as np
import tempfile
import os
import pickle
import copy
from collections import defaultdict
from paddle_hub.downloader import download_and_uncompress
from paddle_hub import module_desc_pb2
from paddle_hub.logger import logger
from paddle_hub.signature import Signature
from paddle_hub.utils import to_list
from paddle_hub.utils import to_list, get_variable_info
from paddle_hub.version import __version__
__all__ = ["Module", "ModuleConfig", "ModuleUtils"]
......@@ -235,25 +235,6 @@ class Module(object):
word_dict = self.config.get_assets_vocab()
return list(map(lambda x: word_dict[x], inputs))
def set_input(self, input_dict):
assert isinstance(input_dict, dict), "input_dict must be a dict"
if not input_dict:
logger.warning("the input_dict is empty")
for key, val in input_dict.items():
assert isinstance(
val, fluid.framework.Variable
), "the input_dict should be a dict with string-Variable pair"
program = val.block.program
assert key in program.global_block(
).vars, "can't found input %s in the module" % key
input_var = val
output_var = program.global_block().var(key)
program.global_block()._prepend_op(
type="assign",
inputs={'X': input_var},
outputs={'Out': output_var})
class ModuleConfig(object):
def __init__(self, module_dir, module_name=None):
......@@ -476,6 +457,71 @@ class ModuleUtils(object):
def __init__(self):
pass
@staticmethod
def connect_program(pre_program, next_program, input_dict=None):
def _copy_vars_and_ops_in_blocks(from_block, to_block):
for var in from_block.vars:
var = from_block.var(var)
var_info = copy.deepcopy(get_variable_info(var))
if isinstance(var, fluid.framework.Parameter):
to_block.create_parameter(**var_info)
else:
to_block.create_var(**var_info)
for op in from_block.ops:
op_info = {
'type': op.type,
'inputs': {
input: [block.var(var) for var in op.input(input)]
for input in op.input_names
},
'outputs': {
output: [block.var(var) for var in op.output(output)]
for output in op.output_names
},
'attrs': copy.deepcopy(op.all_attrs())
}
to_block.append_op(**op_info)
assert isinstance(pre_program,
fluid.Program), "pre_program should be fluid.Program"
assert isinstance(next_program,
fluid.Program), "next_program should be fluid.Program"
new_program = pre_program.clone()
if input_dict:
assert isinstance(
input_dict, dict
), "the input_dict should be a dict with string-Variable pair"
for key, var in input_dict.items():
assert isinstance(
var, fluid.framework.Variable
), "the input_dict should be a dict with string-Variable pair"
var_info = copy.deepcopy(get_variable_info(var))
input_var = new_program.global_block().create_var(**var_info)
output_var = next_program.global_block().var(key)
var_info = copy.deepcopy(get_variable_info(output_var))
output_var = new_program.global_block().create_var(**var_info)
new_program.global_block()._prepend_op(
type="assign",
inputs={'X': input_var},
outputs={'Out': output_var})
block_map = {0: 0}
logger.info("start to connect program")
for index, block in enumerate(next_program.blocks):
if block.idx == 0:
_copy_vars_and_ops_in_blocks(block, new_program.global_block())
else:
block_map[index] = len(new_program.blocks)
logger.info(
"block_%d in next_program merge into block_%d in pre_program"
% (index, block_map[index]))
new_block = new_program._create_block(
parent_idx=block_map[block.parent_idx])
_copy_vars_and_ops_in_blocks(block, new_block)
logger.info("end of connect program")
return new_program
@staticmethod
def remove_feed_fetch_op(program):
""" remove feed and fetch operator and variable for fine-tuning
......
......@@ -17,6 +17,8 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.fluid as fluid
def to_list(input):
......@@ -25,3 +27,29 @@ def to_list(input):
input = [input]
return input
def get_variable_info(var):
assert isinstance(
var,
fluid.framework.Variable), "var should be a fluid.framework.Variable"
var_info = {
'type': var.type,
'name': var.name,
'dtype': var.dtype,
'lod_level': var.lod_level,
'shape': var.shape,
'stop_gradient': var.stop_gradient,
'is_data': var.is_data,
'error_clip': var.error_clip
}
if isinstance(var, fluid.framework.Parameter):
var_info['trainable'] = var.trainable
var_info['optimize_attr'] = var.optimize_attr
var_info['regularizer'] = var.regularizer
var_info['gradient_clip_attr'] = var.gradient_clip_attr
var_info['do_model_average'] = var.do_model_average
else:
var_info['persistable'] = var.persistable
return var_info
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册