提交 b9f9ff25 编写于 作者: W wuzewu

process op attrs

上级 c6089088
...@@ -132,11 +132,17 @@ class Module(object): ...@@ -132,11 +132,17 @@ class Module(object):
for param in program.global_block().iter_parameters(): for param in program.global_block().iter_parameters():
param.trainable = trainable param.trainable = trainable
def _process_op_attr(program, is_test=False):
for op in program.global_block().ops:
if op.has_attr("is_test"):
op._set_attr("is_test", is_test)
if not run_config: if not run_config:
run_config = RunConfig() run_config = RunConfig()
program = self.get_inference_program().clone() program = self.get_inference_program().clone()
_process_op_attr(program=program, is_test=False)
if run_config.param_train_config == ParamTrainConfig.PARAM_TRAIN_ALL: if run_config.param_train_config == ParamTrainConfig.PARAM_TRAIN_ALL:
_set_param_trainable(program=program, trainable=True) _set_param_trainable(program=program, trainable=True)
elif run_config.param_train_config == ParamTrainConfig.PARAM_TRAIN_ALL: elif run_config.param_train_config == ParamTrainConfig.PARAM_TRAIN_ALL:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册