Commit 0f26cee1 authored by S seiriosPlus

add large scale optimizer fuse

Parent 61284c0a
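
This change makes the large-scale sparse pass replace a plain sgd/adam optimizer op on a distributed embedding table with a single fused op (lookup_sparse_table_fuse_sgd / lookup_sparse_table_fuse_adam). A minimal sketch of the rewrite mechanics, using toy Op/Block stand-ins rather than the real Paddle classes:

# Toy model of the rewrite performed in this commit; Op and Block are
# hypothetical stand-ins, not paddle.fluid types.
class Op:
    def __init__(self, type, **attrs):
        self.type = type
        self.attrs = attrs

class Block:
    def __init__(self, ops):
        self.ops = ops

    def _remove_op(self, idx):
        del self.ops[idx]

    def _insert_op(self, idx, op):
        self.ops.insert(idx, op)

def fuse_optimizer_op(block, table_name):
    # find the optimizer op, drop it, and insert the fused op in its place
    for idx, op in enumerate(block.ops):
        if op.type in ("sgd", "adam"):
            block._remove_op(idx)
            block._insert_op(idx, Op("lookup_sparse_table_fuse_" + op.type,
                                     tablename=table_name))
            break

block = Block([Op("lookup_table"), Op("sgd")])
fuse_optimizer_op(block, "embedding_table")
print([op.type for op in block.ops])
# -> ['lookup_table', 'lookup_sparse_table_fuse_sgd']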
@@ -624,6 +624,7 @@ def large_scale_sparse_pass(program, main_program, config, is_startup=False):
        value_dims = []
        grad = None
        opt_idx = -1
        fuse = False

        for op in block.ops:
            opt_idx += 1
@@ -631,6 +632,9 @@ def large_scale_sparse_pass(program, main_program, config, is_startup=False):
            if op.type not in opt_value_map.keys():
                continue

            # sgd and adam each have a dedicated fused large-scale-table op
            if op.type in ["sgd", "adam"]:
                fuse = True

            grad = main_program.global_block().vars[op.input("Grad")[0]]
            for value in opt_value_map[op.type]:
@@ -644,7 +648,68 @@ def large_scale_sparse_pass(program, main_program, config, is_startup=False):
            if value_names:
                break

-        return grad, opt_idx, value_names, value_dims, acture_names
+        return grad, opt_idx, value_names, value_dims, acture_names, fuse
    def add_fuse_large_scale_op(block, global_block, table_name, value_names,
                                acture_names, grad, is_entry, opt_idx):
        op = block.ops[opt_idx]

        if op.type == "sgd":
            grad = main_program.global_block().vars[op.input("Grad")[0]]

            # remove the original optimizer op
            block._remove_op(opt_idx)
            block._insert_op(
                opt_idx,
                type="lookup_sparse_table_fuse_sgd",
                inputs={"Grad": grad},
                attrs={
                    "is_entry": is_entry,
                    "tablename": table_name,
                    "value_names": value_names
                })
        elif op.type == "adam":
            grad = main_program.global_block().vars[op.input("Grad")[0]]
            beta1_pow = main_program.global_block().vars[op.input("Beta1Pow")[
                0]]
            beta2_pow = main_program.global_block().vars[op.input("Beta2Pow")[
                0]]
            beta1_pow_o = main_program.global_block().vars[op.output(
                "Beta1PowOut")[0]]
            beta2_pow_o = main_program.global_block().vars[op.output(
                "Beta2PowOut")[0]]

            beta1 = op.attr('beta1')
            beta2 = op.attr('beta2')
            epsilon = op.attr('epsilon')

            # remove the original optimizer op
            block._remove_op(opt_idx)
            block._insert_op(
                opt_idx,
                type="lookup_sparse_table_fuse_adam",
                inputs={
                    "Grad": grad,
                    "Beta1Pow": beta1_pow,
                    "Beta2Pow": beta2_pow
                },
                outputs={
                    "Beta1PowOut": beta1_pow_o,
                    "Beta2PowOut": beta2_pow_o
                },
                attrs={
                    "beta1": beta1,
                    "beta2": beta2,
                    "epsilon": epsilon,
                    "is_entry": is_entry,
                    "tablename": table_name,
                    "value_names": value_names
                })
        else:
            raise ValueError("only the sgd and adam optimizers are supported now")
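
    # Assumption (not verified in this diff): the lookup_sparse_table_fuse_*
    # ops combine the sparse-row lookup on the large-scale table with the
    # optimizer update in a single kernel, so one fused op replaces the
    # per-value ops that add_large_scale_op below would otherwise emit.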
    def add_large_scale_op(block, global_block, table_name, value_names,
                           acture_names, grad, is_entry, opt_idx):
@@ -711,20 +776,27 @@ def large_scale_sparse_pass(program, main_program, config, is_startup=False):
        for param, blockid in param_blockid_map.items():
            opt_block = program.block(blockid)

-            grad, opt_idx, value_names, value_dims, acture_names = \
+            grad, opt_idx, value_names, value_dims, acture_names, fuse = \
                get_optimizer_values(opt_block)

            entry_attr = get_entry_attr(param)
            is_entry = False if entry_attr == "none" else True

-            add_large_scale_op(opt_block,
-                               program.global_block(), param, value_names,
-                               acture_names, grad, is_entry, opt_idx)
+            if fuse:
+                add_fuse_large_scale_op(opt_block,
+                                        program.global_block(), param,
+                                        value_names, acture_names, grad,
+                                        is_entry, opt_idx)
+            else:
+                add_large_scale_op(opt_block,
+                                   program.global_block(), param, value_names,
+                                   acture_names, grad, is_entry, opt_idx)
    else:
        large_scale_kv_metas = []
        for param, blockid in param_blockid_map.items():
            opt_block = main_program.block(blockid)

-            grad, _, value_names, value_dims, acture_names = \
+            grad, _, value_names, value_dims, acture_names, fuse = \
                get_optimizer_values(opt_block)

            entry_attr = get_entry_attr(param)
......