Unverified  Commit 54344964 authored by JZ-LIANG, committed by GitHub

4D Hybrid Parallelism (#32134)

Parent 6e65fe02
......@@ -134,8 +134,17 @@ class ShardingOptimizer(MetaOptimizerBase):
self.pp_degree,
self.dp_degree, )
self.hybrid_dp = self.user_defined_strategy.sharding_configs[
"hybrid_dp"]
# FIXME (JZ-LIANG) deprecated hybrid_dp
if self.user_defined_strategy.sharding_configs["hybrid_dp"]:
logging.warning(
"[hybrid_dp] API setting is deprecated. Now when dp_degree >= 2, its will be in hybrid dp mode automatically"
)
assert self.dp_degree >= 1
if self.dp_degree > 1:
self.hybrid_dp = True
else:
self.hybrid_dp = False
# NOTE (JZ-LIANG)
# there are 2 kinds of modes for gradient-merge and hybrid-dp in mixed parallelism [sharding] and [pipeline].
# we distinguish these two modes since the gm/hybrid-dp related allreduce should be inserted in different places according to the mode in order to get the best performance:
......
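With this change, hybrid data parallelism is inferred from the parallelism degrees instead of the deprecated `hybrid_dp` flag. A minimal user-side sketch of the new behavior, assuming `dp_degree` is set through `sharding_configs` as in the tests below (values are illustrative only):

```python
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    "sharding_degree": 2,
    "mp_degree": 1,
    "pp_degree": 1,
    "dp_degree": 2,  # dp_degree > 1 -> hybrid dp mode is enabled automatically
    # "hybrid_dp": True,  # deprecated: setting it now only emits a warning
}
```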
......@@ -50,6 +50,38 @@ class TestFleetMetaOptimizer(unittest.TestCase):
strategy = paddle.distributed.fleet.DistributedStrategy()
return avg_cost, strategy
def pp_net(self, main_prog, startup_prog, pp_degree=2):
def fc_block(input_x):
fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
fc_3 = paddle.fluid.layers.fc(input=fc_2, size=64, act='tanh')
return fc_3
with fluid.program_guard(main_prog, startup_prog):
with fluid.unique_name.guard():
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)
with fluid.device_guard("gpu:0"):
input_x = paddle.fluid.layers.data(
name="x", shape=[32], dtype='float32')
input_y = paddle.fluid.layers.data(
name="y", shape=[1], dtype='int64')
for stage_idx in range(pp_degree):
with fluid.device_guard("gpu:" + str(stage_idx)):
input_x = fc_block(input_x)
with fluid.device_guard("gpu:" + str(pp_degree - 1)):
prediction = paddle.fluid.layers.fc(input=[input_x],
size=2,
act='softmax')
cost = paddle.fluid.layers.cross_entropy(
input=prediction, label=input_y)
avg_cost = paddle.fluid.layers.mean(x=cost)
strategy = paddle.distributed.fleet.DistributedStrategy()
return avg_cost, strategy
def optimizer(self,
loss,
strategy,
......
......@@ -298,6 +298,13 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
os.environ[
"PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.1:36002,127.0.0.1:36003,127.0.0.1:36004"
# pre-assigned ring id
self.mp_ring_id = 0
self.sharding_ring_id = 1
self.dp_ring_id = 2
self.global_ring_id = 3
self.pp_ring_id = 20
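# these must match the ring ids that ShardingOptimizer assigns to each communication group; the tests below compare them against the c_comm_init ops in the startup program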
def test_sharding_with_mp(self):
# NOTE(JZ-LIANG) MP parallelism requires the user to build the model with the MP API
train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
......@@ -323,7 +330,7 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
op.desc.attr("ring_id") for op in startup_prog_ops
if op.type == "c_comm_init"
]
self.assertIn(0, created_ring_ids)
self.assertIn(self.mp_ring_id, created_ring_ids)
# check correctness of MP group
sharding_group_waiting_port = None
......@@ -368,7 +375,7 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
op.desc.attr("ring_id") for op in startup_prog_ops
if op.type == "c_comm_init"
]
self.assertIn(2, created_ring_ids)
self.assertIn(self.dp_ring_id, created_ring_ids)
# check correctness of sharding group
sharding_group_waiting_port = None
......@@ -437,7 +444,7 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
op.desc.attr("ring_id") for op in startup_prog_ops
if op.type == "c_comm_init"
]
self.assertIn(2, created_ring_ids)
self.assertIn(self.dp_ring_id, created_ring_ids)
# check correctness of sharding group
sharding_group_waiting_port = None
......@@ -460,56 +467,19 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
fw_bw_ops = [op.type for op in train_prog.blocks[0].ops]
opt_ops = [op.type for op in train_prog.blocks[2].ops]
self.assertEqual(fw_bw_ops, [
'fill_constant',
'fill_constant',
'fill_constant',
'c_sync_calc_stream',
'c_broadcast',
'c_broadcast',
'c_broadcast',
'c_broadcast',
'c_broadcast',
'c_broadcast',
'c_sync_comm_stream',
'mul',
'elementwise_add',
'tanh',
'mul',
'elementwise_add',
'tanh',
'mul',
'elementwise_add',
'softmax',
'cross_entropy2',
'mean',
'fill_constant',
'scale',
'mean_grad',
'cross_entropy_grad2',
'softmax_grad',
'elementwise_add_grad',
'mul_grad',
'tanh_grad',
'elementwise_add_grad',
'mul_grad',
'tanh_grad',
'elementwise_add_grad',
'mul_grad',
'c_sync_calc_stream',
'c_reduce_sum',
'c_reduce_sum',
'c_reduce_sum',
'c_reduce_sum',
'c_reduce_sum',
'c_reduce_sum',
'c_sync_comm_stream',
'elementwise_add',
'elementwise_add',
'elementwise_add',
'increment',
'elementwise_mod',
'equal',
'conditional_block',
'fill_constant', 'fill_constant', 'fill_constant',
'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream',
'mul', 'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh',
'mul', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
'fill_constant', 'scale', 'mean_grad', 'cross_entropy_grad2',
'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream',
'elementwise_add', 'elementwise_add', 'elementwise_add',
'increment', 'elementwise_mod', 'equal', 'conditional_block'
])
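# op sequence above: broadcast of sharded parameters, forward, backward, c_reduce_sum of each grad to its owning shard, then gradient-merge accumulation guarded by increment/elementwise_mod/equal/conditional_block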
self.assertEqual(opt_ops, [
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', 'scale',
......@@ -524,6 +494,93 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
scale_ = float(op.desc.attr("scale"))
self.assertEqual(scale_, 0.25)
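# 0.25 presumably equals 1 / gradient_merge_acc_step (4), i.e. merged gradients are averaged over the accumulation steps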
def test_sharding_with_pp(self):
train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
)
avg_cost, strategy = self.pp_net(train_prog, startup_prog)
strategy.sharding = True
strategy.sharding_configs = {
"sharding_segment_strategy": "segment_broadcast_MB",
"segment_broadcast_MB": 0.1,
"sharding_degree": 2,
"hybrid_dp": False,
"gradient_merge_acc_step": 4,
"mp_degree": 1,
"pp_degree": 2
}
strategy.pipeline = True
strategy.pipeline_configs = {
"schedule_mode": "1F1B",
"micro_batch_size": 2,
"accumulate_steps": 4,
}
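# the 4 trainer endpoints are partitioned as sharding_degree(2) x pp_degree(2); mp_degree and dp_degree both resolve to 1, so no hybrid dp group is created here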
self.optimizer(avg_cost, strategy, train_prog, startup_prog)
train_prog = train_prog._pipeline_opt['section_program']
startup_prog = startup_prog._pipeline_opt['startup_program']
startup_prog_ops = startup_prog.global_block().ops
main_prog_ops = train_prog.global_block().ops
# check program
startup_prog_op_types = [op.type for op in startup_prog_ops]
main_prog_op_types = [op.type for op in main_prog_ops]
self.assertEqual(startup_prog_op_types, [
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'uniform_random',
'fill_constant', 'uniform_random', 'fill_constant', 'c_gen_nccl_id',
'c_comm_init', 'fill_constant', 'c_allreduce_sum', 'c_gen_nccl_id',
'c_comm_init', 'fill_constant', 'c_allreduce_sum', 'c_gen_nccl_id',
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init'
])
self.assertEqual(main_prog_op_types, [
'fill_constant', 'fill_constant', 'fill_constant',
'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_sync_comm_stream', 'recv_v2', 'mul',
'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh', 'mul',
'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'softmax',
'cross_entropy2', 'mean', 'fill_constant', 'scale', 'scale',
'mean_grad', 'cross_entropy_grad2', 'softmax_grad',
'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', 'send_v2',
'c_sync_calc_stream', 'c_reduce_sum', 'c_reduce_sum',
'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream',
'c_sync_comm_stream', 'fill_constant', 'sum', 'fill_constant',
'sum', 'fill_constant', 'sum', 'fill_constant', 'sum',
'fill_constant', 'sum', 'momentum', 'momentum', 'momentum',
'momentum', 'momentum'
])
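# recv_v2/send_v2 are the pipeline-stage boundary communications, and the fill_constant/sum pairs before momentum implement gradient-merge accumulation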
# should have ring ids for sharding and pp
created_ring_ids = [
op.desc.attr("ring_id") for op in startup_prog_ops
if op.type == "c_comm_init"
]
self.assertIn(self.sharding_ring_id, created_ring_ids)
self.assertIn(self.pp_ring_id, created_ring_ids)
# check correctness of pp group
sharding_group_waiting_port = None
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "nccl_id_1":
sharding_group_waiting_ports = op.desc.attr("other_endpoints")
self.assertEqual(sharding_group_waiting_ports, ['127.0.0.1:36003'])
# check correctness of sharding group
sharding_group_waiting_port = None
for op in startup_prog_ops:
if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
0] == "nccl_id_2":
dp_group_waiting_ports = op.desc.attr("other_endpoints")
self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])
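For context, a rough end-to-end sketch of how the sharding + pipeline combination exercised by `test_sharding_with_pp` would look in a user training script; `build_pp_net` is a hypothetical stand-in for a model built with `fluid.device_guard` stages like `pp_net` above, and the launch is assumed to use 4 GPUs (e.g. via `python -m paddle.distributed.launch`):

```python
import paddle
import paddle.distributed.fleet as fleet

fleet.init(is_collective=True)

strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    "sharding_segment_strategy": "segment_broadcast_MB",
    "segment_broadcast_MB": 0.1,
    "sharding_degree": 2,
    "mp_degree": 1,
    "pp_degree": 2,
    "gradient_merge_acc_step": 4,
}
strategy.pipeline = True
strategy.pipeline_configs = {
    "schedule_mode": "1F1B",
    "micro_batch_size": 2,
    "accumulate_steps": 4,
}

avg_cost = build_pp_net()  # hypothetical helper: returns the loss of a stage-annotated model
optimizer = paddle.fluid.optimizer.Momentum(learning_rate=0.01, momentum=0.9)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(avg_cost)
```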
if __name__ == "__main__":
unittest.main()