提交 2c045b9c 编写于 作者: W wuzewu

add config process

上级 1d4afb60
...@@ -96,7 +96,7 @@ class RunCommand(BaseCommand): ...@@ -96,7 +96,7 @@ class RunCommand(BaseCommand):
module( module(
sign_name=self.args.signature, sign_name=self.args.signature,
data=input_data, data=input_data,
config=yaml_config['config']) **yaml_config['config'])
command = RunCommand.instance() command = RunCommand.instance()
...@@ -330,11 +330,11 @@ class Module: ...@@ -330,11 +330,11 @@ class Module:
utils.from_pyobj_to_flexible_data(self.summary, utils.from_pyobj_to_flexible_data(self.summary,
module_info.map.data['summary']) module_info.map.data['summary'])
def __call__(self, sign_name, data, config=None): def __call__(self, sign_name, data, **kwargs):
feed_dict, fetch_dict, program = self.context(sign_name) feed_dict, fetch_dict, program = self.context(sign_name, for_test=True)
#TODO(wuzewu): more option #TODO(wuzewu): more option
program = program.clone(for_test=True) reader = self.processor.reader(
reader = self.processor.reader(sign_name=sign_name, data_dict=data) sign_name=sign_name, data_dict=data, **kwargs)
feed_name_list = list( feed_name_list = list(
set([value.name for key, value in feed_dict.items()])) set([value.name for key, value in feed_dict.items()]))
fetch_list = list(set([value for key, value in fetch_dict.items()])) fetch_list = list(set([value for key, value in fetch_dict.items()]))
...@@ -347,20 +347,37 @@ class Module: ...@@ -347,20 +347,37 @@ class Module:
feed=feeder.feed(batch), feed=feeder.feed(batch),
fetch_list=fetch_list, fetch_list=fetch_list,
return_numpy=False) return_numpy=False)
self.processor.postprocess(sign_name, data_out, config) self.processor.postprocess(sign_name, data_out, **kwargs)
def context(self, sign_name, trainable=False): def context(self,
sign_name,
for_test=False,
trainable=False,
regularizer=None,
learning_rate=1e-3):
assert sign_name in self.signatures, "module did not have a signature with name %s" % sign_name assert sign_name in self.signatures, "module did not have a signature with name %s" % sign_name
signature = self.signatures[sign_name] signature = self.signatures[sign_name]
program = self.program.clone() program = self.program.clone(for_test=for_test)
paddle_helper.remove_feed_fetch_op(program) paddle_helper.remove_feed_fetch_op(program)
if not for_test:
if trainable != "Default":
paddle_helper.set_parameter_trainable(program, trainable) paddle_helper.set_parameter_trainable(program, trainable)
paddle_helper.set_op_attr(program, is_test=False)
if learning_rate != "Default":
paddle_helper.set_parameter_learning_rate(
program, learning_rate)
if regularizer != "Default":
paddle_helper.set_parameter_regularizer(program, regularizer)
self._recovery_parameter(program) self._recovery_parameter(program)
self._recover_variable_info(program) self._recover_variable_info(program)
paddle_helper.set_op_attr(program, is_test=for_test)
#TODO(wuzewu): return feed_list and fetch_list directly #TODO(wuzewu): return feed_list and fetch_list directly
feed_dict = {} feed_dict = {}
fetch_dict = {} fetch_dict = {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册