提交 17b54bfe 编写于 作者: W wuzewu

add set_input method to replace feed var in module

上级 cff7c12e
...@@ -235,6 +235,25 @@ class Module(object): ...@@ -235,6 +235,25 @@ class Module(object):
word_dict = self.config.get_assets_vocab() word_dict = self.config.get_assets_vocab()
return list(map(lambda x: word_dict[x], inputs)) return list(map(lambda x: word_dict[x], inputs))
def set_input(self, input_dict):
assert isinstance(input_dict, dict), "input_dict must be a dict"
if not input_dict:
logger.warning("the input_dict is empty")
for key, val in input_dict.items():
assert isinstance(
val, fluid.framework.Variable
), "the input_dict should be a dict with string-Variable pair"
program = val.block.program
assert key in program.global_block(
).vars, "can't found input %s in the module" % key
input_var = val
output_var = program.global_block().var(key)
program.global_block()._prepend_op(
type="assign",
inputs={'X': input_var},
outputs={'Out': output_var})
class ModuleConfig(object): class ModuleConfig(object):
def __init__(self, module_dir, module_name=None): def __init__(self, module_dir, module_name=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册