diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
index a83ae226a9df1eeec1239881028893278412c44c..2c4ad33c361e01abcca66c33008826a028f8c354 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
@@ -134,8 +134,17 @@ class ShardingOptimizer(MetaOptimizerBase):
             self.pp_degree,
             self.dp_degree, )
 
-        self.hybrid_dp = self.user_defined_strategy.sharding_configs[
-            "hybrid_dp"]
+        # FIXME (JZ-LIANG) deprecated hybrid_dp
+        if self.user_defined_strategy.sharding_configs["hybrid_dp"]:
+            logging.warning(
+                "[hybrid_dp] API setting is deprecated. Now when dp_degree >= 2, its will be in hybrid dp mode automatically"
+            )
+        assert self.dp_degree >= 1
+        if self.dp_degree > 1:
+            self.hybrid_dp = True
+        else:
+            self.hybrid_dp = False
+
         # NOTE (JZ-LIANG)
         # there 2 kind of modes for gradient-merge and hybrid-dp in mixed parallism [sharding] and [pipeline].
         # we distinguish this two modes since the gm/hybrid-dp related allreduce should be insert in different place according different mode to have best performance:
diff --git a/python/paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py b/python/paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py
index 549975f5d3f0f41c4959e4921b736209bd7c3757..730fa4ca60d31ea03428e44cd40096ea9fd16bb4 100755
--- a/python/paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py
+++ b/python/paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py
@@ -50,6 +50,38 @@ class TestFleetMetaOptimizer(unittest.TestCase):
         strategy = paddle.distributed.fleet.DistributedStrategy()
         return avg_cost, strategy
 
+    def pp_net(self, main_prog, startup_prog, pp_degree=2):
+        def fc_block(input_x):
+            fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
+            fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
+            fc_3 = paddle.fluid.layers.fc(input=fc_2, size=64, act='tanh')
+            return fc_3
+
+        with fluid.program_guard(main_prog, startup_prog):
+            with fluid.unique_name.guard():
+                role = role_maker.PaddleCloudRoleMaker(is_collective=True)
+                fleet.init(role)
+                with fluid.device_guard("gpu:0"):
+                    input_x = paddle.fluid.layers.data(
+                        name="x", shape=[32], dtype='float32')
+                    input_y = paddle.fluid.layers.data(
+                        name="y", shape=[1], dtype='int64')
+
+                for stage_idx in range(pp_degree):
+                    with fluid.device_guard("gpu:" + str(stage_idx)):
+                        input_x = fc_block(input_x)
+
+                with fluid.device_guard("gpu:" + str(pp_degree - 1)):
+                    prediction = paddle.fluid.layers.fc(input=[input_x],
+                                                        size=2,
+                                                        act='softmax')
+                    cost = paddle.fluid.layers.cross_entropy(
+                        input=prediction, label=input_y)
+                    avg_cost = paddle.fluid.layers.mean(x=cost)
+
+        strategy = paddle.distributed.fleet.DistributedStrategy()
+        return avg_cost, strategy
+
     def optimizer(self,
                   loss,
                   strategy,
diff --git a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
index f28bf89ff5c30b86e7af7be1a0dd7d79416d6c98..4d1e936558abf726d0a221f24821c79be68398b9 100755
--- a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
+++ b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
@@ -298,6 +298,13 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
         os.environ[
             "PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.1:36002,127.0.0.1:36003,127.0.0.1:36004"
+        # pre-assigned ring id
+        self.mp_ring_id = 0
+        self.sharding_ring_id = 1
+        self.dp_ring_id = 2
+        self.global_ring_id = 3
+        self.pp_ring_id = 20
+
     def test_sharding_with_mp(self):
         # NOTE(JZ-LIANG) MP parallelism need user to build model with MP API
         train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
         )
@@ -323,7 +330,7 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
             op.desc.attr("ring_id") for op in startup_prog_ops
             if op.type == "c_comm_init"
         ]
-        self.assertIn(0, created_ring_ids)
+        self.assertIn(self.mp_ring_id, created_ring_ids)
 
         # check correctness of MP group
         sharding_group_waiting_port = None
@@ -368,7 +375,7 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
             op.desc.attr("ring_id") for op in startup_prog_ops
             if op.type == "c_comm_init"
         ]
-        self.assertIn(2, created_ring_ids)
+        self.assertIn(self.dp_ring_id, created_ring_ids)
 
         # check correctness of sharding group
         sharding_group_waiting_port = None
@@ -437,7 +444,7 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
             op.desc.attr("ring_id") for op in startup_prog_ops
             if op.type == "c_comm_init"
         ]
-        self.assertIn(2, created_ring_ids)
+        self.assertIn(self.dp_ring_id, created_ring_ids)
 
         # check correctness of sharding group
         sharding_group_waiting_port = None
@@ -460,56 +467,19 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
         fw_bw_ops = [op.type for op in train_prog.blocks[0].ops]
         opt_ops = [op.type for op in train_prog.blocks[2].ops]
         self.assertEqual(fw_bw_ops, [
-            'fill_constant',
-            'fill_constant',
-            'fill_constant',
-            'c_sync_calc_stream',
-            'c_broadcast',
-            'c_broadcast',
-            'c_broadcast',
-            'c_broadcast',
-            'c_broadcast',
-            'c_broadcast',
-            'c_sync_comm_stream',
-            'mul',
-            'elementwise_add',
-            'tanh',
-            'mul',
-            'elementwise_add',
-            'tanh',
-            'mul',
-            'elementwise_add',
-            'softmax',
-            'cross_entropy2',
-            'mean',
-            'fill_constant',
-            'scale',
-            'mean_grad',
-            'cross_entropy_grad2',
-            'softmax_grad',
-            'elementwise_add_grad',
-            'mul_grad',
-            'tanh_grad',
-            'elementwise_add_grad',
-            'mul_grad',
-            'tanh_grad',
-            'elementwise_add_grad',
-            'mul_grad',
-            'c_sync_calc_stream',
-            'c_reduce_sum',
-            'c_reduce_sum',
-            'c_reduce_sum',
-            'c_reduce_sum',
-            'c_reduce_sum',
-            'c_reduce_sum',
-            'c_sync_comm_stream',
-            'elementwise_add',
-            'elementwise_add',
-            'elementwise_add',
-            'increment',
-            'elementwise_mod',
-            'equal',
-            'conditional_block',
+            'fill_constant', 'fill_constant', 'fill_constant',
+            'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
+            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream',
+            'mul', 'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh',
+            'mul', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
+            'fill_constant', 'scale', 'mean_grad', 'cross_entropy_grad2',
+            'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
+            'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
+            'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream',
+            'elementwise_add', 'elementwise_add', 'elementwise_add',
+            'increment', 'elementwise_mod', 'equal', 'conditional_block'
         ])
         self.assertEqual(opt_ops, [
             'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', 'scale',
@@ -524,6 +494,93 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
                 scale_ = float(op.desc.attr("scale"))
                 self.assertEqual(scale_, 0.25)
 
+    def test_sharding_with_pp(self):
+        train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
+        )
+        avg_cost, strategy = self.pp_net(train_prog, startup_prog)
+        strategy.sharding = True
+        strategy.sharding_configs = {
+            "sharding_segment_strategy": "segment_broadcast_MB",
+            "segment_broadcast_MB": 0.1,
+            "sharding_degree": 2,
+            "hybrid_dp": False,
+            "gradient_merge_acc_step": 4,
+            "mp_degree": 1,
+            "pp_degree": 2
+        }
+        strategy.pipeline = True
+        strategy.pipeline_configs = {
+            "schedule_mode": "1F1B",
+            "micro_batch_size": 2,
+            "accumulate_steps": 4,
+        }
+        self.optimizer(avg_cost, strategy, train_prog, startup_prog)
+        train_prog = train_prog._pipeline_opt['section_program']
+        startup_prog = startup_prog._pipeline_opt['startup_program']
+
+        startup_prog_ops = startup_prog.global_block().ops
+        main_prog_ops = train_prog.global_block().ops
+
+        # check program
+        startup_prog_op_types = [op.type for op in startup_prog_ops]
+        main_prog_op_types = [op.type for op in main_prog_ops]
+        self.assertEqual(startup_prog_op_types, [
+            'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
+            'fill_constant', 'fill_constant', 'fill_constant', 'uniform_random',
+            'fill_constant', 'uniform_random', 'fill_constant', 'c_gen_nccl_id',
+            'c_comm_init', 'fill_constant', 'c_allreduce_sum', 'c_gen_nccl_id',
+            'c_comm_init', 'fill_constant', 'c_allreduce_sum', 'c_gen_nccl_id',
+            'c_comm_init', 'c_gen_nccl_id', 'c_comm_init'
+        ])
+
+        self.assertEqual(main_prog_op_types, [
+            'fill_constant', 'fill_constant', 'fill_constant',
+            'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
+            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
+            'c_broadcast', 'c_sync_comm_stream', 'recv_v2', 'mul',
+            'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh', 'mul',
+            'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'softmax',
+            'cross_entropy2', 'mean', 'fill_constant', 'scale', 'scale',
+            'mean_grad', 'cross_entropy_grad2', 'softmax_grad',
+            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', 'send_v2',
+            'c_sync_calc_stream', 'c_reduce_sum', 'c_reduce_sum',
+            'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
+            'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream',
+            'c_sync_comm_stream', 'fill_constant', 'sum', 'fill_constant',
+            'sum', 'fill_constant', 'sum', 'fill_constant', 'sum',
+            'fill_constant', 'sum', 'momentum', 'momentum', 'momentum',
+            'momentum', 'momentum'
+        ])
+
+        # should has ring id for pp
+        created_ring_ids = [
+            op.desc.attr("ring_id") for op in startup_prog_ops
+            if op.type == "c_comm_init"
+        ]
+        self.assertIn(self.sharding_ring_id, created_ring_ids)
+        self.assertIn(self.pp_ring_id, created_ring_ids)
+
+        # check correctness of pp group
+        sharding_group_waiting_port = None
+        for op in startup_prog_ops:
+            if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
+                    0] == "nccl_id_1":
+                sharding_group_waiting_ports = op.desc.attr("other_endpoints")
+
+        self.assertEqual(sharding_group_waiting_ports, ['127.0.0.1:36003'])
+
+        # check correctness of sharding group
+        sharding_group_waiting_port = None
+        for op in startup_prog_ops:
+            if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
+                    0] == "nccl_id_2":
+                dp_group_waiting_ports = op.desc.attr("other_endpoints")
+
+        self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])
+
 
 if __name__ == "__main__":
     unittest.main()
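For reference, below is a minimal user-level sketch of the configuration path that the new test_sharding_with_pp case exercises: sharding (sharding_degree=2) combined with pipeline parallelism (pp_degree=2, 1F1B schedule) set through DistributedStrategy. It is illustrative only and not part of the patch; the tiny two-stage network and the Momentum optimizer are stand-ins, and the script assumes a static-graph collective job launched across 4 GPUs (e.g. with fleetrun). Note that with the sharding_optimizer.py change above, hybrid_dp no longer needs to be set explicitly: it is turned on automatically whenever dp_degree > 1.

import paddle
import paddle.fluid as fluid
import paddle.distributed.fleet as fleet

paddle.enable_static()
fleet.init(is_collective=True)

# Illustrative stand-in network (not part of this patch): two pipeline
# stages annotated with device_guard, mirroring pp_net() in the test base.
train_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(train_prog, startup_prog):
    with fluid.device_guard("gpu:0"):
        x = fluid.layers.data(name="x", shape=[32], dtype='float32')
        y = fluid.layers.data(name="y", shape=[1], dtype='int64')
        hidden = fluid.layers.fc(input=x, size=64, act='tanh')
    with fluid.device_guard("gpu:1"):
        prediction = fluid.layers.fc(input=hidden, size=2, act='softmax')
        cost = fluid.layers.cross_entropy(input=prediction, label=y)
        avg_cost = fluid.layers.mean(x=cost)

# Same knobs as test_sharding_with_pp; "hybrid_dp" is omitted since this
# patch derives hybrid dp mode from dp_degree automatically.
strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    "sharding_segment_strategy": "segment_broadcast_MB",
    "segment_broadcast_MB": 0.1,
    "sharding_degree": 2,
    "gradient_merge_acc_step": 4,
    "mp_degree": 1,
    "pp_degree": 2,
}
strategy.pipeline = True
strategy.pipeline_configs = {
    "schedule_mode": "1F1B",
    "micro_batch_size": 2,
    "accumulate_steps": 4,
}

optimizer = fluid.optimizer.Momentum(learning_rate=0.01, momentum=0.9)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
with fluid.program_guard(train_prog, startup_prog):
    optimizer.minimize(avg_cost)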