提交 666a91fa 编写于 作者: W wuzewu

Add helper method

上级 a12300be
......@@ -279,3 +279,32 @@ def clone_program(origin_program, for_test=False):
).vars[name].stop_gradient = var.stop_gradient
return dest_program
def rename_var(block, old_name, new_name):
for op in block.ops:
for input_name in op.input_arg_names:
if input_name == old_name:
op._rename_input(old_name, new_name)
for output_name in op.output_arg_names:
if output_name == old_name:
op._rename_output(old_name, new_name)
block._rename_var(old_name, new_name)
def add_vars_prefix(program, prefix, vars=None):
block = program.global_block()
vars = list(vars) if vars else list(block.vars.keys())
for var in vars:
rename_var(block, var, prefix + var)
def remove_vars_prefix(program, prefix, vars=None):
block = program.global_block()
vars = [var for var in vars if var.startswith(prefix)] if vars else [
var for var in block.vars.keys() if var.startswith(prefix)
]
for var in vars:
rename_var(block, var, var.replace(prefix, "", 1))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册