未验证 提交 b0dff05d 编写于 作者: Y Yulong Ao 提交者: GitHub

[Auto Parallel] Do the physical mapping between the process graph and the cluster graph (#37094)

* [Auto Parallel]  Add the unified cluster representation

* [Auto Parallel] Add the graph class for physical mapping

* [Auto Parallel] Add the simple physical mapper

* Set the timeout of the mapper

* Merge the upstream develop unittests cmake files

* Fix a bug of the process group

* Remove mapper unittest from platforms which is not GPU

* Move the instantiation of process group after resharding

* Add the local id for devices

* Update the rank mapping format

* Add some comments

* Remove the related files about mapping

* Update the unittest for auto mapping

* Remove unused rank_mapping unittest

* Improve the unittest coverage

* Improve the unittest coverage
上级 87e65a99
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import operator
import functools
import json
import paddle
from collections import deque
from .graph import Node
from .graph import Edge
from .graph import Graph
from .cluster import DeviceType
from .process_group import get_process_group
def is_collective_comm_op(op):
comm_list = [
"c_allreduce_sum", "c_allreduce_min", "c_allreduce_max",
"c_allreduce_prod", "c_reduce_sum", "c_reduce_min", "c_reduce_max",
"c_reduce_prod", "c_broadcast", "c_allgather"
]
if op.type in comm_list:
return True
else:
return False
def is_p2p_comm_op(op):
comm_list = ["send_v2", "recv_v2"]
if op.type in comm_list:
return True
else:
return False
def get_dtype_bytes(dtype):
num_bytes = 0
if dtype == paddle.float64:
num_bytes = 8
elif dtype == paddle.float32:
num_bytes = 4
elif dtype == paddle.float16:
num_bytes = 2
elif dtype == paddle.bfloat16:
num_bytes = 2
elif dtype == paddle.int64:
num_bytes = 8
elif dtype == paddle.int32:
num_bytes = 4
elif dtype == paddle.int16:
num_bytes = 2
elif dtype == paddle.int8:
num_bytes = 1
elif dtype == paddle.uint8:
num_bytes = 1
else:
raise ValueError("Unrecognized dtype {}.".format(dtype))
return num_bytes
def get_comm_volume(comm_op, src_rank, tgt_rank):
comm_volume = None
if src_rank == tgt_rank:
return comm_volume
comm_op_type = comm_op.type
if comm_op_type != "recv_v2":
tensor_name = comm_op.input_arg_names[0]
else:
tensor_name = comm_op.output_arg_names[0]
tensor = comm_op.block._find_var_recursive(tensor_name)
assert tensor is not None
tensor_shape = tensor.shape
# Skip the batch dim
new_tensor_shape = []
for val in tensor_shape:
if val == -1:
print("Warning: -1 in the tensor shape.")
new_tensor_shape.append(1)
else:
new_tensor_shape.append(val)
tensor_size = functools.reduce(operator.mul, new_tensor_shape, 1)
tensor_bytes = tensor_size * get_dtype_bytes(tensor.dtype)
if "c_allreduce" in comm_op_type:
comm_volume = 2 * tensor_bytes
elif "c_allgather" in comm_op_type:
comm_volume = tensor_bytes
elif "c_broadcast" in comm_op_type:
if comm_op.attr("root") == src_rank:
comm_volume = tensor_bytes
else:
comm_volume = None
elif "c_reduce" in comm_op_type:
if comm_op.attr("root_id") == src_rank:
comm_volume = None
else:
comm_volume = tensor_bytes
elif "send_v2" in comm_op_type:
if comm_op.attr("peer") == tgt_rank:
comm_volume = tensor_bytes
else:
comm_volume = None
elif "recv_v2" in comm_op_type:
comm_volume = None
else:
raise ValueError("Unrecognized communication operator.")
return comm_volume
def analyze_comm_requirements_from_op(op, rank):
comm_requirements_to_ranks = {}
if is_collective_comm_op(op):
process_group_id = op.attr("ring_id")
process_group = get_process_group(process_group_id)
if rank not in process_group.ranks:
return comm_requirements_to_ranks
for tgt_rank in process_group.ranks:
comm_volume = get_comm_volume(op, rank, tgt_rank)
if comm_volume is not None:
comm_requirements_to_ranks[tgt_rank] = {}
comm_requirements_to_ranks[tgt_rank][
"comm_volume"] = comm_volume
elif is_p2p_comm_op(op):
tgt_rank = op.attr("peer")
comm_volume = get_comm_volume(op, rank, tgt_rank)
if comm_volume is not None:
comm_requirements_to_ranks[tgt_rank] = {}
comm_requirements_to_ranks[tgt_rank]["comm_volume"] = comm_volume
else:
comm_requirements_to_ranks = {}
return comm_requirements_to_ranks
def analyze_requirements_for_program(program, rank):
resource_requirements = {}
comm_requirements_to_ranks = {}
# only support device_type and only support GPU for now
resource_requirements["device_type"] = DeviceType.GPU
for block in program.blocks:
for op in block.ops:
cur_comm_requirements_to_ranks = analyze_comm_requirements_from_op(
op, rank)
for tgt_rank, link_info in cur_comm_requirements_to_ranks.items():
if tgt_rank in comm_requirements_to_ranks:
comm_requirements_to_ranks[tgt_rank][
"comm_volume"] += link_info["comm_volume"]
else:
comm_requirements_to_ranks[tgt_rank] = {}
comm_requirements_to_ranks[tgt_rank][
"comm_volume"] = link_info["comm_volume"]
return resource_requirements, comm_requirements_to_ranks
def build_process_graph(distributed_program):
graph = Graph()
for src_rank, src_program in distributed_program.items():
resource_requirements, comm_requirements_to_ranks = analyze_requirements_for_program(
src_program, src_rank)
graph.add_node(src_rank, resource_requirements=resource_requirements)
for tgt_rank, comm_requirements in comm_requirements_to_ranks.items():
graph.add_edge(
src_rank, tgt_rank, comm_requirements=comm_requirements)
return graph
def build_cluster_graph(cluster):
graph = Graph()
for machine in cluster.machines.values():
for device in machine.devices.values():
graph.add_node(device.global_id, device=device)
for link in machine.links.values():
graph.add_edge(
link.source.global_id, link.target.global_id, link=link)
return graph
def mapping(distributed_program, cluster):
# A very simple mapping algorithm only for GPUs.
# Here we assume one process will be mapped to one GPU.
# In the future, more mapping configurations and algorithms will be supported.
process_graph = build_process_graph(distributed_program)
cluster_graph = build_cluster_graph(cluster)
for cur_rank_node in process_graph:
cur_rank_node["visited"] = False
for cur_device_node in cluster_graph:
cur_device_node["occupied"] = False
def sort_by_comm_volume(rank_edge):
return rank_edge["comm_requirements"]["comm_volume"]
def sort_by_comm_bandwidth(device_edge):
return device_edge["link"].bandwidth
def select_unvisited_rank_node(rank_node_list):
selected_rank_node = None
for rank_node in rank_node_list:
if rank_node["visited"] is False:
selected_rank_node = rank_node
return selected_rank_node
queue = deque()
root_rank_node = select_unvisited_rank_node(
list(process_graph.nodes.values()))
while root_rank_node is not None:
queue.append(root_rank_node)
while queue:
cur_rank_node = queue.popleft()
if cur_rank_node["visited"]:
continue
device_type = cur_rank_node["resource_requirements"]["device_type"]
cur_device_node = None
for device_node in cluster_graph.nodes.values():
if (device_node["device"].type == device_type) and (
not device_node["occupied"]):
device_node["occupied"] = True
cur_rank_node["visited"] = True
cur_rank_node["device"] = device_node["device"]
cur_device_node = device_node
break
assert cur_device_node, "Cannot find a device to satisfy the requirement."
nbr_rank_edges = []
for nbr_rank_node_id, nbr_rank_edge in process_graph.adjs[
cur_rank_node.id].items():
assert nbr_rank_edge.src_id == cur_rank_node.id and nbr_rank_edge.tgt_id == nbr_rank_node_id
queue.append(process_graph.nodes[nbr_rank_node_id])
nbr_rank_edges.append(nbr_rank_edge)
nbr_rank_edges.sort(key=sort_by_comm_volume)
nbr_device_edges = []
for nbr_device_edge in cluster_graph.adjs[
cur_device_node.id].values():
nbr_device_edges.append(nbr_device_edge)
nbr_device_edges.sort(key=sort_by_comm_bandwidth)
for nbr_rank_edge in nbr_rank_edges:
src_rank_node = process_graph.nodes[nbr_rank_edge.src_id][
"visited"]
if src_rank_node:
continue
device_type = src_rank_node["resource_requirements"][
"device_type"]
nbr_rank_node = process_graph.nodes[nbr_rank_edge.tgt_id]
for nbr_device_edge in nbr_device_edges:
nbr_device_node = cluster_graph.nodes[
nbr_device_edge.tgt_id]
if (nbr_device_node["device"].type == device_type) and (
not nbr_device_node["occupied"]):
nbr_device_node["occupied"] = True
nbr_rank_node["visited"] = True
nbr_rank_node["device"] = nbr_device_node["device"]
break
root_rank_node = select_unvisited_rank_node(
list(process_graph.nodes.values()))
rank_mapping = {}
for rank, rank_node in process_graph.nodes.items():
device = rank_node["device"]
machine = device.machine
if machine.id in rank_mapping:
rank_mapping[machine.id]["hostname"] = machine.hostname
rank_mapping[machine.id]["addr"] = machine.addr
rank_mapping[machine.id]["port"] = machine.port
if rank not in rank_mapping[machine.id]["ranks"]:
rank_mapping[machine.id]["ranks"][rank] = []
rank_mapping[machine.id]["ranks"][rank].append(device.local_id)
else:
rank_mapping[machine.id]["ranks"][rank].append(device.local_id)
else:
rank_mapping[machine.id] = {}
rank_mapping[machine.id]["hostname"] = machine.hostname
rank_mapping[machine.id]["addr"] = machine.addr
rank_mapping[machine.id]["port"] = machine.port
rank_mapping[machine.id]["ranks"] = {}
rank_mapping[machine.id]["ranks"][rank] = []
rank_mapping[machine.id]["ranks"][rank].append(device.local_id)
for machine_mapping in rank_mapping.values():
for rank_devices in machine_mapping["ranks"].values():
rank_devices.sort()
return rank_mapping
......@@ -144,6 +144,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
LIST(REMOVE_ITEM TEST_OPS test_disable_signal_handler)
LIST(REMOVE_ITEM TEST_OPS test_fleet_executor)
LIST(REMOVE_ITEM TEST_OPS test_fleet_executor_multi_devices)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_mapper)
LIST(REMOVE_ITEM TEST_OPS test_fleet_executor_task_node)
endif()
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import os
import json
import collections
import math
import numpy as np
import paddle
import paddle.nn as nn
import paddle.fluid as fluid
import paddle.nn.functional as F
import paddle.tensor as tensor
import paddle.utils as utils
import paddle.static as static
from paddle.fluid import core
from paddle.fluid import layers
from paddle.fluid.framework import in_dygraph_mode
from paddle.nn.layer.transformer import _convert_param_attr_to_list
from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer
from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard
from paddle.distributed.auto_parallel.process_group import get_all_process_groups
from paddle.distributed.auto_parallel.process_group import new_process_group
from paddle.distributed.auto_parallel.cluster import Cluster
from paddle.distributed.auto_parallel.cluster import DeviceType
from paddle.distributed.auto_parallel.cluster import LinkType
from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.mapper import build_process_graph
from paddle.distributed.auto_parallel.mapper import build_cluster_graph
from paddle.distributed.auto_parallel.mapper import mapping
from paddle.distributed.auto_parallel.mapper import get_dtype_bytes
from paddle.distributed.auto_parallel.mapper import get_comm_volume
paddle.enable_static()
_global_parallel_strategy = None
_global_process_mesh = None
_global_num_stages = None
cluster_json = """
{
"machines": [
{
"hostname": "machine0",
"addr": "0.0.0.1",
"port": "768",
"devices": [
{
"global_id": 0,
"local_id": 0,
"type": "GPU",
"model": "A100-SXM4-40GB",
"sp_gflops": 19500,
"dp_gflops": 9700,
"memory": 40
},
{
"global_id": 1,
"local_id": 1,
"type": "GPU",
"model": "A100-SXM4-40GB",
"sp_gflops": 19500,
"dp_gflops": 9700,
"memory": 40
},
{
"global_id": 2,
"local_id": 2,
"type": "GPU",
"model": "A100-SXM4-40GB",
"sp_gflops": 19500,
"dp_gflops": 9700,
"memory": 40
},
{
"global_id": 3,
"local_id": 3,
"type": "GPU",
"model": "A100-SXM4-40GB",
"sp_gflops": 19500,
"dp_gflops": 9700,
"memory": 40
},
{
"global_id": 4,
"local_id": 0,
"type": "NIC"
}
],
"links": [
{
"source_global_id": 0,
"target_global_id": 1,
"type": "NVL",
"bandwidth": 42
},
{
"source_global_id": 0,
"target_global_id": 2,
"type": "NVL",
"bandwidth": 42
},
{
"source_global_id": 0,
"target_global_id": 3,
"type": "NVL",
"bandwidth": 42
},
{
"source_global_id": 0,
"target_global_id": 4,
"type": "PHB",
"bandwidth": 12
},
{
"source_global_id": 1,
"target_global_id": 0,
"type": "NVL",
"bandwidth": 42
},
{
"source_global_id": 1,
"target_global_id": 2,
"type": "NVL",
"bandwidth": 42
},
{
"source_global_id": 1,
"target_global_id": 3,
"type": "NVL",
"bandwidth": 42
},
{
"source_global_id": 1,
"target_global_id": 4,
"type": "PHB",
"bandwidth": 12
},
{
"source_global_id": 2,
"target_global_id": 0,
"type": "NVL",
"bandwidth": 42
},
{
"source_global_id": 2,
"target_global_id": 1,
"type": "NVL",
"bandwidth": 42
},
{
"source_global_id": 2,
"target_global_id": 3,
"type": "NVL",
"bandwidth": 42
},
{
"source_global_id": 2,
"target_global_id": 4,
"type": "PHB",
"bandwidth": 12
},
{
"source_global_id": 3,
"target_global_id": 0,
"type": "NVL",
"bandwidth": 42
},
{
"source_global_id": 3,
"target_global_id": 1,
"type": "NVL",
"bandwidth": 42
},
{
"source_global_id": 3,
"target_global_id": 2,
"type": "NVL",
"bandwidth": 42
},
{
"source_global_id": 3,
"target_global_id": 4,
"type": "PHB",
"bandwidth": 12
},
{
"source_global_id": 4,
"target_global_id": 9,
"type": "NET",
"bandwidth": 1
}
]
},
{
"hostname": "machine1",
"addr": "0.0.0.2",
"port": "768",
"devices": [
{
"global_id": 5,
"local_id": 0,
"type": "GPU",
"model": "Tesla V100-SXM2-32GB",
"sp_gflops": 15700,
"dp_gflops": 7800,
"memory": 32
},
{
"global_id": 6,
"local_id": 1,
"type": "GPU",
"model": "Tesla V100-SXM2-32GB",
"sp_gflops": 15700,
"dp_gflops": 7800,
"memory": 32
},
{
"global_id": 7,
"local_id": 2,
"type": "GPU",
"model": "Tesla V100-SXM2-32GB",
"sp_gflops": 15700,
"dp_gflops": 7800,
"memory": 32
},
{
"global_id": 8,
"local_id": 3,
"type": "GPU",
"model": "Tesla V100-SXM2-32GB",
"sp_gflops": 15700,
"dp_gflops": 7800,
"memory": 32
},
{
"global_id": 9,
"local_id": 0,
"type": "NIC"
}
],
"links": [
{
"source_global_id": 5,
"target_global_id": 6,
"type": "NVL",
"bandwidth": 42
},
{
"source_global_id": 5,
"target_global_id": 7,
"type": "NVL",
"bandwidth": 42
},
{
"source_global_id": 5,
"target_global_id": 8,
"type": "NVL",
"bandwidth": 42
},
{
"source_global_id": 5,
"target_global_id": 9,
"type": "PHB",
"bandwidth": 12
},
{
"source_global_id": 6,
"target_global_id": 5,
"type": "NVL",
"bandwidth": 42
},
{
"source_global_id": 6,
"target_global_id": 7,
"type": "NVL",
"bandwidth": 42
},
{
"source_global_id": 6,
"target_global_id": 8,
"type": "NVL",
"bandwidth": 42
},
{
"source_global_id": 6,
"target_global_id": 9,
"type": "PHB",
"bandwidth": 12
},
{
"source_global_id": 7,
"target_global_id": 5,
"type": "NVL",
"bandwidth": 42
},
{
"source_global_id": 7,
"target_global_id": 6,
"type": "NVL",
"bandwidth": 42
},
{
"source_global_id": 7,
"target_global_id": 8,
"type": "NVL",
"bandwidth": 42
},
{
"source_global_id": 7,
"target_global_id": 9,
"type": "PHB",
"bandwidth": 12
},
{
"source_global_id": 8,
"target_global_id": 5,
"type": "NVL",
"bandwidth": 42
},
{
"source_global_id": 8,
"target_global_id": 6,
"type": "NVL",
"bandwidth": 42
},
{
"source_global_id": 8,
"target_global_id": 7,
"type": "NVL",
"bandwidth": 42
},
{
"source_global_id": 8,
"target_global_id": 9,
"type": "PHB",
"bandwidth": 12
},
{
"source_global_id": 9,
"target_global_id": 4,
"type": "NET",
"bandwidth": 1
}
]
}
]
}
"""
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=64,
intermediate_size=4 * 64,
initializer_range=0.02):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
np.random.seed(2021)
arr0 = np.random.normal(0, 0.02, size=(d_model, dim_feedforward))
arr1 = np.random.normal(0, 0.02, size=(dim_feedforward, d_model))
arr2 = np.random.normal(0, 0.02, size=(d_model, dim_feedforward))
arr3 = np.random.normal(0, 0.02, size=(dim_feedforward, d_model))
weight_attr0 = paddle.ParamAttr(initializer=NumpyArrayInitializer(arr0))
weight_attr1 = paddle.ParamAttr(initializer=NumpyArrayInitializer(arr1))
weight_attr2 = paddle.ParamAttr(initializer=NumpyArrayInitializer(arr2))
weight_attr3 = paddle.ParamAttr(initializer=NumpyArrayInitializer(arr3))
bias_attr = None
self.linear0 = nn.Linear(
d_model, dim_feedforward, weight_attr0, bias_attr=bias_attr)
self.linear1 = nn.Linear(
dim_feedforward, d_model, weight_attr1, bias_attr=bias_attr)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.linear2 = nn.Linear(
d_model, dim_feedforward, weight_attr2, bias_attr=bias_attr)
self.linear3 = nn.Linear(
dim_feedforward, d_model, weight_attr3, bias_attr=bias_attr)
def forward(self, input):
if _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor(
self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh[0],
"dims_mapping": [-1, 1]
})
auto.shard_tensor(
self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh[0],
"dims_mapping": [1, -1]
})
auto.shard_tensor(
self.linear2.weight,
dist_attr={
"process_mesh": _global_process_mesh[1],
"dims_mapping": [-1, 1]
})
auto.shard_tensor(
self.linear3.weight,
dist_attr={
"process_mesh": _global_process_mesh[1],
"dims_mapping": [1, -1]
})
out = self.norm(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = self.linear1(out)
out = self.linear2(out)
out = F.gelu(out, approximate=True)
out = self.linear3(out)
return out
def mlp_forward(train_program, start_program):
with static.program_guard(train_program,start_program), \
utils.unique_name.guard():
batch_size = 4
hidden_size = 64
input = static.data(
name="input", shape=[batch_size, hidden_size], dtype='float32')
label = static.data(
name="label", shape=[batch_size, 1], dtype='float32')
if _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor(
input,
dist_attr={
"process_mesh": _global_process_mesh[0],
"dims_mapping": [0, -1]
})
mlp = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
initializer_range=0.02)
predict = mlp(input)
error_cost = paddle.nn.functional.square_error_cost(predict, label)
loss = paddle.mean(error_cost)
return loss, train_program, start_program
def get_dist_prog(train_program, startup_program, dist_context, rank_id):
loss, train_program, startup_program = mlp_forward(train_program,
startup_program)
dist_strategy = fleet.DistributedStrategy()
# auto completion
complete_train_program = auto.complete_annotation(train_program,
dist_context)
partitioner = Partitioner(dist_strategy, dist_context, rank_id)
# logical partition
dist_train_program, dist_startup_prog = partitioner.transpile_forward(
complete_train_program, startup_program)
dist_params_grads = partitioner.apply_backward(
loss, complete_train_program, startup_program, dist_train_program,
dist_startup_prog)
optimizer = paddle.fluid.optimizer.AdamOptimizer()
opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads,
dist_train_program, dist_startup_prog)
reshard(dist_train_program, dist_startup_prog, rank_id, dist_context)
return dist_train_program, dist_startup_prog
def is_in_machine(device_local_id, machine):
for device in machine.devices.values():
if device_local_id == device.local_id:
return True
return False
def get_device_local_ids(machine):
local_ids = []
for device in machine.devices.values():
local_ids.append[device.local_id]
return local_ids
class TestAutoParallelMapper(unittest.TestCase):
def test_mapper_dp_mp_pp(self):
cluster_json_file = ""
cluster_json_object = json.loads(cluster_json)
with open("./auto_parallel_cluster.json", "w") as cluster_json_file:
json.dump(cluster_json_object, cluster_json_file)
cluster = Cluster()
cluster.build_from_file("./auto_parallel_cluster.json")
os.remove("./auto_parallel_cluster.json")
global _global_parallel_strategy
_global_parallel_strategy = "dp_mp_pp"
global _global_num_stages
_global_num_stages = 2
global _global_process_mesh
_global_process_mesh = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
processes = [0, 1, 2, 3, 4, 5, 6, 7]
dist_programs = {}
for rank_id in processes:
train_program = static.Program()
startup_program = static.Program()
dist_context = DistributedContext()
dist_train_program, dist_startup_prog = get_dist_prog(
train_program, startup_program, dist_context, rank_id)
# if rank_id == 0:
# print_program_with_dist_attr(dist_train_program, dist_context)
dist_programs[rank_id] = dist_train_program
rank_mapping = mapping(dist_programs, cluster)
all_mapped_ranks = set()
for machine_id, machine_mapping in rank_mapping.items():
machine = cluster.machines[machine_id]
machine_mapped_ranks = set()
machine_mapped_device_local_ids = set()
for rank, device_ids in machine_mapping["ranks"].items():
# Only allow one process to one device mapping
self.assertEqual(len(device_ids), 1)
self.assertTrue(is_in_machine(device_ids[0], machine))
machine_mapped_ranks.add(rank)
machine_mapped_device_local_ids.add(device_ids[0])
self.assertEqual(
len(machine_mapped_ranks), len(machine_mapped_device_local_ids))
all_mapped_ranks.update(machine_mapped_ranks)
self.assertEqual(set(processes), all_mapped_ranks)
def test_mapper_misc(self):
self.assertEqual(get_dtype_bytes(paddle.float64), 8)
self.assertEqual(get_dtype_bytes(paddle.float32), 4)
self.assertEqual(get_dtype_bytes(paddle.float16), 2)
self.assertEqual(get_dtype_bytes(paddle.bfloat16), 2)
self.assertEqual(get_dtype_bytes(paddle.int64), 8)
self.assertEqual(get_dtype_bytes(paddle.int32), 4)
self.assertEqual(get_dtype_bytes(paddle.int16), 2)
self.assertEqual(get_dtype_bytes(paddle.int8), 1)
self.assertEqual(get_dtype_bytes(paddle.uint8), 1)
self.assertRaises(ValueError, get_dtype_bytes, "unknown type")
train_program = static.Program()
startup_program = static.Program()
ring_id = 0
root_id = 0
nranks = 2
with fluid.program_guard(train_program, startup_program):
input = layers.data(name="input", shape=[10, 10], dtype='float32')
output = train_program.current_block().create_var(
name="outofbroadcast",
dtype='float32',
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False)
broadcast_op = train_program.global_block().append_op(
type="c_broadcast",
inputs={'X': input},
attrs={'ring_id': ring_id,
'root': root_id},
outputs={'Out': output})
self.assertEqual(get_comm_volume(broadcast_op, 0, 1), 400)
self.assertEqual(get_comm_volume(broadcast_op, 1, 0), None)
allgather_op = train_program.global_block().append_op(
type="c_allgather",
inputs={'X': input},
attrs={'ring_id': ring_id,
'nranks': nranks},
outputs={'Out': output})
self.assertEqual(get_comm_volume(allgather_op, 0, 1), 400)
self.assertEqual(get_comm_volume(allgather_op, 0, 0), None)
reduce_op = train_program.global_block().append_op(
type="c_reduce_sum",
inputs={'X': input},
attrs={'ring_id': ring_id,
'root_id': root_id},
outputs={'Out': output})
self.assertEqual(get_comm_volume(reduce_op, 0, 1), None)
self.assertEqual(get_comm_volume(reduce_op, 1, 0), 400)
cast_op = train_program.global_block().append_op(
type="cast",
inputs={"X": input},
outputs={"Out": output},
attrs={
"in_dtype": fluid.core.VarDesc.VarType.FP32,
"out_dtype": fluid.core.VarDesc.VarType.FP32
})
self.assertRaises(ValueError, get_comm_volume, cast_op, 0, 1)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册