未验证 提交 a573a7ed 编写于 作者: Y YipZLF 提交者: GitHub

Add auto parallel cost model and unittests (#36363)

* Add auto parallel cost model and unittests

* Fixed code styles.

* Fixed bugs and codes style

* fixed typo

* Improved code style: object encapsulation.

* Fixed codes.

* Refractored estimate_cost

* Fixed typo
上级 1d5746bd
......@@ -21,5 +21,6 @@ from .interface import ProcessMesh # noqa: F401
from .completion import complete_annotation # noqa: F401
from .completion import complete_backward_annotation # noqa: F401
from .reshard import reshard # noqa: F401
from .cost_model import estimate_cost
__all__ = []
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import json
import queue
import copy
from enum import Enum
import paddle
SUCC = 0 # successor
PRED = 1 # predecessor
class CostNodeType(Enum):
DEFAULT = 0
COMPUTATION = 1
COMMUNICATION = 2
VARIABLE = 3
MERGED = 4
NOP = 5
class Cost(object):
def __init__(self):
self.runtime = None
self.static_mem = None
self.peak_mem = None
class CostModelMode(Enum):
DEFAULT = 0
BENCHMARKING = 1 # costs based on trial runs
ANALYSIS = 2 # costs based on analysis
MIXED = 3
class CostNode(object):
def __init__(self, node, node_type, id=None):
self.id = id
self.node = node
self.type = node_type
self._cost = 0
self.is_optim = False
self.is_bwd = False
@property
def cost(self):
return self._cost
@cost.setter
def cost(self, cost):
if cost < 0:
raise ValueError('Cost must be above 0.')
self._cost = cost
class MergedOpsCostNode(CostNode):
def __init__(self, node_type, id=None, base_node_list=None, is_bwd=False):
super(MergedOpsCostNode, self).__init__(None, node_type, id)
self.node_list = base_node_list
self.is_bwd = is_bwd
class CommOpCostNode(CostNode):
def __init__(self,
node,
node_type,
id=None,
comm_node_list=None,
is_bwd=False):
super(CommOpCostNode, self).__init__(node, node_type, id)
self.node_list = comm_node_list
self.ranks = []
self.comm_type = node.type
self.is_bwd = is_bwd
def set_ranks(self, ranks):
self.ranks = ranks
def set_shapes(self, input_shape, output_shape):
self.input_shape = input_shape
self.output_shape = output_shape
def init_comm_cost(self, cluster=None):
# ref: https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md
# should get from `cluster`
BANDWIDTH = 32 * 1024 / 1000 # MB/ms, V100 PCIe
num_ranks = len(self.ranks)
comm_volumn = np.prod(self.input_shape) * 4
if 'allreduce' in self.comm_type:
self._cost = comm_volumn / (BANDWIDTH * num_ranks /
(2 * (num_ranks - 1)))
elif 'gather' in self.comm_type:
self._cost = comm_volumn / (BANDWIDTH * num_ranks / (num_ranks - 1))
elif 'broadcast' in self.comm_type:
self._cost = comm_volumn / BANDWIDTH
elif 'send' in self.comm_type or 'recv' in self.comm_type:
self._cost = comm_volumn / BANDWIDTH
else:
self._cost = 0
class TensorCostNode(CostNode):
def __init__(self,
node,
node_type,
id=None,
base_node_list=None,
batch_size=None,
shared_node_id=None):
super(TensorCostNode, self).__init__(node, node_type, id)
self.shape = node.shape
self.dtype = node.dtype
self.dtype_factor = 1
self.persistable = None
self.shared_node_id = shared_node_id
if self.dtype == paddle.float32 or node.dtype == paddle.int32:
self.dtype_factor *= 4
elif node.dtype == paddle.int64:
self.dtype_factor *= 8
else:
raise NotImplementedError("{} not counted".format(v.node.dtype))
self.batch_size = None
if batch_size is not None:
self.batch_size = batch_size
def get_size(self):
p = 1
for i in self.node.shape:
if i == -1: # deal with placeholder
assert self.batch_size is not None, "Batch size not decided."
i = self.batch_size
p *= i
return p
class CompOpCostNode(CostNode):
def __init__(self, node, node_type, id=None, is_bwd=False, is_optim=False):
super(CompOpCostNode, self).__init__(node, node_type, id)
self.is_bwd = is_bwd
self.is_optim = is_optim
def init_comp_cost(self, cost_data):
# TODO: improve fluid.CostModel for more specific cost_data
op_name = self.node.type
if op_name in cost_data.keys():
self.cost = cost_data[op_name]
else:
self.cost = 0.0
class PipeEvent(object):
def __init__(self, stage_id, event_name, duration, start_time=-1):
self.stage_id = stage_id
self.name = event_name
self.duration = duration
self.s_time = start_time
self.e_time = -1
class CostModel(object):
def __init__(self,
mode=CostModelMode.BENCHMARKING,
cluster=None,
batch_size=1,
microbatch_num=1,
opcall_overhead=0,
standalone_cost_data=None,
pipeline_config=None):
self.mode = mode
# parameters
self.opcall_overhead = opcall_overhead
self.batch_size = batch_size
self.microbatch_num = microbatch_num
self.nodes = {} # name -> node
self.origin_graph = {} # original graph
self.op_graph = {} # op graph (no variables nodes)
self.runtime_graph = {} # runtime graph, for simulation
self.cluster = cluster
self.cost_data = standalone_cost_data
self.pp2rank = pipeline_config
if self.pp2rank is not None:
self.rank2pp = {}
for stage_idx, ranks in enumerate(self.pp2rank):
for rank in ranks:
self.rank2pp[rank] = stage_idx
else:
self.rank2pp = None
self.ring2rank = {}
self.fwd_time = []
self.bwd_time = []
self.optim_time = []
def _parse_sub_program(self, program, nodes, graph, cost_data, sub_idx):
assert len(
program.blocks) == 1, "Program more than 1 block not supported."
block = program.blocks[0]
for var in block.vars.values():
var_id = var.name
nodes[var_id] = TensorCostNode(var, CostNodeType.VARIABLE, var_id)
graph[var_id] = [[], []]
for op in block.ops:
op_id = op.type + "_" + str(op.idx)
if op.type.startswith('c_') or op.type.startswith(
'send') or op.type.startswith('recv'):
is_bwd = False
if op.type.startswith('c_'):
ring_id = op.attr('ring_id')
if ring_id not in self.ring2rank:
self.ring2rank[ring_id] = set()
self.ring2rank[ring_id].add(sub_idx)
is_bwd = '@GRAD' in op.output('Out')[0]
elif op.type.startswith('recv'):
is_bwd = '@GRAD' in op.output('Out')[0]
elif op.type.startswith('send'):
is_bwd = '@GRAD' in op.input('X')[0]
op_node = CommOpCostNode(op, CostNodeType.COMMUNICATION, op_id,
is_bwd)
else:
is_bwd = '_grad' in op.type
is_optim = 'LearningRate' in op.input_names
op_node = CompOpCostNode(op, CostNodeType.COMPUTATION, op_id,
is_bwd, is_optim)
op_node.init_comp_cost(cost_data)
nodes[op_id] = op_node
graph[op_id] = [[], []]
comm_input_shape = [0]
comm_output_shape = [0]
for i in range(len(op.input_names)):
try:
var_id = op.input(op.input_names[i])[0]
var_node = nodes[var_id]
graph[op_id][PRED].append(var_node.id)
graph[var_id][SUCC].append(op_node.id)
comm_input_shape = var_node.shape
except:
continue
for i in range(len(op.output_names)):
try:
var_id = op.output(op.output_names[i])[0]
var_node = nodes[var_id]
graph[op_id][SUCC].append(var_node.id)
graph[var_id][PRED].append(op_node.id)
comm_output_shape = var_node.shape
except:
continue
if op_node.type == CostNodeType.COMMUNICATION:
op_node.set_shapes(comm_input_shape, comm_output_shape)
# resolve hazard: rename the r/w hazard variable nodes to ensure self.origin_graph is a DAG
new_var_dict = {}
for node_id, node in nodes.items():
if node.type == CostNodeType.VARIABLE and node.node.persistable:
write_op_cnt = 0
for pred_id in graph[node_id][PRED]:
pred = nodes[pred_id]
if pred.type == CostNodeType.COMPUTATION and (
pred_id in graph[node_id][SUCC]):
graph[pred_id][SUCC].remove(node_id)
graph[node_id][PRED].remove(pred_id)
write_op_cnt += 1
new_var_id = node_id + '_write_{}'.format(write_op_cnt)
new_var = TensorCostNode(
node.node,
CostNodeType.VARIABLE,
new_var_id,
shared_node_id=node_id)
graph[new_var_id] = [[], []]
graph[pred_id][SUCC].append(new_var_id)
graph[new_var_id][PRED].append(pred_id)
new_var_dict[new_var_id] = new_var
for k, v in new_var_dict.items():
nodes[k] = v
return nodes
def parse_program(self, distributed_program):
self.distributed_program = distributed_program
self.total_rank = len(self.distributed_program)
sub_prog_cnt = len(distributed_program)
self.nodes = [] * sub_prog_cnt
self.origin_graph = [] * sub_prog_cnt # original graph
self.op_graph = [] * sub_prog_cnt # op graph (no variables nodes)
self.runtime_graph = [] * sub_prog_cnt # runtime graph, for simulation
for sub_idx, sub_prog in enumerate(distributed_program):
self.nodes.append({})
self.origin_graph.append({})
self.op_graph.append({})
self.runtime_graph.append({})
self._parse_sub_program(
sub_prog, self.nodes[sub_idx], self.origin_graph[sub_idx],
self.cost_data[0 if self.rank2pp is None else self.rank2pp[
sub_idx]], sub_idx)
return self.nodes
def _find_succ_op(self, node_id, sub_idx=0):
succ_ops_id = []
for succ_id in self.origin_graph[sub_idx][node_id][SUCC]:
succ = self.nodes[sub_idx][succ_id]
if succ.type == CostNodeType.COMMUNICATION or \
succ.type == CostNodeType.COMPUTATION:
succ_ops_id.append(succ_id)
elif succ.type == CostNodeType.VARIABLE:
succ_ops_id = succ_ops_id + self._find_succ_op(succ_id, sub_idx)
else:
raise NotImplementedError(
'This type of node not supported yet:{}'.format(succ.type))
return succ_ops_id
def build_op_graph(self):
for sub_idx in range(self.total_rank):
op_nodes_id = []
for node_id, node in self.nodes[sub_idx].items():
if node.type == CostNodeType.VARIABLE:
continue
self.op_graph[sub_idx][node_id] = [[], []]
op_nodes_id.append(node_id)
for op_id in op_nodes_id:
succ_nodes_id = self._find_succ_op(op_id, sub_idx)
self.op_graph[sub_idx][op_id][SUCC] = succ_nodes_id
for succ_id in succ_nodes_id:
self.op_graph[sub_idx][succ_id][PRED].append(op_id)
def build_runtime_graph(self):
self.runtime_graph = copy.deepcopy(self.op_graph)
def eliminate_multi_edges(self, graph=None):
for node_id, edges in graph.items():
graph[node_id][PRED] = list(set(edges[PRED]))
graph[node_id][SUCC] = list(set(edges[SUCC]))
def merge_comm(self):
for sub_idx in range(self.total_rank):
for node_id, edges in self.op_graph[sub_idx].items():
node = self.nodes[sub_idx][node_id]
if node_id.startswith('c_'):
ring_id = node.node.attr('ring_id')
node.set_ranks(list(self.ring2rank[ring_id]))
node.init_comm_cost(self.cluster)
elif node_id.startswith('send') or node_id.startswith('recv'):
peer_rank = node.node.attr('peer')
node.set_ranks([sub_idx, peer_rank])
node.init_comm_cost(self.cluster)
else:
pass # Not communication op
def _merge_node(self, to_merge_node_list, merge_type='linear', nodes=None):
nodes_list = []
node_cost = 0
for node in to_merge_node_list:
if isinstance(node, MergedOpsCostNode):
nodes_list += node.node_list
else:
nodes_list.append(node.id)
if merge_type == 'linear':
node_cost += node.cost
elif merge_type == 'branch':
node_cost = max(node_cost, node.cost)
else:
raise NotImplementedError(
'This type of merging is not supported:{}'.format(
merge_type))
merged_node_id = 'merged_' + str(len(nodes))
is_bwd = to_merge_node_list[0].is_bwd
merged_node = MergedOpsCostNode(
CostNodeType.MERGED,
id=merged_node_id,
base_node_list=nodes_list,
is_bwd=is_bwd)
merged_node.cost = node_cost
return merged_node_id, merged_node
def merge_linear(self):
'''
This method does the following:
If X depends on Y only, they must be run sequentially.
[ e.g. A ->- C ->- D D and E depends on C only.]
[ B ->-/ \->- E C depends on A and B. ]
We merge X and Y into a new node and sum up their cost time.
'''
cnt = 0
for sub_idx in range(self.total_rank):
cnt += self._merge_linear(
self.nodes[sub_idx], self.runtime_graph[sub_idx], is_bwd=False)
cnt += self._merge_linear(
self.nodes[sub_idx], self.runtime_graph[sub_idx], is_bwd=True)
return cnt
def merge_branch(self):
'''
This method does the following:
If a node has more than one successor, there is *branch*.
[ e.g. A ->- B ->- D ]
[ \->- C ->- / , B and C can be run at the same time ]
case 1: if B or C is null (or D is directly dependent on A),
it's equivalent to A->C->D or A->B->D, fall back to self.merge_linear
case 2: if both B and C are some op,
merged_cost = max(cost(B), cost(C))
'''
cnt = 0
for sub_idx in range(self.total_rank):
cnt += self._merge_branch(
self.nodes[sub_idx], self.runtime_graph[sub_idx], is_bwd=False)
cnt += self._merge_branch(
self.nodes[sub_idx], self.runtime_graph[sub_idx], is_bwd=True)
return cnt
def _merge_linear(self, nodes, runtime_graph, is_bwd=False):
reduct_cnt = 0
rt_nodes_id = list(runtime_graph.keys())
for node_id in rt_nodes_id:
if node_id not in runtime_graph.keys():
continue
node = nodes[node_id]
if not is_bwd == node.is_bwd or node.is_optim:
continue
edges = runtime_graph[node_id]
ind = len(edges[PRED]) # in_degree
if ind == 1: # only depend on one node
pred_id = edges[PRED][0]
pred = nodes[pred_id]
merged_node_id, merged_node = self._merge_node(
[node, pred], merge_type='linear', nodes=nodes)
nodes[merged_node_id] = merged_node
runtime_graph[merged_node_id] = [[], []]
# delete edges and add new edges
succ = None
runtime_graph[merged_node_id][SUCC] = copy.deepcopy(edges[SUCC])
if len(runtime_graph[pred_id][SUCC]) > 1:
# predecessor has more than 1 successor
# the merged_node is to inherit the rest of its successors
succ = runtime_graph[pred_id][SUCC]
succ.remove(node_id)
runtime_graph[merged_node_id][SUCC] += succ
runtime_graph[merged_node_id][PRED] = runtime_graph[pred_id][
PRED]
for i in runtime_graph[pred_id][PRED]:
runtime_graph[i][SUCC].remove(pred_id)
runtime_graph[i][SUCC].append(merged_node_id)
for i in edges[SUCC]:
runtime_graph[i][PRED].remove(node_id)
runtime_graph[i][PRED].append(merged_node_id)
if succ is not None:
for i in succ:
runtime_graph[i][PRED].remove(pred_id)
runtime_graph[i][PRED].append(merged_node_id)
runtime_graph.pop(node_id)
runtime_graph.pop(pred_id)
reduct_cnt += 1
self.eliminate_multi_edges(runtime_graph)
return reduct_cnt # the number of nodes that have been reduced
def _merge_branch(self, nodes, runtime_graph, is_bwd=False):
reduct_cnt = 0
rt_nodes_id = list(runtime_graph.keys())
for node_id in rt_nodes_id:
node = nodes[node_id]
if not is_bwd == node.is_bwd or node.is_optim:
continue
edges = runtime_graph[node_id]
outd = len(edges[SUCC]) # out_degree
if outd > 1: # branch out
succ_nodes_id = edges[SUCC]
succ_to_elim = []
for succ_id in succ_nodes_id:
for succ_2_id in succ_nodes_id:
tmp = runtime_graph[succ_2_id][SUCC]
if succ_id in tmp:
succ_to_elim.append(succ_id)
break
for id in succ_to_elim:
edges[SUCC].remove(id)
runtime_graph[id][PRED].remove(node_id)
reduct_cnt += 1
to_merge = True
if len(edges[SUCC]) < 1 or len(runtime_graph[edges[SUCC][0]][
SUCC]) < 1:
continue
end_node_id = runtime_graph[edges[SUCC][0]][SUCC][0]
for i in succ_nodes_id:
if len(runtime_graph[i][SUCC]) != 1 or \
runtime_graph[i][SUCC][0] != end_node_id:
to_merge = False # if branches has different end node, we don't merge them
break
if to_merge:
to_merge_node_list = [nodes[i] for i in succ_nodes_id]
merged_node_id, merged_node = self._merge_node(
to_merge_node_list, merge_type='branch', nodes=nodes)
nodes[merged_node_id] = merged_node
runtime_graph[merged_node_id] = [[], []]
# delete edges and add new edges
runtime_graph[merged_node_id][SUCC] = [end_node_id]
runtime_graph[merged_node_id][PRED] = edges[PRED]
runtime_graph[end_node_id][PRED] = [merged_node_id]
runtime_graph[node_id][SUCC] = [merged_node_id]
for i in succ_nodes_id:
runtime_graph.pop(i)
reduct_cnt += len(to_merge_node_list) - 1
return reduct_cnt
def get_runtime_cost(self):
def get_node_cost(node):
node_cost = node.cost + self.opcall_overhead
if isinstance(node, MergedOpsCostNode):
for it in node.node_list:
node_cost += self.opcall_overhead
return node_cost
for sub_idx in range(self.total_rank):
fwd_cost = 0
bwd_cost = 0
optim_cost = 0
for node_id in self.runtime_graph[sub_idx].keys():
node = self.nodes[sub_idx][node_id]
if node.is_optim:
optim_cost += get_node_cost(node)
elif node.is_bwd:
bwd_cost += get_node_cost(node)
else:
fwd_cost += get_node_cost(node)
self.fwd_time.append(fwd_cost)
self.bwd_time.append(bwd_cost)
self.optim_time.append(optim_cost)
return self.fwd_time, self.bwd_time, self.optim_time
def get_mem(self):
static_list = []
top_list = []
for sub_idx in range(self.total_rank):
static_mem, cur_mem, top_mem = self._simulate_mem(
self.nodes[sub_idx], self.origin_graph[sub_idx])
static_list.append(static_mem)
top_list.append(top_mem)
return static_list, top_list
def _simulate_mem(self, nodes, origin_graph):
q = queue.Queue(1024)
sim_graph = copy.deepcopy(origin_graph)
for node_id, node in nodes.items():
if len(sim_graph[node_id][PRED]) == 0:
q.put(node_id)
q.put('nop')
cur_mem = 0
top_mem = -1
static_mem = 0
while not q.empty():
node_id = q.get()
node = None
size = 0
if node_id == 'nop':
top_mem = max(cur_mem, top_mem)
if q.empty():
break
else:
q.put(node_id)
continue
else:
node = nodes[node_id]
if node.type == CostNodeType.VARIABLE:
size = node.get_size()
if node.node.persistable:
static_mem += size
cur_mem += size
edges = sim_graph[node_id]
if not (node.type == CostNodeType.VARIABLE and
node.node.persistable):
for succ_id in edges[SUCC]:
sim_graph[succ_id][PRED].remove(node_id)
if len(sim_graph[succ_id][PRED]) == 0:
q.put(succ_id)
for pred_id in edges[PRED]:
pred = nodes
if pred.type == CostNodeType.VARIABLE:
sim_graph[pred_id][SUCC].remove(node_id)
if len(sim_graph[pred_id][
SUCC]) == 0 and not pred.node.persistable:
cur_mem -= pred.get_size()
return static_mem, cur_mem, top_mem
def get_pipeline_time(self):
if self.total_rank <= 1:
return self.fwd_time[0] + self.bwd_time[0] + self.optim_time[0]
else:
return self._simulate_pipeline()
def _simulate_pipeline(self):
stage_num = len(self.pp2rank)
event_list = []
global_time = [0] * stage_num
total_time = 0
fwd_cnt = list(range(stage_num, 0, -1))
bwd_cnt = [self.microbatch_num] * stage_num
q = queue.Queue(1024)
for i in range(self.microbatch_num):
q.put(PipeEvent(0, 'fwd', self.fwd_time[0]))
while not q.empty():
e = q.get()
stid = e.stage_id
if e.name == 'fwd':
if fwd_cnt[stid] > 0:
e.s_time = max(global_time[stid], e.s_time)
e.e_time = e.s_time + e.duration
event_list.append(e)
if stid != stage_num - 1:
q.put(
PipeEvent(
stid + 1,
'fwd',
self.fwd_time[stid + 1],
start_time=e.e_time))
else:
q.put(
PipeEvent(
stid,
'bwd',
self.bwd_time[stid],
start_time=e.e_time))
fwd_cnt[stid] -= 1
global_time[stid] = e.e_time
else:
q.put(e)
elif e.name == 'bwd':
e.s_time = max(global_time[stid], e.s_time)
e.e_time = e.s_time + e.duration
event_list.append(e)
if stid != 0:
q.put(
PipeEvent(
stid - 1,
'bwd',
self.bwd_time[stid - 1],
start_time=e.e_time))
fwd_cnt[stid] += 1
bwd_cnt[stid] -= 1
if bwd_cnt[stid] == 0:
q.put(
PipeEvent(
stid,
'optim',
self.optim_time[stid],
start_time=e.e_time))
global_time[stid] = e.e_time
elif e.name == 'optim':
e.s_time = max(global_time[stid], e.s_time)
e.e_time = e.s_time + e.duration
event_list.append(e)
global_time[stid] = e.e_time
else:
raise NotImplementedError(
'This type of pipe event is not supported yet.{}'.format(
e.name))
for t in global_time:
total_time = max(total_time, t)
return total_time
def get_cost(self):
cost = Cost()
static_mem, peak_mem = self.get_mem()
cost.static_mem = static_mem
cost.peak_mem = peak_mem
self.merge_comm()
while True:
cnt = 0
cnt += self.merge_linear()
cnt += self.merge_branch()
if cnt == 0: # can't be further merged
break
self.get_runtime_cost()
cost.runtime = self.get_pipeline_time()
return cost
def init(self, distributed_program):
self.parse_program(distributed_program)
self.build_op_graph()
for sub_idx in range(self.total_rank):
self.eliminate_multi_edges(self.op_graph[sub_idx])
self.build_runtime_graph()
def estimate_cost(distributed_program, cluster, pipeline_config,
standalone_cost_data, batch_size):
"""
Estimated cost from distributed program, cluster model and distributed settings.
Args:
distributed_program(list): list of paddle programs
cluster(Cluster): cluster model
standalone_cost_data(CostData): cost data given by paddle.core
batch_size(int): batch size of the training workload
pipeline_config(list): configuration of pipeline stage allocation
"""
# the following line is left for now, cluster model will be involved in the future
assert cluster is None, "For now, cluster remains None"
cm_ctx = CostModel(
cluster=cluster,
batch_size=batch_size,
standalone_cost_data=standalone_cost_data,
pipeline_config=pipeline_config)
cm_ctx.init(distributed_program)
cost = cm_ctx.get_cost()
return cost
......@@ -91,6 +91,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_serial)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_mppp)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_dpmppp)
list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_cost_model)
foreach(TEST_OP ${MIXED_DIST_TEST_OPS})
list(REMOVE_ITEM TEST_OPS ${TEST_OP})
endforeach()
......@@ -234,6 +235,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_serial)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_mppp)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_dpmppp)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_cost_model)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_data_unshard)
elseif(WITH_GPU)
if (${CUDNN_VERSION} VERSION_LESS 7100)
......@@ -608,6 +610,7 @@ if(WITH_DISTRIBUTE)
py_test_modules(test_auto_parallel_reshard_serial MODULES test_auto_parallel_reshard_serial ENVS ${dist_ENVS})
py_test_modules(test_auto_parallel_reshard_mppp MODULES test_auto_parallel_reshard_mppp ENVS ${dist_ENVS})
py_test_modules(test_auto_parallel_reshard_dpmppp MODULES test_auto_parallel_reshard_dpmppp ENVS ${dist_ENVS})
py_test_modules(test_auto_parallel_cost_model MODULES test_auto_parallel_cost_model ENVS ${dist_ENVS})
endif(NOT WIN32)
endif(NOT APPLE)
if(WITH_DGC)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.context import DistributedContext
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.completion import complete_backward_annotation
from paddle.distributed.auto_parallel.reshard import reshard
from paddle.distributed.auto_parallel.cost_model import estimate_cost
import paddle.fluid.core as core
paddle.enable_static()
_global_parallel_strategy = "dp_mp_pp"
ROOT_MESH = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]])
_global_process_mesh = auto.ProcessMesh(
[[[0, 1], [4, 5]], [[2, 3], [6, 7]]], parent=ROOT_MESH)
PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]], parent=ROOT_MESH)
PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]], parent=ROOT_MESH)
NUM_RANKS = 8
STAGE_0_CNT = 5
STAGE_1_CNT = 10
pp_cfg = [[0, 1, 4, 5], [2, 3, 6, 7]]
device = "gpu" if core.is_compiled_with_cuda() else "cpu"
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=256,
intermediate_size=4 * 256,
initializer_range=0.02,
is_distributed=True):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range))
bias_attr = None
self.linear0 = nn.Linear(
d_model, dim_feedforward, weight_attr, bias_attr=bias_attr)
self.linear1 = nn.Linear(
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.is_distributed = is_distributed
def forward(self, input):
if self.is_distributed:
auto.shard_tensor(
self.linear0.weight, PP_MESH_0, dim_mapping=[-1, 1])
auto.shard_tensor(
self.linear1.weight, PP_MESH_1, dim_mapping=[1, -1])
out = self.norm(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = self.linear1(out)
return out
def get_single_node_data():
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
loss, train_program, startup_program = mlp_forward(
train_program, startup_program, is_distributed=False)
cost_model = core.CostModel()
cost_data = cost_model.profile_measure(train_program, startup_program,
device, ["time"])
op_name2cost = [{}, {}]
for idx, op in enumerate(train_program.blocks[0].ops):
if idx <= STAGE_0_CNT:
op_name2cost[0][op.type] = cost_data.get_op_time_ms(idx)
elif idx <= STAGE_1_CNT:
op_name2cost[1][op.type] = cost_data.get_op_time_ms(idx)
return op_name2cost
def mlp_forward(train_program, start_program, is_distributed=True):
with static.program_guard(train_program,
start_program), utils.unique_name.guard():
batch_size = 4
hidden_size = 256
sequence_len = 128
if is_distributed:
input = static.data(
name="input", shape=[batch_size, hidden_size], dtype='float32')
label = static.data(
name="label", shape=[batch_size, 1], dtype='float32')
else:
input = paddle.ones(
name="input", shape=[batch_size, hidden_size], dtype='float32')
label = paddle.ones(
name="label", shape=[batch_size, 1], dtype='float32')
if is_distributed:
auto.shard_tensor(input, PP_MESH_0, dim_mapping=[0, -1])
auto.shard_tensor(label, PP_MESH_1, dim_mapping=[0, -1])
mlp = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
initializer_range=0.02,
is_distributed=is_distributed)
predict = mlp(input)
error_cost = paddle.nn.functional.square_error_cost(predict, label)
loss = paddle.mean(error_cost)
return loss, train_program, start_program
def get_dist_prog(train_program, startup_program, dist_context, rank_id):
global _global_process_mesh
dist_context.set_process_mesh(_global_process_mesh)
loss, train_program, startup_program = mlp_forward(train_program,
startup_program)
# auto completion
complete_train_program = auto.complete_annotation(train_program,
dist_context)
dist_strategy = fleet.DistributedStrategy()
dist_main_prog = []
dist_startup_prog = []
for rank_id in range(NUM_RANKS):
partitioner = Partitioner(dist_strategy, dist_context, rank_id)
# logical partition
auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward(
complete_train_program, startup_program)
dist_params_grads = partitioner.apply_backward(
loss, complete_train_program, startup_program,
auto_parallel_main_prog, auto_parallel_startup_prog)
optimizer = paddle.fluid.optimizer.AdamOptimizer()
opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads,
auto_parallel_main_prog,
auto_parallel_startup_prog)
dist_main_prog.append(auto_parallel_main_prog)
dist_startup_prog.append(auto_parallel_startup_prog)
return dist_main_prog, dist_startup_prog
def check_runtime_estimation(cost):
return cost.runtime > 0
def check_memory_estimation(cost):
for i in range(NUM_RANKS):
if cost.static_mem[i] <= 0 or cost.peak_mem[i] <= 0:
return False
if cost.static_mem[i] > cost.peak_mem[i]:
return False
return True
def check_empty_program_runtime(cost):
return cost.runtime == 0
def check_empty_program_memory(cost):
for mem in cost.peak_mem:
if mem > 0:
return False
for mem in cost.static_mem:
if mem > 0:
return False
return True
class TestCostModel(unittest.TestCase):
def test_empty_program_cost_model(self):
empty_program = paddle.static.Program()
startup_program = paddle.static.Program()
standalone_cost_data = [{}]
empty_pp_cfg = None
cluster = None
cost = estimate_cost(
[empty_program],
cluster=cluster,
pipeline_config=empty_pp_cfg,
standalone_cost_data=standalone_cost_data,
batch_size=1)
self.assertTrue(check_empty_program_runtime(cost))
self.assertTrue(check_empty_program_memory(cost))
def test_auto_parallel_cost_model(self):
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
dist_context = DistributedContext()
standalone_cost_data = get_single_node_data()
distributed_program, dist_startup_prog = get_dist_prog(
train_program, startup_program, dist_context, 0)
for rank_id in range(NUM_RANKS):
complete_backward_annotation(distributed_program[rank_id],
dist_context)
reshard(distributed_program[rank_id], dist_startup_prog[rank_id],
rank_id, dist_context)
cluster = None
cost = estimate_cost(
distributed_program,
cluster=cluster,
pipeline_config=pp_cfg,
standalone_cost_data=standalone_cost_data,
batch_size=4)
self.assertTrue(check_runtime_estimation(cost))
self.assertTrue(check_memory_estimation(cost))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册