提交 b9f9ff25 编写于 作者: W wuzewu

process op attrs

上级 c6089088
......@@ -132,11 +132,17 @@ class Module(object):
for param in program.global_block().iter_parameters():
param.trainable = trainable
def _process_op_attr(program, is_test=False):
for op in program.global_block().ops:
if op.has_attr("is_test"):
op._set_attr("is_test", is_test)
if not run_config:
run_config = RunConfig()
program = self.get_inference_program().clone()
_process_op_attr(program=program, is_test=False)
if run_config.param_train_config == ParamTrainConfig.PARAM_TRAIN_ALL:
_set_param_trainable(program=program, trainable=True)
elif run_config.param_train_config == ParamTrainConfig.PARAM_TRAIN_ALL:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册