Unverified commit 423ea978, authored by Yuang Liu, committed by GitHub

all reduce fusion for sharding, test=develop (#34480); a usage sketch follows below the commit header.

Parent 79e758c6
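Before the diff itself, a minimal usage sketch (editor's addition, not part of the commit) of how the fused all-reduce path is switched on from user code. It mirrors the strategy used by the test case at the bottom of this change; the fleet.init / distributed_optimizer boilerplate and the SGD optimizer are assumptions for illustration only.

import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()
fleet.init(is_collective=True)  # assumes a multi-process collective launch

strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    "sharding_segment_strategy": "segment_broadcast_MB",
    "segment_broadcast_MB": 0.1,
    "segment_anchors": None,
    "sharding_degree": 2,
    "dp_degree": 2,
    "hybrid_dp": True,
    "gradient_merge_acc_step": 1,
    "mp_degree": 1
}
# The switches this commit wires into the sharding pass: pack the data-parallel
# gradient all-reduces into coalesced buffers of at most fuse_grad_size_in_MB
# megabytes each (32 MB matches the new default; the test below uses 2 MB).
strategy.fuse_all_reduce_ops = True
strategy.fuse_grad_size_in_MB = 32

optimizer = paddle.optimizer.SGD(learning_rate=0.01)  # illustrative optimizer
optimizer = fleet.distributed_optimizer(optimizer, strategy)
# optimizer.minimize(avg_cost)  # avg_cost comes from the user's static-graph net

With hybrid_dp enabled, the sharding optimizer now forwards this strategy to insert_allreduce_ops, which dispatches to the new insert_fused_allreduce_ops whenever fuse_all_reduce_ops is set.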
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.fluid import core
from paddle.fluid import core, unique_name
from functools import reduce
from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
@@ -333,13 +333,19 @@ def insert_allreduce_ops(block,
ring_id,
allreduce_vars,
op_role=OpRole.Backward,
use_calc_stream=False):
use_calc_stream=False,
user_defined_strategy=None):
"""
_add_allreduce_ops
"""
if len(allreduce_vars) == 0:
return
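    # When the user strategy enables fuse_all_reduce_ops, pack the gradients
    # into coalesced buffers and issue one c_allreduce_sum per buffer instead
    # of one per gradient; otherwise fall back to per-variable all-reduce.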
if user_defined_strategy and user_defined_strategy.fuse_all_reduce_ops:
insert_fused_allreduce_ops(block, insert_idx, ring_id, allreduce_vars,
op_role, use_calc_stream,
user_defined_strategy.fuse_grad_size_in_MB)
else:
for var in allreduce_vars:
block._insert_op_without_sync(
insert_idx,
@@ -355,6 +361,70 @@ def insert_allreduce_ops(block,
return

def insert_fused_allreduce_ops(block,
                               insert_idx,
                               ring_id,
                               allreduce_vars,
                               op_role=OpRole.Backward,
                               use_calc_stream=False,
                               fuse_grad_size_in_MB=32):
    # Group the gradients into segments whose accumulated size stays within
    # fuse_grad_size_in_MB; a new segment is also started when the dtype
    # changes, since tensors of different dtypes cannot share a fused buffer.
    segments = []
    cur_size = 0.
    last_dtype = None
    for var in allreduce_vars:
        real_var = block.var(var)
        var_size = get_var_size(real_var)
        if cur_size + var_size > fuse_grad_size_in_MB \
                or len(segments) == 0 \
                or real_var.dtype != last_dtype:
            segments.append([real_var])
            cur_size = var_size
            last_dtype = real_var.dtype
        else:
            segments[-1].append(real_var)
            cur_size += var_size

    # Coalesce each segment into one contiguous FusedOutput buffer so a single
    # c_allreduce_sum can cover the whole segment.
    fused_vars = []
    for segment in segments:
        tmp_var = block.create_var(
            name=unique_name.generate('FusedOutput_{}'.format(segment[0].name)),
            dtype=segment[0].dtype,
            persistable=False,
            stop_gradient=True)
        fused_vars.append(tmp_var)
        block._insert_op_without_sync(
            insert_idx,
            type="coalesce_tensor",
            inputs={"Input": segment},
            outputs={"Output": segment,
                     "FusedOutput": tmp_var},
            attrs={
                "copy_data": True,
                "use_align": True,
                "dtype": segment[0].dtype,
                OP_ROLE_KEY: op_role
            })

    # Emit one all-reduce per fused buffer; inserting at
    # insert_idx + len(fused_vars) places them after all coalesce_tensor ops.
    for fused_var in fused_vars:
        block._insert_op_without_sync(
            insert_idx + len(fused_vars),
            type='c_allreduce_sum',
            inputs={'X': fused_var},
            outputs={'Out': fused_var},
            attrs={
                'ring_id': ring_id,
                'use_calc_stream': use_calc_stream,
                OP_ROLE_KEY: op_role
            })
        if not use_calc_stream:
            # When the all-reduce runs on the communication stream, make it
            # wait for the calculation stream to finish the coalesced copy.
            block._insert_op_without_sync(
                insert_idx + len(fused_vars),
                type='c_sync_calc_stream',
                inputs={'X': fused_var},
                outputs={'Out': fused_var},
                attrs={OP_ROLE_KEY: op_role})

def insert_reduce_ops(block,
insert_idx,
ring_id,
......
@@ -322,7 +322,8 @@ class ShardingOptimizer(MetaOptimizerBase):
self.dp_ring_id,
accumulated_grad_names,
core.op_proto_and_checker_maker.OpRole.Optimize,
use_calc_stream=True)
use_calc_stream=True,
user_defined_strategy=self.user_defined_strategy)
# if sharding is not used, adapt amp/clip for the remaining parallelism.
# cast --> amp --> clip --> opt
@@ -778,8 +779,12 @@ class ShardingOptimizer(MetaOptimizerBase):
shard_allredue_vars) >= 1:
insert_sync_comm_ops(block, self._segments[-1]._end_idx,
self.dp_ring_id, shard_allredue_vars)
insert_allreduce_ops(block, self._segments[-1]._end_idx,
self.dp_ring_id, shard_allredue_vars)
insert_allreduce_ops(
block,
self._segments[-1]._end_idx,
self.dp_ring_id,
shard_allredue_vars,
user_defined_strategy=self.user_defined_strategy)
# gradient merge
elif self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1:
self.create_persistable_gradients_and_insert_merge_ops(
@@ -896,8 +901,12 @@ class ShardingOptimizer(MetaOptimizerBase):
if self.gradient_merge_mode != "sharding_gm" or self._gradient_merge_acc_step <= 1:
if self.hybrid_dp and self.hybrid_dp_mode == "sharding_hybrid_dp" and len(
shard_allredue_vars) >= 1:
insert_allreduce_ops(block, segment._start_idx,
self.dp_ring_id, shard_allredue_vars)
insert_allreduce_ops(
block,
segment._start_idx,
self.dp_ring_id,
shard_allredue_vars,
user_defined_strategy=self.user_defined_strategy)
insert_sync_comm_ops(block, segment._start_idx,
self.sharding_ring_id, allreduce_vars)
# gradient merge
......
@@ -586,6 +586,36 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])

    def test_sharding_dp_with_allreduce_fuse(self):
        train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program()
        avg_cost, _ = self.net(train_prog, startup_prog)
        strategy = paddle.distributed.fleet.DistributedStrategy()
        strategy.sharding = True
        strategy.sharding_configs = {
            "sharding_segment_strategy": "segment_broadcast_MB",
            "segment_broadcast_MB": 0.1,
            "segment_anchors": None,
            "sharding_degree": 2,
            "dp_degree": 2,
            "hybrid_dp": True,
            "gradient_merge_acc_step": 1,
            "mp_degree": 1
        }
        strategy.fuse_all_reduce_ops = True
        strategy.fuse_grad_size_in_MB = 2
        self.optimizer(avg_cost, strategy, train_prog, startup_prog)

        main_prog_ops = train_prog.global_block().ops
        main_prog_op_types = [op.type for op in main_prog_ops]

        # The fused path must emit coalesce_tensor ops, and every
        # c_allreduce_sum should read from a FusedOutput buffer rather than
        # a raw gradient.
        assert 'c_allreduce_sum' in main_prog_op_types
        assert 'coalesce_tensor' in main_prog_op_types
        for op in main_prog_ops:
            if op.type == 'c_allreduce_sum':
                assert 'FusedOutput' in op.input_arg_names[0]

if __name__ == "__main__":
unittest.main()