未验证 提交 1bb2c68a 编写于 作者: C caozhou 提交者: GitHub

Add mcmc of planner, of update cost model and relaunch (#38177)

* add planner

* add planner

* add cost model update

* add relaunch updation

* update process_group

* fix error

* add unitest

* update unitest

* update cost model

* avoid api problem
上级 6f439e5a
......@@ -11,12 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import json
import queue
import copy
from enum import Enum
import numpy as np
import paddle
from paddle.fluid import core
from paddle.distributed.fleet.meta_optimizers.common import OpRole
SUCC = 0 # successor
PRED = 1 # predecessor
......@@ -121,6 +126,10 @@ class TensorCostNode(CostNode):
batch_size=None,
shared_node_id=None):
super(TensorCostNode, self).__init__(node, node_type, id)
if node.name == "create_py_reader_0" or node.name == "double_buffer_0":
self.shape = [2, 2]
self.dtype = paddle.float32
else:
self.shape = node.shape
self.dtype = node.dtype
self.dtype_factor = 1
......@@ -130,9 +139,10 @@ class TensorCostNode(CostNode):
self.dtype_factor *= 4
elif node.dtype == paddle.int64:
self.dtype_factor *= 8
elif node.dtype == paddle.uint8:
self.dtype_factor = 1
else:
raise NotImplementedError("{} not counted".format(node.dtype))
self.batch_size = None
if batch_size is not None:
self.batch_size = batch_size
......@@ -155,9 +165,9 @@ class CompOpCostNode(CostNode):
def init_comp_cost(self, cost_data):
# TODO: improve fluid.CostModel for more specific cost_data
op_name = self.node.type
if op_name in cost_data.keys():
self.cost = cost_data[op_name]
op_id = self.node.desc.id()
if op_id in cost_data.keys():
self.cost = cost_data[op_id]
else:
self.cost = 0.0
......@@ -215,8 +225,17 @@ class CostModel(object):
program.blocks) == 1, "Program more than 1 block not supported."
block = program.blocks[0]
var_id = "lod_tensor_blocking_queue_0"
new_var = program.global_block().create_var(
name=var_id,
dtype=paddle.float32,
type=core.VarDesc.VarType.LOD_TENSOR)
nodes[var_id] = TensorCostNode(new_var, CostNodeType.VARIABLE,
"lod_tensor_blocking_queue_0")
for var in block.vars.values():
var_id = var.name
# if var.name == "create_py_reader_0" or var.name == "double_buffer_0":
# continue
nodes[var_id] = TensorCostNode(var, CostNodeType.VARIABLE, var_id)
graph[var_id] = [[], []]
......@@ -225,7 +244,10 @@ class CostModel(object):
if op.type.startswith('c_') or op.type.startswith(
'send') or op.type.startswith('recv'):
is_bwd = False
if op.type.startswith('c_'):
if op.type.startswith(
'c_'
) and op.type != "c_sync_calc_stream" and not op.type.startswith(
'c_embedding'):
ring_id = op.attr('ring_id')
if ring_id not in self.ring2rank:
self.ring2rank[ring_id] = set()
......@@ -238,7 +260,8 @@ class CostModel(object):
op_node = CommOpCostNode(op, CostNodeType.COMMUNICATION, op_id,
is_bwd)
else:
is_bwd = '_grad' in op.type
is_bwd = (int(op.attr('op_role')) == int(OpRole.Backward)
) or "@GRAD" in op.input_arg_names
is_optim = 'LearningRate' in op.input_names
op_node = CompOpCostNode(op, CostNodeType.COMPUTATION, op_id,
is_bwd, is_optim)
......@@ -258,6 +281,7 @@ class CostModel(object):
comm_input_shape = var_node.shape
except:
continue
for i in range(len(op.output_names)):
try:
var_id = op.output(op.output_names[i])[0]
......@@ -361,7 +385,9 @@ class CostModel(object):
for sub_idx in range(self.total_rank):
for node_id, edges in self.op_graph[sub_idx].items():
node = self.nodes[sub_idx][node_id]
if node_id.startswith('c_'):
if node_id.startswith('c_') and not node.id.startswith(
"c_sync_calc_stream") and not node.id.startswith(
'c_embedding'):
ring_id = node.node.attr('ring_id')
node.set_ranks(list(self.ring2rank[ring_id]))
node.init_comm_cost(self.cluster)
......@@ -454,31 +480,52 @@ class CostModel(object):
# delete edges and add new edges
succ = None
runtime_graph[merged_node_id][SUCC] = copy.deepcopy(edges[SUCC])
try:
runtime_graph[merged_node_id][SUCC] = copy.deepcopy(edges[
SUCC])
if len(runtime_graph[pred_id][SUCC]) > 1:
# predecessor has more than 1 successor
# the merged_node is to inherit the rest of its successors
succ = runtime_graph[pred_id][SUCC]
succ.remove(node_id)
runtime_graph[merged_node_id][SUCC] += succ
runtime_graph[merged_node_id][PRED] = runtime_graph[pred_id][
PRED]
runtime_graph[merged_node_id][PRED] = runtime_graph[
pred_id][PRED]
except:
pass
try:
for i in runtime_graph[pred_id][PRED]:
try:
runtime_graph[i][SUCC].remove(pred_id)
except:
continue
runtime_graph[i][SUCC].append(merged_node_id)
except:
pass
try:
for i in edges[SUCC]:
runtime_graph[i][PRED].remove(node_id)
runtime_graph[i][PRED].append(merged_node_id)
except:
pass
if succ is not None:
for i in succ:
try:
runtime_graph[i][PRED].remove(pred_id)
except:
continue
runtime_graph[i][PRED].append(merged_node_id)
runtime_graph.pop(node_id)
try:
runtime_graph.pop(pred_id)
except:
continue
reduct_cnt += 1
self.eliminate_multi_edges(runtime_graph)
break
return reduct_cnt # the number of nodes that have been reduced
def _merge_branch(self, nodes, runtime_graph, is_bwd=False):
......@@ -496,7 +543,10 @@ class CostModel(object):
succ_to_elim = []
for succ_id in succ_nodes_id:
for succ_2_id in succ_nodes_id:
try:
tmp = runtime_graph[succ_2_id][SUCC]
except:
continue
if succ_id in tmp:
succ_to_elim.append(succ_id)
break
......@@ -506,16 +556,22 @@ class CostModel(object):
reduct_cnt += 1
to_merge = True
if len(edges[SUCC]) < 1 or len(runtime_graph[edges[SUCC][0]][
SUCC]) < 1:
try:
if len(edges[SUCC]) < 1 or len(runtime_graph[edges[SUCC][0]]
[SUCC]) < 1:
continue
except:
continue
end_node_id = runtime_graph[edges[SUCC][0]][SUCC][0]
for i in succ_nodes_id:
try:
if len(runtime_graph[i][SUCC]) != 1 or \
runtime_graph[i][SUCC][0] != end_node_id:
to_merge = False # if branches has different end node, we don't merge them
break
if to_merge:
except:
continue
if to_merge and len(succ_nodes_id) > 1:
to_merge_node_list = [nodes[i] for i in succ_nodes_id]
merged_node_id, merged_node = self._merge_node(
to_merge_node_list, merge_type='branch', nodes=nodes)
......@@ -529,9 +585,13 @@ class CostModel(object):
runtime_graph[end_node_id][PRED] = [merged_node_id]
runtime_graph[node_id][SUCC] = [merged_node_id]
try:
for i in succ_nodes_id:
runtime_graph.pop(i)
reduct_cnt += len(to_merge_node_list) - 1
break
except:
pass
return reduct_cnt
def get_runtime_cost(self):
......@@ -615,7 +675,7 @@ class CostModel(object):
return static_mem, cur_mem, top_mem
def get_pipeline_time(self):
if self.total_rank <= 1:
if self.pp2rank is None:
return self.fwd_time[0] + self.bwd_time[0] + self.optim_time[0]
else:
return self._simulate_pipeline()
......
......@@ -118,11 +118,11 @@ def get_comm_volume(comm_op, src_rank, tgt_rank):
return comm_volume
def analyze_comm_requirements_from_op(op, rank):
def analyze_comm_requirements_from_op(op, rank, g_process_group_map):
comm_requirements_to_ranks = {}
if is_collective_comm_op(op):
process_group_id = op.attr("ring_id")
process_group = get_process_group(process_group_id)
process_group = get_process_group(process_group_id, g_process_group_map)
if rank not in process_group.ranks:
return comm_requirements_to_ranks
for tgt_rank in process_group.ranks:
......@@ -142,7 +142,9 @@ def analyze_comm_requirements_from_op(op, rank):
return comm_requirements_to_ranks
def analyze_requirements_for_program(program, rank):
def analyze_requirements_for_program(src_info, rank):
program = src_info[0]
g_process_group_map = src_info[1]
resource_requirements = {}
comm_requirements_to_ranks = {}
# only support device_type and only support GPU for now
......@@ -150,7 +152,7 @@ def analyze_requirements_for_program(program, rank):
for block in program.blocks:
for op in block.ops:
cur_comm_requirements_to_ranks = analyze_comm_requirements_from_op(
op, rank)
op, rank, g_process_group_map)
for tgt_rank, link_info in cur_comm_requirements_to_ranks.items():
if tgt_rank in comm_requirements_to_ranks:
comm_requirements_to_ranks[tgt_rank][
......@@ -164,9 +166,9 @@ def analyze_requirements_for_program(program, rank):
def build_process_graph(distributed_program):
graph = Graph()
for src_rank, src_program in distributed_program.items():
for src_rank, src_info in distributed_program.items():
resource_requirements, comm_requirements_to_ranks = analyze_requirements_for_program(
src_program, src_rank)
src_info, src_rank)
graph.add_node(src_rank, resource_requirements=resource_requirements)
for tgt_rank, comm_requirements in comm_requirements_to_ranks.items():
graph.add_edge(
......
......@@ -308,6 +308,8 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
assert len(x_dims_mapping) >= len(
y_dims_mapping), "now just support x dims > y dims"
if len(y_dims_mapping) != 2:
return False
if len(x_dims_mapping) == len(y_dims_mapping) and len(
x_dims_mapping) == 4:
if x_dims_mapping[:2] != y_dims_mapping[:2]:
......@@ -602,6 +604,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
# for gpt2, x dims > y dims, this is a temporary solution
assert len(x_dims_mapping) >= len(
y_dims_mapping), "now just support x dims > y dims"
if len(y_dims_mapping) != 2:
return False
if len(x_dims_mapping) == len(y_dims_mapping) and len(
x_dims_mapping) == 4:
if x_dims_mapping[:2] != y_dims_mapping[:2]:
......@@ -889,6 +893,8 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
y_dims_mapping
), "now just support x dims > y dims,but x:{0} and y:{1}".format(
x_dims_mapping, y_dims_mapping)
if len(y_dims_mapping) != 2:
return False
if len(x_dims_mapping) == len(y_dims_mapping) and len(
x_dims_mapping) == 4:
if x_dims_mapping[:2] != y_dims_mapping[:2]:
......@@ -1010,6 +1016,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
return False
assert len(x_dims_mapping) >= len(
y_dims_mapping), "now just support x dims > y dims"
if len(y_dims_mapping) != 2:
return False
if len(x_dims_mapping) == len(y_dims_mapping) and len(
x_dims_mapping) == 4:
if x_dims_mapping[:2] != y_dims_mapping[:2]:
......@@ -1297,6 +1305,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
assert len(x_dims_mapping) >= len(
y_dims_mapping), "now just support x dims > y dims"
if len(y_dims_mapping) != 2:
return False
if len(x_dims_mapping) == len(y_dims_mapping) and len(
x_dims_mapping) == 4:
if x_dims_mapping[:2] != y_dims_mapping[:2]:
......@@ -1583,7 +1593,8 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
y_dims_mapping
), "now just support x dims > y dims,but x:{0} and y:{1}".format(
x_dims_mapping, y_dims_mapping)
if len(y_dims_mapping) != 2:
return False
if len(x_dims_mapping) == len(y_dims_mapping) and len(
x_dims_mapping) == 4:
if x_dims_mapping[:2] != y_dims_mapping[:2]:
......
......@@ -20,6 +20,9 @@ import copy
import pathlib
import subprocess
import logging
import pickle
import time
import paddle
from paddle.distributed.utils import get_logger
from paddle.distributed.fleet import cloud_utils
......@@ -30,13 +33,18 @@ from .dist_context import set_default_distributed_context
from .completion import complete_annotation, complete_backward_annotation
from .partitioner import Partitioner
from .process_group import get_all_process_groups
from .process_group import get_process_group
from .process_group import get_world_process_groups
from .process_group import _g_process_group_map, ProcessGroup
from .utils import make_data_unshard
from .utils import set_grad_var_shape
from .reshard import reshard
from .utils import SerialProgramInfo
from .reshard import reshard, HAS_SENT, HAS_RECV, HAS_ALLGATHER
from .cluster import Cluster
from .mapper import mapping
# from .auto_search import auto_search
from .dist_op import DistributedOperator
from .dist_tensor import DistributedTensor
from .planner import Planner
_logger = get_logger(logging.INFO)
......@@ -82,12 +90,20 @@ class AutoParallelizer:
if suffix in attr_name:
op._remove_attr(attr_name)
def _get_dist_program(self, dist_context, rank):
def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False):
completed_main_program = None
if dist_context is None:
# Annotation completion
self._dist_context = DistributedContext()
_logger.info("Start annotation dist attr.")
completed_main_program = complete_annotation(self._main_program,
dist_context)
self._dist_context)
else:
completed_main_program = self._main_program
self._dist_context = copy.deepcopy(dist_context)
# Logical partition
partitioner = Partitioner(self._dist_strategy, dist_context, rank)
partitioner = Partitioner(self._dist_strategy, self._dist_context, rank)
dist_main_prog, dist_startup_prog = partitioner.transpile_forward(
completed_main_program, self._startup_program)
dist_params_grads = partitioner.apply_backward(
......@@ -97,11 +113,21 @@ class AutoParallelizer:
copy.deepcopy(self._optimizer), dist_params_grads, dist_main_prog,
dist_startup_prog)
make_data_unshard(dist_main_prog, dist_startup_prog, dist_context)
set_grad_var_shape(dist_main_prog, self._dist_context)
reshard(dist_main_prog, dist_startup_prog, rank, dist_context)
make_data_unshard(dist_main_prog, dist_startup_prog, self._dist_context)
return dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog
reshard(dist_main_prog, dist_startup_prog, rank, self._dist_context)
g_process_group_map = None
if not relaunch_phase:
g_process_group_map = copy.deepcopy(_g_process_group_map)
HAS_SENT.clear()
HAS_RECV.clear()
HAS_ALLGATHER.clear()
_g_process_group_map.clear()
_g_process_group_map[0] = ProcessGroup(0, [])
return dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog, g_process_group_map
def parallelize(self,
loss,
......@@ -121,11 +147,51 @@ class AutoParallelizer:
"The cluster must not be none when using auto mapping."
dist_programs = {}
world_process_group = get_world_process_groups()
dist_context = None
# auto search
if self._dist_strategy.auto_search:
logging.info("Start searching dist attr.")
serial_program_info = SerialProgramInfo(
self._main_program, self._startup_program, self._loss,
self._optimizer, self._cluster)
planner = Planner(
serial_program_info,
algorithm_config={"name": "mcmc",
"max_search_times": 5})
dist_context, _ = planner.search()
logging.info("End searching dist attr.")
# serialize the dist context by planner
if dist_context is not None:
logging.info("Start serialize searched dist attr")
cwd = pathlib.Path().resolve()
searched_dist_context_path = os.path.join(
cwd, f"searched_dist_context_{time.time()}.pkl")
saved_dist_context = {}
ops_dist_attr = {}
tensors_dist_attr = {}
for key, dist_op in dist_context._dist_ops_for_program.items():
ops_dist_attr[key] = dist_op.dist_attr
for key, dist_tensor in dist_context._dist_tensors_for_program.items(
):
tensors_dist_attr[key] = dist_tensor.dist_attr
saved_dist_context["ops_dist_attr"] = ops_dist_attr
saved_dist_context["tensors_dist_attr"] = tensors_dist_attr
saved_dist_context[
"process_meshes"] = dist_context._process_meshes
with open(searched_dist_context_path,
"wb") as dist_context_file:
pickle.dump(saved_dist_context, dist_context_file)
os.environ[
'PADDLE_SEARCHED_DIST_CONTEXT_PATH'] = searched_dist_context_path
logging.info(
f"End serialize searched dist attr to {searched_dist_context_path}"
)
for rank in world_process_group.ranks:
dist_context = DistributedContext()
dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog = self._get_dist_program(
dist_context, rank)
dist_programs[rank] = dist_main_prog
dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog, g_process_group_map = self._get_dist_program(
rank, dist_context)
dist_programs[rank] = [dist_main_prog, g_process_group_map]
# Do the mapping between the distributed program graph and the cluster graph
rank_mapping_dict = mapping(dist_programs, self._cluster)
......@@ -162,9 +228,64 @@ class AutoParallelizer:
else:
# Parallelization after the mapping pass
rank = paddle.distributed.get_rank()
dist_context = None
searched_dist_context_path = os.getenv(
"PADDLE_SEARCHED_DIST_CONTEXT_PATH", None)
if searched_dist_context_path is not None:
with open(searched_dist_context_path,
"rb") as dist_context_file:
saved_dist_context = pickle.load(dist_context_file)
dist_context = DistributedContext()
for op in self._main_program.global_block().ops:
dist_attr = saved_dist_context["ops_dist_attr"][
op.desc.id()]
dist_op = DistributedOperator(op, dist_attr)
dist_context.add_dist_op_for_program(dist_op)
vars = self._main_program.global_block().vars
for var in vars.values():
dist_attr = saved_dist_context["tensors_dist_attr"][
var.desc.id()]
dist_tensor = DistributedTensor(var, dist_attr)
dist_context.add_dist_tensor_for_program(dist_tensor)
dist_context._process_meshes = saved_dist_context[
"process_meshes"]
else:
if self._dist_strategy.auto_search:
serial_program_info = SerialProgramInfo(
self._main_program,
self._startup_program,
self._loss,
self._optimizer,
cluster=self._cluster)
planner = Planner(
serial_program_info,
algorithm_config={
"name": "mcmc",
"max_search_times": 5
})
dist_context, _ = planner.search()
# rebuild g_process_group
if dist_context is not None:
pg0 = get_process_group(0)
for process_mesh in dist_context._process_meshes:
pg0.add_ranks(process_mesh.processes)
dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog, _ = self._get_dist_program(
rank, dist_context, relaunch_phase=True)
dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog = self._get_dist_program(
self._dist_context, rank)
# NOTE: This is a trick to fix hang in pipeline mode when dist context is searched by planner
if self._dist_strategy.auto_search:
is_pipeline = False
for op in dist_main_prog.global_block().ops:
if op.type == "send_v2" or op.type == "recv_v2":
is_pipeline = True
break
if is_pipeline:
with paddle.static.program_guard(dist_main_prog):
paddle.distributed.barrier()
# Traverse different rank programs and traverse each op of them,
# instantiate communication by process_mapping.
......
......@@ -32,6 +32,7 @@ from .completion import is_elementwise_like_op
from .operators.common import get_distributed_operator_impl_container
from .utils import update_op_dims_mapping_by_default_dist_impl
from .utils import update_op_dims_mapping_by_elementwise_like_dist_impl
from .utils import get_all_distributed_main_program
from .dist_context import DistributedContext, DistributedOperatorContext
from .dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
......@@ -370,3 +371,499 @@ class PlanSpace:
)] = [op_valid_dist_attrs, pipeline_stage]
return valid_dist_attr_dict, pipeline_process_meshes, global_process_mesh
class SearchAlgorithm:
def __init__(self, name):
self._name = name
@property
def name(self):
self.name = name
def search(self):
raise NotImplementedError("Please Implement this method in subclass.")
class MCMC(SearchAlgorithm):
def __init__(self, serial_program_info, max_search_times=5):
super(MCMC, self).__init__("mcmc")
self._serial_program_info = serial_program_info
self._max_search_times = max_search_times
@property
def serial_program_info(self):
return self._serial_program_info
@property
def max_search_times(self):
return self._max_search_times
def make_special_op_unshard(self, op, ops, vars, dist_context,
valid_dist_attr_dict):
if op.type == "softmax_with_cross_entropy":
for var_name in op.input_arg_names:
dims_mapping = dist_context.get_op_dist_attr_for_program(
op).get_input_dims_mapping(var_name)
if dims_mapping != dist_context.get_tensor_dist_attr_for_program(
vars[var_name]).dims_mapping:
has_changed = False
for search_op in ops:
if var_name in search_op.output_arg_names:
op_dist_attr_list = valid_dist_attr_dict[
search_op.desc.id()][0]
for op_dist_attr in op_dist_attr_list:
if op_dist_attr.get_output_dims_mapping(
var_name) == dims_mapping:
dist_context.set_op_dist_attr_for_program(
search_op, op_dist_attr)
tensor_dist_attr = TensorDistributedAttribute(
)
tensor_dist_attr.process_mesh = op_dist_attr.process_mesh
tensor_dist_attr.dims_mapping = op_dist_attr.get_output_dims_mapping(
var_name)
dist_context.set_tensor_dist_attr_for_program(
vars[var_name], tensor_dist_attr)
has_changed = True
break
if has_changed:
break
if not has_changed:
raise ValueError(
"Change softmax_with_cross_entropy dist attr failed")
def init_program(self, valid_dist_attr_dict, program,
pipeline_process_meshes, global_process_mesh):
ops = program.global_block().ops
vars = program.global_block().vars
new_dist_context = DistributedContext()
for op in ops:
op_valid_dist_attr_list = valid_dist_attr_dict[op.desc.id()][0]
random_op_dist_attr = np.random.randint(
len(op_valid_dist_attr_list))
init_op_dist_attr = op_valid_dist_attr_list[random_op_dist_attr]
new_dist_context.set_op_dist_attr_for_program(op, init_op_dist_attr)
for var_name in op.input_arg_names:
if var_name == "lod_tensor_blocking_queue_0":
continue
if new_dist_context.get_tensor_dist_attr_for_program(vars[
var_name]) is None:
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.process_mesh = init_op_dist_attr.process_mesh
tensor_dist_attr.dims_mapping = init_op_dist_attr.get_input_dims_mapping(
var_name)
new_dist_context.set_tensor_dist_attr_for_program(
vars[var_name], tensor_dist_attr)
for var_name in op.output_arg_names:
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.process_mesh = init_op_dist_attr.process_mesh
tensor_dist_attr.dims_mapping = init_op_dist_attr.get_output_dims_mapping(
var_name)
new_dist_context.set_tensor_dist_attr_for_program(
vars[var_name], tensor_dist_attr)
# NOTE: this is a temporary solution to make softmax_with_cross_entropy unshard
self.make_special_op_unshard(op, ops, vars, new_dist_context,
valid_dist_attr_dict)
# add process meshes to distributed context
if global_process_mesh is not None:
new_dist_context.add_process_mesh(global_process_mesh)
elif pipeline_process_meshes is not None:
for process_mesh in pipeline_process_meshes:
new_dist_context.add_process_mesh(process_mesh)
return new_dist_context
def estimate_searched_strategy_cost(self,
dist_context,
pipeline_process_meshes=None):
cost = None
# get all distributed programs
all_dist_main_program = get_all_distributed_main_program(
self.serial_program_info, dist_context)
pipeline_config = [
process_mesh.processes for process_mesh in pipeline_process_meshes
] if pipeline_process_meshes is not None else None
microbatch_size = 1
for program in all_dist_main_program:
searched_batch_size = False
for var in program.list_vars():
if var.is_data and "@RESHARD" in var.name:
microbatch_size = var.shape[0]
searched_batch_size = True
break
if searched_batch_size:
break
from .utils import get_standalone_cost_data
standalone_cost_data = get_standalone_cost_data(all_dist_main_program)
# cost model does not support cluster argument
cost = estimate_cost(
all_dist_main_program,
cluster=None,
pipeline_config=pipeline_config,
standalone_cost_data=standalone_cost_data,
batch_size=microbatch_size)
return cost
def set_tensor_dist_attr(self, op, op_dist_attr, vars, dist_context):
# set output tensor distributed attribute
for var_name in op.output_arg_names:
process_mesh = op_dist_attr.process_mesh
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.process_mesh = process_mesh
tensor_dist_attr.dims_mapping = op_dist_attr.get_output_dims_mapping(
var_name)
dist_context.set_tensor_dist_attr_for_program(vars[var_name],
tensor_dist_attr)
# set input tensor distributed attribute if input is data or parameter
for var_name in op.input_arg_names:
if vars[var_name].is_parameter or vars[var_name].is_data:
process_mesh = op_dist_attr.process_mesh
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.process_mesh = process_mesh
tensor_dist_attr.dims_mapping = op_dist_attr.get_input_dims_mapping(
var_name)
dist_context.set_tensor_dist_attr_for_program(vars[var_name],
tensor_dist_attr)
def change_process_mesh(self, op, changed_process_mesh, vars, dist_context):
dist_context.get_op_dist_attr_for_program(
op).process_mesh = changed_process_mesh
for var_name in op.output_arg_names:
dist_context.get_tensor_dist_attr_for_program(vars[
var_name]).process_mesh = changed_process_mesh
for var_name in op.input_arg_names:
if vars[var_name].is_parameter or vars[var_name].is_data:
dist_context.get_tensor_dist_attr_for_program(vars[
var_name]).process_mesh = changed_process_mesh
def search_once(self,
program,
valid_dist_attr_dict,
dist_context,
pipeline_process_meshes=None):
raw_ops = program.global_block().ops
ops = []
for op in raw_ops:
if op.type not in PlanSpace.not_enum_ops:
ops.append(op)
assert ops, "The ops of program have no distributed attributes."
vars = program.global_block().vars
new_dist_context = copy.deepcopy(dist_context)
new_dist_context._dist_op_context = DistributedOperatorContext()
new_valid_dist_attr_dict = None
random_selected_op_idx = np.random.randint(len(ops))
selected_op = ops[random_selected_op_idx]
op_valid_dist_attr_list = valid_dist_attr_dict[selected_op.desc.id()][0]
pipeline_stage = valid_dist_attr_dict[selected_op.desc.id()][1]
random_selected_dist_attr_idx = np.random.randint(
len(op_valid_dist_attr_list))
selected_op_dist_attr = copy.deepcopy(op_valid_dist_attr_list[
random_selected_dist_attr_idx])
start_idx = ops[0].desc.id()
if pipeline_stage > -1:
# in pipeline mode, the above phase just select a dims mapping
# 0 represents not changed, 1 represents to be the same with before stage, 2 represents to be the same with the latter stage
new_valid_dist_attr_dict = copy.deepcopy(valid_dist_attr_dict)
changed_mode = np.random.randint(3)
if changed_mode == 0:
# not change the process mesh, just change dims mapping
new_dist_context.set_op_dist_attr_for_program(
selected_op, selected_op_dist_attr)
self.set_tensor_dist_attr(selected_op, selected_op_dist_attr,
vars, new_dist_context)
elif changed_mode == 1:
changed_stage = pipeline_stage - 1
if changed_stage == -1 or random_selected_op_idx == len(ops) - 1 or \
(random_selected_op_idx + 1 == len(ops) - 1 and new_valid_dist_attr_dict[ops[random_selected_op_idx + 1].desc.id()][1] == pipeline_stage + 1 ):
new_dist_context.set_op_dist_attr_for_program(
selected_op, selected_op_dist_attr)
self.set_tensor_dist_attr(selected_op,
selected_op_dist_attr, vars,
new_dist_context)
else:
selected_op_process_mesh = pipeline_process_meshes[
pipeline_stage]
next_op_id = ops[random_selected_op_idx + 1].desc.id()
if new_valid_dist_attr_dict[next_op_id][
1] == pipeline_stage + 1 and random_selected_op_idx + 1 != len(
ops) - 1:
new_valid_dist_attr_dict[next_op_id][1] = pipeline_stage
for op_dist_attr in new_valid_dist_attr_dict[
next_op_id][0]:
op_dist_attr.process_mesh = selected_op_process_mesh
# set next op dist attr in the discontext and output/input tensor process mesh
self.change_process_mesh(
ops[random_selected_op_idx + 1],
selected_op_process_mesh, vars, new_dist_context)
# change the selected op stage and output dist attr
new_valid_dist_attr_dict[selected_op.desc.id()][
1] = changed_stage
new_process_mesh = pipeline_process_meshes[changed_stage]
selected_op_dist_attr.process_mesh = new_process_mesh
for op_dist_attr in new_valid_dist_attr_dict[
selected_op.desc.id()][0]:
op_dist_attr.process_mesh = new_process_mesh
new_dist_context.set_op_dist_attr_for_program(
selected_op, selected_op_dist_attr)
self.set_tensor_dist_attr(selected_op,
selected_op_dist_attr, vars,
new_dist_context)
# change the pre op stage
for idx in range(random_selected_op_idx - 1, -1, -1):
stage = new_valid_dist_attr_dict[ops[idx].desc.id()][1]
valid_dist_attr_list = new_valid_dist_attr_dict[ops[
idx].desc.id()][0]
new_process_mesh = pipeline_process_meshes[
changed_stage]
if stage == changed_stage + 1:
new_valid_dist_attr_dict[ops[idx].desc.id()][
1] = changed_stage
for op_dist_attr in valid_dist_attr_list:
op_dist_attr.process_mesh = new_process_mesh
new_dist_context.get_op_dist_attr_for_program(ops[
idx]).process_mesh = new_process_mesh
# change process mesh of the output and input tensor
self.change_process_mesh(ops[idx], new_process_mesh,
vars, new_dist_context)
else:
break
else:
changed_stage = pipeline_stage + 1
if changed_stage == len(
pipeline_process_meshes) or random_selected_op_idx == 0 or \
(new_valid_dist_attr_dict[ops[random_selected_op_idx - 1].desc.id()][1] == pipeline_stage - 1 and (random_selected_op_idx == 1)):
new_dist_context.set_op_dist_attr_for_program(
selected_op, selected_op_dist_attr)
self.set_tensor_dist_attr(selected_op,
selected_op_dist_attr, vars,
new_dist_context)
else:
selected_op_process_mesh = pipeline_process_meshes[
pipeline_stage]
pre_op_id = ops[random_selected_op_idx - 1].desc.id()
if new_valid_dist_attr_dict[pre_op_id][
1] == pipeline_stage - 1 and random_selected_op_idx != 1:
new_valid_dist_attr_dict[pre_op_id][1] = pipeline_stage
for op_dist_attr in new_valid_dist_attr_dict[pre_op_id][
0]:
op_dist_attr.process_mesh = selected_op_process_mesh
# set pre op dist attr in the discontext and output tensor process mesh
self.change_process_mesh(
ops[random_selected_op_idx - 1],
selected_op_process_mesh, vars, new_dist_context)
# change the selected op stage and output tensor dist attr
new_valid_dist_attr_dict[selected_op.desc.id()][
1] = changed_stage
new_process_mesh = pipeline_process_meshes[changed_stage]
selected_op_dist_attr.process_mesh = new_process_mesh
for op_dist_attr in new_valid_dist_attr_dict[
selected_op.desc.id()][0]:
op_dist_attr.process_mesh = new_process_mesh
new_dist_context.set_op_dist_attr_for_program(
selected_op, selected_op_dist_attr)
self.set_tensor_dist_attr(selected_op,
selected_op_dist_attr, vars,
new_dist_context)
# change the next op stage
for idx in range(random_selected_op_idx + 1, len(ops)):
stage = new_valid_dist_attr_dict[ops[idx].desc.id()][1]
valid_dist_attr_list = new_valid_dist_attr_dict[ops[
idx].desc.id()][0]
new_process_mesh = pipeline_process_meshes[
changed_stage]
if stage == changed_stage - 1:
new_valid_dist_attr_dict[ops[idx].desc.id()][
1] = changed_stage
for op_dist_attr in valid_dist_attr_list:
op_dist_attr.process_mesh = new_process_mesh
new_dist_context.get_op_dist_attr_for_program(ops[
idx]).process_mesh = new_process_mesh
# change the output tensor dist attr
self.change_process_mesh(ops[idx], new_process_mesh,
vars, new_dist_context)
else:
break
else:
new_dist_context.set_op_dist_attr_for_program(selected_op,
selected_op_dist_attr)
self.set_tensor_dist_attr(selected_op, selected_op_dist_attr, vars,
new_dist_context)
for op in ops:
# make softmax_with_cross_entropy unshard
if op.type == "softmax_with_cross_entropy":
self.make_special_op_unshard(op, ops, vars, new_dist_context,
valid_dist_attr_dict)
break
if new_valid_dist_attr_dict is None:
return valid_dist_attr_dict, new_dist_context
else:
return new_valid_dist_attr_dict, new_dist_context
def _search_core(self,
valid_dist_attr_dict,
init_dist_context,
pipeline_process_meshes=None):
times = 0
best_dist_context = init_dist_context
cost = self.estimate_searched_strategy_cost(
init_dist_context, pipeline_process_meshes).runtime
min_cost = cost
while times < self.max_search_times:
times += 1
new_dist_context = self.search_once(
self.serial_program_info.train_program, valid_dist_attr_dict,
best_dist_context, pipeline_process_meshes)[1]
cur_cost = self.estimate_searched_strategy_cost(
new_dist_context, pipeline_process_meshes).runtime
if (min_cost - cur_cost) > 0:
best_dist_context = copy.deepcopy(new_dist_context)
min_cost = cur_cost
times = 0
return best_dist_context, min_cost
def search(self):
logging.info("Start MCMC searching.")
start_time = time.time()
train_program = self.serial_program_info.train_program
cluster = self.serial_program_info.cluster
processes = paddle.distributed.get_world_size(
) if cluster is None else len(cluster.get_all_devices("GPU"))
assert processes > 0, "Get process failed."
process_mesh_topology_list = PlanSpace.enum_process_mesh_topology(
processes)
searched_dist_context = None
min_cost = None
searched_pipeline_dist_context = None
pipeline_min_cost = None
for process_mesh_topology in process_mesh_topology_list:
logging.info(
"MCMC search: search process mesh {} with pipeline mode.".
format(process_mesh_topology))
valid_dist_attr_dict, pipeline_process_meshes, global_process_mesh = PlanSpace.enum_valid_dist_attr_for_program(
train_program, process_mesh_topology, True)
init_dist_context = self.init_program(
valid_dist_attr_dict, train_program, pipeline_process_meshes,
global_process_mesh)
best_dist_context, cost = self._search_core(valid_dist_attr_dict,
init_dist_context,
pipeline_process_meshes)
logging.info(
"MCMC search: the min cost is {} in the process mesh {} with pipeline mode.".
format(cost, process_mesh_topology))
best_dist_context._dist_op_context = DistributedOperatorContext()
pipeline_min_cost = cost if pipeline_min_cost is None else pipeline_min_cost
searched_pipeline_dist_context = best_dist_context if searched_pipeline_dist_context is None else searched_pipeline_dist_context
if pipeline_min_cost > cost:
searched_pipeline_dist_context = best_dist_context
pipeline_min_cost = cost
searched_non_pipeline_dist_context = None
non_pipeline_min_cost = None
for process_mesh_topology in process_mesh_topology_list:
# if process_mesh_topology shape is 3, include pipeline mode by default
if len(process_mesh_topology) == 3:
continue
logging.info(
"MCMC search: search process mesh {} without pipeline mode.".
format(process_mesh_topology))
valid_dist_attr_dict, pipeline_process_meshes, global_process_mesh = PlanSpace.enum_valid_dist_attr_for_program(
train_program, process_mesh_topology, False)
init_dist_context = self.init_program(
valid_dist_attr_dict, train_program, pipeline_process_meshes,
global_process_mesh)
best_dist_context, cost = self._search_core(valid_dist_attr_dict,
init_dist_context,
pipeline_process_meshes)
logging.info(
"MCMC search: the min cost is {} in the process mesh {} without pipeline mode.".
format(cost, process_mesh_topology))
best_dist_context._dist_op_context = DistributedOperatorContext()
non_pipeline_min_cost = cost if non_pipeline_min_cost is None else non_pipeline_min_cost
searched_non_pipeline_dist_context = best_dist_context if searched_non_pipeline_dist_context is None else searched_non_pipeline_dist_context
if non_pipeline_min_cost > cost:
searched_non_pipeline_dist_context = best_dist_context
non_pipeline_min_cost = cost
if non_pipeline_min_cost > pipeline_min_cost:
searched_dist_context = searched_pipeline_dist_context
min_cost = pipeline_min_cost
logging.info(
"Better set FLAGS_benchmark=1 to avoid hang problem in the pipeline mode."
)
else:
searched_dist_context = searched_non_pipeline_dist_context
min_cost = non_pipeline_min_cost
# rebuild g_process_group
pg0 = get_process_group(0)
for process_mesh in searched_dist_context._process_meshes:
pg0.add_ranks(process_mesh.processes)
end_time = time.time()
logging.info(
"End MCMC searching: the min cost is {} and the search time is {}s.".
format(min_cost, end_time - start_time))
return searched_dist_context, min_cost
class Planner:
def __init__(self, serial_program_info, algorithm_config=None):
self._serial_program_info = serial_program_info
self._algorithm_config = algorithm_config
self._algorithm_searcher = self.create_algorithm_searcher(
algorithm_config)
@property
def serial_program_info(self):
return self._serial_program_info
@property
def algorithm_config(self):
return self._algorithm_config
@property
def algorithm_searcher(self):
return self._algorithm_searcher
def create_algorithm_searcher(self, algorithm_config):
name = algorithm_config.get("name", None)
assert name is not None, "Invalid algorithm config."
algorithm_searcher = None
if name == "mcmc":
# NOTE: Only GPU clusters are supported now.
max_search_times = algorithm_config.get("max_search_times", None)
algorithm_searcher = MCMC(
self.serial_program_info,
max_search_times) if max_search_times is not None else MCMC(
self.serial_program_info)
else:
raise NotImplementedError(
"Other search algorithms have not been supported now.")
return algorithm_searcher
def search(self):
return self.algorithm_searcher.search()
......@@ -25,9 +25,12 @@ def get_all_process_groups():
return _g_process_group_map.values()
def get_process_group(group_id):
def get_process_group(group_id, g_process_group_map=None):
global _g_process_group_map
return _g_process_group_map.get(group_id, None)
return _g_process_group_map.get(
group_id,
None) if g_process_group_map is None else g_process_group_map.get(
group_id, None)
def get_world_process_groups():
......
......@@ -19,9 +19,11 @@ import threading
import numpy as np
import warnings
import logging
from functools import reduce
import paddle.fluid.core as core
from paddle.framework.io import _to_LodTensor
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.fluid.io import is_parameter, is_belong_to_optimizer
......@@ -1258,3 +1260,95 @@ class SerialProgramInfo:
@property
def cluster(self):
return self._cluster
def get_standalone_cost_data(distributed_programs):
def _compute_runtime(op_cost, op, vars):
runtime = 0
try:
runtime = float(op_cost["op_time"])
except:
return runtime
op_config = op_cost["config"]
total_static_input_size = 0
total_actual_input_size = 0
parsed_info = op_config.split("\n")
variable = "(Variable)"
for info in parsed_info:
variable = "(Variable)" if "(Variable)" in info else "(list<Variable>"
if variable in info:
arg_name_lower = info[:info.find(variable) - 1]
shape_left_boundary = info.find("[")
shape_right_boundary = info.find("]")
assert shape_left_boundary > 0 and shape_right_boundary > 0 and shape_right_boundary > shape_left_boundary, "Get shape failed."
shape = info[shape_left_boundary + 1:
shape_right_boundary].split(",")
shape = list(map(lambda x: int(x.strip()), shape))
dtype_factor = 1
total_static_input_size += reduce(lambda x, y: x * y, shape)
# print(arg_name_lower)
if op.type == "c_embedding":
arg_name_lower = "w" if arg_name_lower == "weight" else "ids"
for arg_name in op.input_names:
if arg_name.lower() == arg_name_lower:
for var_name in op.input(arg_name):
var = vars[var_name]
total_actual_input_size += reduce(
lambda x, y: x * y, var.shape)
break
assert total_static_input_size > 0 and total_actual_input_size > 0, "Get input size failed."
actual_runtime = total_actual_input_size / total_static_input_size * runtime
return actual_runtime
cost_model = paddle.cost_model.CostModel()
cost_model.static_cost_data()
DEFAULT_MULTIPLE = 2
OP_NAME_MAPPING = {
"c_embedding": "embedding",
"matmul_v2": "matmul",
"transpose2": "transpose",
"reshape2": "reshape",
"unsqueeze2": "unsqueeze",
"reduce_sum": "sum",
"elementwise_div": "divide"
}
standalone_cost_data = []
not_enum_ops = ["create_py_reader", "create_double_buffer_reader", "read"]
for distributed_program in distributed_programs:
cost_data = {}
vars = distributed_program.global_block().vars
for op in distributed_program.global_block().ops:
runtime = 0
if op.type in not_enum_ops:
cost_data[op.desc.id()] = runtime
continue
dtype = str(vars[op.input_arg_names[0]]
.dtype) if op.input_arg_names else "float32"
if int(op.attr('op_role')) == int(OpRole.Backward):
if "_grad" in op.type:
forward_op_name = op.type[:-5]
if forward_op_name in OP_NAME_MAPPING.keys():
forward_op_name = OP_NAME_MAPPING[forward_op_name]
op_cost = cost_model.get_static_op_time(
forward_op_name, forward=False, dtype=dtype)
if op_cost:
runtime = _compute_runtime(op_cost, op, vars)
else:
op_cost = cost_model.get_static_op_time(
forward_op_name, dtype=dtype)
if op_cost:
runtime = 2 * _compute_runtime(op_cost, op, vars)
elif int(op.attr('op_role')) == int(OpRole.Forward):
op_name = OP_NAME_MAPPING[
op.type] if op.type in OP_NAME_MAPPING.keys() else op.type
op_cost = cost_model.get_static_op_time(op_name)
if op_cost:
runtime = _compute_runtime(op_cost, op, vars)
cost_data[op.desc.id()] = runtime
standalone_cost_data.append(cost_data)
return standalone_cost_data
......@@ -1436,7 +1436,7 @@ class Fleet(object):
context["role_maker"] = self._role_maker
# Use the auto-parallel's routines instead
if self._user_defined_strategy.semi_auto:
if self._user_defined_strategy.semi_auto or self._user_defined_strategy.auto_search:
from ...auto_parallel.parallelizer import AutoParallelizer
auto_parallelizer = AutoParallelizer(self)
optimize_ops, params_grads, dist_startup_prog, dist_main_prog = auto_parallelizer.parallelize(
......@@ -1586,13 +1586,13 @@ class Fleet(object):
]
param_grads_fp16 = [
param._grad_ivar() for param in optimizer._parameter_list
if (param._grad_ivar() is not None) and
(param._grad_ivar().dtype == core.VarDesc.VarType.FP16)
if (param._grad_ivar() is not None) and (param._grad_ivar(
).dtype == core.VarDesc.VarType.FP16)
]
param_grads_fp32 = [
param._grad_ivar() for param in optimizer._parameter_list
if (param._grad_ivar() is not None) and
(param._grad_ivar().dtype == core.VarDesc.VarType.FP32)
if (param._grad_ivar() is not None) and (param._grad_ivar(
).dtype == core.VarDesc.VarType.FP32)
]
temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool))
temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool))
......
......@@ -3,4 +3,6 @@
if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_auto_parallel_relaunch MODULES test_auto_parallel_relaunch ENVS ${dist_ENVS})
set_tests_properties(test_auto_parallel_relaunch PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120)
py_test_modules(test_relaunch_with_planner MODULES test_relaunch_with_planner ENVS ${dist_ENVS})
set_tests_properties(test_relaunch_with_planner PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120)
endif()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.static as static
from paddle.distributed import fleet
def train():
from auto_parallel_relaunch_model import mlp_pretrain_forward
from auto_parallel_relaunch_model import batch_generator_creator
dist_strategy = fleet.DistributedStrategy()
# init parallel optimizer
dist_strategy.auto_search = True
fleet.init(is_collective=True, strategy=dist_strategy)
train_program = static.Program()
start_program = static.Program()
loss, train_program, start_program, loader = mlp_pretrain_forward(
train_program, start_program)
optimizer = paddle.fluid.optimizer.AdamOptimizer(
learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
optimizer = fleet.distributed_optimizer(optimizer)
_, _, distributed_startup_program, distributed_main_program = optimizer.minimize(
loss, start_program)
places = static.cuda_places()
loader.set_batch_generator(batch_generator_creator(), places=places)
exe = paddle.static.Executor(places[0])
exe.run(distributed_startup_program)
for data in loader():
exe.run(distributed_main_program, feed=data)
if __name__ == "__main__":
train()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import sys
import json
import shutil
import subprocess
from paddle.distributed.fleet.launch_utils import run_with_coverage
class TestPlannerReLaunch(unittest.TestCase):
def test_relaunch_with_planner(self):
from test_auto_parallel_relaunch import cluster_json
file_dir = os.path.dirname(os.path.abspath(__file__))
cluster_json_path = os.path.join(file_dir, "auto_parallel_cluster.json")
cluster_json_object = json.loads(cluster_json)
with open(cluster_json_path, "w") as cluster_json_file:
json.dump(cluster_json_object, cluster_json_file)
launch_model_path = os.path.join(
file_dir, "auto_parallel_relaunch_with_planner.py")
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
cmd = [sys.executable, "-u"] + coverage_args + [
"-m", "launch", "--cluster_topo_path", cluster_json_path,
"--enable_auto_mapping", "True", launch_model_path
]
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
# Remove unnecessary files
if os.path.exists(cluster_json_path):
os.remove(cluster_json_path)
rank_mapping_json_path = os.path.join(file_dir,
"auto_parallel_rank_mapping.json")
if os.path.exists(rank_mapping_json_path):
os.remove(rank_mapping_json_path)
log_path = os.path.join(file_dir, "log")
if os.path.exists(log_path):
shutil.rmtree(log_path)
if __name__ == "__main__":
unittest.main()
......@@ -187,10 +187,10 @@ def check_empty_program_runtime(cost):
def check_empty_program_memory(cost):
for mem in cost.peak_mem:
if mem > 0:
if mem > 1:
return False
for mem in cost.static_mem:
if mem > 0:
if mem > 1:
return False
return True
......
......@@ -529,7 +529,7 @@ class TestAutoParallelMapper(unittest.TestCase):
train_program, startup_program, dist_context, rank_id)
# if rank_id == 0:
# print_program_with_dist_attr(dist_train_program, dist_context)
dist_programs[rank_id] = dist_train_program
dist_programs[rank_id] = [dist_train_program, None]
rank_mapping = mapping(dist_programs, cluster)
......
......@@ -174,7 +174,7 @@ class Testcompatible(unittest.TestCase):
op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, -1])
op_dist_attr.set_input_dims_mapping(Y, [-1, -1, -1, -1])
op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, -1])
self.assertTrue(impls[2].is_auto_compatible(
self.assertFalse(impls[2].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(Y, [0, -1, -1, -1])
self.assertFalse(impls[2].is_auto_compatible(
......@@ -261,7 +261,7 @@ class Testcompatible(unittest.TestCase):
op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, 1])
op_dist_attr.set_input_dims_mapping(Y, [-1, -1, 1, -1])
op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, -1])
self.assertTrue(impls[1].is_auto_compatible(
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(Y, [0, -1, -1, -1])
self.assertFalse(impls[1].is_auto_compatible(
......@@ -362,7 +362,7 @@ class Testcompatible(unittest.TestCase):
op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, -1])
op_dist_attr.set_input_dims_mapping(Y, [-1, -1, -1, 1])
op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, 1])
self.assertTrue(impls[0].is_auto_compatible(
self.assertFalse(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(out, [0, -1, -1, 1])
self.assertFalse(impls[0].is_auto_compatible(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册