未验证 提交 1bb2c68a 编写于 作者: C caozhou 提交者: GitHub

Add mcmc of planner, of update cost model and relaunch (#38177)

* add planner

* add planner

* add cost model update

* add relaunch updation

* update process_group

* fix error

* add unitest

* update unitest

* update cost model

* avoid api problem
上级 6f439e5a
...@@ -11,12 +11,17 @@ ...@@ -11,12 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import numpy as np
import json import json
import queue import queue
import copy import copy
from enum import Enum from enum import Enum
import numpy as np
import paddle import paddle
from paddle.fluid import core
from paddle.distributed.fleet.meta_optimizers.common import OpRole
SUCC = 0 # successor SUCC = 0 # successor
PRED = 1 # predecessor PRED = 1 # predecessor
...@@ -121,6 +126,10 @@ class TensorCostNode(CostNode): ...@@ -121,6 +126,10 @@ class TensorCostNode(CostNode):
batch_size=None, batch_size=None,
shared_node_id=None): shared_node_id=None):
super(TensorCostNode, self).__init__(node, node_type, id) super(TensorCostNode, self).__init__(node, node_type, id)
if node.name == "create_py_reader_0" or node.name == "double_buffer_0":
self.shape = [2, 2]
self.dtype = paddle.float32
else:
self.shape = node.shape self.shape = node.shape
self.dtype = node.dtype self.dtype = node.dtype
self.dtype_factor = 1 self.dtype_factor = 1
...@@ -130,9 +139,10 @@ class TensorCostNode(CostNode): ...@@ -130,9 +139,10 @@ class TensorCostNode(CostNode):
self.dtype_factor *= 4 self.dtype_factor *= 4
elif node.dtype == paddle.int64: elif node.dtype == paddle.int64:
self.dtype_factor *= 8 self.dtype_factor *= 8
elif node.dtype == paddle.uint8:
self.dtype_factor = 1
else: else:
raise NotImplementedError("{} not counted".format(node.dtype)) raise NotImplementedError("{} not counted".format(node.dtype))
self.batch_size = None self.batch_size = None
if batch_size is not None: if batch_size is not None:
self.batch_size = batch_size self.batch_size = batch_size
...@@ -155,9 +165,9 @@ class CompOpCostNode(CostNode): ...@@ -155,9 +165,9 @@ class CompOpCostNode(CostNode):
def init_comp_cost(self, cost_data): def init_comp_cost(self, cost_data):
# TODO: improve fluid.CostModel for more specific cost_data # TODO: improve fluid.CostModel for more specific cost_data
op_name = self.node.type op_id = self.node.desc.id()
if op_name in cost_data.keys(): if op_id in cost_data.keys():
self.cost = cost_data[op_name] self.cost = cost_data[op_id]
else: else:
self.cost = 0.0 self.cost = 0.0
...@@ -215,8 +225,17 @@ class CostModel(object): ...@@ -215,8 +225,17 @@ class CostModel(object):
program.blocks) == 1, "Program more than 1 block not supported." program.blocks) == 1, "Program more than 1 block not supported."
block = program.blocks[0] block = program.blocks[0]
var_id = "lod_tensor_blocking_queue_0"
new_var = program.global_block().create_var(
name=var_id,
dtype=paddle.float32,
type=core.VarDesc.VarType.LOD_TENSOR)
nodes[var_id] = TensorCostNode(new_var, CostNodeType.VARIABLE,
"lod_tensor_blocking_queue_0")
for var in block.vars.values(): for var in block.vars.values():
var_id = var.name var_id = var.name
# if var.name == "create_py_reader_0" or var.name == "double_buffer_0":
# continue
nodes[var_id] = TensorCostNode(var, CostNodeType.VARIABLE, var_id) nodes[var_id] = TensorCostNode(var, CostNodeType.VARIABLE, var_id)
graph[var_id] = [[], []] graph[var_id] = [[], []]
...@@ -225,7 +244,10 @@ class CostModel(object): ...@@ -225,7 +244,10 @@ class CostModel(object):
if op.type.startswith('c_') or op.type.startswith( if op.type.startswith('c_') or op.type.startswith(
'send') or op.type.startswith('recv'): 'send') or op.type.startswith('recv'):
is_bwd = False is_bwd = False
if op.type.startswith('c_'): if op.type.startswith(
'c_'
) and op.type != "c_sync_calc_stream" and not op.type.startswith(
'c_embedding'):
ring_id = op.attr('ring_id') ring_id = op.attr('ring_id')
if ring_id not in self.ring2rank: if ring_id not in self.ring2rank:
self.ring2rank[ring_id] = set() self.ring2rank[ring_id] = set()
...@@ -238,7 +260,8 @@ class CostModel(object): ...@@ -238,7 +260,8 @@ class CostModel(object):
op_node = CommOpCostNode(op, CostNodeType.COMMUNICATION, op_id, op_node = CommOpCostNode(op, CostNodeType.COMMUNICATION, op_id,
is_bwd) is_bwd)
else: else:
is_bwd = '_grad' in op.type is_bwd = (int(op.attr('op_role')) == int(OpRole.Backward)
) or "@GRAD" in op.input_arg_names
is_optim = 'LearningRate' in op.input_names is_optim = 'LearningRate' in op.input_names
op_node = CompOpCostNode(op, CostNodeType.COMPUTATION, op_id, op_node = CompOpCostNode(op, CostNodeType.COMPUTATION, op_id,
is_bwd, is_optim) is_bwd, is_optim)
...@@ -258,6 +281,7 @@ class CostModel(object): ...@@ -258,6 +281,7 @@ class CostModel(object):
comm_input_shape = var_node.shape comm_input_shape = var_node.shape
except: except:
continue continue
for i in range(len(op.output_names)): for i in range(len(op.output_names)):
try: try:
var_id = op.output(op.output_names[i])[0] var_id = op.output(op.output_names[i])[0]
...@@ -361,7 +385,9 @@ class CostModel(object): ...@@ -361,7 +385,9 @@ class CostModel(object):
for sub_idx in range(self.total_rank): for sub_idx in range(self.total_rank):
for node_id, edges in self.op_graph[sub_idx].items(): for node_id, edges in self.op_graph[sub_idx].items():
node = self.nodes[sub_idx][node_id] node = self.nodes[sub_idx][node_id]
if node_id.startswith('c_'): if node_id.startswith('c_') and not node.id.startswith(
"c_sync_calc_stream") and not node.id.startswith(
'c_embedding'):
ring_id = node.node.attr('ring_id') ring_id = node.node.attr('ring_id')
node.set_ranks(list(self.ring2rank[ring_id])) node.set_ranks(list(self.ring2rank[ring_id]))
node.init_comm_cost(self.cluster) node.init_comm_cost(self.cluster)
...@@ -454,31 +480,52 @@ class CostModel(object): ...@@ -454,31 +480,52 @@ class CostModel(object):
# delete edges and add new edges # delete edges and add new edges
succ = None succ = None
runtime_graph[merged_node_id][SUCC] = copy.deepcopy(edges[SUCC]) try:
runtime_graph[merged_node_id][SUCC] = copy.deepcopy(edges[
SUCC])
if len(runtime_graph[pred_id][SUCC]) > 1: if len(runtime_graph[pred_id][SUCC]) > 1:
# predecessor has more than 1 successor # predecessor has more than 1 successor
# the merged_node is to inherit the rest of its successors # the merged_node is to inherit the rest of its successors
succ = runtime_graph[pred_id][SUCC] succ = runtime_graph[pred_id][SUCC]
succ.remove(node_id) succ.remove(node_id)
runtime_graph[merged_node_id][SUCC] += succ runtime_graph[merged_node_id][SUCC] += succ
runtime_graph[merged_node_id][PRED] = runtime_graph[pred_id][ runtime_graph[merged_node_id][PRED] = runtime_graph[
PRED] pred_id][PRED]
except:
pass
try:
for i in runtime_graph[pred_id][PRED]: for i in runtime_graph[pred_id][PRED]:
try:
runtime_graph[i][SUCC].remove(pred_id) runtime_graph[i][SUCC].remove(pred_id)
except:
continue
runtime_graph[i][SUCC].append(merged_node_id) runtime_graph[i][SUCC].append(merged_node_id)
except:
pass
try:
for i in edges[SUCC]: for i in edges[SUCC]:
runtime_graph[i][PRED].remove(node_id) runtime_graph[i][PRED].remove(node_id)
runtime_graph[i][PRED].append(merged_node_id) runtime_graph[i][PRED].append(merged_node_id)
except:
pass
if succ is not None: if succ is not None:
for i in succ: for i in succ:
try:
runtime_graph[i][PRED].remove(pred_id) runtime_graph[i][PRED].remove(pred_id)
except:
continue
runtime_graph[i][PRED].append(merged_node_id) runtime_graph[i][PRED].append(merged_node_id)
runtime_graph.pop(node_id) runtime_graph.pop(node_id)
try:
runtime_graph.pop(pred_id) runtime_graph.pop(pred_id)
except:
continue
reduct_cnt += 1 reduct_cnt += 1
self.eliminate_multi_edges(runtime_graph) self.eliminate_multi_edges(runtime_graph)
break
return reduct_cnt # the number of nodes that have been reduced return reduct_cnt # the number of nodes that have been reduced
def _merge_branch(self, nodes, runtime_graph, is_bwd=False): def _merge_branch(self, nodes, runtime_graph, is_bwd=False):
...@@ -496,7 +543,10 @@ class CostModel(object): ...@@ -496,7 +543,10 @@ class CostModel(object):
succ_to_elim = [] succ_to_elim = []
for succ_id in succ_nodes_id: for succ_id in succ_nodes_id:
for succ_2_id in succ_nodes_id: for succ_2_id in succ_nodes_id:
try:
tmp = runtime_graph[succ_2_id][SUCC] tmp = runtime_graph[succ_2_id][SUCC]
except:
continue
if succ_id in tmp: if succ_id in tmp:
succ_to_elim.append(succ_id) succ_to_elim.append(succ_id)
break break
...@@ -506,16 +556,22 @@ class CostModel(object): ...@@ -506,16 +556,22 @@ class CostModel(object):
reduct_cnt += 1 reduct_cnt += 1
to_merge = True to_merge = True
if len(edges[SUCC]) < 1 or len(runtime_graph[edges[SUCC][0]][ try:
SUCC]) < 1: if len(edges[SUCC]) < 1 or len(runtime_graph[edges[SUCC][0]]
[SUCC]) < 1:
continue
except:
continue continue
end_node_id = runtime_graph[edges[SUCC][0]][SUCC][0] end_node_id = runtime_graph[edges[SUCC][0]][SUCC][0]
for i in succ_nodes_id: for i in succ_nodes_id:
try:
if len(runtime_graph[i][SUCC]) != 1 or \ if len(runtime_graph[i][SUCC]) != 1 or \
runtime_graph[i][SUCC][0] != end_node_id: runtime_graph[i][SUCC][0] != end_node_id:
to_merge = False # if branches has different end node, we don't merge them to_merge = False # if branches has different end node, we don't merge them
break break
if to_merge: except:
continue
if to_merge and len(succ_nodes_id) > 1:
to_merge_node_list = [nodes[i] for i in succ_nodes_id] to_merge_node_list = [nodes[i] for i in succ_nodes_id]
merged_node_id, merged_node = self._merge_node( merged_node_id, merged_node = self._merge_node(
to_merge_node_list, merge_type='branch', nodes=nodes) to_merge_node_list, merge_type='branch', nodes=nodes)
...@@ -529,9 +585,13 @@ class CostModel(object): ...@@ -529,9 +585,13 @@ class CostModel(object):
runtime_graph[end_node_id][PRED] = [merged_node_id] runtime_graph[end_node_id][PRED] = [merged_node_id]
runtime_graph[node_id][SUCC] = [merged_node_id] runtime_graph[node_id][SUCC] = [merged_node_id]
try:
for i in succ_nodes_id: for i in succ_nodes_id:
runtime_graph.pop(i) runtime_graph.pop(i)
reduct_cnt += len(to_merge_node_list) - 1 reduct_cnt += len(to_merge_node_list) - 1
break
except:
pass
return reduct_cnt return reduct_cnt
def get_runtime_cost(self): def get_runtime_cost(self):
...@@ -615,7 +675,7 @@ class CostModel(object): ...@@ -615,7 +675,7 @@ class CostModel(object):
return static_mem, cur_mem, top_mem return static_mem, cur_mem, top_mem
def get_pipeline_time(self): def get_pipeline_time(self):
if self.total_rank <= 1: if self.pp2rank is None:
return self.fwd_time[0] + self.bwd_time[0] + self.optim_time[0] return self.fwd_time[0] + self.bwd_time[0] + self.optim_time[0]
else: else:
return self._simulate_pipeline() return self._simulate_pipeline()
......
...@@ -118,11 +118,11 @@ def get_comm_volume(comm_op, src_rank, tgt_rank): ...@@ -118,11 +118,11 @@ def get_comm_volume(comm_op, src_rank, tgt_rank):
return comm_volume return comm_volume
def analyze_comm_requirements_from_op(op, rank): def analyze_comm_requirements_from_op(op, rank, g_process_group_map):
comm_requirements_to_ranks = {} comm_requirements_to_ranks = {}
if is_collective_comm_op(op): if is_collective_comm_op(op):
process_group_id = op.attr("ring_id") process_group_id = op.attr("ring_id")
process_group = get_process_group(process_group_id) process_group = get_process_group(process_group_id, g_process_group_map)
if rank not in process_group.ranks: if rank not in process_group.ranks:
return comm_requirements_to_ranks return comm_requirements_to_ranks
for tgt_rank in process_group.ranks: for tgt_rank in process_group.ranks:
...@@ -142,7 +142,9 @@ def analyze_comm_requirements_from_op(op, rank): ...@@ -142,7 +142,9 @@ def analyze_comm_requirements_from_op(op, rank):
return comm_requirements_to_ranks return comm_requirements_to_ranks
def analyze_requirements_for_program(program, rank): def analyze_requirements_for_program(src_info, rank):
program = src_info[0]
g_process_group_map = src_info[1]
resource_requirements = {} resource_requirements = {}
comm_requirements_to_ranks = {} comm_requirements_to_ranks = {}
# only support device_type and only support GPU for now # only support device_type and only support GPU for now
...@@ -150,7 +152,7 @@ def analyze_requirements_for_program(program, rank): ...@@ -150,7 +152,7 @@ def analyze_requirements_for_program(program, rank):
for block in program.blocks: for block in program.blocks:
for op in block.ops: for op in block.ops:
cur_comm_requirements_to_ranks = analyze_comm_requirements_from_op( cur_comm_requirements_to_ranks = analyze_comm_requirements_from_op(
op, rank) op, rank, g_process_group_map)
for tgt_rank, link_info in cur_comm_requirements_to_ranks.items(): for tgt_rank, link_info in cur_comm_requirements_to_ranks.items():
if tgt_rank in comm_requirements_to_ranks: if tgt_rank in comm_requirements_to_ranks:
comm_requirements_to_ranks[tgt_rank][ comm_requirements_to_ranks[tgt_rank][
...@@ -164,9 +166,9 @@ def analyze_requirements_for_program(program, rank): ...@@ -164,9 +166,9 @@ def analyze_requirements_for_program(program, rank):
def build_process_graph(distributed_program): def build_process_graph(distributed_program):
graph = Graph() graph = Graph()
for src_rank, src_program in distributed_program.items(): for src_rank, src_info in distributed_program.items():
resource_requirements, comm_requirements_to_ranks = analyze_requirements_for_program( resource_requirements, comm_requirements_to_ranks = analyze_requirements_for_program(
src_program, src_rank) src_info, src_rank)
graph.add_node(src_rank, resource_requirements=resource_requirements) graph.add_node(src_rank, resource_requirements=resource_requirements)
for tgt_rank, comm_requirements in comm_requirements_to_ranks.items(): for tgt_rank, comm_requirements in comm_requirements_to_ranks.items():
graph.add_edge( graph.add_edge(
......
...@@ -308,6 +308,8 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): ...@@ -308,6 +308,8 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
assert len(x_dims_mapping) >= len( assert len(x_dims_mapping) >= len(
y_dims_mapping), "now just support x dims > y dims" y_dims_mapping), "now just support x dims > y dims"
if len(y_dims_mapping) != 2:
return False
if len(x_dims_mapping) == len(y_dims_mapping) and len( if len(x_dims_mapping) == len(y_dims_mapping) and len(
x_dims_mapping) == 4: x_dims_mapping) == 4:
if x_dims_mapping[:2] != y_dims_mapping[:2]: if x_dims_mapping[:2] != y_dims_mapping[:2]:
...@@ -602,6 +604,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): ...@@ -602,6 +604,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
# for gpt2, x dims > y dims, this is a temporary solution # for gpt2, x dims > y dims, this is a temporary solution
assert len(x_dims_mapping) >= len( assert len(x_dims_mapping) >= len(
y_dims_mapping), "now just support x dims > y dims" y_dims_mapping), "now just support x dims > y dims"
if len(y_dims_mapping) != 2:
return False
if len(x_dims_mapping) == len(y_dims_mapping) and len( if len(x_dims_mapping) == len(y_dims_mapping) and len(
x_dims_mapping) == 4: x_dims_mapping) == 4:
if x_dims_mapping[:2] != y_dims_mapping[:2]: if x_dims_mapping[:2] != y_dims_mapping[:2]:
...@@ -889,6 +893,8 @@ class DistributedMatmulImpl2(DistributedOperatorImpl): ...@@ -889,6 +893,8 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
y_dims_mapping y_dims_mapping
), "now just support x dims > y dims,but x:{0} and y:{1}".format( ), "now just support x dims > y dims,but x:{0} and y:{1}".format(
x_dims_mapping, y_dims_mapping) x_dims_mapping, y_dims_mapping)
if len(y_dims_mapping) != 2:
return False
if len(x_dims_mapping) == len(y_dims_mapping) and len( if len(x_dims_mapping) == len(y_dims_mapping) and len(
x_dims_mapping) == 4: x_dims_mapping) == 4:
if x_dims_mapping[:2] != y_dims_mapping[:2]: if x_dims_mapping[:2] != y_dims_mapping[:2]:
...@@ -1010,6 +1016,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): ...@@ -1010,6 +1016,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
return False return False
assert len(x_dims_mapping) >= len( assert len(x_dims_mapping) >= len(
y_dims_mapping), "now just support x dims > y dims" y_dims_mapping), "now just support x dims > y dims"
if len(y_dims_mapping) != 2:
return False
if len(x_dims_mapping) == len(y_dims_mapping) and len( if len(x_dims_mapping) == len(y_dims_mapping) and len(
x_dims_mapping) == 4: x_dims_mapping) == 4:
if x_dims_mapping[:2] != y_dims_mapping[:2]: if x_dims_mapping[:2] != y_dims_mapping[:2]:
...@@ -1297,6 +1305,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): ...@@ -1297,6 +1305,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
assert len(x_dims_mapping) >= len( assert len(x_dims_mapping) >= len(
y_dims_mapping), "now just support x dims > y dims" y_dims_mapping), "now just support x dims > y dims"
if len(y_dims_mapping) != 2:
return False
if len(x_dims_mapping) == len(y_dims_mapping) and len( if len(x_dims_mapping) == len(y_dims_mapping) and len(
x_dims_mapping) == 4: x_dims_mapping) == 4:
if x_dims_mapping[:2] != y_dims_mapping[:2]: if x_dims_mapping[:2] != y_dims_mapping[:2]:
...@@ -1583,7 +1593,8 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl): ...@@ -1583,7 +1593,8 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
y_dims_mapping y_dims_mapping
), "now just support x dims > y dims,but x:{0} and y:{1}".format( ), "now just support x dims > y dims,but x:{0} and y:{1}".format(
x_dims_mapping, y_dims_mapping) x_dims_mapping, y_dims_mapping)
if len(y_dims_mapping) != 2:
return False
if len(x_dims_mapping) == len(y_dims_mapping) and len( if len(x_dims_mapping) == len(y_dims_mapping) and len(
x_dims_mapping) == 4: x_dims_mapping) == 4:
if x_dims_mapping[:2] != y_dims_mapping[:2]: if x_dims_mapping[:2] != y_dims_mapping[:2]:
......
...@@ -20,6 +20,9 @@ import copy ...@@ -20,6 +20,9 @@ import copy
import pathlib import pathlib
import subprocess import subprocess
import logging import logging
import pickle
import time
import paddle import paddle
from paddle.distributed.utils import get_logger from paddle.distributed.utils import get_logger
from paddle.distributed.fleet import cloud_utils from paddle.distributed.fleet import cloud_utils
...@@ -30,13 +33,18 @@ from .dist_context import set_default_distributed_context ...@@ -30,13 +33,18 @@ from .dist_context import set_default_distributed_context
from .completion import complete_annotation, complete_backward_annotation from .completion import complete_annotation, complete_backward_annotation
from .partitioner import Partitioner from .partitioner import Partitioner
from .process_group import get_all_process_groups from .process_group import get_all_process_groups
from .process_group import get_process_group
from .process_group import get_world_process_groups from .process_group import get_world_process_groups
from .process_group import _g_process_group_map, ProcessGroup
from .utils import make_data_unshard from .utils import make_data_unshard
from .utils import set_grad_var_shape from .utils import set_grad_var_shape
from .reshard import reshard from .utils import SerialProgramInfo
from .reshard import reshard, HAS_SENT, HAS_RECV, HAS_ALLGATHER
from .cluster import Cluster from .cluster import Cluster
from .mapper import mapping from .mapper import mapping
# from .auto_search import auto_search from .dist_op import DistributedOperator
from .dist_tensor import DistributedTensor
from .planner import Planner
_logger = get_logger(logging.INFO) _logger = get_logger(logging.INFO)
...@@ -82,12 +90,20 @@ class AutoParallelizer: ...@@ -82,12 +90,20 @@ class AutoParallelizer:
if suffix in attr_name: if suffix in attr_name:
op._remove_attr(attr_name) op._remove_attr(attr_name)
def _get_dist_program(self, dist_context, rank): def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False):
completed_main_program = None
if dist_context is None:
# Annotation completion # Annotation completion
self._dist_context = DistributedContext()
_logger.info("Start annotation dist attr.")
completed_main_program = complete_annotation(self._main_program, completed_main_program = complete_annotation(self._main_program,
dist_context) self._dist_context)
else:
completed_main_program = self._main_program
self._dist_context = copy.deepcopy(dist_context)
# Logical partition # Logical partition
partitioner = Partitioner(self._dist_strategy, dist_context, rank) partitioner = Partitioner(self._dist_strategy, self._dist_context, rank)
dist_main_prog, dist_startup_prog = partitioner.transpile_forward( dist_main_prog, dist_startup_prog = partitioner.transpile_forward(
completed_main_program, self._startup_program) completed_main_program, self._startup_program)
dist_params_grads = partitioner.apply_backward( dist_params_grads = partitioner.apply_backward(
...@@ -97,11 +113,21 @@ class AutoParallelizer: ...@@ -97,11 +113,21 @@ class AutoParallelizer:
copy.deepcopy(self._optimizer), dist_params_grads, dist_main_prog, copy.deepcopy(self._optimizer), dist_params_grads, dist_main_prog,
dist_startup_prog) dist_startup_prog)
make_data_unshard(dist_main_prog, dist_startup_prog, dist_context) set_grad_var_shape(dist_main_prog, self._dist_context)
reshard(dist_main_prog, dist_startup_prog, rank, dist_context) make_data_unshard(dist_main_prog, dist_startup_prog, self._dist_context)
return dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog reshard(dist_main_prog, dist_startup_prog, rank, self._dist_context)
g_process_group_map = None
if not relaunch_phase:
g_process_group_map = copy.deepcopy(_g_process_group_map)
HAS_SENT.clear()
HAS_RECV.clear()
HAS_ALLGATHER.clear()
_g_process_group_map.clear()
_g_process_group_map[0] = ProcessGroup(0, [])
return dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog, g_process_group_map
def parallelize(self, def parallelize(self,
loss, loss,
...@@ -121,11 +147,51 @@ class AutoParallelizer: ...@@ -121,11 +147,51 @@ class AutoParallelizer:
"The cluster must not be none when using auto mapping." "The cluster must not be none when using auto mapping."
dist_programs = {} dist_programs = {}
world_process_group = get_world_process_groups() world_process_group = get_world_process_groups()
dist_context = None
# auto search
if self._dist_strategy.auto_search:
logging.info("Start searching dist attr.")
serial_program_info = SerialProgramInfo(
self._main_program, self._startup_program, self._loss,
self._optimizer, self._cluster)
planner = Planner(
serial_program_info,
algorithm_config={"name": "mcmc",
"max_search_times": 5})
dist_context, _ = planner.search()
logging.info("End searching dist attr.")
# serialize the dist context by planner
if dist_context is not None:
logging.info("Start serialize searched dist attr")
cwd = pathlib.Path().resolve()
searched_dist_context_path = os.path.join(
cwd, f"searched_dist_context_{time.time()}.pkl")
saved_dist_context = {}
ops_dist_attr = {}
tensors_dist_attr = {}
for key, dist_op in dist_context._dist_ops_for_program.items():
ops_dist_attr[key] = dist_op.dist_attr
for key, dist_tensor in dist_context._dist_tensors_for_program.items(
):
tensors_dist_attr[key] = dist_tensor.dist_attr
saved_dist_context["ops_dist_attr"] = ops_dist_attr
saved_dist_context["tensors_dist_attr"] = tensors_dist_attr
saved_dist_context[
"process_meshes"] = dist_context._process_meshes
with open(searched_dist_context_path,
"wb") as dist_context_file:
pickle.dump(saved_dist_context, dist_context_file)
os.environ[
'PADDLE_SEARCHED_DIST_CONTEXT_PATH'] = searched_dist_context_path
logging.info(
f"End serialize searched dist attr to {searched_dist_context_path}"
)
for rank in world_process_group.ranks: for rank in world_process_group.ranks:
dist_context = DistributedContext() dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog, g_process_group_map = self._get_dist_program(
dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog = self._get_dist_program( rank, dist_context)
dist_context, rank) dist_programs[rank] = [dist_main_prog, g_process_group_map]
dist_programs[rank] = dist_main_prog
# Do the mapping between the distributed program graph and the cluster graph # Do the mapping between the distributed program graph and the cluster graph
rank_mapping_dict = mapping(dist_programs, self._cluster) rank_mapping_dict = mapping(dist_programs, self._cluster)
...@@ -162,9 +228,64 @@ class AutoParallelizer: ...@@ -162,9 +228,64 @@ class AutoParallelizer:
else: else:
# Parallelization after the mapping pass # Parallelization after the mapping pass
rank = paddle.distributed.get_rank() rank = paddle.distributed.get_rank()
dist_context = None
searched_dist_context_path = os.getenv(
"PADDLE_SEARCHED_DIST_CONTEXT_PATH", None)
if searched_dist_context_path is not None:
with open(searched_dist_context_path,
"rb") as dist_context_file:
saved_dist_context = pickle.load(dist_context_file)
dist_context = DistributedContext()
for op in self._main_program.global_block().ops:
dist_attr = saved_dist_context["ops_dist_attr"][
op.desc.id()]
dist_op = DistributedOperator(op, dist_attr)
dist_context.add_dist_op_for_program(dist_op)
vars = self._main_program.global_block().vars
for var in vars.values():
dist_attr = saved_dist_context["tensors_dist_attr"][
var.desc.id()]
dist_tensor = DistributedTensor(var, dist_attr)
dist_context.add_dist_tensor_for_program(dist_tensor)
dist_context._process_meshes = saved_dist_context[
"process_meshes"]
else:
if self._dist_strategy.auto_search:
serial_program_info = SerialProgramInfo(
self._main_program,
self._startup_program,
self._loss,
self._optimizer,
cluster=self._cluster)
planner = Planner(
serial_program_info,
algorithm_config={
"name": "mcmc",
"max_search_times": 5
})
dist_context, _ = planner.search()
# rebuild g_process_group
if dist_context is not None:
pg0 = get_process_group(0)
for process_mesh in dist_context._process_meshes:
pg0.add_ranks(process_mesh.processes)
dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog, _ = self._get_dist_program(
rank, dist_context, relaunch_phase=True)
dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog = self._get_dist_program( # NOTE: This is a trick to fix hang in pipeline mode when dist context is searched by planner
self._dist_context, rank) if self._dist_strategy.auto_search:
is_pipeline = False
for op in dist_main_prog.global_block().ops:
if op.type == "send_v2" or op.type == "recv_v2":
is_pipeline = True
break
if is_pipeline:
with paddle.static.program_guard(dist_main_prog):
paddle.distributed.barrier()
# Traverse different rank programs and traverse each op of them, # Traverse different rank programs and traverse each op of them,
# instantiate communication by process_mapping. # instantiate communication by process_mapping.
......
...@@ -25,9 +25,12 @@ def get_all_process_groups(): ...@@ -25,9 +25,12 @@ def get_all_process_groups():
return _g_process_group_map.values() return _g_process_group_map.values()
def get_process_group(group_id): def get_process_group(group_id, g_process_group_map=None):
global _g_process_group_map global _g_process_group_map
return _g_process_group_map.get(group_id, None) return _g_process_group_map.get(
group_id,
None) if g_process_group_map is None else g_process_group_map.get(
group_id, None)
def get_world_process_groups(): def get_world_process_groups():
......
...@@ -19,9 +19,11 @@ import threading ...@@ -19,9 +19,11 @@ import threading
import numpy as np import numpy as np
import warnings import warnings
import logging import logging
from functools import reduce
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.framework.io import _to_LodTensor from paddle.framework.io import _to_LodTensor
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.fluid.io import is_parameter, is_belong_to_optimizer from paddle.fluid.io import is_parameter, is_belong_to_optimizer
...@@ -1258,3 +1260,95 @@ class SerialProgramInfo: ...@@ -1258,3 +1260,95 @@ class SerialProgramInfo:
@property @property
def cluster(self): def cluster(self):
return self._cluster return self._cluster
def get_standalone_cost_data(distributed_programs):
def _compute_runtime(op_cost, op, vars):
runtime = 0
try:
runtime = float(op_cost["op_time"])
except:
return runtime
op_config = op_cost["config"]
total_static_input_size = 0
total_actual_input_size = 0
parsed_info = op_config.split("\n")
variable = "(Variable)"
for info in parsed_info:
variable = "(Variable)" if "(Variable)" in info else "(list<Variable>"
if variable in info:
arg_name_lower = info[:info.find(variable) - 1]
shape_left_boundary = info.find("[")
shape_right_boundary = info.find("]")
assert shape_left_boundary > 0 and shape_right_boundary > 0 and shape_right_boundary > shape_left_boundary, "Get shape failed."
shape = info[shape_left_boundary + 1:
shape_right_boundary].split(",")
shape = list(map(lambda x: int(x.strip()), shape))
dtype_factor = 1
total_static_input_size += reduce(lambda x, y: x * y, shape)
# print(arg_name_lower)
if op.type == "c_embedding":
arg_name_lower = "w" if arg_name_lower == "weight" else "ids"
for arg_name in op.input_names:
if arg_name.lower() == arg_name_lower:
for var_name in op.input(arg_name):
var = vars[var_name]
total_actual_input_size += reduce(
lambda x, y: x * y, var.shape)
break
assert total_static_input_size > 0 and total_actual_input_size > 0, "Get input size failed."
actual_runtime = total_actual_input_size / total_static_input_size * runtime
return actual_runtime
cost_model = paddle.cost_model.CostModel()
cost_model.static_cost_data()
DEFAULT_MULTIPLE = 2
OP_NAME_MAPPING = {
"c_embedding": "embedding",
"matmul_v2": "matmul",
"transpose2": "transpose",
"reshape2": "reshape",
"unsqueeze2": "unsqueeze",
"reduce_sum": "sum",
"elementwise_div": "divide"
}
standalone_cost_data = []
not_enum_ops = ["create_py_reader", "create_double_buffer_reader", "read"]
for distributed_program in distributed_programs:
cost_data = {}
vars = distributed_program.global_block().vars
for op in distributed_program.global_block().ops:
runtime = 0
if op.type in not_enum_ops:
cost_data[op.desc.id()] = runtime
continue
dtype = str(vars[op.input_arg_names[0]]
.dtype) if op.input_arg_names else "float32"
if int(op.attr('op_role')) == int(OpRole.Backward):
if "_grad" in op.type:
forward_op_name = op.type[:-5]
if forward_op_name in OP_NAME_MAPPING.keys():
forward_op_name = OP_NAME_MAPPING[forward_op_name]
op_cost = cost_model.get_static_op_time(
forward_op_name, forward=False, dtype=dtype)
if op_cost:
runtime = _compute_runtime(op_cost, op, vars)
else:
op_cost = cost_model.get_static_op_time(
forward_op_name, dtype=dtype)
if op_cost:
runtime = 2 * _compute_runtime(op_cost, op, vars)
elif int(op.attr('op_role')) == int(OpRole.Forward):
op_name = OP_NAME_MAPPING[
op.type] if op.type in OP_NAME_MAPPING.keys() else op.type
op_cost = cost_model.get_static_op_time(op_name)
if op_cost:
runtime = _compute_runtime(op_cost, op, vars)
cost_data[op.desc.id()] = runtime
standalone_cost_data.append(cost_data)
return standalone_cost_data
...@@ -1436,7 +1436,7 @@ class Fleet(object): ...@@ -1436,7 +1436,7 @@ class Fleet(object):
context["role_maker"] = self._role_maker context["role_maker"] = self._role_maker
# Use the auto-parallel's routines instead # Use the auto-parallel's routines instead
if self._user_defined_strategy.semi_auto: if self._user_defined_strategy.semi_auto or self._user_defined_strategy.auto_search:
from ...auto_parallel.parallelizer import AutoParallelizer from ...auto_parallel.parallelizer import AutoParallelizer
auto_parallelizer = AutoParallelizer(self) auto_parallelizer = AutoParallelizer(self)
optimize_ops, params_grads, dist_startup_prog, dist_main_prog = auto_parallelizer.parallelize( optimize_ops, params_grads, dist_startup_prog, dist_main_prog = auto_parallelizer.parallelize(
...@@ -1586,13 +1586,13 @@ class Fleet(object): ...@@ -1586,13 +1586,13 @@ class Fleet(object):
] ]
param_grads_fp16 = [ param_grads_fp16 = [
param._grad_ivar() for param in optimizer._parameter_list param._grad_ivar() for param in optimizer._parameter_list
if (param._grad_ivar() is not None) and if (param._grad_ivar() is not None) and (param._grad_ivar(
(param._grad_ivar().dtype == core.VarDesc.VarType.FP16) ).dtype == core.VarDesc.VarType.FP16)
] ]
param_grads_fp32 = [ param_grads_fp32 = [
param._grad_ivar() for param in optimizer._parameter_list param._grad_ivar() for param in optimizer._parameter_list
if (param._grad_ivar() is not None) and if (param._grad_ivar() is not None) and (param._grad_ivar(
(param._grad_ivar().dtype == core.VarDesc.VarType.FP32) ).dtype == core.VarDesc.VarType.FP32)
] ]
temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool)) temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool))
temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool)) temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool))
......
...@@ -3,4 +3,6 @@ ...@@ -3,4 +3,6 @@
if(WITH_DISTRIBUTE AND WITH_GPU) if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_auto_parallel_relaunch MODULES test_auto_parallel_relaunch ENVS ${dist_ENVS}) py_test_modules(test_auto_parallel_relaunch MODULES test_auto_parallel_relaunch ENVS ${dist_ENVS})
set_tests_properties(test_auto_parallel_relaunch PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120) set_tests_properties(test_auto_parallel_relaunch PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120)
py_test_modules(test_relaunch_with_planner MODULES test_relaunch_with_planner ENVS ${dist_ENVS})
set_tests_properties(test_relaunch_with_planner PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120)
endif() endif()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.static as static
from paddle.distributed import fleet
def train():
from auto_parallel_relaunch_model import mlp_pretrain_forward
from auto_parallel_relaunch_model import batch_generator_creator
dist_strategy = fleet.DistributedStrategy()
# init parallel optimizer
dist_strategy.auto_search = True
fleet.init(is_collective=True, strategy=dist_strategy)
train_program = static.Program()
start_program = static.Program()
loss, train_program, start_program, loader = mlp_pretrain_forward(
train_program, start_program)
optimizer = paddle.fluid.optimizer.AdamOptimizer(
learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
optimizer = fleet.distributed_optimizer(optimizer)
_, _, distributed_startup_program, distributed_main_program = optimizer.minimize(
loss, start_program)
places = static.cuda_places()
loader.set_batch_generator(batch_generator_creator(), places=places)
exe = paddle.static.Executor(places[0])
exe.run(distributed_startup_program)
for data in loader():
exe.run(distributed_main_program, feed=data)
if __name__ == "__main__":
train()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import sys
import json
import shutil
import subprocess
from paddle.distributed.fleet.launch_utils import run_with_coverage
class TestPlannerReLaunch(unittest.TestCase):
def test_relaunch_with_planner(self):
from test_auto_parallel_relaunch import cluster_json
file_dir = os.path.dirname(os.path.abspath(__file__))
cluster_json_path = os.path.join(file_dir, "auto_parallel_cluster.json")
cluster_json_object = json.loads(cluster_json)
with open(cluster_json_path, "w") as cluster_json_file:
json.dump(cluster_json_object, cluster_json_file)
launch_model_path = os.path.join(
file_dir, "auto_parallel_relaunch_with_planner.py")
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
cmd = [sys.executable, "-u"] + coverage_args + [
"-m", "launch", "--cluster_topo_path", cluster_json_path,
"--enable_auto_mapping", "True", launch_model_path
]
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
# Remove unnecessary files
if os.path.exists(cluster_json_path):
os.remove(cluster_json_path)
rank_mapping_json_path = os.path.join(file_dir,
"auto_parallel_rank_mapping.json")
if os.path.exists(rank_mapping_json_path):
os.remove(rank_mapping_json_path)
log_path = os.path.join(file_dir, "log")
if os.path.exists(log_path):
shutil.rmtree(log_path)
if __name__ == "__main__":
unittest.main()
...@@ -187,10 +187,10 @@ def check_empty_program_runtime(cost): ...@@ -187,10 +187,10 @@ def check_empty_program_runtime(cost):
def check_empty_program_memory(cost): def check_empty_program_memory(cost):
for mem in cost.peak_mem: for mem in cost.peak_mem:
if mem > 0: if mem > 1:
return False return False
for mem in cost.static_mem: for mem in cost.static_mem:
if mem > 0: if mem > 1:
return False return False
return True return True
......
...@@ -529,7 +529,7 @@ class TestAutoParallelMapper(unittest.TestCase): ...@@ -529,7 +529,7 @@ class TestAutoParallelMapper(unittest.TestCase):
train_program, startup_program, dist_context, rank_id) train_program, startup_program, dist_context, rank_id)
# if rank_id == 0: # if rank_id == 0:
# print_program_with_dist_attr(dist_train_program, dist_context) # print_program_with_dist_attr(dist_train_program, dist_context)
dist_programs[rank_id] = dist_train_program dist_programs[rank_id] = [dist_train_program, None]
rank_mapping = mapping(dist_programs, cluster) rank_mapping = mapping(dist_programs, cluster)
......
...@@ -174,7 +174,7 @@ class Testcompatible(unittest.TestCase): ...@@ -174,7 +174,7 @@ class Testcompatible(unittest.TestCase):
op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, -1]) op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, -1])
op_dist_attr.set_input_dims_mapping(Y, [-1, -1, -1, -1]) op_dist_attr.set_input_dims_mapping(Y, [-1, -1, -1, -1])
op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, -1]) op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, -1])
self.assertTrue(impls[2].is_auto_compatible( self.assertFalse(impls[2].is_auto_compatible(
DistributedOperator(op, op_dist_attr))) DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(Y, [0, -1, -1, -1]) op_dist_attr.set_input_dims_mapping(Y, [0, -1, -1, -1])
self.assertFalse(impls[2].is_auto_compatible( self.assertFalse(impls[2].is_auto_compatible(
...@@ -261,7 +261,7 @@ class Testcompatible(unittest.TestCase): ...@@ -261,7 +261,7 @@ class Testcompatible(unittest.TestCase):
op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, 1]) op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, 1])
op_dist_attr.set_input_dims_mapping(Y, [-1, -1, 1, -1]) op_dist_attr.set_input_dims_mapping(Y, [-1, -1, 1, -1])
op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, -1]) op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, -1])
self.assertTrue(impls[1].is_auto_compatible( self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr))) DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(Y, [0, -1, -1, -1]) op_dist_attr.set_input_dims_mapping(Y, [0, -1, -1, -1])
self.assertFalse(impls[1].is_auto_compatible( self.assertFalse(impls[1].is_auto_compatible(
...@@ -362,7 +362,7 @@ class Testcompatible(unittest.TestCase): ...@@ -362,7 +362,7 @@ class Testcompatible(unittest.TestCase):
op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, -1]) op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, -1])
op_dist_attr.set_input_dims_mapping(Y, [-1, -1, -1, 1]) op_dist_attr.set_input_dims_mapping(Y, [-1, -1, -1, 1])
op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, 1]) op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, 1])
self.assertTrue(impls[0].is_auto_compatible( self.assertFalse(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr))) DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(out, [0, -1, -1, 1]) op_dist_attr.set_output_dims_mapping(out, [0, -1, -1, 1])
self.assertFalse(impls[0].is_auto_compatible( self.assertFalse(impls[0].is_auto_compatible(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册