提交 77ac2c2b 编写于 作者: S seiriosPlus

fix fuse

上级 3cf50b74
...@@ -658,8 +658,8 @@ def large_scale_sparse_pass(program, main_program, config, is_startup=False): ...@@ -658,8 +658,8 @@ def large_scale_sparse_pass(program, main_program, config, is_startup=False):
if op.type == "sgd": if op.type == "sgd":
grad = main_program.global_block().vars[op.input("Grad")[0]] grad = main_program.global_block().vars[op.input("Grad")[0]]
lr = main_program.global_block().vars[op.input("LearningRate")[0]] lr = main_program.global_block().vars[op.input("LearningRate")[0]]
# remove origin optimzier op ## remove origin optimzier op
block._remove_op(opt_idx) #block._remove_op(opt_idx)
block._insert_op( block._insert_op(
opt_idx, opt_idx,
...@@ -688,8 +688,8 @@ def large_scale_sparse_pass(program, main_program, config, is_startup=False): ...@@ -688,8 +688,8 @@ def large_scale_sparse_pass(program, main_program, config, is_startup=False):
beta2 = op.attr('beta2') beta2 = op.attr('beta2')
epsilon = op.attr('epsilon') epsilon = op.attr('epsilon')
# remove origin optimzier op ## remove origin optimzier op
block._remove_op(opt_idx) #block._remove_op(opt_idx)
block._insert_op( block._insert_op(
opt_idx, opt_idx,
...@@ -800,11 +800,15 @@ def large_scale_sparse_pass(program, main_program, config, is_startup=False): ...@@ -800,11 +800,15 @@ def large_scale_sparse_pass(program, main_program, config, is_startup=False):
for param, blockid in param_blockid_map.items(): for param, blockid in param_blockid_map.items():
opt_block = main_program.block(blockid) opt_block = main_program.block(blockid)
grad, _, value_names, value_dims, acture_names, fuse = \ grad, opt_idx, value_names, value_dims, acture_names, fuse = \
get_optimizer_values(opt_block) get_optimizer_values(opt_block)
entry_attr = get_entry_attr(param) entry_attr = get_entry_attr(param)
if fuse:
# remove origin optimzier op
opt_block._remove_op(opt_idx)
# training/infer # training/infer
mode = "0" mode = "0"
names_str = ",".join(value_names) names_str = ",".join(value_names)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册