提交 2c045b9c 编写于 作者: W wuzewu

add config process

上级 1d4afb60
......@@ -96,7 +96,7 @@ class RunCommand(BaseCommand):
module(
sign_name=self.args.signature,
data=input_data,
config=yaml_config['config'])
**yaml_config['config'])
command = RunCommand.instance()
......@@ -330,11 +330,11 @@ class Module:
utils.from_pyobj_to_flexible_data(self.summary,
module_info.map.data['summary'])
def __call__(self, sign_name, data, config=None):
feed_dict, fetch_dict, program = self.context(sign_name)
def __call__(self, sign_name, data, **kwargs):
feed_dict, fetch_dict, program = self.context(sign_name, for_test=True)
#TODO(wuzewu): more option
program = program.clone(for_test=True)
reader = self.processor.reader(sign_name=sign_name, data_dict=data)
reader = self.processor.reader(
sign_name=sign_name, data_dict=data, **kwargs)
feed_name_list = list(
set([value.name for key, value in feed_dict.items()]))
fetch_list = list(set([value for key, value in fetch_dict.items()]))
......@@ -347,20 +347,37 @@ class Module:
feed=feeder.feed(batch),
fetch_list=fetch_list,
return_numpy=False)
self.processor.postprocess(sign_name, data_out, config)
self.processor.postprocess(sign_name, data_out, **kwargs)
def context(self, sign_name, trainable=False):
def context(self,
sign_name,
for_test=False,
trainable=False,
regularizer=None,
learning_rate=1e-3):
assert sign_name in self.signatures, "module did not have a signature with name %s" % sign_name
signature = self.signatures[sign_name]
program = self.program.clone()
program = self.program.clone(for_test=for_test)
paddle_helper.remove_feed_fetch_op(program)
paddle_helper.set_parameter_trainable(program, trainable)
paddle_helper.set_op_attr(program, is_test=False)
self._recovery_parameter(program)
if not for_test:
if trainable != "Default":
paddle_helper.set_parameter_trainable(program, trainable)
if learning_rate != "Default":
paddle_helper.set_parameter_learning_rate(
program, learning_rate)
if regularizer != "Default":
paddle_helper.set_parameter_regularizer(program, regularizer)
self._recovery_parameter(program)
self._recover_variable_info(program)
paddle_helper.set_op_attr(program, is_test=for_test)
#TODO(wuzewu): return feed_list and fetch_list directly
feed_dict = {}
fetch_dict = {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册