Unverified commit 51c414b6, authored by T Tian, committed by GitHub

add perf test api to fleet (#54856)

Parent 21fa0346
......@@ -103,4 +103,5 @@ set_log_level = log_util.set_log_level
get_log_level_code = log_util.get_log_level_code
get_log_level_name = log_util.get_log_level_name
save_cache_table = fleet.save_cache_table
perf_test = fleet.perf_test
from .. import auto_parallel as auto
......@@ -14,6 +14,7 @@
import copy
import os
import time
import paddle
from paddle.fluid import compiler
......@@ -192,7 +193,6 @@ class Fleet:
log_level (Integer, String, optional): An ``Integer`` or ``String`` value determining how high
    the logging level is. Default is "INFO".
Returns:
None
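For reference, a minimal sketch of passing this argument when initializing fleet, assuming ``log_level`` is accepted directly by ``fleet.init`` as the docstring above suggests (the chosen level is only an illustration):

import paddle.distributed.fleet as fleet

# Raise fleet's logging verbosity; a level name string or its integer code is accepted.
fleet.init(is_collective=True, log_level="DEBUG")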
......@@ -382,6 +382,80 @@ class Fleet:
        )
        return self

    def perf_test(self, round=50):
        # test allreduce perf
        def allreduce_test(iteration, x, group):
            paddle.distributed.barrier()
            paddle.device.cuda.synchronize()
            start_t = time.time()
            for _ in range(iteration):
                paddle.distributed.all_reduce(x, group=group)
            paddle.device.cuda.synchronize()
            end_t = time.time()
            return (end_t - start_t) / iteration

        # test reduce perf
        def reduce_test(iteration, x, group):
            paddle.distributed.barrier()
            paddle.device.cuda.synchronize()
            start_t = time.time()
            for _ in range(iteration):
                # TODO: shuffle dst
                paddle.distributed.reduce(x, dst=min(group.ranks), group=group)
            paddle.device.cuda.synchronize()
            end_t = time.time()
            return (end_t - start_t) / iteration

        # test broadcast perf
        def broadcast_test(iteration, x, group):
            paddle.distributed.barrier()
            paddle.device.cuda.synchronize()
            start_t = time.time()
            for _ in range(iteration):
                # TODO: shuffle src
                paddle.distributed.broadcast(
                    x, src=min(group.ranks), group=group
                )
            paddle.device.cuda.synchronize()
            end_t = time.time()
            return (end_t - start_t) / iteration

        hcg = self.get_hybrid_communicate_group()
        dp_group = hcg.get_data_parallel_group()
        sharding_group = hcg.get_sharding_parallel_group()
        test_group = None
        if dp_group.nranks > 1:
            test_group = dp_group
        elif sharding_group.nranks > 1:
            test_group = sharding_group
        else:
            logger.warning(
                f"hcg created with dp_degree: {dp_group.nranks} and sharding_degree: {sharding_group.nranks}, skipping perf test..."
            )
            return

        # test 1M ~ 1G
        nbytes = 1 << 20  # 1048576 (1MB)
        final_nbytes = 1 << 30  # 1073741824 (1GB)
        dtype = paddle.float32
        while nbytes <= final_nbytes:
            x = paddle.zeros([nbytes // 4], dtype=dtype)
            # warmup
            allreduce_test(iteration=10, x=x, group=test_group)
            # test allreduce
            ret = allreduce_test(iteration=round, x=x, group=test_group)
            logger.info(
                f"[AllReduceTest] nbytes {nbytes}B test result: {ret} s/iter"
            )
            ret = reduce_test(iteration=round, x=x, group=test_group)
            logger.info(
                f"[ReduceTest] nbytes {nbytes}B test result: {ret} s/iter"
            )
            ret = broadcast_test(iteration=round, x=x, group=test_group)
            logger.info(
                f"[BroadcastTest] nbytes {nbytes}B test result: {ret} s/iter"
            )
            nbytes = nbytes << 1

    def _init_hybrid_parallel_env(self):
        """initialize the hybrid environment"""
        self.hybrid_configs = self._user_defined_strategy.hybrid_configs
......
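For context, a minimal sketch of driving the new API once a hybrid fleet environment is up; the dp_degree value and the two-GPU setup are assumptions for illustration, not part of this diff:

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
# perf_test only runs when the data-parallel (or sharding) group has more than one rank.
strategy.hybrid_configs = {"dp_degree": 2, "mp_degree": 1, "pp_degree": 1}
fleet.init(is_collective=True, strategy=strategy)

# Sweeps buffers from 1 MB to 1 GB, logging seconds/iteration for
# all_reduce, reduce and broadcast on the selected group.
fleet.perf_test(round=50)

Run under paddle.distributed.launch on at least two devices, each size step logs one [AllReduceTest], [ReduceTest] and [BroadcastTest] line.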
......@@ -87,6 +87,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
list(REMOVE_ITEM TEST_OPS test_c_comm_init_all_op)
list(REMOVE_ITEM TEST_OPS test_c_embedding_op)
list(REMOVE_ITEM TEST_OPS test_pipeline_parallel)
list(REMOVE_ITEM TEST_OPS test_fleet_perf_test)
list(REMOVE_ITEM TEST_OPS test_memcpy_op)
list(REMOVE_ITEM TEST_OPS test_raw_program_optimizer)
list(REMOVE_ITEM TEST_OPS test_fleet_gradient_scale)
......@@ -1062,11 +1063,13 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32)
PROPERTIES TIMEOUT 120)
set_tests_properties(test_pipeline_parallel PROPERTIES LABELS
"RUN_TYPE=DIST")
set_tests_properties(test_fleet_perf_test PROPERTIES LABELS "RUN_TYPE=DIST")
set_tests_properties(test_reducescatter PROPERTIES TIMEOUT 120)
set_tests_properties(test_allgather PROPERTIES TIMEOUT 120)
endif()
set_tests_properties(test_paddle_multiprocessing PROPERTIES TIMEOUT 120)
set_tests_properties(test_pipeline_parallel PROPERTIES TIMEOUT 120)
set_tests_properties(test_fleet_perf_test PROPERTIES TIMEOUT 120)
endif()
if(WITH_GPU OR WITH_ROCM)
set_tests_properties(test_rank_attention_op PROPERTIES TIMEOUT 120)
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import unittest

import numpy as np

import paddle
from paddle.distributed import fleet


def set_random_seed(seed, dp_id, rank_id):
    """Set random seed for reproducibility."""
    random.seed(seed)
    np.random.seed(seed + dp_id)
    paddle.seed(seed + dp_id)


batch_size = 4
micro_batch_size = 2


class TestDistDPTraining(unittest.TestCase):
    def setUp(self):
        strategy = fleet.DistributedStrategy()
        self.model_parallel_size = 1
        self.data_parallel_size = 2
        self.pipeline_parallel_size = 1
        strategy.hybrid_configs = {
            "dp_degree": self.data_parallel_size,
            "mp_degree": self.model_parallel_size,
            "pp_degree": self.pipeline_parallel_size,
        }
        strategy.pipeline_configs = {
            "accumulate_steps": batch_size // micro_batch_size,
            "micro_batch_size": micro_batch_size,
        }
        fleet.init(is_collective=True, strategy=strategy)

    def build_optimizer(self, model):
        scheduler = paddle.optimizer.lr.PiecewiseDecay(
            boundaries=[2], values=[0.001, 0.002], verbose=True
        )
        optimizer = paddle.optimizer.SGD(
            learning_rate=scheduler, parameters=model.parameters()
        )
        return scheduler, optimizer

    def test_communication_perf(self):
        fleet.perf_test(round=1)


if __name__ == "__main__":
    unittest.main()
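Assuming two visible GPUs, the driver above could also be launched directly with Paddle's distributed launcher (device ids are an assumption):

python -m paddle.distributed.launch --gpus "0,1" hybrid_parallel_perf_test.py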
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

from test_parallel_dygraph_dataparallel import TestMultipleGpus


class TestFleetPerfTest(TestMultipleGpus):
    def test_fleet_perf_test(self):
        self.run_mnist_2gpu('hybrid_parallel_perf_test.py')


if __name__ == "__main__":
    unittest.main()
......@@ -538,6 +538,7 @@ HIGH_PARALLEL_JOB_NEW = [
'test_dist_fleet_ps3',
'test_dist_mnist_pg',
'test_pipeline_parallel',
'test_fleet_perf_test',
'test_dist_fleet_ps5',
'test_dist_fleet_sparse_embedding_ctr',
'test_collective_broadcast_api',
......