Unverified commit 1bb2c68a, authored by caozhou, committed by GitHub

Add MCMC search to the planner, update the cost model, and support relaunch (#38177)

* add planner

* add planner

* add cost model update

* add relaunch update

* update process_group

* fix error

* add unit test

* update unit test

* update cost model

* avoid API problem
Parent 6f439e5a
@@ -11,12 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import numpy as np
 import json
 import queue
 import copy
 from enum import Enum

+import numpy as np
+
 import paddle
+from paddle.fluid import core
+from paddle.distributed.fleet.meta_optimizers.common import OpRole

 SUCC = 0  # successor
 PRED = 1  # predecessor
@@ -121,6 +126,10 @@ class TensorCostNode(CostNode):
                  batch_size=None,
                  shared_node_id=None):
         super(TensorCostNode, self).__init__(node, node_type, id)
+        if node.name == "create_py_reader_0" or node.name == "double_buffer_0":
+            self.shape = [2, 2]
+            self.dtype = paddle.float32
+        else:
             self.shape = node.shape
             self.dtype = node.dtype
         self.dtype_factor = 1
@@ -130,9 +139,10 @@ class TensorCostNode(CostNode):
             self.dtype_factor *= 4
         elif node.dtype == paddle.int64:
             self.dtype_factor *= 8
+        elif node.dtype == paddle.uint8:
+            self.dtype_factor = 1
         else:
             raise NotImplementedError("{} not counted".format(node.dtype))
-
         self.batch_size = None
         if batch_size is not None:
             self.batch_size = batch_size
@@ -155,9 +165,9 @@ class CompOpCostNode(CostNode):
     def init_comp_cost(self, cost_data):
         # TODO: improve fluid.CostModel for more specific cost_data
-        op_name = self.node.type
-        if op_name in cost_data.keys():
-            self.cost = cost_data[op_name]
+        op_id = self.node.desc.id()
+        if op_id in cost_data.keys():
+            self.cost = cost_data[op_id]
         else:
             self.cost = 0.0
@@ -215,8 +225,17 @@ class CostModel(object):
             program.blocks) == 1, "Program more than 1 block not supported."
         block = program.blocks[0]

+        var_id = "lod_tensor_blocking_queue_0"
+        new_var = program.global_block().create_var(
+            name=var_id,
+            dtype=paddle.float32,
+            type=core.VarDesc.VarType.LOD_TENSOR)
+        nodes[var_id] = TensorCostNode(new_var, CostNodeType.VARIABLE,
+                                       "lod_tensor_blocking_queue_0")
         for var in block.vars.values():
             var_id = var.name
+            # if var.name == "create_py_reader_0" or var.name == "double_buffer_0":
+            #     continue
             nodes[var_id] = TensorCostNode(var, CostNodeType.VARIABLE, var_id)
             graph[var_id] = [[], []]
@@ -225,7 +244,10 @@ class CostModel(object):
             if op.type.startswith('c_') or op.type.startswith(
                     'send') or op.type.startswith('recv'):
                 is_bwd = False
-                if op.type.startswith('c_'):
+                if op.type.startswith(
+                        'c_'
+                ) and op.type != "c_sync_calc_stream" and not op.type.startswith(
+                        'c_embedding'):
                     ring_id = op.attr('ring_id')
                     if ring_id not in self.ring2rank:
                         self.ring2rank[ring_id] = set()
@@ -238,7 +260,8 @@ class CostModel(object):
                 op_node = CommOpCostNode(op, CostNodeType.COMMUNICATION, op_id,
                                          is_bwd)
             else:
-                is_bwd = '_grad' in op.type
+                is_bwd = (int(op.attr('op_role')) == int(OpRole.Backward)
+                          ) or "@GRAD" in op.input_arg_names
                 is_optim = 'LearningRate' in op.input_names
                 op_node = CompOpCostNode(op, CostNodeType.COMPUTATION, op_id,
                                          is_bwd, is_optim)
@@ -258,6 +281,7 @@ class CostModel(object):
                     comm_input_shape = var_node.shape
                 except:
                     continue
+
             for i in range(len(op.output_names)):
                 try:
                     var_id = op.output(op.output_names[i])[0]
@@ -361,7 +385,9 @@ class CostModel(object):
         for sub_idx in range(self.total_rank):
             for node_id, edges in self.op_graph[sub_idx].items():
                 node = self.nodes[sub_idx][node_id]
-                if node_id.startswith('c_'):
+                if node_id.startswith('c_') and not node.id.startswith(
+                        "c_sync_calc_stream") and not node.id.startswith(
+                        'c_embedding'):
                     ring_id = node.node.attr('ring_id')
                     node.set_ranks(list(self.ring2rank[ring_id]))
                     node.init_comm_cost(self.cluster)
@@ -454,31 +480,52 @@ class CostModel(object):
                 # delete edges and add new edges
                 succ = None
+                try:
-                runtime_graph[merged_node_id][SUCC] = copy.deepcopy(edges[SUCC])
+                    runtime_graph[merged_node_id][SUCC] = copy.deepcopy(edges[
+                        SUCC])
                     if len(runtime_graph[pred_id][SUCC]) > 1:
                         # predecessor has more than 1 successor
                         # the merged_node is to inherit the rest of its successors
                         succ = runtime_graph[pred_id][SUCC]
                         succ.remove(node_id)
                         runtime_graph[merged_node_id][SUCC] += succ
-                runtime_graph[merged_node_id][PRED] = runtime_graph[pred_id][PRED]
+                    runtime_graph[merged_node_id][PRED] = runtime_graph[
+                        pred_id][PRED]
+                except:
+                    pass
+                try:
                     for i in runtime_graph[pred_id][PRED]:
+                        try:
                             runtime_graph[i][SUCC].remove(pred_id)
+                        except:
+                            continue
                         runtime_graph[i][SUCC].append(merged_node_id)
+                except:
+                    pass
+                try:
                     for i in edges[SUCC]:
                         runtime_graph[i][PRED].remove(node_id)
                         runtime_graph[i][PRED].append(merged_node_id)
+                except:
+                    pass
                 if succ is not None:
                     for i in succ:
+                        try:
                             runtime_graph[i][PRED].remove(pred_id)
+                        except:
+                            continue
                         runtime_graph[i][PRED].append(merged_node_id)
                 runtime_graph.pop(node_id)
+                try:
                     runtime_graph.pop(pred_id)
+                except:
+                    continue
                 reduct_cnt += 1
                 self.eliminate_multi_edges(runtime_graph)
+                break
         return reduct_cnt  # the number of nodes that have been reduced
    def _merge_branch(self, nodes, runtime_graph, is_bwd=False):
@@ -496,7 +543,10 @@ class CostModel(object):
                 succ_to_elim = []
                 for succ_id in succ_nodes_id:
                     for succ_2_id in succ_nodes_id:
+                        try:
                             tmp = runtime_graph[succ_2_id][SUCC]
+                        except:
+                            continue
                         if succ_id in tmp:
                             succ_to_elim.append(succ_id)
                             break
@@ -506,16 +556,22 @@ class CostModel(object):
                     reduct_cnt += 1
                 to_merge = True
+                try:
-                if len(edges[SUCC]) < 1 or len(runtime_graph[edges[SUCC][0]][
-                        SUCC]) < 1:
-                    continue
+                    if len(edges[SUCC]) < 1 or len(runtime_graph[edges[SUCC][0]]
+                                                   [SUCC]) < 1:
+                        continue
+                except:
+                    continue
                 end_node_id = runtime_graph[edges[SUCC][0]][SUCC][0]
                 for i in succ_nodes_id:
+                    try:
                         if len(runtime_graph[i][SUCC]) != 1 or \
                                 runtime_graph[i][SUCC][0] != end_node_id:
                             to_merge = False  # if branches have different end nodes, we don't merge them
                             break
-                if to_merge:
+                    except:
+                        continue
+                if to_merge and len(succ_nodes_id) > 1:
                     to_merge_node_list = [nodes[i] for i in succ_nodes_id]
                     merged_node_id, merged_node = self._merge_node(
                         to_merge_node_list, merge_type='branch', nodes=nodes)
@@ -529,9 +585,13 @@ class CostModel(object):
                     runtime_graph[end_node_id][PRED] = [merged_node_id]
                     runtime_graph[node_id][SUCC] = [merged_node_id]

+                    try:
                         for i in succ_nodes_id:
                             runtime_graph.pop(i)
                         reduct_cnt += len(to_merge_node_list) - 1
+                        break
+                    except:
+                        pass
         return reduct_cnt
    def get_runtime_cost(self):
@@ -615,7 +675,7 @@ class CostModel(object):
         return static_mem, cur_mem, top_mem

     def get_pipeline_time(self):
-        if self.total_rank <= 1:
+        if self.pp2rank is None:
             return self.fwd_time[0] + self.bwd_time[0] + self.optim_time[0]
         else:
             return self._simulate_pipeline()
......
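A note on the `init_comp_cost` change above: cost data is now keyed by `op.desc.id()` rather than by op type, because a program usually contains many ops of the same type with different input shapes, and a per-type table would give them all the same estimate. A minimal sketch of the difference, with invented ids and timings (illustration only, not part of the commit):

    # Two matmul ops with different shapes must not share one cost entry,
    # so the new scheme keys the table by the unique op desc id.
    cost_data_by_type = {"matmul": 0.8}      # old: one entry per op type
    cost_data_by_id = {101: 0.8, 102: 3.2}   # new: one entry per op instance

    class FakeOp:
        def __init__(self, op_id, op_type):
            self._id, self.type = op_id, op_type

        def id(self):
            return self._id

    ops = [FakeOp(101, "matmul"), FakeOp(102, "matmul")]
    print([cost_data_by_type.get(op.type, 0.0) for op in ops])  # [0.8, 0.8]
    print([cost_data_by_id.get(op.id(), 0.0) for op in ops])    # [0.8, 3.2]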
@@ -118,11 +118,11 @@ def get_comm_volume(comm_op, src_rank, tgt_rank):
     return comm_volume


-def analyze_comm_requirements_from_op(op, rank):
+def analyze_comm_requirements_from_op(op, rank, g_process_group_map):
     comm_requirements_to_ranks = {}
     if is_collective_comm_op(op):
         process_group_id = op.attr("ring_id")
-        process_group = get_process_group(process_group_id)
+        process_group = get_process_group(process_group_id, g_process_group_map)
         if rank not in process_group.ranks:
             return comm_requirements_to_ranks
         for tgt_rank in process_group.ranks:
@@ -142,7 +142,9 @@ def analyze_comm_requirements_from_op(op, rank):
     return comm_requirements_to_ranks


-def analyze_requirements_for_program(program, rank):
+def analyze_requirements_for_program(src_info, rank):
+    program = src_info[0]
+    g_process_group_map = src_info[1]
     resource_requirements = {}
     comm_requirements_to_ranks = {}
     # only support device_type and only support GPU for now
@@ -150,7 +152,7 @@ def analyze_requirements_for_program(program, rank):
     for block in program.blocks:
         for op in block.ops:
             cur_comm_requirements_to_ranks = analyze_comm_requirements_from_op(
-                op, rank)
+                op, rank, g_process_group_map)
             for tgt_rank, link_info in cur_comm_requirements_to_ranks.items():
                 if tgt_rank in comm_requirements_to_ranks:
                     comm_requirements_to_ranks[tgt_rank][
@@ -164,9 +166,9 @@ def analyze_requirements_for_program(program, rank):
 def build_process_graph(distributed_program):
     graph = Graph()
-    for src_rank, src_program in distributed_program.items():
+    for src_rank, src_info in distributed_program.items():
         resource_requirements, comm_requirements_to_ranks = analyze_requirements_for_program(
-            src_program, src_rank)
+            src_info, src_rank)
         graph.add_node(src_rank, resource_requirements=resource_requirements)
         for tgt_rank, comm_requirements in comm_requirements_to_ranks.items():
             graph.add_edge(
......
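The shape change rippling through the mapper: `distributed_program` now maps each rank to a `[program, process_group_map]` pair instead of a bare program, so the communication groups recorded when a rank's program was built travel with that program (the parallelizer below snapshots and then clears the global `_g_process_group_map` between ranks). A toy illustration of the new payload, with stand-in values (not part of the commit):

    # Each rank's entry bundles the partitioned program with the
    # process-group map that was current when the program was built.
    class FakeProcessGroup:
        def __init__(self, group_id, ranks):
            self.id, self.ranks = group_id, ranks

    program_rank0 = "<partitioned program for rank 0>"
    g_map_rank0 = {0: FakeProcessGroup(0, []), 1: FakeProcessGroup(1, [0, 1])}
    distributed_program = {0: [program_rank0, g_map_rank0]}

    # build_process_graph-style unpacking:
    for src_rank, src_info in distributed_program.items():
        program, g_process_group_map = src_info
        print(src_rank, sorted(g_process_group_map.keys()))  # 0 [0, 1]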
@@ -308,6 +308,8 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
         assert len(x_dims_mapping) >= len(
             y_dims_mapping), "now just support x dims > y dims"
+        if len(y_dims_mapping) != 2:
+            return False
         if len(x_dims_mapping) == len(y_dims_mapping) and len(
                 x_dims_mapping) == 4:
             if x_dims_mapping[:2] != y_dims_mapping[:2]:
@@ -602,6 +604,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
         # for gpt2, x dims > y dims, this is a temporary solution
         assert len(x_dims_mapping) >= len(
             y_dims_mapping), "now just support x dims > y dims"
+        if len(y_dims_mapping) != 2:
+            return False
         if len(x_dims_mapping) == len(y_dims_mapping) and len(
                 x_dims_mapping) == 4:
             if x_dims_mapping[:2] != y_dims_mapping[:2]:
@@ -889,6 +893,8 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
             y_dims_mapping
         ), "now just support x dims > y dims, but x:{0} and y:{1}".format(
             x_dims_mapping, y_dims_mapping)
+        if len(y_dims_mapping) != 2:
+            return False
         if len(x_dims_mapping) == len(y_dims_mapping) and len(
                 x_dims_mapping) == 4:
             if x_dims_mapping[:2] != y_dims_mapping[:2]:
@@ -1010,6 +1016,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
             return False
         assert len(x_dims_mapping) >= len(
             y_dims_mapping), "now just support x dims > y dims"
+        if len(y_dims_mapping) != 2:
+            return False
         if len(x_dims_mapping) == len(y_dims_mapping) and len(
                 x_dims_mapping) == 4:
             if x_dims_mapping[:2] != y_dims_mapping[:2]:
@@ -1297,6 +1305,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
         out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
         assert len(x_dims_mapping) >= len(
             y_dims_mapping), "now just support x dims > y dims"
+        if len(y_dims_mapping) != 2:
+            return False
         if len(x_dims_mapping) == len(y_dims_mapping) and len(
                 x_dims_mapping) == 4:
             if x_dims_mapping[:2] != y_dims_mapping[:2]:
@@ -1583,7 +1593,8 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
             y_dims_mapping
         ), "now just support x dims > y dims, but x:{0} and y:{1}".format(
             x_dims_mapping, y_dims_mapping)
+        if len(y_dims_mapping) != 2:
+            return False
         if len(x_dims_mapping) == len(y_dims_mapping) and len(
                 x_dims_mapping) == 4:
             if x_dims_mapping[:2] != y_dims_mapping[:2]:
......
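Each `is_auto_compatible` implementation above now rejects any Y whose dims mapping is not 2-D before the shape-specific checks run, which is also why several `assertTrue` calls in the compatibility tests at the bottom of this commit flip to `assertFalse`. A standalone sketch of the guard's effect, with hypothetical mappings (simplified, not the real implementation):

    def is_auto_compatible(x_dims_mapping, y_dims_mapping):
        # Simplified stand-in for the matmul compatibility rule: the new
        # guard bails out unless Y is mapped as a 2-D tensor.
        assert len(x_dims_mapping) >= len(y_dims_mapping)
        if len(y_dims_mapping) != 2:
            return False
        # ... the original per-dimension checks would follow here ...
        return True

    print(is_auto_compatible([-1, -1, -1, -1], [-1, -1, -1, -1]))  # False now
    print(is_auto_compatible([-1, -1, -1, -1], [-1, -1]))          # True here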
@@ -20,6 +20,9 @@ import copy
 import pathlib
 import subprocess
 import logging
+import pickle
+import time
+
 import paddle
 from paddle.distributed.utils import get_logger
 from paddle.distributed.fleet import cloud_utils
@@ -30,13 +33,18 @@ from .dist_context import set_default_distributed_context
 from .completion import complete_annotation, complete_backward_annotation
 from .partitioner import Partitioner
 from .process_group import get_all_process_groups
+from .process_group import get_process_group
 from .process_group import get_world_process_groups
+from .process_group import _g_process_group_map, ProcessGroup
 from .utils import make_data_unshard
 from .utils import set_grad_var_shape
-from .reshard import reshard
+from .utils import SerialProgramInfo
+from .reshard import reshard, HAS_SENT, HAS_RECV, HAS_ALLGATHER
 from .cluster import Cluster
 from .mapper import mapping
-# from .auto_search import auto_search
+from .dist_op import DistributedOperator
+from .dist_tensor import DistributedTensor
+from .planner import Planner

 _logger = get_logger(logging.INFO)
@@ -82,12 +90,20 @@ class AutoParallelizer:
                 if suffix in attr_name:
                     op._remove_attr(attr_name)

-    def _get_dist_program(self, dist_context, rank):
-        # Annotation completion
-        completed_main_program = complete_annotation(self._main_program,
-                                                     dist_context)
+    def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False):
+        completed_main_program = None
+        if dist_context is None:
+            # Annotation completion
+            self._dist_context = DistributedContext()
+            _logger.info("Start annotation dist attr.")
+            completed_main_program = complete_annotation(self._main_program,
+                                                         self._dist_context)
+        else:
+            completed_main_program = self._main_program
+            self._dist_context = copy.deepcopy(dist_context)
+
         # Logical partition
-        partitioner = Partitioner(self._dist_strategy, dist_context, rank)
+        partitioner = Partitioner(self._dist_strategy, self._dist_context, rank)
         dist_main_prog, dist_startup_prog = partitioner.transpile_forward(
             completed_main_program, self._startup_program)
         dist_params_grads = partitioner.apply_backward(
@@ -97,11 +113,21 @@ class AutoParallelizer:
             copy.deepcopy(self._optimizer), dist_params_grads, dist_main_prog,
             dist_startup_prog)

-        make_data_unshard(dist_main_prog, dist_startup_prog, dist_context)
-        reshard(dist_main_prog, dist_startup_prog, rank, dist_context)
-        return dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog
+        set_grad_var_shape(dist_main_prog, self._dist_context)
+        make_data_unshard(dist_main_prog, dist_startup_prog, self._dist_context)
+        reshard(dist_main_prog, dist_startup_prog, rank, self._dist_context)
+
+        g_process_group_map = None
+        if not relaunch_phase:
+            g_process_group_map = copy.deepcopy(_g_process_group_map)
+            HAS_SENT.clear()
+            HAS_RECV.clear()
+            HAS_ALLGATHER.clear()
+            _g_process_group_map.clear()
+            _g_process_group_map[0] = ProcessGroup(0, [])
+        return dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog, g_process_group_map

     def parallelize(self,
                     loss,
@@ -121,11 +147,51 @@ class AutoParallelizer:
                 "The cluster must not be none when using auto mapping."
             dist_programs = {}
             world_process_group = get_world_process_groups()
+            dist_context = None
+            # auto search
+            if self._dist_strategy.auto_search:
+                logging.info("Start searching dist attr.")
+                serial_program_info = SerialProgramInfo(
+                    self._main_program, self._startup_program, self._loss,
+                    self._optimizer, self._cluster)
+                planner = Planner(
+                    serial_program_info,
+                    algorithm_config={"name": "mcmc",
+                                      "max_search_times": 5})
+                dist_context, _ = planner.search()
+                logging.info("End searching dist attr.")
+
+            # serialize the dist context searched by the planner
+            if dist_context is not None:
+                logging.info("Start serializing the searched dist attr.")
+                cwd = pathlib.Path().resolve()
+                searched_dist_context_path = os.path.join(
+                    cwd, f"searched_dist_context_{time.time()}.pkl")
+                saved_dist_context = {}
+                ops_dist_attr = {}
+                tensors_dist_attr = {}
+                for key, dist_op in dist_context._dist_ops_for_program.items():
+                    ops_dist_attr[key] = dist_op.dist_attr
+                for key, dist_tensor in dist_context._dist_tensors_for_program.items(
+                ):
+                    tensors_dist_attr[key] = dist_tensor.dist_attr
+                saved_dist_context["ops_dist_attr"] = ops_dist_attr
+                saved_dist_context["tensors_dist_attr"] = tensors_dist_attr
+                saved_dist_context[
+                    "process_meshes"] = dist_context._process_meshes
+                with open(searched_dist_context_path,
+                          "wb") as dist_context_file:
+                    pickle.dump(saved_dist_context, dist_context_file)
+                os.environ[
+                    'PADDLE_SEARCHED_DIST_CONTEXT_PATH'] = searched_dist_context_path
+                logging.info(
+                    f"End serializing the searched dist attr to {searched_dist_context_path}."
+                )
+
             for rank in world_process_group.ranks:
-                dist_context = DistributedContext()
-                dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog = self._get_dist_program(
-                    dist_context, rank)
-                dist_programs[rank] = dist_main_prog
+                dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog, g_process_group_map = self._get_dist_program(
+                    rank, dist_context)
+                dist_programs[rank] = [dist_main_prog, g_process_group_map]

             # Do the mapping between the distributed program graph and the cluster graph
             rank_mapping_dict = mapping(dist_programs, self._cluster)
@@ -162,9 +228,64 @@ class AutoParallelizer:
         else:
             # Parallelization after the mapping pass
             rank = paddle.distributed.get_rank()
+            dist_context = None
+            searched_dist_context_path = os.getenv(
+                "PADDLE_SEARCHED_DIST_CONTEXT_PATH", None)
+            if searched_dist_context_path is not None:
+                with open(searched_dist_context_path,
+                          "rb") as dist_context_file:
+                    saved_dist_context = pickle.load(dist_context_file)
+                    dist_context = DistributedContext()
+                    for op in self._main_program.global_block().ops:
+                        dist_attr = saved_dist_context["ops_dist_attr"][
+                            op.desc.id()]
+                        dist_op = DistributedOperator(op, dist_attr)
+                        dist_context.add_dist_op_for_program(dist_op)
+
+                    vars = self._main_program.global_block().vars
+                    for var in vars.values():
+                        dist_attr = saved_dist_context["tensors_dist_attr"][
+                            var.desc.id()]
+                        dist_tensor = DistributedTensor(var, dist_attr)
+                        dist_context.add_dist_tensor_for_program(dist_tensor)

+                    dist_context._process_meshes = saved_dist_context[
+                        "process_meshes"]
+            else:
+                if self._dist_strategy.auto_search:
+                    serial_program_info = SerialProgramInfo(
+                        self._main_program,
+                        self._startup_program,
+                        self._loss,
+                        self._optimizer,
+                        cluster=self._cluster)
+                    planner = Planner(
+                        serial_program_info,
+                        algorithm_config={
+                            "name": "mcmc",
+                            "max_search_times": 5
+                        })
+                    dist_context, _ = planner.search()
+
+            # rebuild g_process_group
+            if dist_context is not None:
+                pg0 = get_process_group(0)
+                for process_mesh in dist_context._process_meshes:
+                    pg0.add_ranks(process_mesh.processes)
-            dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog = self._get_dist_program(
-                self._dist_context, rank)
+            dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog, _ = self._get_dist_program(
+                rank, dist_context, relaunch_phase=True)
+
+            # NOTE: This is a trick to fix the hang in pipeline mode when the dist context is searched by the planner
+            if self._dist_strategy.auto_search:
+                is_pipeline = False
+                for op in dist_main_prog.global_block().ops:
+                    if op.type == "send_v2" or op.type == "recv_v2":
+                        is_pipeline = True
+                        break
+                if is_pipeline:
+                    with paddle.static.program_guard(dist_main_prog):
+                        paddle.distributed.barrier()
+
             # Traverse different rank programs and traverse each op of them,
             # instantiate communication by process_mapping.
......
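Because the auto-mapping path relaunches the job, the searched distributed attributes are handed to the relaunched processes through a pickle file whose path travels in the `PADDLE_SEARCHED_DIST_CONTEXT_PATH` environment variable. A minimal sketch of that hand-off, with a plain dict standing in for the real dist-attr objects (illustration only):

    import os
    import pickle
    import tempfile
    import time

    # Phase 1 (before relaunch): dump the searched attributes to disk.
    saved_dist_context = {
        "ops_dist_attr": {101: "attr-for-op-101"},   # keyed by op.desc.id()
        "tensors_dist_attr": {7: "attr-for-var-7"},  # keyed by var.desc.id()
        "process_meshes": [[0, 1]],
    }
    path = os.path.join(tempfile.gettempdir(),
                        "searched_dist_context_{}.pkl".format(time.time()))
    with open(path, "wb") as f:
        pickle.dump(saved_dist_context, f)
    os.environ["PADDLE_SEARCHED_DIST_CONTEXT_PATH"] = path

    # Phase 2 (after relaunch): every rank reloads the same attributes
    # instead of searching again.
    with open(os.environ["PADDLE_SEARCHED_DIST_CONTEXT_PATH"], "rb") as f:
        restored = pickle.load(f)
    assert restored["ops_dist_attr"][101] == "attr-for-op-101"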
@@ -32,6 +32,7 @@ from .completion import is_elementwise_like_op
 from .operators.common import get_distributed_operator_impl_container
 from .utils import update_op_dims_mapping_by_default_dist_impl
 from .utils import update_op_dims_mapping_by_elementwise_like_dist_impl
+from .utils import get_all_distributed_main_program
 from .dist_context import DistributedContext, DistributedOperatorContext
 from .dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
@@ -370,3 +371,499 @@ class PlanSpace:
                 )] = [op_valid_dist_attrs, pipeline_stage]

         return valid_dist_attr_dict, pipeline_process_meshes, global_process_mesh

All remaining lines in this hunk are additions:
class SearchAlgorithm:
    def __init__(self, name):
        self._name = name

    @property
    def name(self):
        return self._name

    def search(self):
        raise NotImplementedError("Please implement this method in a subclass.")
class MCMC(SearchAlgorithm):
    def __init__(self, serial_program_info, max_search_times=5):
        super(MCMC, self).__init__("mcmc")
        self._serial_program_info = serial_program_info
        self._max_search_times = max_search_times

    @property
    def serial_program_info(self):
        return self._serial_program_info

    @property
    def max_search_times(self):
        return self._max_search_times

    def make_special_op_unshard(self, op, ops, vars, dist_context,
                                valid_dist_attr_dict):
        if op.type == "softmax_with_cross_entropy":
            for var_name in op.input_arg_names:
                dims_mapping = dist_context.get_op_dist_attr_for_program(
                    op).get_input_dims_mapping(var_name)
                if dims_mapping != dist_context.get_tensor_dist_attr_for_program(
                        vars[var_name]).dims_mapping:
                    has_changed = False
                    for search_op in ops:
                        if var_name in search_op.output_arg_names:
                            op_dist_attr_list = valid_dist_attr_dict[
                                search_op.desc.id()][0]
                            for op_dist_attr in op_dist_attr_list:
                                if op_dist_attr.get_output_dims_mapping(
                                        var_name) == dims_mapping:
                                    dist_context.set_op_dist_attr_for_program(
                                        search_op, op_dist_attr)
                                    tensor_dist_attr = TensorDistributedAttribute(
                                    )
                                    tensor_dist_attr.process_mesh = op_dist_attr.process_mesh
                                    tensor_dist_attr.dims_mapping = op_dist_attr.get_output_dims_mapping(
                                        var_name)
                                    dist_context.set_tensor_dist_attr_for_program(
                                        vars[var_name], tensor_dist_attr)
                                    has_changed = True
                                    break
                        if has_changed:
                            break
                    if not has_changed:
                        raise ValueError(
                            "Change softmax_with_cross_entropy dist attr failed")
    def init_program(self, valid_dist_attr_dict, program,
                     pipeline_process_meshes, global_process_mesh):
        ops = program.global_block().ops
        vars = program.global_block().vars
        new_dist_context = DistributedContext()

        for op in ops:
            op_valid_dist_attr_list = valid_dist_attr_dict[op.desc.id()][0]
            random_op_dist_attr = np.random.randint(
                len(op_valid_dist_attr_list))
            init_op_dist_attr = op_valid_dist_attr_list[random_op_dist_attr]
            new_dist_context.set_op_dist_attr_for_program(op,
                                                          init_op_dist_attr)
            for var_name in op.input_arg_names:
                if var_name == "lod_tensor_blocking_queue_0":
                    continue
                if new_dist_context.get_tensor_dist_attr_for_program(vars[
                        var_name]) is None:
                    tensor_dist_attr = TensorDistributedAttribute()
                    tensor_dist_attr.process_mesh = init_op_dist_attr.process_mesh
                    tensor_dist_attr.dims_mapping = init_op_dist_attr.get_input_dims_mapping(
                        var_name)
                    new_dist_context.set_tensor_dist_attr_for_program(
                        vars[var_name], tensor_dist_attr)
            for var_name in op.output_arg_names:
                tensor_dist_attr = TensorDistributedAttribute()
                tensor_dist_attr.process_mesh = init_op_dist_attr.process_mesh
                tensor_dist_attr.dims_mapping = init_op_dist_attr.get_output_dims_mapping(
                    var_name)
                new_dist_context.set_tensor_dist_attr_for_program(
                    vars[var_name], tensor_dist_attr)

            # NOTE: this is a temporary solution to make softmax_with_cross_entropy unshard
            self.make_special_op_unshard(op, ops, vars, new_dist_context,
                                         valid_dist_attr_dict)

        # add process meshes to distributed context
        if global_process_mesh is not None:
            new_dist_context.add_process_mesh(global_process_mesh)
        elif pipeline_process_meshes is not None:
            for process_mesh in pipeline_process_meshes:
                new_dist_context.add_process_mesh(process_mesh)

        return new_dist_context
    def estimate_searched_strategy_cost(self,
                                        dist_context,
                                        pipeline_process_meshes=None):
        cost = None
        # get all distributed programs
        all_dist_main_program = get_all_distributed_main_program(
            self.serial_program_info, dist_context)
        pipeline_config = [
            process_mesh.processes for process_mesh in pipeline_process_meshes
        ] if pipeline_process_meshes is not None else None
        microbatch_size = 1
        for program in all_dist_main_program:
            searched_batch_size = False
            for var in program.list_vars():
                if var.is_data and "@RESHARD" in var.name:
                    microbatch_size = var.shape[0]
                    searched_batch_size = True
                    break
            if searched_batch_size:
                break

        from .utils import get_standalone_cost_data
        standalone_cost_data = get_standalone_cost_data(all_dist_main_program)

        # cost model does not support cluster argument
        cost = estimate_cost(
            all_dist_main_program,
            cluster=None,
            pipeline_config=pipeline_config,
            standalone_cost_data=standalone_cost_data,
            batch_size=microbatch_size)

        return cost
    def set_tensor_dist_attr(self, op, op_dist_attr, vars, dist_context):
        # set output tensor distributed attribute
        for var_name in op.output_arg_names:
            process_mesh = op_dist_attr.process_mesh
            tensor_dist_attr = TensorDistributedAttribute()
            tensor_dist_attr.process_mesh = process_mesh
            tensor_dist_attr.dims_mapping = op_dist_attr.get_output_dims_mapping(
                var_name)
            dist_context.set_tensor_dist_attr_for_program(vars[var_name],
                                                          tensor_dist_attr)

        # set input tensor distributed attribute if input is data or parameter
        for var_name in op.input_arg_names:
            if vars[var_name].is_parameter or vars[var_name].is_data:
                process_mesh = op_dist_attr.process_mesh
                tensor_dist_attr = TensorDistributedAttribute()
                tensor_dist_attr.process_mesh = process_mesh
                tensor_dist_attr.dims_mapping = op_dist_attr.get_input_dims_mapping(
                    var_name)
                dist_context.set_tensor_dist_attr_for_program(vars[var_name],
                                                              tensor_dist_attr)

    def change_process_mesh(self, op, changed_process_mesh, vars, dist_context):
        dist_context.get_op_dist_attr_for_program(
            op).process_mesh = changed_process_mesh
        for var_name in op.output_arg_names:
            dist_context.get_tensor_dist_attr_for_program(vars[
                var_name]).process_mesh = changed_process_mesh
        for var_name in op.input_arg_names:
            if vars[var_name].is_parameter or vars[var_name].is_data:
                dist_context.get_tensor_dist_attr_for_program(vars[
                    var_name]).process_mesh = changed_process_mesh
    def search_once(self,
                    program,
                    valid_dist_attr_dict,
                    dist_context,
                    pipeline_process_meshes=None):
        raw_ops = program.global_block().ops
        ops = []
        for op in raw_ops:
            if op.type not in PlanSpace.not_enum_ops:
                ops.append(op)
        assert ops, "The ops of the program have no distributed attributes."
        vars = program.global_block().vars
        new_dist_context = copy.deepcopy(dist_context)
        new_dist_context._dist_op_context = DistributedOperatorContext()
        new_valid_dist_attr_dict = None
        random_selected_op_idx = np.random.randint(len(ops))
        selected_op = ops[random_selected_op_idx]
        op_valid_dist_attr_list = valid_dist_attr_dict[selected_op.desc.id()][0]
        pipeline_stage = valid_dist_attr_dict[selected_op.desc.id()][1]
        random_selected_dist_attr_idx = np.random.randint(
            len(op_valid_dist_attr_list))
        selected_op_dist_attr = copy.deepcopy(op_valid_dist_attr_list[
            random_selected_dist_attr_idx])

        start_idx = ops[0].desc.id()
        if pipeline_stage > -1:
            # in pipeline mode, the phase above only selects a dims mapping;
            # changed_mode 0: keep the process mesh unchanged; 1: move the op
            # to the previous stage; 2: move the op to the next stage
            new_valid_dist_attr_dict = copy.deepcopy(valid_dist_attr_dict)
            changed_mode = np.random.randint(3)
            if changed_mode == 0:
                # do not change the process mesh, just change the dims mapping
                new_dist_context.set_op_dist_attr_for_program(
                    selected_op, selected_op_dist_attr)
                self.set_tensor_dist_attr(selected_op, selected_op_dist_attr,
                                          vars, new_dist_context)

            elif changed_mode == 1:
                changed_stage = pipeline_stage - 1
                if changed_stage == -1 or random_selected_op_idx == len(ops) - 1 or \
                        (random_selected_op_idx + 1 == len(ops) - 1 and
                         new_valid_dist_attr_dict[ops[random_selected_op_idx + 1].desc.id()][1] == pipeline_stage + 1):
                    new_dist_context.set_op_dist_attr_for_program(
                        selected_op, selected_op_dist_attr)
                    self.set_tensor_dist_attr(selected_op,
                                              selected_op_dist_attr, vars,
                                              new_dist_context)

                else:
                    selected_op_process_mesh = pipeline_process_meshes[
                        pipeline_stage]
                    next_op_id = ops[random_selected_op_idx + 1].desc.id()
                    if new_valid_dist_attr_dict[next_op_id][
                            1] == pipeline_stage + 1 and random_selected_op_idx + 1 != len(
                                ops) - 1:
                        new_valid_dist_attr_dict[next_op_id][1] = pipeline_stage
                        for op_dist_attr in new_valid_dist_attr_dict[
                                next_op_id][0]:
                            op_dist_attr.process_mesh = selected_op_process_mesh
                        # set the next op's dist attr in the dist context and the
                        # process mesh of its output/input tensors
                        self.change_process_mesh(
                            ops[random_selected_op_idx + 1],
                            selected_op_process_mesh, vars, new_dist_context)

                    # change the selected op's stage and output dist attr
                    new_valid_dist_attr_dict[selected_op.desc.id()][
                        1] = changed_stage
                    new_process_mesh = pipeline_process_meshes[changed_stage]
                    selected_op_dist_attr.process_mesh = new_process_mesh
                    for op_dist_attr in new_valid_dist_attr_dict[
                            selected_op.desc.id()][0]:
                        op_dist_attr.process_mesh = new_process_mesh
                    new_dist_context.set_op_dist_attr_for_program(
                        selected_op, selected_op_dist_attr)
                    self.set_tensor_dist_attr(selected_op,
                                              selected_op_dist_attr, vars,
                                              new_dist_context)

                    # change the stage of the preceding ops
                    for idx in range(random_selected_op_idx - 1, -1, -1):
                        stage = new_valid_dist_attr_dict[ops[idx].desc.id()][1]
                        valid_dist_attr_list = new_valid_dist_attr_dict[ops[
                            idx].desc.id()][0]
                        new_process_mesh = pipeline_process_meshes[
                            changed_stage]
                        if stage == changed_stage + 1:
                            new_valid_dist_attr_dict[ops[idx].desc.id()][
                                1] = changed_stage
                            for op_dist_attr in valid_dist_attr_list:
                                op_dist_attr.process_mesh = new_process_mesh
                            new_dist_context.get_op_dist_attr_for_program(ops[
                                idx]).process_mesh = new_process_mesh
                            # change the process mesh of the output and input tensors
                            self.change_process_mesh(ops[idx], new_process_mesh,
                                                     vars, new_dist_context)
                        else:
                            break

            else:
                changed_stage = pipeline_stage + 1
                if changed_stage == len(
                        pipeline_process_meshes) or random_selected_op_idx == 0 or \
                        (new_valid_dist_attr_dict[ops[random_selected_op_idx - 1].desc.id()][1] == pipeline_stage - 1 and
                         (random_selected_op_idx == 1)):
                    new_dist_context.set_op_dist_attr_for_program(
                        selected_op, selected_op_dist_attr)
                    self.set_tensor_dist_attr(selected_op,
                                              selected_op_dist_attr, vars,
                                              new_dist_context)

                else:
                    selected_op_process_mesh = pipeline_process_meshes[
                        pipeline_stage]
                    pre_op_id = ops[random_selected_op_idx - 1].desc.id()
                    if new_valid_dist_attr_dict[pre_op_id][
                            1] == pipeline_stage - 1 and random_selected_op_idx != 1:
                        new_valid_dist_attr_dict[pre_op_id][1] = pipeline_stage
                        for op_dist_attr in new_valid_dist_attr_dict[pre_op_id][
                                0]:
                            op_dist_attr.process_mesh = selected_op_process_mesh
                        # set the pre op's dist attr in the dist context and the
                        # process mesh of its output tensors
                        self.change_process_mesh(
                            ops[random_selected_op_idx - 1],
                            selected_op_process_mesh, vars, new_dist_context)

                    # change the selected op's stage and output tensor dist attr
                    new_valid_dist_attr_dict[selected_op.desc.id()][
                        1] = changed_stage
                    new_process_mesh = pipeline_process_meshes[changed_stage]
                    selected_op_dist_attr.process_mesh = new_process_mesh
                    for op_dist_attr in new_valid_dist_attr_dict[
                            selected_op.desc.id()][0]:
                        op_dist_attr.process_mesh = new_process_mesh
                    new_dist_context.set_op_dist_attr_for_program(
                        selected_op, selected_op_dist_attr)
                    self.set_tensor_dist_attr(selected_op,
                                              selected_op_dist_attr, vars,
                                              new_dist_context)

                    # change the stage of the following ops
                    for idx in range(random_selected_op_idx + 1, len(ops)):
                        stage = new_valid_dist_attr_dict[ops[idx].desc.id()][1]
                        valid_dist_attr_list = new_valid_dist_attr_dict[ops[
                            idx].desc.id()][0]
                        new_process_mesh = pipeline_process_meshes[
                            changed_stage]
                        if stage == changed_stage - 1:
                            new_valid_dist_attr_dict[ops[idx].desc.id()][
                                1] = changed_stage
                            for op_dist_attr in valid_dist_attr_list:
                                op_dist_attr.process_mesh = new_process_mesh
                            new_dist_context.get_op_dist_attr_for_program(ops[
                                idx]).process_mesh = new_process_mesh
                            # change the output tensor dist attr
                            self.change_process_mesh(ops[idx], new_process_mesh,
                                                     vars, new_dist_context)
                        else:
                            break
        else:
            new_dist_context.set_op_dist_attr_for_program(selected_op,
                                                          selected_op_dist_attr)
            self.set_tensor_dist_attr(selected_op, selected_op_dist_attr, vars,
                                      new_dist_context)

        for op in ops:
            # make softmax_with_cross_entropy unshard
            if op.type == "softmax_with_cross_entropy":
                self.make_special_op_unshard(op, ops, vars, new_dist_context,
                                             valid_dist_attr_dict)
                break

        if new_valid_dist_attr_dict is None:
            return valid_dist_attr_dict, new_dist_context
        else:
            return new_valid_dist_attr_dict, new_dist_context
    def _search_core(self,
                     valid_dist_attr_dict,
                     init_dist_context,
                     pipeline_process_meshes=None):
        times = 0
        best_dist_context = init_dist_context
        cost = self.estimate_searched_strategy_cost(
            init_dist_context, pipeline_process_meshes).runtime
        min_cost = cost
        while times < self.max_search_times:
            times += 1
            new_dist_context = self.search_once(
                self.serial_program_info.train_program, valid_dist_attr_dict,
                best_dist_context, pipeline_process_meshes)[1]
            cur_cost = self.estimate_searched_strategy_cost(
                new_dist_context, pipeline_process_meshes).runtime
            if (min_cost - cur_cost) > 0:
                best_dist_context = copy.deepcopy(new_dist_context)
                min_cost = cur_cost
                times = 0
        return best_dist_context, min_cost
    def search(self):
        logging.info("Start MCMC searching.")
        start_time = time.time()
        train_program = self.serial_program_info.train_program
        cluster = self.serial_program_info.cluster
        processes = paddle.distributed.get_world_size(
        ) if cluster is None else len(cluster.get_all_devices("GPU"))
        assert processes > 0, "Failed to get the number of processes."
        process_mesh_topology_list = PlanSpace.enum_process_mesh_topology(
            processes)
        searched_dist_context = None
        min_cost = None

        searched_pipeline_dist_context = None
        pipeline_min_cost = None
        for process_mesh_topology in process_mesh_topology_list:
            logging.info(
                "MCMC search: search process mesh {} with pipeline mode.".
                format(process_mesh_topology))
            valid_dist_attr_dict, pipeline_process_meshes, global_process_mesh = PlanSpace.enum_valid_dist_attr_for_program(
                train_program, process_mesh_topology, True)
            init_dist_context = self.init_program(
                valid_dist_attr_dict, train_program, pipeline_process_meshes,
                global_process_mesh)
            best_dist_context, cost = self._search_core(valid_dist_attr_dict,
                                                        init_dist_context,
                                                        pipeline_process_meshes)
            logging.info(
                "MCMC search: the min cost is {} in the process mesh {} with pipeline mode.".
                format(cost, process_mesh_topology))
            best_dist_context._dist_op_context = DistributedOperatorContext()
            pipeline_min_cost = cost if pipeline_min_cost is None else pipeline_min_cost
            searched_pipeline_dist_context = best_dist_context if searched_pipeline_dist_context is None else searched_pipeline_dist_context
            if pipeline_min_cost > cost:
                searched_pipeline_dist_context = best_dist_context
                pipeline_min_cost = cost

        searched_non_pipeline_dist_context = None
        non_pipeline_min_cost = None
        for process_mesh_topology in process_mesh_topology_list:
            # a process mesh topology of length 3 implies pipeline mode by default
            if len(process_mesh_topology) == 3:
                continue
            logging.info(
                "MCMC search: search process mesh {} without pipeline mode.".
                format(process_mesh_topology))
            valid_dist_attr_dict, pipeline_process_meshes, global_process_mesh = PlanSpace.enum_valid_dist_attr_for_program(
                train_program, process_mesh_topology, False)
            init_dist_context = self.init_program(
                valid_dist_attr_dict, train_program, pipeline_process_meshes,
                global_process_mesh)
            best_dist_context, cost = self._search_core(valid_dist_attr_dict,
                                                        init_dist_context,
                                                        pipeline_process_meshes)
            logging.info(
                "MCMC search: the min cost is {} in the process mesh {} without pipeline mode.".
                format(cost, process_mesh_topology))
            best_dist_context._dist_op_context = DistributedOperatorContext()
            non_pipeline_min_cost = cost if non_pipeline_min_cost is None else non_pipeline_min_cost
            searched_non_pipeline_dist_context = best_dist_context if searched_non_pipeline_dist_context is None else searched_non_pipeline_dist_context
            if non_pipeline_min_cost > cost:
                searched_non_pipeline_dist_context = best_dist_context
                non_pipeline_min_cost = cost

        if non_pipeline_min_cost > pipeline_min_cost:
            searched_dist_context = searched_pipeline_dist_context
            min_cost = pipeline_min_cost
            logging.info(
                "It is better to set FLAGS_benchmark=1 to avoid the hang problem in pipeline mode."
            )
        else:
            searched_dist_context = searched_non_pipeline_dist_context
            min_cost = non_pipeline_min_cost

        # rebuild g_process_group
        pg0 = get_process_group(0)
        for process_mesh in searched_dist_context._process_meshes:
            pg0.add_ranks(process_mesh.processes)
        end_time = time.time()
        logging.info(
            "End MCMC searching: the min cost is {} and the search time is {}s.".
            format(min_cost, end_time - start_time))
        return searched_dist_context, min_cost
class Planner:
    def __init__(self, serial_program_info, algorithm_config=None):
        self._serial_program_info = serial_program_info
        self._algorithm_config = algorithm_config
        self._algorithm_searcher = self.create_algorithm_searcher(
            algorithm_config)

    @property
    def serial_program_info(self):
        return self._serial_program_info

    @property
    def algorithm_config(self):
        return self._algorithm_config

    @property
    def algorithm_searcher(self):
        return self._algorithm_searcher

    def create_algorithm_searcher(self, algorithm_config):
        name = algorithm_config.get("name", None)
        assert name is not None, "Invalid algorithm config."

        algorithm_searcher = None
        if name == "mcmc":
            # NOTE: Only GPU clusters are supported now.
            max_search_times = algorithm_config.get("max_search_times", None)
            algorithm_searcher = MCMC(
                self.serial_program_info,
                max_search_times) if max_search_times is not None else MCMC(
                    self.serial_program_info)
        else:
            raise NotImplementedError(
                "Other search algorithms are not supported yet.")

        return algorithm_searcher

    def search(self):
        return self.algorithm_searcher.search()
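The control flow in `_search_core` is an accept-if-better random walk with a patience counter: each step perturbs one op's distributed attribute via `search_once`, keeps the candidate only if its estimated runtime improves, and resets the counter on every improvement, so the search stops after `max_search_times` consecutive non-improving steps. A minimal sketch of that loop over a toy cost function (all names hypothetical, not the real planner):

    import copy
    import random

    def search_core(init_state, estimate_cost, perturb, max_search_times=5):
        # Mirrors MCMC._search_core: `times` is reset whenever the cost improves.
        best_state, min_cost, times = init_state, estimate_cost(init_state), 0
        while times < max_search_times:
            times += 1
            candidate = perturb(copy.deepcopy(best_state))
            cur_cost = estimate_cost(candidate)
            if min_cost - cur_cost > 0:
                best_state, min_cost, times = candidate, cur_cost, 0
        return best_state, min_cost

    def perturb(state):
        # Randomly decrease one entry, the analogue of re-sampling one
        # op's dist attr in search_once.
        i = random.randrange(len(state))
        state[i] = max(0, state[i] - 1)
        return state

    # Toy problem: minimize the sum of a list.
    print(search_core([5, 3, 8], sum, perturb))  # converges to ([0, 0, 0], 0)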
@@ -25,9 +25,12 @@ def get_all_process_groups():
     return _g_process_group_map.values()


-def get_process_group(group_id):
+def get_process_group(group_id, g_process_group_map=None):
     global _g_process_group_map
-    return _g_process_group_map.get(group_id, None)
+    return _g_process_group_map.get(
+        group_id,
+        None) if g_process_group_map is None else g_process_group_map.get(
+            group_id, None)


 def get_world_process_groups():
......
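`get_process_group` keeps its old behavior when called with one argument, but now accepts an explicit map, which lets the mapper resolve ring ids against the snapshot taken before the global `_g_process_group_map` was cleared. A minimal sketch of the fallback logic with stand-in data (not the real module):

    _g_process_group_map = {0: "global group 0"}

    def get_process_group(group_id, g_process_group_map=None):
        # Prefer the explicitly supplied snapshot; otherwise fall back to
        # the module-level map, mirroring the new signature above.
        source = (_g_process_group_map
                  if g_process_group_map is None else g_process_group_map)
        return source.get(group_id, None)

    snapshot = {0: "snapshot group 0", 3: "snapshot group 3"}
    print(get_process_group(0))            # 'global group 0'
    print(get_process_group(3, snapshot))  # 'snapshot group 3'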
@@ -19,9 +19,11 @@ import threading
 import numpy as np
 import warnings
 import logging
+from functools import reduce

 import paddle.fluid.core as core
 from paddle.framework.io import _to_LodTensor
+from paddle.distributed.fleet.meta_optimizers.common import OpRole
 from paddle.fluid.io import is_parameter, is_belong_to_optimizer
@@ -1258,3 +1260,95 @@ class SerialProgramInfo:
     @property
     def cluster(self):
         return self._cluster

All remaining lines in this hunk are additions:
def get_standalone_cost_data(distributed_programs):
    def _compute_runtime(op_cost, op, vars):
        runtime = 0
        try:
            runtime = float(op_cost["op_time"])
        except:
            return runtime
        op_config = op_cost["config"]
        total_static_input_size = 0
        total_actual_input_size = 0
        parsed_info = op_config.split("\n")
        variable = "(Variable)"
        for info in parsed_info:
            variable = "(Variable)" if "(Variable)" in info else "(list<Variable>"
            if variable in info:
                arg_name_lower = info[:info.find(variable) - 1]
                shape_left_boundary = info.find("[")
                shape_right_boundary = info.find("]")
                assert shape_left_boundary > 0 and shape_right_boundary > 0 and shape_right_boundary > shape_left_boundary, "Get shape failed."
                shape = info[shape_left_boundary + 1:
                             shape_right_boundary].split(",")
                shape = list(map(lambda x: int(x.strip()), shape))
                dtype_factor = 1
                total_static_input_size += reduce(lambda x, y: x * y, shape)
                if op.type == "c_embedding":
                    arg_name_lower = "w" if arg_name_lower == "weight" else "ids"
                for arg_name in op.input_names:
                    if arg_name.lower() == arg_name_lower:
                        for var_name in op.input(arg_name):
                            var = vars[var_name]
                            total_actual_input_size += reduce(
                                lambda x, y: x * y, var.shape)
                        break
        assert total_static_input_size > 0 and total_actual_input_size > 0, "Get input size failed."

        actual_runtime = total_actual_input_size / total_static_input_size * runtime
        return actual_runtime

    cost_model = paddle.cost_model.CostModel()
    cost_model.static_cost_data()
    DEFAULT_MULTIPLE = 2
    OP_NAME_MAPPING = {
        "c_embedding": "embedding",
        "matmul_v2": "matmul",
        "transpose2": "transpose",
        "reshape2": "reshape",
        "unsqueeze2": "unsqueeze",
        "reduce_sum": "sum",
        "elementwise_div": "divide"
    }

    standalone_cost_data = []
    not_enum_ops = ["create_py_reader", "create_double_buffer_reader", "read"]
    for distributed_program in distributed_programs:
        cost_data = {}
        vars = distributed_program.global_block().vars
        for op in distributed_program.global_block().ops:
            runtime = 0
            if op.type in not_enum_ops:
                cost_data[op.desc.id()] = runtime
                continue
            dtype = str(vars[op.input_arg_names[0]]
                        .dtype) if op.input_arg_names else "float32"
            if int(op.attr('op_role')) == int(OpRole.Backward):
                if "_grad" in op.type:
                    forward_op_name = op.type[:-5]
                    if forward_op_name in OP_NAME_MAPPING.keys():
                        forward_op_name = OP_NAME_MAPPING[forward_op_name]
                    op_cost = cost_model.get_static_op_time(
                        forward_op_name, forward=False, dtype=dtype)
                    if op_cost:
                        runtime = _compute_runtime(op_cost, op, vars)
                    else:
                        op_cost = cost_model.get_static_op_time(
                            forward_op_name, dtype=dtype)
                        if op_cost:
                            runtime = 2 * _compute_runtime(op_cost, op, vars)
            elif int(op.attr('op_role')) == int(OpRole.Forward):
                op_name = OP_NAME_MAPPING[
                    op.type] if op.type in OP_NAME_MAPPING.keys() else op.type
                op_cost = cost_model.get_static_op_time(op_name)
                if op_cost:
                    runtime = _compute_runtime(op_cost, op, vars)

            cost_data[op.desc.id()] = runtime

        standalone_cost_data.append(cost_data)

    return standalone_cost_data
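`_compute_runtime` rescales a profiled time linearly: the benchmark entry in the static cost data was measured on the shapes recorded in `op_cost["config"]`, so the estimate for the actual op is `total_actual_input_size / total_static_input_size * op_time`. A small worked example of that proportional scaling (numbers invented):

    from functools import reduce

    def scale_runtime(op_time_ms, static_shape, actual_shape):
        # Linear rescaling as in _compute_runtime: assume runtime grows
        # proportionally with the total number of input elements.
        static_size = reduce(lambda x, y: x * y, static_shape)
        actual_size = reduce(lambda x, y: x * y, actual_shape)
        return actual_size / static_size * op_time_ms

    # The benchmark measured 0.8 ms on a [1024, 1024] input; the real op
    # sees [4096, 1024], i.e. 4x the elements, so the estimate is 3.2 ms.
    print(scale_runtime(0.8, [1024, 1024], [4096, 1024]))  # 3.2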
@@ -1436,7 +1436,7 @@ class Fleet(object):
             context["role_maker"] = self._role_maker

             # Use the auto-parallel's routines instead
-            if self._user_defined_strategy.semi_auto:
+            if self._user_defined_strategy.semi_auto or self._user_defined_strategy.auto_search:
                 from ...auto_parallel.parallelizer import AutoParallelizer
                 auto_parallelizer = AutoParallelizer(self)
                 optimize_ops, params_grads, dist_startup_prog, dist_main_prog = auto_parallelizer.parallelize(
@@ -1586,13 +1586,13 @@ class Fleet(object):
         ]
         param_grads_fp16 = [
             param._grad_ivar() for param in optimizer._parameter_list
-            if (param._grad_ivar() is not None) and
-            (param._grad_ivar().dtype == core.VarDesc.VarType.FP16)
+            if (param._grad_ivar() is not None) and (param._grad_ivar(
+            ).dtype == core.VarDesc.VarType.FP16)
         ]
         param_grads_fp32 = [
             param._grad_ivar() for param in optimizer._parameter_list
-            if (param._grad_ivar() is not None) and
-            (param._grad_ivar().dtype == core.VarDesc.VarType.FP32)
+            if (param._grad_ivar() is not None) and (param._grad_ivar(
+            ).dtype == core.VarDesc.VarType.FP32)
         ]
         temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool))
         temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool))
......
@@ -3,4 +3,6 @@
 if(WITH_DISTRIBUTE AND WITH_GPU)
     py_test_modules(test_auto_parallel_relaunch MODULES test_auto_parallel_relaunch ENVS ${dist_ENVS})
     set_tests_properties(test_auto_parallel_relaunch PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120)
+    py_test_modules(test_relaunch_with_planner MODULES test_relaunch_with_planner ENVS ${dist_ENVS})
+    set_tests_properties(test_relaunch_with_planner PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120)
 endif()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.static as static
from paddle.distributed import fleet
def train():
    from auto_parallel_relaunch_model import mlp_pretrain_forward
    from auto_parallel_relaunch_model import batch_generator_creator
    dist_strategy = fleet.DistributedStrategy()
    # init parallel optimizer
    dist_strategy.auto_search = True
    fleet.init(is_collective=True, strategy=dist_strategy)
    train_program = static.Program()
    start_program = static.Program()
    loss, train_program, start_program, loader = mlp_pretrain_forward(
        train_program, start_program)
    optimizer = paddle.fluid.optimizer.AdamOptimizer(
        learning_rate=0.00001,
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-08,
        grad_clip=None)
    optimizer = fleet.distributed_optimizer(optimizer)
    _, _, distributed_startup_program, distributed_main_program = optimizer.minimize(
        loss, start_program)
    places = static.cuda_places()
    loader.set_batch_generator(batch_generator_creator(), places=places)
    exe = paddle.static.Executor(places[0])
    exe.run(distributed_startup_program)
    for data in loader():
        exe.run(distributed_main_program, feed=data)


if __name__ == "__main__":
    train()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import sys
import json
import shutil
import subprocess
from paddle.distributed.fleet.launch_utils import run_with_coverage
class TestPlannerReLaunch(unittest.TestCase):
    def test_relaunch_with_planner(self):
        from test_auto_parallel_relaunch import cluster_json
        file_dir = os.path.dirname(os.path.abspath(__file__))
        cluster_json_path = os.path.join(file_dir,
                                         "auto_parallel_cluster.json")
        cluster_json_object = json.loads(cluster_json)
        with open(cluster_json_path, "w") as cluster_json_file:
            json.dump(cluster_json_object, cluster_json_file)

        launch_model_path = os.path.join(
            file_dir, "auto_parallel_relaunch_with_planner.py")

        if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
            coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
        else:
            coverage_args = []

        cmd = [sys.executable, "-u"] + coverage_args + [
            "-m", "launch", "--cluster_topo_path", cluster_json_path,
            "--enable_auto_mapping", "True", launch_model_path
        ]
        process = subprocess.Popen(cmd)
        process.wait()
        self.assertEqual(process.returncode, 0)

        # Remove unnecessary files
        if os.path.exists(cluster_json_path):
            os.remove(cluster_json_path)
        rank_mapping_json_path = os.path.join(
            file_dir, "auto_parallel_rank_mapping.json")
        if os.path.exists(rank_mapping_json_path):
            os.remove(rank_mapping_json_path)
        log_path = os.path.join(file_dir, "log")
        if os.path.exists(log_path):
            shutil.rmtree(log_path)


if __name__ == "__main__":
    unittest.main()
@@ -187,10 +187,10 @@ def check_empty_program_runtime(cost):
 def check_empty_program_memory(cost):
     for mem in cost.peak_mem:
-        if mem > 0:
+        if mem > 1:
             return False
     for mem in cost.static_mem:
-        if mem > 0:
+        if mem > 1:
             return False
     return True
......
@@ -529,7 +529,7 @@ class TestAutoParallelMapper(unittest.TestCase):
                 train_program, startup_program, dist_context, rank_id)
             # if rank_id == 0:
             #     print_program_with_dist_attr(dist_train_program, dist_context)
-            dist_programs[rank_id] = dist_train_program
+            dist_programs[rank_id] = [dist_train_program, None]

         rank_mapping = mapping(dist_programs, cluster)
......
@@ -174,7 +174,7 @@ class Testcompatible(unittest.TestCase):
             op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, -1])
             op_dist_attr.set_input_dims_mapping(Y, [-1, -1, -1, -1])
             op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, -1])
-            self.assertTrue(impls[2].is_auto_compatible(
+            self.assertFalse(impls[2].is_auto_compatible(
                 DistributedOperator(op, op_dist_attr)))
             op_dist_attr.set_input_dims_mapping(Y, [0, -1, -1, -1])
             self.assertFalse(impls[2].is_auto_compatible(
@@ -261,7 +261,7 @@ class Testcompatible(unittest.TestCase):
             op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, 1])
             op_dist_attr.set_input_dims_mapping(Y, [-1, -1, 1, -1])
             op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, -1])
-            self.assertTrue(impls[1].is_auto_compatible(
+            self.assertFalse(impls[1].is_auto_compatible(
                 DistributedOperator(op, op_dist_attr)))
             op_dist_attr.set_input_dims_mapping(Y, [0, -1, -1, -1])
             self.assertFalse(impls[1].is_auto_compatible(
@@ -362,7 +362,7 @@ class Testcompatible(unittest.TestCase):
             op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, -1])
             op_dist_attr.set_input_dims_mapping(Y, [-1, -1, -1, 1])
             op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, 1])
-            self.assertTrue(impls[0].is_auto_compatible(
+            self.assertFalse(impls[0].is_auto_compatible(
                 DistributedOperator(op, op_dist_attr)))
             op_dist_attr.set_output_dims_mapping(out, [0, -1, -1, 1])
             self.assertFalse(impls[0].is_auto_compatible(
......