提交 002e52c3 编写于 作者: S seiriosPlus

fix fuse

上级 f10138b2
...@@ -658,8 +658,6 @@ def large_scale_sparse_pass(program, main_program, config, is_startup=False): ...@@ -658,8 +658,6 @@ def large_scale_sparse_pass(program, main_program, config, is_startup=False):
if op.type == "sgd": if op.type == "sgd":
grad = main_program.global_block().vars[op.input("Grad")[0]] grad = main_program.global_block().vars[op.input("Grad")[0]]
lr = main_program.global_block().vars[op.input("LearningRate")[0]] lr = main_program.global_block().vars[op.input("LearningRate")[0]]
## remove origin optimzier op
#block._remove_op(opt_idx)
block._insert_op( block._insert_op(
opt_idx, opt_idx,
...@@ -679,18 +677,15 @@ def large_scale_sparse_pass(program, main_program, config, is_startup=False): ...@@ -679,18 +677,15 @@ def large_scale_sparse_pass(program, main_program, config, is_startup=False):
0]] 0]]
beta2_pow = main_program.global_block().vars[op.input("Beta2Pow")[ beta2_pow = main_program.global_block().vars[op.input("Beta2Pow")[
0]] 0]]
beta1_pow_o = main_program.global_block().vars[op.input( beta1_pow_o = main_program.global_block().vars[op.output(
"Beta1PowOut")[0]] "Beta1PowOut")[0]]
beta2_pow_o = main_program.global_block().vars[op.input( beta2_pow_o = main_program.global_block().vars[op.output(
"Beta2PowOut")[0]] "Beta2PowOut")[0]]
beta1 = op.attr('shape') beta1 = op.attr('shape')
beta2 = op.attr('beta2') beta2 = op.attr('beta2')
epsilon = op.attr('epsilon') epsilon = op.attr('epsilon')
## remove origin optimzier op
#block._remove_op(opt_idx)
block._insert_op( block._insert_op(
opt_idx, opt_idx,
type="lookup_sparse_table_fuse_adam", type="lookup_sparse_table_fuse_adam",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册