未验证 提交 6bbe92a1 编写于 作者: L LiYuRio 提交者: GitHub

Optimize initialize time by decrease the number of pp group (#53559)

* use global group to pass meta

* use batch isend irecv

* add partial send/recv

* remove communication group

* remove p2p on npu and xpu

* remove virtual pp ut
上级 8c74ffc0
......@@ -294,11 +294,6 @@ class HybridCommunicateGroup:
def _set_p2p_group(self):
comm_lists = self._topo.get_comm_list('pipe')
self.send_next_group = None
self.send_prev_group = None
self.recv_next_group = None
self.recv_prev_group = None
for comm_ranks in comm_lists:
assert len(comm_ranks) == self._pp_degree
for idx, rank in enumerate(comm_ranks):
......@@ -310,28 +305,6 @@ class HybridCommunicateGroup:
self.next_rank = next_rank
self.prev_rank = prev_rank
next_group = paddle.distributed.new_group(
ranks=[curr_rank, next_rank]
)
if self.global_rank == curr_rank:
self.send_next_group = next_group
elif self.global_rank == next_rank:
self.recv_prev_group = next_group
prev_group = paddle.distributed.new_group(
ranks=[prev_rank, curr_rank]
)
if self.global_rank == curr_rank:
self.send_prev_group = prev_group
elif self.global_rank == prev_rank:
self.recv_next_group = prev_group
assert self.send_next_group is not None
assert self.send_prev_group is not None
assert self.recv_next_group is not None
assert self.recv_prev_group is not None
def topology(self):
return self._topo
......@@ -384,12 +357,7 @@ class HybridCommunicateGroup:
return self._pp_comm_group
def get_p2p_groups(self):
return (
self.send_next_group,
self.send_prev_group,
self.recv_next_group,
self.recv_prev_group,
)
return None
# sharding parallel message:
def _get_sharding_parallel_id(self):
......
......@@ -10,7 +10,6 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import warnings
import paddle
from paddle import framework
......@@ -629,10 +628,9 @@ class PipelineParallelWithInterleave(PipelineParallel):
def __init__(self, layers, hcg, strategy):
super().__init__(layers=layers, hcg=hcg, strategy=strategy)
assert layers.get_num_virtual_stages() > 1
if self.num_stages <= 2:
warnings.warn(
"Deprecate warning! In the near future the virtual pp will only available when pp degree > 2."
)
assert (
self.num_stages > 2
), "virtual pipeline must run under pp degree > 2"
assert (
framework.in_dynamic_mode()
), "virtual pipeline stage with interleave only support eager dygraph mode"
......
......@@ -19,17 +19,20 @@ from legacy_test.test_parallel_dygraph_dataparallel import TestMultipleGpus
class TestHybridPipeParallelWithVirtualStage(TestMultipleGpus):
def test_hybrid_parallel_pp_layer_with_virtual_stage(self):
self.run_mnist_2gpu('hybrid_parallel_pp_layer_with_virtual_stage.py')
# self.run_mnist_2gpu('hybrid_parallel_pp_layer_with_virtual_stage.py')
pass
def test_hybrid_parallel_pp_transformer_with_virtual_stage(self):
self.run_mnist_2gpu(
'hybrid_parallel_pp_transformer_with_virtual_stage.py'
)
# self.run_mnist_2gpu(
# 'hybrid_parallel_pp_transformer_with_virtual_stage.py'
# )
pass
def test_hybrid_parallel_save_load_with_virtual_stage(self):
self.run_mnist_2gpu(
'hybrid_parallel_pp_save_load_with_virtual_stage.py'
)
# self.run_mnist_2gpu(
# 'hybrid_parallel_pp_save_load_with_virtual_stage.py'
# )
pass
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册