diff --git a/paddle_hub/commands/run.py b/paddle_hub/commands/run.py index ec4866b72769ceae7e4dccf0a98bd7a38ef52b5b..2d210d81a59ecf1de8c414f6e4a754a8cc0330b7 100644 --- a/paddle_hub/commands/run.py +++ b/paddle_hub/commands/run.py @@ -96,7 +96,7 @@ class RunCommand(BaseCommand): module( sign_name=self.args.signature, data=input_data, - config=yaml_config['config']) + **yaml_config['config']) command = RunCommand.instance() diff --git a/paddle_hub/module/module.py b/paddle_hub/module/module.py index 6cc5331e558450e7ab2e69c3f2e38a15a2040544..e697d32029b6735539e2a60d9378af7cf9168ad5 100644 --- a/paddle_hub/module/module.py +++ b/paddle_hub/module/module.py @@ -330,11 +330,11 @@ class Module: utils.from_pyobj_to_flexible_data(self.summary, module_info.map.data['summary']) - def __call__(self, sign_name, data, config=None): - feed_dict, fetch_dict, program = self.context(sign_name) + def __call__(self, sign_name, data, **kwargs): + feed_dict, fetch_dict, program = self.context(sign_name, for_test=True) #TODO(wuzewu): more option - program = program.clone(for_test=True) - reader = self.processor.reader(sign_name=sign_name, data_dict=data) + reader = self.processor.reader( + sign_name=sign_name, data_dict=data, **kwargs) feed_name_list = list( set([value.name for key, value in feed_dict.items()])) fetch_list = list(set([value for key, value in fetch_dict.items()])) @@ -347,20 +347,37 @@ class Module: feed=feeder.feed(batch), fetch_list=fetch_list, return_numpy=False) - self.processor.postprocess(sign_name, data_out, config) + self.processor.postprocess(sign_name, data_out, **kwargs) - def context(self, sign_name, trainable=False): + def context(self, + sign_name, + for_test=False, + trainable=False, + regularizer=None, + learning_rate=1e-3): assert sign_name in self.signatures, "module did not have a signature with name %s" % sign_name signature = self.signatures[sign_name] - program = self.program.clone() + program = self.program.clone(for_test=for_test) paddle_helper.remove_feed_fetch_op(program) - paddle_helper.set_parameter_trainable(program, trainable) - paddle_helper.set_op_attr(program, is_test=False) - self._recovery_parameter(program) + + if not for_test: + if trainable != "Default": + paddle_helper.set_parameter_trainable(program, trainable) + + if learning_rate != "Default": + paddle_helper.set_parameter_learning_rate( + program, learning_rate) + + if regularizer != "Default": + paddle_helper.set_parameter_regularizer(program, regularizer) + + self._recovery_parameter(program) + self._recover_variable_info(program) + paddle_helper.set_op_attr(program, is_test=for_test) #TODO(wuzewu): return feed_list and fetch_list directly feed_dict = {} fetch_dict = {}