diff --git a/paddle_hub/module.py b/paddle_hub/module.py index 95bb75e83990f2f22573b9b4cb75e17b4b667482..b6d48adc15dd9503fb4b4fe74eaeadb1a03718ea 100644 --- a/paddle_hub/module.py +++ b/paddle_hub/module.py @@ -235,6 +235,25 @@ class Module(object): word_dict = self.config.get_assets_vocab() return list(map(lambda x: word_dict[x], inputs)) + def set_input(self, input_dict): + assert isinstance(input_dict, dict), "input_dict must be a dict" + if not input_dict: + logger.warning("the input_dict is empty") + + for key, val in input_dict.items(): + assert isinstance( + val, fluid.framework.Variable + ), "the input_dict should be a dict with string-Variable pair" + program = val.block.program + assert key in program.global_block( + ).vars, "can't found input %s in the module" % key + input_var = val + output_var = program.global_block().var(key) + program.global_block()._prepend_op( + type="assign", + inputs={'X': input_var}, + outputs={'Out': output_var}) + class ModuleConfig(object): def __init__(self, module_dir, module_name=None):