From 2c045b9c300ac32d36834ead0c96e48bd6e49ceb Mon Sep 17 00:00:00 2001 From: wuzewu Date: Mon, 11 Mar 2019 17:32:13 +0800 Subject: [PATCH] add config process --- paddle_hub/commands/run.py | 2 +- paddle_hub/module/module.py | 37 +++++++++++++++++++++++++++---------- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/paddle_hub/commands/run.py b/paddle_hub/commands/run.py index ec4866b7..2d210d81 100644 --- a/paddle_hub/commands/run.py +++ b/paddle_hub/commands/run.py @@ -96,7 +96,7 @@ class RunCommand(BaseCommand): module( sign_name=self.args.signature, data=input_data, - config=yaml_config['config']) + **yaml_config['config']) command = RunCommand.instance() diff --git a/paddle_hub/module/module.py b/paddle_hub/module/module.py index 6cc5331e..e697d320 100644 --- a/paddle_hub/module/module.py +++ b/paddle_hub/module/module.py @@ -330,11 +330,11 @@ class Module: utils.from_pyobj_to_flexible_data(self.summary, module_info.map.data['summary']) - def __call__(self, sign_name, data, config=None): - feed_dict, fetch_dict, program = self.context(sign_name) + def __call__(self, sign_name, data, **kwargs): + feed_dict, fetch_dict, program = self.context(sign_name, for_test=True) #TODO(wuzewu): more option - program = program.clone(for_test=True) - reader = self.processor.reader(sign_name=sign_name, data_dict=data) + reader = self.processor.reader( + sign_name=sign_name, data_dict=data, **kwargs) feed_name_list = list( set([value.name for key, value in feed_dict.items()])) fetch_list = list(set([value for key, value in fetch_dict.items()])) @@ -347,20 +347,37 @@ class Module: feed=feeder.feed(batch), fetch_list=fetch_list, return_numpy=False) - self.processor.postprocess(sign_name, data_out, config) + self.processor.postprocess(sign_name, data_out, **kwargs) - def context(self, sign_name, trainable=False): + def context(self, + sign_name, + for_test=False, + trainable=False, + regularizer=None, + learning_rate=1e-3): assert sign_name in self.signatures, "module did not have a signature with name %s" % sign_name signature = self.signatures[sign_name] - program = self.program.clone() + program = self.program.clone(for_test=for_test) paddle_helper.remove_feed_fetch_op(program) - paddle_helper.set_parameter_trainable(program, trainable) - paddle_helper.set_op_attr(program, is_test=False) - self._recovery_parameter(program) + + if not for_test: + if trainable != "Default": + paddle_helper.set_parameter_trainable(program, trainable) + + if learning_rate != "Default": + paddle_helper.set_parameter_learning_rate( + program, learning_rate) + + if regularizer != "Default": + paddle_helper.set_parameter_regularizer(program, regularizer) + + self._recovery_parameter(program) + self._recover_variable_info(program) + paddle_helper.set_op_attr(program, is_test=for_test) #TODO(wuzewu): return feed_list and fetch_list directly feed_dict = {} fetch_dict = {} -- GitLab