未验证 提交 6bbe92a1 编写于 作者: L LiYuRio 提交者: GitHub

Optimize initialize time by decrease the number of pp group (#53559)

* use global group to pass meta

* use batch isend irecv

* add partial send/recv

* remove communication group

* remove p2p on npu and xpu

* remove virtual pp ut
上级 8c74ffc0
...@@ -294,11 +294,6 @@ class HybridCommunicateGroup: ...@@ -294,11 +294,6 @@ class HybridCommunicateGroup:
def _set_p2p_group(self): def _set_p2p_group(self):
comm_lists = self._topo.get_comm_list('pipe') comm_lists = self._topo.get_comm_list('pipe')
self.send_next_group = None
self.send_prev_group = None
self.recv_next_group = None
self.recv_prev_group = None
for comm_ranks in comm_lists: for comm_ranks in comm_lists:
assert len(comm_ranks) == self._pp_degree assert len(comm_ranks) == self._pp_degree
for idx, rank in enumerate(comm_ranks): for idx, rank in enumerate(comm_ranks):
...@@ -310,28 +305,6 @@ class HybridCommunicateGroup: ...@@ -310,28 +305,6 @@ class HybridCommunicateGroup:
self.next_rank = next_rank self.next_rank = next_rank
self.prev_rank = prev_rank self.prev_rank = prev_rank
next_group = paddle.distributed.new_group(
ranks=[curr_rank, next_rank]
)
if self.global_rank == curr_rank:
self.send_next_group = next_group
elif self.global_rank == next_rank:
self.recv_prev_group = next_group
prev_group = paddle.distributed.new_group(
ranks=[prev_rank, curr_rank]
)
if self.global_rank == curr_rank:
self.send_prev_group = prev_group
elif self.global_rank == prev_rank:
self.recv_next_group = prev_group
assert self.send_next_group is not None
assert self.send_prev_group is not None
assert self.recv_next_group is not None
assert self.recv_prev_group is not None
def topology(self): def topology(self):
return self._topo return self._topo
...@@ -384,12 +357,7 @@ class HybridCommunicateGroup: ...@@ -384,12 +357,7 @@ class HybridCommunicateGroup:
return self._pp_comm_group return self._pp_comm_group
def get_p2p_groups(self): def get_p2p_groups(self):
return ( return None
self.send_next_group,
self.send_prev_group,
self.recv_next_group,
self.recv_prev_group,
)
# sharding parallel message: # sharding parallel message:
def _get_sharding_parallel_id(self): def _get_sharding_parallel_id(self):
......
...@@ -10,7 +10,6 @@ ...@@ -10,7 +10,6 @@
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
import warnings
import paddle import paddle
from paddle import framework from paddle import framework
...@@ -629,10 +628,9 @@ class PipelineParallelWithInterleave(PipelineParallel): ...@@ -629,10 +628,9 @@ class PipelineParallelWithInterleave(PipelineParallel):
def __init__(self, layers, hcg, strategy): def __init__(self, layers, hcg, strategy):
super().__init__(layers=layers, hcg=hcg, strategy=strategy) super().__init__(layers=layers, hcg=hcg, strategy=strategy)
assert layers.get_num_virtual_stages() > 1 assert layers.get_num_virtual_stages() > 1
if self.num_stages <= 2: assert (
warnings.warn( self.num_stages > 2
"Deprecate warning! In the near future the virtual pp will only available when pp degree > 2." ), "virtual pipeline must run under pp degree > 2"
)
assert ( assert (
framework.in_dynamic_mode() framework.in_dynamic_mode()
), "virtual pipeline stage with interleave only support eager dygraph mode" ), "virtual pipeline stage with interleave only support eager dygraph mode"
......
...@@ -19,17 +19,20 @@ from legacy_test.test_parallel_dygraph_dataparallel import TestMultipleGpus ...@@ -19,17 +19,20 @@ from legacy_test.test_parallel_dygraph_dataparallel import TestMultipleGpus
class TestHybridPipeParallelWithVirtualStage(TestMultipleGpus): class TestHybridPipeParallelWithVirtualStage(TestMultipleGpus):
def test_hybrid_parallel_pp_layer_with_virtual_stage(self): def test_hybrid_parallel_pp_layer_with_virtual_stage(self):
self.run_mnist_2gpu('hybrid_parallel_pp_layer_with_virtual_stage.py') # self.run_mnist_2gpu('hybrid_parallel_pp_layer_with_virtual_stage.py')
pass
def test_hybrid_parallel_pp_transformer_with_virtual_stage(self): def test_hybrid_parallel_pp_transformer_with_virtual_stage(self):
self.run_mnist_2gpu( # self.run_mnist_2gpu(
'hybrid_parallel_pp_transformer_with_virtual_stage.py' # 'hybrid_parallel_pp_transformer_with_virtual_stage.py'
) # )
pass
def test_hybrid_parallel_save_load_with_virtual_stage(self): def test_hybrid_parallel_save_load_with_virtual_stage(self):
self.run_mnist_2gpu( # self.run_mnist_2gpu(
'hybrid_parallel_pp_save_load_with_virtual_stage.py' # 'hybrid_parallel_pp_save_load_with_virtual_stage.py'
) # )
pass
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册