Unverified commit 658387b0, authored by zhaoyingli, committed by GitHub

[AutoParallel] fix insert concat op (#47710)

* fix insert concat op

* fix fp16 assert
Parent d926c270
@@ -134,7 +134,7 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
             raise StopIteration
 
     def _infer_steps(self):
-        if self.steps_per_epoch is not None:
+        if isinstance(self.steps_per_epoch, int) and self.steps_per_epoch > 1:
             return self.steps_per_epoch
         try:
             if isinstance(self.dataset, IterableDataset):
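Note on the hunk above: the old guard returned `steps_per_epoch` whenever it was not None, so a zero, negative, or non-integer value would be handed straight to the training loop. The new guard only honors an explicit positive integer and otherwise falls back to inferring the step count from the dataset. A minimal sketch of that behavior, using a hypothetical standalone `infer_steps` helper rather than the actual dataloader method:

```python
# Hypothetical standalone helper mirroring the tightened guard in _infer_steps.
def infer_steps(steps_per_epoch, dataset_length, batch_size):
    # Only trust an explicit integer override larger than 1.
    if isinstance(steps_per_epoch, int) and steps_per_epoch > 1:
        return steps_per_epoch
    # Otherwise derive the number of steps from the dataset itself.
    return dataset_length // batch_size


assert infer_steps(None, 1000, 10) == 100  # None falls back to the derived count
assert infer_steps(0, 1000, 10) == 100     # non-positive overrides are ignored
assert infer_steps(50, 1000, 10) == 50     # a valid override is used as-is
```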
@@ -487,9 +487,8 @@ class FP16State:
 
             # create cast grad
             grad_slot_name = slot_name + "@GRAD"
-            assert (
-                grad_slot_name in op.output_names
-            ), "[{}], Current Op: {}".format(grad_slot_name, str(op))
+            if grad_slot_name not in op.output_names:
+                continue
 
             # some forward input maybe stop_gradient=True, e.g. input_mask
             if len(op.output(grad_slot_name)) == 0:
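Note on the hunk above: the hard assert is replaced by a skip. When a forward input has stop_gradient=True (e.g. input_mask), its backward op simply has no matching @GRAD output slot, so aborting with an assertion is too strict; the pass now just moves on to the next op. A small sketch of the pattern, with a hypothetical BackwardOp stand-in instead of a real Paddle operator:

```python
# Hypothetical stand-in for a backward op exposing output_names/output(),
# only to illustrate the assert-to-skip pattern from the hunk above.
class BackwardOp:
    def __init__(self, outputs):
        self.outputs = outputs  # dict: slot name -> list of output var names

    @property
    def output_names(self):
        return list(self.outputs)

    def output(self, name):
        return self.outputs[name]


def collect_grad_outputs(ops, slot_name):
    grads = []
    grad_slot_name = slot_name + "@GRAD"
    for op in ops:
        # Some forward inputs (e.g. input_mask) have stop_gradient=True,
        # so their backward op has no grad slot at all: skip, don't assert.
        if grad_slot_name not in op.output_names:
            continue
        if len(op.output(grad_slot_name)) == 0:
            continue
        grads.extend(op.output(grad_slot_name))
    return grads


ops = [BackwardOp({"X@GRAD": ["x_grad"]}), BackwardOp({})]
assert collect_grad_outputs(ops, "X") == ["x_grad"]
```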
@@ -785,33 +784,67 @@ class FP16Pass(AMPPass):
             with main_program._optimized_guard([]):
                 block = main_program.global_block()
 
-                all_infs = paddle.fluid.layers.concat(found_infs)
+                # all_infs = paddle.fluid.layers.concat(found_infs)
+                all_infs = block.create_var(
+                    name=paddle.fluid.unique_name.generate_with_ignorable_key(
+                        ".".join(['concat', 'tmp'])
+                    ),
+                    dtype=found_infs[0].dtype,
+                    shape=None,
+                    lod_level=found_infs[0].lod_level,
+                    type=found_infs[0].type,
+                    persistable=False,
+                    stop_gradient=False,
+                )
+                concat_op = block.append_op(
+                    type='concat',
+                    inputs={'X': found_infs},
+                    outputs={'Out': [all_infs]},
+                    attrs={'axis': 0},
+                )
                 set_var_dist_attr(
                     self.dist_context,
                     all_infs,
                     [-1],
                     world_process_group.ranks,
                 )
-                new_op = block.ops[-1]
-                assert new_op.type == "concat"
                 _set_op_dist_attr_with_ranks(
-                    new_op,
+                    concat_op,
                     world_process_group.ranks,
                     block,
                     self.dist_context,
                 )
 
-                found_inf = paddle.fluid.layers.reduce_any(all_infs)
+                # found_inf = paddle.fluid.layers.reduce_any(all_infs)
+                found_inf = block.create_var(
+                    name=paddle.fluid.unique_name.generate_with_ignorable_key(
+                        ".".join(['reduce_any', 'tmp'])
+                    ),
+                    dtype=all_infs.dtype,
+                    shape=None,
+                    lod_level=all_infs.lod_level,
+                    type=all_infs.type,
+                    persistable=False,
+                    stop_gradient=False,
+                )
+                reduce_any_op = block.append_op(
+                    type='reduce_any',
+                    inputs={'X': all_infs},
+                    outputs={'Out': found_inf},
+                    attrs={
+                        'dim': [0],
+                        'keep_dim': False,
+                        'reduce_all': True,
+                    },
+                )
                 set_var_dist_attr(
                     self.dist_context,
                     found_inf,
                     [-1],
                     world_process_group.ranks,
                 )
-                new_op = block.ops[-1]
-                assert new_op.type == "reduce_any"
                 _set_op_dist_attr_with_ranks(
-                    new_op,
+                    reduce_any_op,
                     world_process_group.ranks,
                     block,
                     self.dist_context,
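Note on the hunk above: instead of calling the `paddle.fluid.layers.concat` and `reduce_any` wrappers and then recovering the just-appended op via `block.ops[-1]` plus an assert on its type, the pass now creates the output variable itself and appends the op with `block.append_op`, which returns the op handle directly. That handle is what receives the distributed attributes, so the code no longer depends on the newly inserted op being the last one in the block. A minimal standalone sketch of the pattern under the Paddle 2.x static-graph API (the program and variable names here are illustrative, not taken from the pass):

```python
import paddle
import paddle.fluid as fluid

paddle.enable_static()

main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
    x = paddle.static.data(name='x', shape=[2], dtype='float32')
    y = paddle.static.data(name='y', shape=[3], dtype='float32')

    block = main_program.global_block()
    # Create the output variable explicitly instead of letting a layer helper do it.
    out = block.create_var(
        name=fluid.unique_name.generate_with_ignorable_key(
            ".".join(['concat', 'tmp'])
        ),
        dtype=x.dtype,
        shape=None,
        persistable=False,
        stop_gradient=False,
    )
    # append_op returns the op it just inserted, so we keep a direct handle to it.
    concat_op = block.append_op(
        type='concat',
        inputs={'X': [x, y]},
        outputs={'Out': [out]},
        attrs={'axis': 0},
    )
    # No need to fish out block.ops[-1] and assert its type afterwards.
    assert concat_op is block.ops[-1] and concat_op.type == 'concat'
```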