未验证 提交 48faf638 编写于 作者: Y Yulong Ao 提交者: GitHub

[Auto Parallel] Add the graph class for the process and cluster (#37482)

* [Auto Parallel]  Add the unified cluster representation

* [Auto Parallel] Add the graph class for physical mapping

* [Auto Parallel] Add the simple physical mapper

* Set the timeout of the mapper

* Merge the upstream develop unittests cmake files

* Fix a bug of the process group

* Remove mapper unittest from platforms which is not GPU

* Move the instantiation of process group after resharding

* Add the local id for devices

* Update the rank mapping format

* Add some comments

* Remove the related files about mapping

* Remove unused rank_mapping unittest

* Improve the unittest coverage
上级 e7bda1dd
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
class Node:
def __init__(self, id, **attrs):
# Each node must has a unique id
self._id = id
# Attributes for Node
self._attrs = {}
self._attrs.update(attrs)
@property
def id(self):
return self._id
@property
def attrs(self):
return self._attrs
def __getitem__(self, attr_name):
return self._attrs[attr_name]
def __setitem__(self, attr_name, attr_value):
self._attrs[attr_name] = attr_value
def __contains__(self, attr_name):
try:
return attr_name in self._attrs
except TypeError:
return False
def __str__(self):
str = "(id: {}, attrs: {})".format(self.id, self.attrs)
return str
class Edge:
def __init__(self, src_id, tgt_id, **attrs):
# The id of source node in an Edge
self._src_id = src_id
# The id of target node in an Edge
self._tgt_id = tgt_id
# Attributes for Edge
self._attrs = {}
self._attrs.update(attrs)
@property
def src_id(self):
return self._src_id
@property
def tgt_id(self):
return self._tgt_id
@property
def attrs(self):
return self._attrs
def __getitem__(self, attr_name):
return self._attrs[attr_name]
def __setitem__(self, attr_name, attr_value):
self._attrs[attr_name] = attr_value
def __contains__(self, attr_name):
try:
return attr_name in self._attrs
except TypeError:
return False
def __str__(self):
str = ""
str += "(src_id: {}, tgt_id: {}, attrs: {})".format(
self.src_id, self.tgt_id, self._attrs)
return str
class Graph:
def __init__(self, **attrs):
# _nodes is dict for storing the nodes of the graph.
# The key of this dict is the node id.
self._nodes = {}
# _adjs is a dict of dict for storing the adjacency of the graph.
# The key of the outer dict is the node id of the source node and
# the key of the inner dict is the node id of the target node.
self._adjs = {}
# Attributes for Graph
self._attrs = {}
self._attrs.update(attrs)
@property
def nodes(self):
return self._nodes
@property
def attrs(self):
return self._attrs
@property
def adjs(self):
return self._adjs
def add_node(self, node_id, **attrs):
if node_id is None:
raise ValueError("None cannot be a node")
if node_id not in self._nodes:
node = Node(node_id, **attrs)
self._nodes[node_id] = node
self._adjs[node_id] = {}
else:
self._nodes[node_id].attrs.update(attrs)
def add_edge(self, src_id, tgt_id, **attrs):
# add nodes
if src_id is None:
raise ValueError("None cannot be a node")
if tgt_id is None:
raise ValueError("None cannot be a node")
if src_id not in self._nodes:
src_node = Node(src_id)
self._nodes[src_id] = src_node
self._adjs[src_id] = {}
if tgt_id not in self._nodes:
tgt_node = Node(tgt_id)
self._nodes[tgt_id] = tgt_node
self._adjs[tgt_id] = {}
# add the edge
edge = Edge(src_id, tgt_id, **attrs)
self._adjs[src_id][tgt_id] = edge
def __len__(self):
return len(self._nodes)
def __iter__(self):
return iter(self._nodes.values())
def __getitem__(self, node_id):
# Return the adjacency of a node
return self._adjs[node_id]
def __contains__(self, node_id):
# Check whether a node in the graph
try:
return node_id in self._nodes
except TypeError:
return False
def __str__(self):
str = ""
str += "**************Nodes**************\n"
for node_id in self.nodes:
str += "{}\n".format(self.nodes[node_id])
str += "**************Edges**************\n"
for src_id in self.adjs:
str += "--------------{}--------------\n".format(src_id)
for idx, tgt_id in enumerate(self.adjs[src_id]):
str += "{}\n".format(self.adjs[src_id][tgt_id])
return str
...@@ -106,7 +106,7 @@ class AutoParallelizer: ...@@ -106,7 +106,7 @@ class AutoParallelizer:
# instantiate communication by process_mapping. # instantiate communication by process_mapping.
all_process_groups = get_all_process_groups() all_process_groups = get_all_process_groups()
for process_group in all_process_groups: for process_group in all_process_groups:
if rank not in process_group._ranks: if rank not in process_group.ranks:
continue continue
process_group.instantiate() process_group.instantiate()
......
...@@ -19,6 +19,8 @@ from ..collective import _new_ring_id ...@@ -19,6 +19,8 @@ from ..collective import _new_ring_id
from ...fluid.framework import in_dygraph_mode from ...fluid.framework import in_dygraph_mode
from ...fluid.layers.tensor import fill_constant from ...fluid.layers.tensor import fill_constant
# Note that Process group 0 is reserved for representing all ranks.
# At the begining, group 0 is empty and new ranks will be added automatically.
_g_process_group_map = {} _g_process_group_map = {}
...@@ -27,25 +29,27 @@ def get_all_process_groups(): ...@@ -27,25 +29,27 @@ def get_all_process_groups():
return _g_process_group_map.values() return _g_process_group_map.values()
def get_process_group(group_id):
global _g_process_group_map
return _g_process_group_map.get(group_id, None)
def new_process_group(ranks): def new_process_group(ranks):
global _g_process_group_map global _g_process_group_map
if not _g_process_group_map: # A key constructed from ranks is used for avoiding duplication
genv = _get_global_env() new_key = ''.join(map(str, sorted(ranks)))
_g_process_group_map["global_group"] = ProcessGroup( for pg_id, pg in _g_process_group_map.items():
0, list(range(genv.world_size))) cur_key = ''.join(map(str, sorted(pg.ranks)))
# A key constructed from ranks is used in the global process group map if pg_id != 0 and new_key == cur_key:
key = ''.join(map(str, sorted(ranks))) return pg
if key not in _g_process_group_map: # If not matching the existing one, construt a new process group
num_groups = len(_g_process_group_map) num_groups = len(_g_process_group_map)
# Note: our process group may interfere with the original implementation # Note: our process group may interfere with the original implementation
# so the created group id should start from the original _new_ring_id() # so the created group id should start from the original _new_ring_id()
group_id = _new_ring_id() + num_groups + 1 group_id = _new_ring_id() + num_groups + 1
pg = ProcessGroup(group_id, ranks) new_pg = ProcessGroup(group_id, ranks)
_g_process_group_map[key] = pg _g_process_group_map[group_id] = new_pg
return pg return new_pg
else:
pg = _g_process_group_map[key]
return pg
# This implementation refers to lots of Paddle/python/paddle/distributed/collective.py, # This implementation refers to lots of Paddle/python/paddle/distributed/collective.py,
...@@ -56,22 +60,40 @@ def new_process_group(ranks): ...@@ -56,22 +60,40 @@ def new_process_group(ranks):
# handle the communication implementation choice. # handle the communication implementation choice.
class ProcessGroup: class ProcessGroup:
def __init__(self, group_id, ranks): def __init__(self, group_id, ranks):
if group_id == 0 and get_process_group(0) is not None:
assert group_id != 0, "Process group id 0 is reserved for all ranks."
self._group_id = group_id self._group_id = group_id
self._ranks = sorted(ranks) self._ranks = sorted(ranks)
self._nranks = len(self._ranks) # Add the current ranks into group 0
if group_id != 0:
global _g_process_group_map
_g_process_group_map[0].add_ranks(ranks)
self._is_instantiate = False self._is_instantiate = False
@property @property
def id(self): def id(self):
return self._group_id return self._group_id
# @property @property
# def key(self): def ranks(self):
# return ''.join(map(str, sorted(self._ranks))) return self._ranks
@property
def nranks(self):
return len(self._ranks)
def add_ranks(self, new_ranks):
if set(new_ranks) <= set(self.ranks):
return
else:
assert self.is_instantiate() == False, \
"Cannot add new ranks after instantiating the process group"
self._ranks.extend(new_ranks)
self._ranks = sorted(list(set(self.ranks)))
def local_rank(self, global_rank): def local_rank(self, global_rank):
if global_rank in self._ranks: if global_rank in self.ranks:
return self._ranks.index(global_rank) return self.ranks.index(global_rank)
else: else:
assert False, \ assert False, \
"Rank {} doesn't belong to this group".format(global_rank) "Rank {} doesn't belong to this group".format(global_rank)
...@@ -86,12 +108,12 @@ class ProcessGroup: ...@@ -86,12 +108,12 @@ class ProcessGroup:
genv = _get_global_env() genv = _get_global_env()
global_rank = genv.rank global_rank = genv.rank
if self._nranks >= 2: if self.nranks >= 2:
strategy = core.ParallelStrategy() strategy = core.ParallelStrategy()
strategy.nranks = self._nranks strategy.nranks = self.nranks
strategy.local_rank = self.local_rank(global_rank) strategy.local_rank = self.local_rank(global_rank)
strategy.trainer_endpoints = [ strategy.trainer_endpoints = [
genv.trainer_endpoints[i] for i in self._ranks genv.trainer_endpoints[i] for i in self.ranks
] ]
strategy.current_endpoint = genv.current_endpoint strategy.current_endpoint = genv.current_endpoint
strategy.nrings = 1 strategy.nrings = 1
...@@ -113,7 +135,20 @@ class ProcessGroup: ...@@ -113,7 +135,20 @@ class ProcessGroup:
self._is_instantiate = True self._is_instantiate = True
# def __eq__(self, other):
# if not isinstance(other, ProcessGroup):
# return False
# if self.id != other.id:
# return False
# return True
# def __ne__(self, other):
# return not self.__eq__(other)
def __str__(self): def __str__(self):
string = "id: {}, nranks: {}, ranks: {}.".format( string = "id: {}, nranks: {}, ranks: {}.".format(
self.id, self._nranks, ", ".join(map(str, self._ranks))) self.id, self.nranks, ", ".join(map(str, self.ranks)))
return string return string
_g_process_group_map[0] = ProcessGroup(0, [])
...@@ -93,9 +93,14 @@ class ProcessMesh(object): ...@@ -93,9 +93,14 @@ class ProcessMesh(object):
self._topology = _get_nested_list_shape(mesh) self._topology = _get_nested_list_shape(mesh)
self._processes = processes self._processes = processes
# Store all process meshes
from .dist_context import get_default_distributed_context from .dist_context import get_default_distributed_context
default_dist_cxt = get_default_distributed_context() default_dist_cxt = get_default_distributed_context()
default_dist_cxt.add_process_mesh(self) default_dist_cxt.add_process_mesh(self)
# Add new processes to process group 0
from .process_group import get_process_group
pg0 = get_process_group(0)
pg0.add_ranks(self.processes)
@property @property
def topology(self): def topology(self):
......
...@@ -627,13 +627,13 @@ def _insert_allgather_op(block, idx, tensor, ranks): ...@@ -627,13 +627,13 @@ def _insert_allgather_op(block, idx, tensor, ranks):
attrs={ attrs={
'ring_id': group.id, 'ring_id': group.id,
'use_calc_stream': True, 'use_calc_stream': True,
'nranks': group._nranks 'nranks': group.nranks
}) })
idx_offset += 1 idx_offset += 1
# insert split op # insert split op
split_out = _insert_split_op(block, idx + idx_offset, allgather_out, split_out = _insert_split_op(block, idx + idx_offset, allgather_out,
group._nranks) group.nranks)
idx_offset += 1 idx_offset += 1
tensor_list.extend(split_out) tensor_list.extend(split_out)
return tensor_list, idx_offset return tensor_list, idx_offset
...@@ -665,14 +665,6 @@ def _concat_partitions_with_op(partition_tensor_list, tensor, partition_index, ...@@ -665,14 +665,6 @@ def _concat_partitions_with_op(partition_tensor_list, tensor, partition_index,
partition_tensor_list.append((tensor, partition_index)) partition_tensor_list.append((tensor, partition_index))
def _init_comm_for_send_recv():
if not _g_process_group_map:
genv = _get_global_env()
_g_process_group_map["global_group"] = ProcessGroup(
0, list(range(genv.world_size)))
_g_process_group_map["global_group"].instantiate()
HAS_SENT = {} HAS_SENT = {}
HAS_RECV = {} HAS_RECV = {}
HAS_ALLGATHER = {} HAS_ALLGATHER = {}
...@@ -726,7 +718,6 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op, ...@@ -726,7 +718,6 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
assert tensor_list, "The result of parsing allgather op should not be None." assert tensor_list, "The result of parsing allgather op should not be None."
elif isinstance(op_desc, SendOpDesc): elif isinstance(op_desc, SendOpDesc):
_init_comm_for_send_recv()
if var_name not in HAS_SENT.keys(): if var_name not in HAS_SENT.keys():
HAS_SENT[var_name] = [] HAS_SENT[var_name] = []
if op_desc.dst not in HAS_SENT[var_name]: if op_desc.dst not in HAS_SENT[var_name]:
...@@ -735,7 +726,6 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op, ...@@ -735,7 +726,6 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
HAS_SENT[var_name].append(op_desc.dst) HAS_SENT[var_name].append(op_desc.dst)
elif isinstance(op_desc, RecvOpDesc): elif isinstance(op_desc, RecvOpDesc):
_init_comm_for_send_recv()
if var_name not in HAS_RECV.keys(): if var_name not in HAS_RECV.keys():
HAS_RECV[var_name] = {} HAS_RECV[var_name] = {}
if op_desc.src not in HAS_RECV[var_name].keys(): if op_desc.src not in HAS_RECV[var_name].keys():
......
...@@ -61,7 +61,6 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_run_random_port) ...@@ -61,7 +61,6 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_run_random_port)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_launch_async) list(APPEND MIXED_DIST_TEST_OPS test_fleet_launch_async)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_launch_cloud) list(APPEND MIXED_DIST_TEST_OPS test_fleet_launch_cloud)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_launch_ascend) list(APPEND MIXED_DIST_TEST_OPS test_fleet_launch_ascend)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_launch_rank_mapping)
list(APPEND MIXED_DIST_TEST_OPS test_ascend_group) list(APPEND MIXED_DIST_TEST_OPS test_ascend_group)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_launch_nproc) list(APPEND MIXED_DIST_TEST_OPS test_fleet_launch_nproc)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_api_input) list(APPEND MIXED_DIST_TEST_OPS test_fleet_api_input)
...@@ -669,7 +668,6 @@ if(WITH_DISTRIBUTE) ...@@ -669,7 +668,6 @@ if(WITH_DISTRIBUTE)
bash_test_modules(test_fleet_launch_async START_BASH test_fleet_launch_async.sh ENVS PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR}) bash_test_modules(test_fleet_launch_async START_BASH test_fleet_launch_async.sh ENVS PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR})
bash_test_modules(test_fleet_launch_cloud START_BASH test_fleet_launch_cloud.sh ENVS PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR}) bash_test_modules(test_fleet_launch_cloud START_BASH test_fleet_launch_cloud.sh ENVS PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR})
bash_test_modules(test_fleet_launch_nproc START_BASH test_fleet_launch_nproc.sh ENVS PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR}) bash_test_modules(test_fleet_launch_nproc START_BASH test_fleet_launch_nproc.sh ENVS PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR})
bash_test_modules(test_fleet_launch_rank_mapping START_BASH test_fleet_launch_rank_mapping.sh ENVS PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR})
if(WITH_ASCEND OR WITH_ASCEND_CL) if(WITH_ASCEND OR WITH_ASCEND_CL)
bash_test_modules(test_fleet_launch_ascend START_BASH test_fleet_launch_ascend.sh ENVS PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR}) bash_test_modules(test_fleet_launch_ascend START_BASH test_fleet_launch_ascend.sh ENVS PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR})
bash_test_modules(test_ascend_group START_BASH test_ascend_group.sh ENVS PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR}) bash_test_modules(test_ascend_group START_BASH test_ascend_group.sh ENVS PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR})
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import os
import json
from paddle.distributed.auto_parallel.graph import Node
from paddle.distributed.auto_parallel.graph import Edge
from paddle.distributed.auto_parallel.graph import Graph
class TestAutoParallelGraph(unittest.TestCase):
def test_graph(self):
graph = Graph(name="foo")
self.assertEqual(graph.attrs["name"], "foo")
graph.add_node(1, weight=0)
# Overide the existing node attribute
graph.add_node(1, weight=1)
graph.add_node(2, weight=2)
graph.add_node(3, weight=3)
node = graph.nodes[1]
node["info"] = "is a node"
self.assertTrue(node.id, 1)
self.assertTrue("weight" in node)
self.assertTrue("info" in node)
for node_attr in node.attrs:
self.assertTrue(node_attr in ["weight", "info"])
self.assertTrue(1 in graph)
self.assertTrue(2 in graph)
self.assertTrue(3 in graph)
self.assertEqual(len(graph), 3)
self.assertEqual(graph.nodes[1].id, 1)
self.assertEqual(graph.nodes[2].id, 2)
self.assertEqual(graph.nodes[3].id, 3)
for node in graph:
if node.id == 1:
self.assertEqual(node["weight"], 1)
if node.id == 2:
self.assertEqual(node["weight"], 2)
if node.id == 3:
self.assertEqual(node["weight"], 3)
graph.add_edge(1, 2, weight=0.1)
graph.add_edge(1, 3, weight=0.2)
graph.add_edge(2, 3, weight=0.3)
graph.add_edge(4, 5, weight=0.4)
edge = graph[1][2]
edge["info"] = "is a edge"
self.assertTrue(edge.src_id, 1)
self.assertTrue(edge.tgt_id, 2)
self.assertTrue("weight" in edge)
self.assertTrue("info" in edge)
for edge_attr in edge.attrs:
self.assertTrue(edge_attr in ["weight", "info"])
self.assertEqual(graph[1][2]["weight"], 0.1)
self.assertEqual(graph[1][3]["weight"], 0.2)
self.assertEqual(graph[2][3]["weight"], 0.3)
self.assertEqual(graph[4][5]["weight"], 0.4)
str = "{}".format(graph)
self.assertIsNotNone(str)
self.assertRaises(TypeError, 6 in graph)
self.assertRaises(TypeError, "unkown_attr" in graph.nodes[1])
self.assertRaises(TypeError, "unkown_attr" in graph[1][2])
self.assertRaises(ValueError, graph.add_node, None)
self.assertRaises(ValueError, graph.add_edge, 3, None)
self.assertRaises(ValueError, graph.add_edge, None, 3)
if __name__ == '__main__':
unittest.main()
...@@ -166,8 +166,6 @@ def get_dist_prog_with_parallelizer(train_program, startup_program, ...@@ -166,8 +166,6 @@ def get_dist_prog_with_parallelizer(train_program, startup_program,
grad_clip=None) grad_clip=None)
optimizer = fleet.distributed_optimizer(optimizer) optimizer = fleet.distributed_optimizer(optimizer)
# fake a comm group
pg = new_process_group([3, 4])
_, _, distributed_startup_program, distributed_main_program = optimizer.minimize( _, _, distributed_startup_program, distributed_main_program = optimizer.minimize(
loss, startup_program) loss, startup_program)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册