未验证 提交 6a15d407 编写于 作者: C caozhou 提交者: GitHub

[Auto Paralle]Add reshard cost and update estimator (#45118)

* update reshard cost and cost estimator

* add unittest

* add dropout cost

* fix import error

* fix reshard code style error

* improve unittest coverage
上级 933db9d4
......@@ -148,6 +148,25 @@ class ConcatOpCost(CompOpCost):
return 0
@register_op_cost
class DropoutOpCost(CompOpCost):
OP_TYPE = "dropout"
def __init__(self, op=None, op_desc=None, cluster=None):
super(DropoutOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class ElementwiseAddOpCost(CompOpCost):
OP_TYPE = "elementwise_add"
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,26 +12,56 @@
# See the License for the specific language governing permissions and
# limitations under the License
from collections import OrderedDict
from functools import reduce
import paddle
import paddle.fluid.core as core
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from .base_cost import Cost
from ..operators.common import get_distributed_operator_impl_container
from ..dist_tensor import DistributedTensor
class CostEstimator:
_sepical_op_type = ["fused_attention", "fused_feedforward"]
def __init__(self,
program,
cluster=None,
dist_context=None,
mode="modeling"):
cluster,
mode="modeling",
rank=None,
loop_count=10):
self._program = program
self._cluster = cluster
self._dist_context = dist_context
self._check_mode(mode)
self._mode = mode
self._global_cost = None
self._local_cost = {}
self._rank = rank if rank is not None else paddle.distributed.get_rank()
self._loop_count = loop_count
self._global_cost = Cost()
self._local_cost_mapping = {}
self._detailed_cost = OrderedDict(
) # {`op_id`: {"reshard": [], "dist_op": [], "local_cost": local_cost}}}
self._bubble_time_mapping = {}
self._ordered_ops = []
@property
def loop_count(self):
return self._loop_count
@property
def detailed_cost(self):
return self._detailed_cost
@property
def program(self):
return self._program
@property
def rank(self):
return self._rank
@property
def dist_context(self):
return self._dist_context
......@@ -46,25 +76,337 @@ class CostEstimator:
@property
def global_cost(self):
max_time = 0
memory = 0
flops = 0
for rank in self._local_cost_mapping:
cost = self._local_cost_mapping[rank]
if cost.time > max_time:
max_time = cost.time
memory += cost.memory
flops += cost.flops
self._global_cost.time = max_time
self._global_cost.memory = memory
self._global_cost.flops = flops
return self._global_cost
@property
def local_cost(self):
return self._local_cost
def get_op_cost(self):
return 0
def local_cost(self, rank=None):
rank = self.rank if rank is None else rank
if rank not in self._local_cost_mapping:
self._local_cost_mapping[rank] = Cost()
def get_tensor_cost(self):
return 0
return self._local_cost_mapping[rank]
def get_global_cost(self):
return 0
def get_local_cost(self, rank=None):
return 0
def local_bubble_time(self, rank=None):
rank = self.rank if rank is None else rank
return self._bubble_time_mapping[rank]
def _check_mode(self, mode):
if mode not in ["modeling", "profiling"]:
raise ValueError(
"Just support modeling and profiling, but got {}".format(mode))
def _is_special_var_name(self, var_name):
special_var_name = ["lod_tensor_blocking_queue_0"]
if var_name in special_var_name:
return True
return False
def _estimate_core(self, dist_context, resharder, block):
from ..reshard import get_var_with_recursion
ops = block.ops
loop_count = None
if block.desc.id != self.program.global_block().desc.id:
loop_count = self.loop_count
else:
loop_count = 1
for i in range(loop_count):
for op in ops:
self._detailed_cost[op.desc.id()] = OrderedDict()
# if in the while sub block, the detail of cost is the last cost
detail = self._detailed_cost[op.desc.id()]
detail["reshard_cost"] = OrderedDict() #
detail["dist_op_cost"] = []
if int(op.attr('op_role')) == int(OpRole.Optimize):
continue
if op.type in [
"create_py_reader", "create_double_buffer_reader",
"read"
]:
continue
# NOTE: It does not support nested loop and just supports while op when op has sub block now.
if op.type == "while":
while_block = self.program.blocks[op.attr("sub_block").id]
self._estimate_core(dist_context, resharder, while_block)
continue
for var_name in op.input_arg_names:
if self._is_special_var_name(var_name):
continue
var = get_var_with_recursion(var_name, block, self.program)
reshard_cost = resharder.get_cost(op, var, self.cluster)
# calc reshard cost
if reshard_cost is not None:
detail["reshard_cost"][var_name] = reshard_cost
comm_costs = reshard_cost[0]
local_comp_cost = reshard_cost[1]
for comm_cost in comm_costs:
# time is cumulative in global cost and local cost, but memory and flops just are cumulative in global cost.
# comm sync
for item in comm_cost:
group_ranks, cost = item
max_time = None
cost_time = {}
for rank in group_ranks:
rank_cost = self.local_cost(rank)
cost_time[rank] = rank_cost.time
if max_time is None:
max_time = rank_cost.time
else:
if max_time < rank_cost.time:
max_time = rank_cost.time
for rank in group_ranks:
self.local_cost(
rank).time = max_time + cost.time
if rank not in self._bubble_time_mapping:
self._bubble_time_mapping[rank] = 0
self._bubble_time_mapping[rank] += (
max_time - cost_time[rank])
for rank in local_comp_cost:
for comp_cost in local_comp_cost[rank]:
self.local_cost(rank).time += comp_cost.time
# calc dist op cost
dist_op = dist_context.get_dist_op_for_program(op)
op_dist_attr = dist_op.dist_attr
processes = op_dist_attr.process_mesh.processes
container = get_distributed_operator_impl_container(
op_dist_attr.impl_type)
dist_impl = container.impls[op_dist_attr.impl_idx]
dist_op_cost = dist_impl.calc_cost(op.attr('op_role'), dist_op,
dist_context, self.cluster)
detail["dist_op_cost"] = dist_op_cost
if dist_op_cost is None:
assert dist_op.serial_op.type in CostEstimator._sepical_op_type
continue
for item in dist_op_cost:
if isinstance(item, list):
# comm sync
for comm_op_cost in item:
max_time = None
cost_time = {}
group_ranks = comm_op_cost.group_ranks
for rank in comm_op_cost.group_ranks:
rank_cost = self.local_cost(rank)
cost_time[rank] = rank_cost.time
if max_time is None:
max_time = rank_cost.time
else:
if max_time < rank_cost.time:
max_time = rank_cost.time
for rank in group_ranks:
self.local_cost(
rank).time = max_time + comm_op_cost.time
if rank not in self._bubble_time_mapping:
self._bubble_time_mapping[rank] = 0
self._bubble_time_mapping[rank] += (
max_time - cost_time[rank])
elif isinstance(item, dict):
# op just one
for rank in processes:
# dp+pp+mp
if rank not in item:
continue
self.local_cost(rank).time += item[rank].time
def prepare(self):
self._global_cost = Cost()
self._local_cost_mapping = {}
self._detailed_cost = OrderedDict()
self._bubble_time_mapping = {}
def _calculate_bytes(self, sizes, dtype):
if sizes:
total_count = reduce(lambda x, y: x * y, sizes)
else:
total_count = 0
if dtype == paddle.float64 or dtype == paddle.int64:
dtype_factor = 8
elif dtype == paddle.float32 or dtype == paddle.int32:
dtype_factor = 4
elif dtype == paddle.float16 or dtype == paddle.bfloat16 \
or dtype == paddle.int16:
dtype_factor = 2
elif dtype == paddle.int8 or dtype == paddle.uint8:
dtype_factor = 1
else:
dtype_factor = 8
memory = total_count * dtype_factor
return memory
def _estimate_max_memory_by_dist_op(self, dist_context):
# This estimation will be improved, now reshard and inplace are not considered.
# Persist var is not free.
def _convert_pm_and_dm_to_str(process_mesh, dims_mapping):
processes = ",".join([str(x) for x in process_mesh.processes])
topology = ",".join([str(x) for x in process_mesh.topology])
dims_mapping = ",".join([str(x) for x in dims_mapping])
result = processes + topology + dims_mapping
return result
memories = {}
max_memories = {}
var_info = {
} # var_name: [[process_mesh, dims_mapping], [id]], [[process_mesh, dims_mapping], [id]]}
for block in self.program.blocks:
for op in block.ops:
self._ordered_ops.append([op.desc.id(), op])
self._ordered_ops.sort(key=lambda x: x[0])
for op_id, op in self._ordered_ops:
dist_op = dist_context.get_dist_op_for_program(op)
process_mesh = dist_op.dist_attr.process_mesh
for var_name in op.input_arg_names:
input_dims_mapping = dist_op.dist_attr.get_input_dims_mapping(
var_name)
if var_name not in var_info:
var_info[var_name] = {}
key = _convert_pm_and_dm_to_str(process_mesh,
input_dims_mapping)
if key not in var_info[var_name]:
var_info[var_name][key] = {}
# it is even partition now
if "memory" not in var_info[var_name][key]:
var = dist_op.get_serial_input(var_name)
global_sizes = var.shape
dtype = var.dtype
sizes = DistributedTensor.get_local_sizes(
global_sizes, input_dims_mapping, process_mesh.topology,
process_mesh.processes)
var_info[var_name][key]["memory"] = self._calculate_bytes(
sizes, dtype)
if "position" not in var_info[var_name][key]:
var_info[var_name][key]["position"] = []
var_info[var_name][key]["position"].append(op_id)
for var_name in op.output_arg_names:
output_dims_mapping = dist_op.dist_attr.get_output_dims_mapping(
var_name)
if var_name not in var_info:
var_info[var_name] = {}
key = _convert_pm_and_dm_to_str(process_mesh,
output_dims_mapping)
if key not in var_info[var_name]:
var_info[var_name][key] = {}
if "memory" not in var_info[var_name][key]:
var = dist_op.get_serial_output(var_name)
global_sizes = var.shape
dtype = var.dtype
sizes = DistributedTensor.get_local_sizes(
global_sizes, output_dims_mapping,
process_mesh.topology, process_mesh.processes)
var_info[var_name][key]["memory"] = self._calculate_bytes(
sizes, dtype)
if "position" not in var_info[var_name][key]:
var_info[var_name][key]["position"] = []
var_info[var_name][key]["position"].append(op_id)
has_used_vars = set()
for op_id, op in self._ordered_ops:
can_free_memories = {}
can_free_vars = set()
dist_op = dist_context.get_dist_op_for_program(op)
process_mesh = dist_op.dist_attr.process_mesh
for var_name in op.input_arg_names:
input_dims_mapping = dist_op.dist_attr.get_input_dims_mapping(
var_name)
key = _convert_pm_and_dm_to_str(process_mesh,
input_dims_mapping)
has_used_var = var_name + key
var = dist_op.get_serial_input(var_name)
# not used
if var_name + key not in has_used_vars:
has_used_vars.add(has_used_var)
for process in process_mesh.processes:
if process not in memories:
memories[process] = 0
memories[process] += var_info[var_name][key]["memory"]
# used
else:
if op_id == var_info[var_name][key]["position"][-1]:
if has_used_var not in can_free_vars:
can_free_vars.add(has_used_var)
if not var.persistable:
for process in process_mesh.processes:
if process not in can_free_memories:
can_free_memories[process] = 0
can_free_memories[process] += var_info[
var_name][key]["memory"]
for var_name in op.output_arg_names:
output_dims_mapping = dist_op.dist_attr.get_output_dims_mapping(
var_name)
key = _convert_pm_and_dm_to_str(process_mesh,
output_dims_mapping)
has_used_var = var_name + key
var = dist_op.get_serial_output(var_name)
# not used
if var_name + key not in has_used_vars:
has_used_vars.add(has_used_var)
for process in process_mesh.processes:
if process not in memories:
memories[process] = 0
memories[process] += var_info[var_name][key]["memory"]
# used
else:
if op_id == var_info[var_name][key]["position"][-1]:
if has_used_var not in can_free_vars:
can_free_vars.add(has_used_var)
if not var.persistable:
for process in process_mesh.processes:
if process not in can_free_memories:
can_free_memories[process] = 0
can_free_memories[process] += var_info[
var_name][key]["memory"]
# calc peak memory
for process in memories:
if process not in max_memories:
max_memories[process] = memories[process]
else:
if memories[process] > max_memories[process]:
max_memories[process] = memories[process]
# free memory
for process in can_free_memories:
if process in memories:
memories[process] -= can_free_memories[process]
# Calculate the max memory in all ranks
max_memory = max(max_memories.values())
return max_memory
def estimate(self, dist_context, resharder=None):
self.prepare()
from ..reshard import Resharder
resharder = Resharder(self.program, None, self.rank, dist_context,
[]) if resharder is None else resharder
block = self.program.global_block()
self._estimate_core(dist_context, resharder, block)
return self.global_cost
......@@ -34,7 +34,8 @@ from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_idx_in_axis, _get_corresponding_rank
from ..cost import build_comp_desc_from_dist_op, build_comm_desc_from_dist_op
from ..cost import build_comm_costs_from_descs, build_comp_costs_from_descs, build_dp_costs
from ..cost import EmbeddingOpCost, EmbeddingGradOpCost, AllreduceSumOpCost, IdentityOpCost
from ..cost import EmbeddingOpCost, EmbeddingGradOpCost
from paddle.distributed.auto_parallel.cost.comm_op_cost import AllreduceSumOpCost, IdentityOpCost
class DistributedEmbedding(DistributedOperatorImplContainer):
......
......@@ -32,7 +32,7 @@ from .dist_default import DistributedDefaultImpl0
from ..cost import FillConstantBatchSizeLikeOpCost
from ..cost import build_comp_desc_from_dist_op, build_dp_costs
from ..cost import build_comp_costs_from_descs
from ..cost import AllreduceSumOpCost
from paddle.distributed.auto_parallel.cost.comm_op_cost import AllreduceSumOpCost
class DistributedFillConstantBatchSizeLike(DistributedOperatorImplContainer):
......
......@@ -39,8 +39,9 @@ from ..utils import _get_comm_group, _get_corresponding_rank
from .dist_default import DistributedDefaultImpl0
from ..cost import build_comp_desc_from_dist_op, build_comm_desc_from_dist_op, build_dp_costs
from ..cost import build_comm_costs_from_descs, build_comp_costs_from_descs
from ..cost import MatmulV2OpCost, MatmulOpCost, MulOpCost, IdentityOpCost, AllreduceSumOpCost
from ..cost import MatmulV2OpCost, MatmulOpCost, MulOpCost
from ..cost import MatmulV2GradOpCost, MatmulGradOpCost, MulGradOpCost
from paddle.distributed.auto_parallel.cost.comm_op_cost import AllreduceSumOpCost, IdentityOpCost
def copy_op_with_new_input_output(ctx, block, src_op, **kwargs):
......
......@@ -24,11 +24,12 @@ from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from .dist_default import DistributedDefaultImpl0
from ..cost import AllreduceSumOpCost, _g_op_cost_factory
from ..cost import _g_op_cost_factory
from ..cost import build_comp_desc_from_dist_op, build_dp_costs
from ..cost import build_comp_costs_from_descs
from ..cost import SoftmaxOpCost, SoftmaxGradOpCost
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.distributed.auto_parallel.cost.comm_op_cost import AllreduceSumOpCost
class DistributedSoftmax(DistributedOperatorImplContainer):
......
......@@ -24,10 +24,11 @@ from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from .dist_default import DistributedDefaultImpl0
from ..cost import AllreduceSumOpCost, Transpose2OpCost, Transpose2GradOpCost
from ..cost import Transpose2OpCost, Transpose2GradOpCost
from ..cost import build_comp_desc_from_dist_op, build_comm_desc_from_dist_op, build_dp_costs
from ..cost import build_comp_costs_from_descs
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.distributed.auto_parallel.cost.comm_op_cost import AllreduceSumOpCost
class DistributedTranspose2(DistributedOperatorImplContainer):
......
......@@ -2065,3 +2065,209 @@ class Resharder:
# reset some variable when remove operation ended
Resharder.while_block_info = {}
def get_cost(self, op, tensor, cluster):
# NOTE: The program should be the serial_program which is not been parted
global _g_special_ops
not_supported_op_type = _g_special_ops + ["while"]
reshard_op_cost = None
if op.type in not_supported_op_type:
return reshard_op_cost
else:
tensor_name = tensor.name
if tensor_name == "lod_tensor_blocking_queue_0":
return reshard_op_cost
else:
dist_tensor = self.dist_context.get_dist_tensor_for_program(
tensor)
# simplified processing: ignore union process mesh and output reshard
dist_op = self.dist_context.get_dist_op_for_program(op)
dims_mapping = dist_op.dist_attr.get_input_dims_mapping(
tensor.name)
process_mesh = dist_op.dist_attr.process_mesh
dist_attr = [process_mesh, dims_mapping]
if dist_tensor is not None and self.need_reshard(
dist_tensor, dist_attr):
if tensor_name not in self._has_resharded:
self._has_resharded[tensor_name] = [dist_op]
else:
for item in self._has_resharded[tensor_name]:
item_dist_attr = item.dist_attr
item_dims_mapping = item_dist_attr.get_input_dims_mapping(
tensor_name)
item_process_mesh = item_dist_attr.process_mesh
if dims_mapping == item_dims_mapping and item_process_mesh == process_mesh:
return reshard_op_cost
self._has_resharded[tensor_name].append(dist_op)
reshard_op_desc = self.find_op_desc_seq(dist_tensor,
dist_attr,
serial=True)
dtype = dist_tensor.serial_tensor.dtype
reshard_op_cost = self.parse_op_desc_for_cost(
reshard_op_desc, dtype, cluster)
return reshard_op_cost
def _concat_partitions_for_cost(self, partition_tensor_list,
partition_index, dtype, rank_id,
local_rank_comp_cost, cluster):
if not partition_tensor_list:
partition_tensor_list.append(partition_index)
else:
i = 0
has_concat = False
while i < len(partition_tensor_list):
concat_axis, first_order, new_partition = Resharder.compute_concat_info(
partition_tensor_list[i], partition_index)
if concat_axis != -1:
has_concat = True
concat_desc = {}
concat_desc["op"] = "concat"
concat_desc["attrs"] = {"axis": concat_axis}
if first_order == 0:
concat_desc["inputs"] = {
"X": [(dtype, partition_tensor_list[i]),
(dtype, partition_index)]
}
else:
concat_desc["inputs"] = {
"X": [(dtype, partition_index),
(dtype, partition_tensor_list[i])]
}
partition_tensor_list.pop(i)
if rank_id not in local_rank_comp_cost:
local_rank_comp_cost[rank_id] = []
local_rank_comp_cost[rank_id].append(
ConcatOpCost(op_desc=concat_desc, cluster=cluster))
self._concat_partitions_for_cost(partition_tensor_list,
new_partition, dtype,
rank_id,
local_rank_comp_cost,
cluster)
break
i += 1
if not has_concat:
partition_tensor_list.append(partition_index)
def parse_op_desc_for_cost(self, reshard_op_desc, dtype, cluster):
def _get_idx(comm_ranks, group_ranks):
res, is_the_same = None, False
idx = 0
while idx < len(comm_ranks):
if comm_ranks[idx] == set(group_ranks):
is_the_same = True
for rank in group_ranks:
if rank in comm_ranks[idx]:
res = idx
comm_ranks[idx].add(rank)
if res is None:
idx += 1
else:
break
return res, is_the_same
comm_context = CommContext(cluster)
# run communication op before computation op
# TODO: Communication cost is not calculated when the var has been transfered by the same group in the past
comm_costs = []
comm_ranks = []
local_rank_comp_cost = {}
for key in reshard_op_desc:
partition_tensor_list = []
op_desc_list = reshard_op_desc[key]
for op_desc in op_desc_list:
if isinstance(op_desc, SendOpDesc):
group_ranks = [key, op_desc.dst]
shape = op_desc.shape
send_desc = build_comm_desc("send_v2", group_ranks, dtype,
shape)
idx, is_the_same = _get_idx(comm_ranks, group_ranks)
if idx is None:
comm_costs.append([
(group_ranks,
SendOpCost(op_desc=send_desc,
comm_context=comm_context))
])
comm_ranks.append(set(group_ranks))
else:
if not is_the_same:
comm_costs[idx].append(
(group_ranks,
SendOpCost(op_desc=send_desc,
comm_context=comm_context)))
elif isinstance(op_desc, AllGatherOpDesc):
# NOTE: fill_const and other unnecessary op is not calculated because those cost is very small
group_ranks = op_desc.group
shape = op_desc.shape
allgather_desc = build_comm_desc("c_allgather", group_ranks,
dtype, shape)
split_inputs_shape = []
for idx, dim in enumerate(shape):
if idx == 0:
split_inputs_shape.append(dim * len(group_ranks))
else:
split_inputs_shape.append(dim)
idx, is_the_same = _get_idx(comm_ranks, group_ranks)
if idx is None:
comm_costs.append([
(group_ranks,
AllgatherOpCost(op_desc=allgather_desc,
comm_context=comm_context))
])
comm_ranks.append(set(group_ranks))
else:
if not is_the_same:
comm_costs[idx].append(
(group_ranks,
AllgatherOpCost(op_desc=allgather_desc,
comm_context=comm_context)))
# calc the split op cost
if key not in local_rank_comp_cost:
local_rank_comp_cost[key] = []
split_desc = {}
split_desc["op"] = "split"
split_desc["inputs"] = {
"inputs": [(dtype, split_inputs_shape)]
}
split_desc["attrs"] = {"num": len(group_ranks), "axis": 0}
local_rank_comp_cost[key].append(
SplitOpCost(op_desc=split_desc, cluster=cluster))
elif isinstance(op_desc, ConcatOpDesc):
partition_index_list = op_desc._partition_index_list
for idx, partion_idex in enumerate(partition_index_list):
self._concat_partitions_for_cost(
partition_tensor_list, partion_idex, dtype, key,
local_rank_comp_cost, cluster)
elif isinstance(op_desc, SliceOpDesc):
if key not in local_rank_comp_cost:
local_rank_comp_cost[key] = []
assert len(
partition_tensor_list) == 1 or not partition_tensor_list
to_slice_tensor_shape = []
if len(partition_tensor_list) == 1:
for item in partition_tensor_list[0]:
to_slice_tensor_shape.append(item[1] - item[0])
else:
to_slice_tensor_shape = op_desc.shape
slice_desc = {}
slice_desc["op"] = "slice"
infer_flags = list(1 for i in range(len(op_desc.axes)))
slice_desc["attrs"] = {
"axes": op_desc.axes,
"starts": op_desc.starts,
"ends": op_desc.ends,
"infer_flags": infer_flags
}
slice_desc["inputs"] = {
"Input": [(dtype, to_slice_tensor_shape)]
}
local_rank_comp_cost[key].append(
SliceOpCost(op_desc=slice_desc, cluster=cluster))
res = (comm_costs, local_rank_comp_cost)
return res
......@@ -15,6 +15,9 @@
import paddle
import paddle.static as static
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.cost import CostEstimator
from paddle.distributed.auto_parallel.cluster import Cluster
from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context
def train():
......@@ -39,6 +42,30 @@ def train():
_, _, distributed_startup_program, distributed_main_program = optimizer.minimize(
loss, start_program)
# add cost estimator
dist_context = get_default_distributed_context()
cluster = Cluster()
for op in train_program.global_block().ops:
dist_op = dist_context.get_dist_op_for_program(op)
for var_name in op.input_arg_names:
dims_mapping = dist_op.dist_attr.get_input_dims_mapping(var_name)
if dims_mapping is None:
dist_op.dist_attr.set_input_dims_mapping(
var_name, [
-1 for i in range(
len(train_program.global_block().vars[var_name].
shape))
])
cluster.gen_default_config_cluster(device_count=2)
cost_estimator = CostEstimator(train_program, cluster)
global_cost = cost_estimator.estimate(dist_context)
max_memory = cost_estimator._estimate_max_memory_by_dist_op(dist_context)
# test cache
global_cost = cost_estimator.estimate(dist_context)
max_memory = cost_estimator._estimate_max_memory_by_dist_op(dist_context)
assert global_cost.time > 0
assert max_memory > 0
places = static.cuda_places()
loader.set_batch_generator(batch_generator_creator(), places=places)
exe = paddle.static.Executor(places[0])
......
......@@ -19,6 +19,7 @@ import tempfile
import paddle
import paddle.distributed.auto_parallel.cost as cost_model
from paddle.distributed.auto_parallel.cost.base_cost import build_comp_desc_from_op
from paddle.distributed.auto_parallel.cost.base_cost import build_comp_desc_str_for_predict
from paddle.distributed.auto_parallel.cost.base_cost import calc_time_by_modeling
......
......@@ -29,6 +29,8 @@ from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import Resharder
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.cost import CostEstimator
from paddle.distributed.auto_parallel.cluster import Cluster
paddle.enable_static()
_global_parallel_strategy = "dp_mp_pp"
......@@ -196,6 +198,21 @@ class TestMLPReshard(unittest.TestCase):
rank_id = 2
dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
train_program, startup_program, dist_context, rank_id)
# test estimator
cluster = Cluster()
cluster.gen_default_config_cluster(device_count=8)
cost_estimator = CostEstimator(train_program, cluster)
global_cost = cost_estimator.estimate(dist_context)
max_memory = cost_estimator._estimate_max_memory_by_dist_op(
dist_context)
# test cache
global_cost = cost_estimator.estimate(dist_context)
max_memory = cost_estimator._estimate_max_memory_by_dist_op(
dist_context)
assert global_cost.time > 0
assert max_memory > 0
resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
dist_context, dist_params_grads)
resharder.reshard()
......
......@@ -29,6 +29,8 @@ from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import Resharder
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.cost import CostEstimator
from paddle.distributed.auto_parallel.cluster import Cluster
paddle.enable_static()
_global_parallel_strategy = "mp_pp"
......@@ -247,7 +249,7 @@ class TestMLPReshard(unittest.TestCase):
def test_allgather(self):
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
process_mesh = auto.ProcessMesh(mesh=[0, 3])
process_mesh = auto.ProcessMesh(mesh=[0, 1])
with static.program_guard(train_program, startup_program):
x = paddle.static.data(name="x", shape=[4, 4], dtype='float32')
x = auto.shard_tensor(x,
......@@ -284,6 +286,21 @@ class TestMLPReshard(unittest.TestCase):
dist_context.block_state.parse_forward_blocks(complete_train_program)
partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads = partitioner.partition(
complete_train_program, startup_program, [])
# test estimator
cluster = Cluster()
cluster.gen_default_config_cluster(device_count=2)
cost_estimator = CostEstimator(train_program, cluster)
global_cost = cost_estimator.estimate(dist_context)
max_memory = cost_estimator._estimate_max_memory_by_dist_op(
dist_context)
# test cache
global_cost = cost_estimator.estimate(dist_context)
max_memory = cost_estimator._estimate_max_memory_by_dist_op(
dist_context)
assert global_cost.time > 0
assert max_memory > 0
resharder = Resharder(partitioned_main_prog, partitioned_startup_prog,
rank_id, dist_context, partitioned_params_grads)
resharder.reshard()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册