Unverified commit c234f1f2, authored by zhaoyingli, committed by GitHub

remove allreduce before c_allgather (#55143)

* remove allreduce before c_allgather

* update reshard insert_fill_constant_op func

* insert_fill_constant_op add shape arg
Parent 86694ce3
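Summary of the change: `Inserter.insert_fill_constant_op` now takes the target `shape` from its caller instead of hard-coding `shape=[0]`, and the dummy fill_constant / c_allreduce_sum / c_sync_calc_stream sequence that used to be inserted before the first `c_allgather` (solely to force instantiation of the new process group) is removed. Illustrative sketches follow the relevant hunks below.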
@@ -568,7 +568,7 @@ class Inserter:
         return outs

     @staticmethod
-    def insert_fill_constant_op(block, idx, op_role):
+    def insert_fill_constant_op(block, idx, op_role, shape):
         """Insert fill constant op into block at the given index."""
         # to avoid name conflict with framework
         helper = LayerHelper('fill_constant@RESHARD', **locals())
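Previously the helper always materialized an empty tensor (`shape=[0]`); the new `shape` parameter lets the call site choose the real shape. A minimal standalone sketch of the op this helper emits, using the public `paddle.full` API (which lowers to a `fill_constant` op in a static program) rather than the internal helper:

```python
# Standalone sketch (not the PR's code): in a static program, paddle.full
# lowers to a fill_constant op whose shape attribute is the one we pass,
# which is what insert_fill_constant_op's new `shape` argument controls.
import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    # shape=[0] builds an empty tensor -- the zero-byte payload the old
    # warm-up allreduce used; any other shape is now possible too
    t = paddle.full(shape=[0], fill_value=0.0, dtype='float32')
print(main)  # the printed program contains a fill_constant op with shape [0]
```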
@@ -591,7 +591,7 @@ class Inserter:
         attrs['dtype'] = out.dtype
         attrs['op_role'] = op_role
         paddle.utils.get_shape_tensor_inputs(
-            inputs=inputs, attrs=attrs, shape=[0], op_type='fill_constant'
+            inputs=inputs, attrs=attrs, shape=shape, op_type='fill_constant'
         )
         fillconstant_op = block._insert_op(
             idx,
@@ -611,38 +611,6 @@ class Inserter:
         group = new_process_group(ranks)
         idx_offset = 0
-        # instant process group before insert allgather op.
-        if not group.is_instantiate():
-            # insert fill_constant op
-            fill_constant_out = Inserter.insert_fill_constant_op(
-                block, idx, op_role
-            )
-            fill_constant_out.stop_gradient = True
-            # insert c_allreduce_sum op
-            allreduce_op = block._insert_op(
-                idx + 1,
-                type="c_allreduce_sum",
-                inputs={'X': [fill_constant_out]},
-                outputs={'Out': [fill_constant_out]},
-                attrs={
-                    'ring_id': 0,
-                    'use_calc_stream': True,
-                    'op_role': op_role,
-                },
-            )
-            allreduce_op._set_attr('op_namescope', "/auto_parallel/reshard")
-            # insert c_sync_calc_stream op
-            sync_calc_op = block._insert_op(
-                idx + 2,
-                type="c_sync_calc_stream",
-                inputs={'X': [fill_constant_out]},
-                outputs={'Out': [fill_constant_out]},
-                attrs={'op_role': op_role},
-            )
-            sync_calc_op._set_attr('op_namescope', "/auto_parallel/reshard")
-            idx_offset = 3
         # insert c_allgather op
         op_type = 'c_allgather'
         # to avoid name conflict with framework
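With the warm-up block gone, `idx_offset` stays 0 and the `c_allgather` op is inserted directly, presumably because the new process group no longer needs an eager dummy allreduce to become usable. A standalone sketch of the behavior this relies on, using the public dynamic-mode collectives (assumes a multi-process run via `paddle.distributed.launch`; it illustrates communicator bootstrap, not the reshard pass itself):

```python
# Illustrative only: all_gather works without issuing an all_reduce first;
# the communicator is initialized on first use. Run with, e.g.:
#   python -m paddle.distributed.launch --nproc_per_node=2 demo.py
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
x = paddle.full([2], float(dist.get_rank()))
gathered = []
dist.all_gather(gathered, x)  # no warm-up all_reduce before the gather
print([t.numpy().tolist() for t in gathered])
```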