提交 77ac2c2b 编写于 作者: S seiriosPlus

fix fuse

上级 3cf50b74
......@@ -658,8 +658,8 @@ def large_scale_sparse_pass(program, main_program, config, is_startup=False):
if op.type == "sgd":
grad = main_program.global_block().vars[op.input("Grad")[0]]
lr = main_program.global_block().vars[op.input("LearningRate")[0]]
# remove origin optimzier op
block._remove_op(opt_idx)
## remove origin optimzier op
#block._remove_op(opt_idx)
block._insert_op(
opt_idx,
......@@ -688,8 +688,8 @@ def large_scale_sparse_pass(program, main_program, config, is_startup=False):
beta2 = op.attr('beta2')
epsilon = op.attr('epsilon')
# remove origin optimzier op
block._remove_op(opt_idx)
## remove origin optimzier op
#block._remove_op(opt_idx)
block._insert_op(
opt_idx,
......@@ -800,11 +800,15 @@ def large_scale_sparse_pass(program, main_program, config, is_startup=False):
for param, blockid in param_blockid_map.items():
opt_block = main_program.block(blockid)
grad, _, value_names, value_dims, acture_names, fuse = \
grad, opt_idx, value_names, value_dims, acture_names, fuse = \
get_optimizer_values(opt_block)
entry_attr = get_entry_attr(param)
if fuse:
# remove origin optimzier op
opt_block._remove_op(opt_idx)
# training/infer
mode = "0"
names_str = ",".join(value_names)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册