diff --git a/python/paddle/distributed/fleet/__init__.py b/python/paddle/distributed/fleet/__init__.py
index 271114c09cb95f9fc104678ccfb1128098cac507..947950f38ec48cff71af55f2aa11abc9785d5691 100755
--- a/python/paddle/distributed/fleet/__init__.py
+++ b/python/paddle/distributed/fleet/__init__.py
@@ -104,4 +104,5 @@ get_log_level_code = log_util.get_log_level_code
 get_log_level_name = log_util.get_log_level_name
 save_cache_table = fleet.save_cache_table
 perf_test = fleet.perf_test
+monitor_perf = fleet.monitor_perf
 from .. import auto_parallel as auto
diff --git a/python/paddle/distributed/fleet/fleet.py b/python/paddle/distributed/fleet/fleet.py
index bbe99ebda709ea2d08e9e37c37acbdf7f1731f7c..854a41e1e94f05e594533ad054c7d4c05b933520 100755
--- a/python/paddle/distributed/fleet/fleet.py
+++ b/python/paddle/distributed/fleet/fleet.py
@@ -382,80 +382,245 @@ class Fleet:
         )
         return self
 
-    def perf_test(self, round=50):
-        # test allreduce perf
-        def allreduce_test(iteration, x, group):
-            paddle.distributed.barrier()
-            paddle.device.cuda.synchronize()
-            start_t = time.time()
-            for _ in range(iteration):
-                paddle.distributed.all_reduce(x, group=group)
-            paddle.device.cuda.synchronize()
-            end_t = time.time()
-            return (end_t - start_t) / iteration
-
-        # test reduce perf
-        def reduce_test(iteration, x, group):
-            paddle.distributed.barrier()
-            paddle.device.cuda.synchronize()
-            start_t = time.time()
-            for _ in range(iteration):
-                # TODO: shuffle dst
-                paddle.distributed.reduce(x, dst=min(group.ranks), group=group)
-            paddle.device.cuda.synchronize()
-            end_t = time.time()
-            return (end_t - start_t) / iteration
-
-        # test broadcast perf
-        def broadcast_test(iteration, x, group):
-            paddle.distributed.barrier()
-            paddle.device.cuda.synchronize()
-            start_t = time.time()
-            for _ in range(iteration):
-                # TODO: shuffle src
-                paddle.distributed.broadcast(
-                    x, src=min(group.ranks), group=group
-                )
-            paddle.device.cuda.synchronize()
-            end_t = time.time()
-            return (end_t - start_t) / iteration
+    # test allreduce perf
+    def allreduce_test(
+        self,
+        iteration,
+        x,
+        group,
+        allreduce_size,
+        allreduce_thres_time,
+        warmup=False,
+    ):
+        if group is None or group.nranks <= 1:
+            logger.warning("allreduce_test skipped: invalid group!")
+            return
+        paddle.distributed.barrier()
+        paddle.device.cuda.synchronize()
+        start_t = time.time()
+        for _ in range(iteration):
+            paddle.distributed.all_reduce(x, group=group)
+        paddle.device.cuda.synchronize()
+        end_t = time.time()
+        ret = (end_t - start_t) / iteration
+        if warmup:
+            return
+        logger.info(
+            f"[AllReduceTest] nbytes {allreduce_size}B test result: {ret} s/iter"
+        )
+        if allreduce_thres_time > -1 and ret > allreduce_thres_time:
+            logger.warning(
+                f"[Perf Warning] AllReduce Test Timeout! {ret} > {allreduce_thres_time}"
+            )
+
+    # test reduce perf
+    def reduce_test(self, iteration, x, group, reduce_size, reduce_thres_time):
+        if group is None or group.nranks <= 1:
+            logger.warning("reduce_test skipped: invalid group!")
+            return
+        paddle.distributed.barrier()
+        paddle.device.cuda.synchronize()
+        start_t = time.time()
+        for _ in range(iteration):
+            paddle.distributed.reduce(x, dst=min(group.ranks), group=group)
+        paddle.device.cuda.synchronize()
+        end_t = time.time()
+        ret = (end_t - start_t) / iteration
+        logger.info(
+            f"[ReduceTest] nbytes {reduce_size}B test result: {ret} s/iter"
+        )
+        if reduce_thres_time > -1 and ret > reduce_thres_time:
+            logger.warning(
+                f"[Perf Warning] Reduce Test Timeout! {ret} > {reduce_thres_time}"
+            )
+
+    # test broadcast perf
+    def broadcast_test(
+        self, iteration, x, group, broadcast_size, broadcast_thres_time
+    ):
+        if group is None or group.nranks <= 1:
+            logger.warning("broadcast_test skipped: invalid group!")
+            return
+        paddle.distributed.barrier()
+        paddle.device.cuda.synchronize()
+        start_t = time.time()
+        for _ in range(iteration):
+            paddle.distributed.broadcast(x, src=min(group.ranks), group=group)
+        paddle.device.cuda.synchronize()
+        end_t = time.time()
+        ret = (end_t - start_t) / iteration
+        logger.info(
+            f"[BroadcastTest] nbytes {broadcast_size}B test result: {ret} s/iter"
+        )
+        if broadcast_thres_time > -1 and ret > broadcast_thres_time:
+            logger.warning(
+                f"[Perf Warning] Broadcast Test Timeout! {ret} > {broadcast_thres_time}"
+            )
+
+    # test allgather perf
+    def allgather_test(
+        self, iteration, x, group, allgather_size, allgather_thres_time
+    ):
+        if group is None or group.nranks <= 1:
+            logger.warning("allgather_test skipped: invalid group!")
+            return
+        paddle.distributed.barrier()
+        paddle.device.cuda.synchronize()
+        start_t = time.time()
+        for _ in range(iteration):
+            tmp = []
+            paddle.distributed.all_gather(tmp, x, group=group)
+        paddle.device.cuda.synchronize()
+        end_t = time.time()
+        ret = (end_t - start_t) / iteration
+        logger.info(
+            f"[AllgatherTest] nbytes {allgather_size}B test result: {ret} s/iter"
+        )
+        if allgather_thres_time > -1 and ret > allgather_thres_time:
+            logger.warning(
+                f"[Perf Warning] Allgather Test Timeout! {ret} > {allgather_thres_time}"
+            )
+
+    # test reduce_scatter perf
+    def reduce_scatter_test(
+        self,
+        iteration,
+        x,
+        group,
+        reduce_scatter_size,
+        reduce_scatter_thres_time,
+    ):
+        if group is None or group.nranks <= 1:
+            logger.warning("reduce_scatter_test skipped: invalid group!")
+            return
+        paddle.distributed.barrier()
+        paddle.device.cuda.synchronize()
+        parallelism = group.nranks
+        output_shape = x.shape
+        if x.shape[0] % parallelism != 0:
+            logger.warning(
+                f"input shape[0] ({x.shape[0]}) is not divisible by the reduce_scatter parallelism ({parallelism}), test stopped!"
+            )
+            return
+        output_shape[0] = output_shape[0] // parallelism
+        output = paddle.empty(shape=output_shape, dtype=x.dtype)
+        start_t = time.time()
+        for _ in range(iteration):
+            paddle.distributed.stream.reduce_scatter(
+                output,
+                x,
+                op=paddle.distributed.ReduceOp.SUM,
+                group=group,
+                sync_op=True,
+            )
+        paddle.device.cuda.synchronize()
+        end_t = time.time()
+        ret = (end_t - start_t) / iteration
+        logger.info(
+            f"[ReduceScatterTest] nbytes {reduce_scatter_size}B test result: {ret} s/iter"
+        )
+        if reduce_scatter_thres_time > -1 and ret > reduce_scatter_thres_time:
+            logger.warning(
+                f"[Perf Warning] ReduceScatter Test Timeout! {ret} > {reduce_scatter_thres_time}"
+            )
+
+    def perf_test(self, round=50, test_comm=[], context={}, hcg=None):
+        if hcg is None:
+            hcg = self.get_hybrid_communicate_group()
 
-        hcg = self.get_hybrid_communicate_group()
         dp_group = hcg.get_data_parallel_group()
         sharding_group = hcg.get_sharding_parallel_group()
+        mp_group = hcg.get_model_parallel_group()
         test_group = None
         if dp_group.nranks > 1:
             test_group = dp_group
         elif sharding_group.nranks > 1:
             test_group = sharding_group
-        else:
-            logger.warning(
-                f"hcg created with dp_degree: {dp_group.nranks} and sharding_degree: {sharding_group.nranks}, skipping perf test..."
-            )
-            return
+
         # test 1M ~ 1G
         nbytes = 1 << 20  # 1048576(1MB)
         final_nbytes = 1 << 30  # 1073741824(1GB)
         dtype = paddle.float32
+
+        # run the scan only once when specific package sizes are given.
+        test_specific_size = False
+        for k, st in context.items():
+            if st[0] > 0:
+                test_specific_size = True
+                break
+
+        if test_specific_size:
+            test_comm = list(context.keys())
+
+        if len(test_comm) == 0:
+            return
+
         while nbytes <= final_nbytes:
             x = paddle.zeros([nbytes // 4], dtype=dtype)
             # warmup
-            allreduce_test(iteration=10, x=x, group=test_group)
-            # test-allreduce
-            ret = allreduce_test(iteration=round, x=x, group=test_group)
-            logger.info(
-                f"[AllReduceTest] nbytes {nbytes}B test result: {ret} s/iter"
+            self.allreduce_test(10, x, test_group, nbytes, -1, warmup=True)
+
+            allreduce_size, allreduce_thres_time = context.get(
+                "allreduce", [nbytes, -1]
+            )
+            reduce_size, reduce_thres_time = context.get("reduce", [nbytes, -1])
+            broadcast_size, broadcast_thres_time = context.get(
+                "broadcast", [nbytes, -1]
             )
-            ret = reduce_test(iteration=round, x=x, group=test_group)
-            logger.info(
-                f"[ReduceTest] nbytes {nbytes}B test result: {ret} s/iter"
+            allgather_size, allgather_thres_time = context.get(
+                "allgather", [nbytes, -1]
             )
-            ret = broadcast_test(iteration=round, x=x, group=test_group)
-            logger.info(
-                f"[BroadcastTest] nbytes {nbytes}B test result: {ret} s/iter"
+            reduce_scatter_size, reduce_scatter_thres_time = context.get(
+                "reduce_scatter", [nbytes, -1]
             )
+
+            # inter-machine tests (dp/sharding group)
+            if "allreduce" in test_comm:
+                x = paddle.zeros([allreduce_size // 4], dtype=dtype)
+                self.allreduce_test(
+                    round, x, test_group, allreduce_size, allreduce_thres_time
+                )
+
+            if "reduce" in test_comm:
+                x = paddle.zeros([reduce_size // 4], dtype=dtype)
+                self.reduce_test(
+                    round, x, test_group, reduce_size, reduce_thres_time
+                )
+
+            if "broadcast" in test_comm:
+                x = paddle.zeros([broadcast_size // 4], dtype=dtype)
+                self.broadcast_test(
+                    round, x, test_group, broadcast_size, broadcast_thres_time
+                )
+
+            # intra-machine tests (mp group)
+            if "allgather" in test_comm:
+                x = paddle.zeros([allgather_size // 4], dtype=dtype)
+                self.allgather_test(
+                    round, x, mp_group, allgather_size, allgather_thres_time
+                )
+
+            if "reduce_scatter" in test_comm:
+                x = paddle.zeros([reduce_scatter_size // 4], dtype=dtype)
+                self.reduce_scatter_test(
+                    round,
+                    x,
+                    mp_group,
+                    reduce_scatter_size,
+                    reduce_scatter_thres_time,
+                )
+
+            # run only one iteration when testing specific package sizes.
+            if test_specific_size:
+                break
+
             nbytes = nbytes << 1
 
+    def monitor_perf(self, comm_type, round=50, size_and_time={}, hcg=None):
+        for size, time_thres in size_and_time.items():
+            context = {comm_type: [size, time_thres]}
+            self.perf_test(round=round, context=context, hcg=hcg)
+
     def _init_hybrid_parallel_env(self):
         """initialize the hybrid environment"""
         self.hybrid_configs = self._user_defined_strategy.hybrid_configs
diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_perf_test.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_perf_test.py
index ddb09d72d07035a4df1de139a9bd7de3a3e9a174..57ea6d3673647e9aa82df843cf631aaba1466f7c 100644
--- a/python/paddle/fluid/tests/unittests/hybrid_parallel_perf_test.py
+++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_perf_test.py
@@ -60,6 +60,106 @@ class TestDistDPTraning(unittest.TestCase):
 
     def test_communication_perf(self):
         fleet.perf_test(round=1)
+        # test the comm types in test_comm, scanning sizes from 1MB to 1GB
+        fleet.perf_test(
+            round=1,
+            test_comm=[
+                "allreduce",
+                "reduce",
+                "broadcast",
+                "allgather",
+                "reduce_scatter",
+            ],
+        )
+        # context: {comm_type: [size, time_threshold]}
+        # test allreduce/reduce/broadcast at 1024B with a 0.00000001s threshold,
+        # and allgather at 8192B with a 2s threshold.
+        fleet.perf_test(
+            round=30,
+            test_comm=[
+                "allreduce",
+                "reduce",
+                "broadcast",
+                "allgather",
+                "reduce_scatter",
+            ],
+            context={
+                "allreduce": [1024, 0.00000001],
+                "reduce": [1024, 0.00000001],
+                "broadcast": [1024, 0.00000001],
+                "allgather": [8192, 2],
+            },
+        )
+        # test allreduce with specific package sizes and time thresholds.
+        fleet.monitor_perf(
+            "allreduce",
+            round=50,
+            size_and_time={1024: 0.00000001, 4096: 0.01, 8192: 2},
+        )
+
+
+class TestDistMPTraning(unittest.TestCase):
+    def setUp(self):
+        strategy = fleet.DistributedStrategy()
+        self.model_parallel_size = 2
+        self.data_parallel_size = 1
+        self.pipeline_parallel_size = 1
+        strategy.hybrid_configs = {
+            "dp_degree": self.data_parallel_size,
+            "mp_degree": self.model_parallel_size,
+            "pp_degree": self.pipeline_parallel_size,
+        }
+        strategy.pipeline_configs = {
+            "accumulate_steps": batch_size // micro_batch_size,
+            "micro_batch_size": micro_batch_size,
+        }
+        fleet.init(is_collective=True, strategy=strategy)
+        from paddle.distributed.fleet.base.topology import (
+            CommunicateTopology,
+            HybridCommunicateGroup,
+        )
+
+        topo = CommunicateTopology(
+            hybrid_group_names=["data", "pipe", "sharding", "model"],
+            dims=[1, 1, 1, 2],
+        )
+        self.hcg = HybridCommunicateGroup(topo)
+
+    def build_optimizer(self, model):
+        scheduler = paddle.optimizer.lr.PiecewiseDecay(
+            boundaries=[2], values=[0.001, 0.002], verbose=True
+        )
+        optimizer = paddle.optimizer.SGD(
+            learning_rate=scheduler, parameters=model.parameters()
+        )
+        return scheduler, optimizer
+
+    def test_communication_perf(self):
+        # test the comm types in test_comm, scanning sizes from 1MB to 1GB
+        fleet.perf_test(
+            round=1,
+            test_comm=["allreduce", "allgather", "reduce_scatter"],
+            hcg=self.hcg,
+        )
+        # context: {comm_type: [size, time_threshold]}
+        # test reduce at 1024B with a 1s threshold, and allgather and
+        # reduce_scatter at 8192B with a 0.00000002s threshold.
+        fleet.perf_test(
+            round=100000,
+            context={
+                "reduce": [1024, 1],
+                "allgather": [8192, 0.00000002],
+                "reduce_scatter": [8192, 0.00000002],
+            },
+            hcg=self.hcg,
+        )
+        # test allgather with specific package sizes and time thresholds.
+        fleet.monitor_perf(
+            "allgather",
+            round=50,
+            size_and_time={1024: 1, 4096: 0.01, 8192: 0.00000002},
+            hcg=self.hcg,
+        )
 
 
 if __name__ == "__main__":
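
Usage sketch (illustrative, based only on the perf_test/monitor_perf API introduced above): the snippet shows how the checks could be driven from a standalone script rather than the unit tests. The script name, GPU list, and hybrid_configs values are assumptions, and the script has to be started through paddle.distributed.launch so that more than one rank exists.

# perf_check.py -- assumed launch: python -m paddle.distributed.launch --gpus 0,1 perf_check.py
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
# dp_degree=2 is an illustrative choice; any hybrid setup whose dp/sharding
# group has nranks > 1 gives perf_test a valid test group.
strategy.hybrid_configs = {"dp_degree": 2, "mp_degree": 1, "pp_degree": 1}
fleet.init(is_collective=True, strategy=strategy)

# scan allreduce and broadcast from 1MB to 1GB, 50 iterations per package size
fleet.perf_test(round=50, test_comm=["allreduce", "broadcast"])

# monitor allreduce at two fixed package sizes (bytes); a warning is logged
# whenever the measured s/iter exceeds the given threshold
fleet.monitor_perf("allreduce", round=50, size_and_time={1048576: 0.01, 16777216: 0.1})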