提交 dfa96d6d 编写于 作者: W wuzewu

Add a function to connect two programs

上级 0da9baf5
...@@ -25,3 +25,4 @@ from paddle_hub.module import create_module ...@@ -25,3 +25,4 @@ from paddle_hub.module import create_module
from paddle_hub.downloader import download_and_uncompress from paddle_hub.downloader import download_and_uncompress
from paddle_hub.signature import create_signature from paddle_hub.signature import create_signature
from paddle_hub.version import __version__ from paddle_hub.version import __version__
# Re-export ModuleUtils.connect_program at package level so callers can use
# paddle_hub.connect_program(...) directly.
connect_program = ModuleUtils.connect_program
...@@ -23,14 +23,14 @@ import paddle.fluid as fluid ...@@ -23,14 +23,14 @@ import paddle.fluid as fluid
import numpy as np import numpy as np
import tempfile import tempfile
import os import os
import pickle import copy
from collections import defaultdict from collections import defaultdict
from paddle_hub.downloader import download_and_uncompress from paddle_hub.downloader import download_and_uncompress
from paddle_hub import module_desc_pb2 from paddle_hub import module_desc_pb2
from paddle_hub.logger import logger from paddle_hub.logger import logger
from paddle_hub.signature import Signature from paddle_hub.signature import Signature
from paddle_hub.utils import to_list from paddle_hub.utils import to_list, get_variable_info
from paddle_hub.version import __version__ from paddle_hub.version import __version__
__all__ = ["Module", "ModuleConfig", "ModuleUtils"] __all__ = ["Module", "ModuleConfig", "ModuleUtils"]
...@@ -235,25 +235,6 @@ class Module(object): ...@@ -235,25 +235,6 @@ class Module(object):
word_dict = self.config.get_assets_vocab() word_dict = self.config.get_assets_vocab()
return list(map(lambda x: word_dict[x], inputs)) return list(map(lambda x: word_dict[x], inputs))
def set_input(self, input_dict):
    """Bind external variables to the module's inputs.

    For each (name, variable) pair, prepend an ``assign`` op to the
    variable's program so the external variable's value is copied into
    the module's global-block variable of the same name.
    """
    assert isinstance(input_dict, dict), "input_dict must be a dict"
    if not input_dict:
        logger.warning("the input_dict is empty")

    for name, external_var in input_dict.items():
        assert isinstance(
            external_var, fluid.framework.Variable
        ), "the input_dict should be a dict with string-Variable pair"
        main_program = external_var.block.program
        global_block = main_program.global_block()
        assert name in global_block.vars, "can't found input %s in the module" % name
        # Prepend so the assignment runs before any op that reads the input.
        global_block._prepend_op(
            type="assign",
            inputs={'X': external_var},
            outputs={'Out': global_block.var(name)})
class ModuleConfig(object): class ModuleConfig(object):
def __init__(self, module_dir, module_name=None): def __init__(self, module_dir, module_name=None):
...@@ -476,6 +457,71 @@ class ModuleUtils(object): ...@@ -476,6 +457,71 @@ class ModuleUtils(object):
def __init__(self): def __init__(self):
pass pass
@staticmethod
def connect_program(pre_program, next_program, input_dict=None):
    """Append ``next_program`` onto a clone of ``pre_program``.

    Args:
        pre_program: a ``fluid.Program`` whose clone becomes the base.
        next_program: a ``fluid.Program`` whose blocks (vars + ops) are
            copied into the clone.
        input_dict: optional {name: Variable} mapping; for each pair an
            ``assign`` op is prepended so the value of the variable from
            ``pre_program`` feeds the like-named input of ``next_program``.

    Returns:
        The merged ``fluid.Program``.
    """

    def _copy_vars_and_ops_in_blocks(from_block, to_block):
        # Recreate every variable of from_block inside to_block with
        # identical attributes (names are preserved).
        for var_name in from_block.vars:
            var = from_block.var(var_name)
            var_info = copy.deepcopy(get_variable_info(var))
            if isinstance(var, fluid.framework.Parameter):
                to_block.create_parameter(**var_info)
            else:
                to_block.create_var(**var_info)

        for op in from_block.ops:
            # BUGFIX: the original referenced the *enclosing loop's*
            # `block` variable here (late-binding closure); use the
            # helper's own from_block so the helper is self-contained.
            op_info = {
                'type': op.type,
                'inputs': {
                    in_name:
                    [from_block.var(name) for name in op.input(in_name)]
                    for in_name in op.input_names
                },
                'outputs': {
                    out_name:
                    [from_block.var(name) for name in op.output(out_name)]
                    for out_name in op.output_names
                },
                'attrs': copy.deepcopy(op.all_attrs())
            }
            to_block.append_op(**op_info)

    assert isinstance(pre_program,
                      fluid.Program), "pre_program should be fluid.Program"
    assert isinstance(next_program,
                      fluid.Program), "next_program should be fluid.Program"

    new_program = pre_program.clone()
    if input_dict:
        assert isinstance(
            input_dict, dict
        ), "the input_dict should be a dict with string-Variable pair"
        for key, var in input_dict.items():
            assert isinstance(
                var, fluid.framework.Variable
            ), "the input_dict should be a dict with string-Variable pair"
            # Mirror the source variable and next_program's like-named
            # input into the merged program, then wire them together.
            var_info = copy.deepcopy(get_variable_info(var))
            input_var = new_program.global_block().create_var(**var_info)
            output_var = next_program.global_block().var(key)
            var_info = copy.deepcopy(get_variable_info(output_var))
            output_var = new_program.global_block().create_var(**var_info)
            # Prepend so the feed runs before any copied op executes.
            new_program.global_block()._prepend_op(
                type="assign",
                inputs={'X': input_var},
                outputs={'Out': output_var})

    # next_program's block 0 merges into the clone's global block; every
    # other block gets a fresh block whose parent index is remapped.
    block_map = {0: 0}
    logger.info("start to connect program")
    for index, block in enumerate(next_program.blocks):
        if block.idx == 0:
            _copy_vars_and_ops_in_blocks(block, new_program.global_block())
        else:
            block_map[index] = len(new_program.blocks)
            logger.info(
                "block_%d in next_program merge into block_%d in pre_program"
                % (index, block_map[index]))
            new_block = new_program._create_block(
                parent_idx=block_map[block.parent_idx])
            _copy_vars_and_ops_in_blocks(block, new_block)
    logger.info("end of connect program")
    return new_program
@staticmethod @staticmethod
def remove_feed_fetch_op(program): def remove_feed_fetch_op(program):
""" remove feed and fetch operator and variable for fine-tuning """ remove feed and fetch operator and variable for fine-tuning
......
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import paddle
import paddle.fluid as fluid
def to_list(input): def to_list(input):
...@@ -25,3 +27,29 @@ def to_list(input): ...@@ -25,3 +27,29 @@ def to_list(input):
input = [input] input = [input]
return input return input
def get_variable_info(var):
    """Return a kwargs dict describing *var*, suitable for re-creating it
    via ``block.create_var`` / ``block.create_parameter``.
    """
    assert isinstance(
        var,
        fluid.framework.Variable), "var should be a fluid.framework.Variable"

    # Attributes common to every Variable.
    info = {
        'type': var.type,
        'name': var.name,
        'dtype': var.dtype,
        'lod_level': var.lod_level,
        'shape': var.shape,
        'stop_gradient': var.stop_gradient,
        'is_data': var.is_data,
        'error_clip': var.error_clip,
    }

    if isinstance(var, fluid.framework.Parameter):
        # Extra attributes only create_parameter accepts.
        info.update(
            trainable=var.trainable,
            optimize_attr=var.optimize_attr,
            regularizer=var.regularizer,
            gradient_clip_attr=var.gradient_clip_attr,
            do_model_average=var.do_model_average)
    else:
        info['persistable'] = var.persistable

    return info
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册