From f9f32b4114088781ac4b8aa69de92748cb447da4 Mon Sep 17 00:00:00 2001
From: caozhou <48191911+Caozhou1995@users.noreply.github.com>
Date: Tue, 14 Dec 2021 19:43:14 +0800
Subject: [PATCH] add PlanFilter, PlanSpace of auto planner (#37858)

* update Planner
* update unitest
* update PlanSpace
* update PlanSpace
* modify set_grad_var_shape
* update code style
---
 .../distributed/auto_parallel/cluster.py      |   8 +
 .../distributed/auto_parallel/planner.py      | 372 ++++++++++++++++++
 .../paddle/distributed/auto_parallel/utils.py |  90 ++++-
 .../unittests/test_auto_parallel_cluster.py   |   4 +
 .../unittests/test_auto_parallel_searcher.py  |  89 ++++-
 5 files changed, 553 insertions(+), 10 deletions(-)
 create mode 100644 python/paddle/distributed/auto_parallel/planner.py

diff --git a/python/paddle/distributed/auto_parallel/cluster.py b/python/paddle/distributed/auto_parallel/cluster.py
index d65612fc6e9..ef05ff7c746 100644
--- a/python/paddle/distributed/auto_parallel/cluster.py
+++ b/python/paddle/distributed/auto_parallel/cluster.py
@@ -351,6 +351,14 @@ class Cluster:
         self._num_machines += 1
         return cur_machine_id
 
+    def get_all_devices(self, device_type):
+        devices = []
+        for machine in self.machines.values():
+            for device in machine.devices.values():
+                if device.type == DeviceType[device_type]:
+                    devices.append(device)
+        return devices
+
     def __str__(self):
         str = ""
         for machine in self.machines.values():
diff --git a/python/paddle/distributed/auto_parallel/planner.py b/python/paddle/distributed/auto_parallel/planner.py
new file mode 100644
index 00000000000..8d1ecf9deaa
--- /dev/null
+++ b/python/paddle/distributed/auto_parallel/planner.py
@@ -0,0 +1,372 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import time
+import random
+import logging
+from functools import reduce
+from itertools import chain, product
+from collections import OrderedDict
+
+import numpy as np
+
+import paddle
+import paddle.distributed.auto_parallel as auto
+from .cost_model import estimate_cost
+from .dist_op import DistributedOperator
+from .process_group import _g_process_group_map
+from .process_group import ProcessGroup, get_process_group
+from .completion import is_elementwise_like_op
+from .operators.common import get_distributed_operator_impl_container
+from .utils import update_op_dims_mapping_by_default_dist_impl
+from .utils import update_op_dims_mapping_by_elementwise_like_dist_impl
+from .dist_context import DistributedContext, DistributedOperatorContext
+from .dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
+
+paddle.enable_static()
+paddle.seed(123)
+random.seed(123)
+np.random.seed(123)
+
+
+class PlanFilter:
+    @staticmethod
+    def check_dims_mapping_for_tensor(process_mesh_topology, tensor_shape,
+                                      dims_mapping):
+        valid = True
+        assert len(tensor_shape) == len(dims_mapping)
+
+        for idx, dim_mapping in enumerate(dims_mapping):
+            if dim_mapping != -1:
+                if tensor_shape[idx] % process_mesh_topology[
+                        dim_mapping] != 0 or dims_mapping.count(
+                            dim_mapping) > 1:
+                    valid = False
+            if dim_mapping != -1 and process_mesh_topology[0] == 1:
+                valid = False
+
+        return valid
+
+    @staticmethod
+    def check_dims_mapping_for_op(op, op_dist_attr, vars):
+        process_mesh = op_dist_attr.process_mesh
+        assert process_mesh is not None, "The process mesh should not be None."
+        for var_name in op.input_arg_names:
+            dims_mapping = op_dist_attr.get_input_dims_mapping(var_name)
+            if not PlanFilter.check_dims_mapping_for_tensor(
+                    process_mesh.topology, vars[var_name].shape, dims_mapping):
+                return False
+            if vars[var_name].is_data and len(dims_mapping) > 1:
+                for dim in dims_mapping[1:]:
+                    if dim != -1:
+                        return False
+
+        for var_name in op.output_arg_names:
+            dims_mapping = op_dist_attr.get_output_dims_mapping(var_name)
+            if not PlanFilter.check_dims_mapping_for_tensor(
+                    process_mesh.topology, vars[var_name].shape, dims_mapping):
+                return False
+
+        return True
+
+    @staticmethod
+    def check_dims_mapping_for_special_op(op, op_dist_attr, vars):
+        if op.type == "layer_norm":
+            bias_dims_mapping = op_dist_attr.get_input_dims_mapping(
+                op.input("Bias")[0])
+            scale_dims_mapping = op_dist_attr.get_input_dims_mapping(
+                op.input("Scale")[0])
+            x_dims_mapping = op_dist_attr.get_input_dims_mapping(
+                op.input("X")[0])
+            mean_dims_mapping = op_dist_attr.get_output_dims_mapping(
+                op.output("Mean")[0])
+            variance_dims_mapping = op_dist_attr.get_output_dims_mapping(
+                op.output("Variance")[0])
+            y_dims_mapping = op_dist_attr.get_output_dims_mapping(
+                op.output("Y")[0])
+            if x_dims_mapping != y_dims_mapping:
+                return False
+
+            if scale_dims_mapping[0] != x_dims_mapping[-1]:
+                return False
+
+            if bias_dims_mapping[0] != y_dims_mapping[-1]:
+                return False
+
+            if mean_dims_mapping[0] != x_dims_mapping[0]:
+                return False
+
+            if variance_dims_mapping[0] != x_dims_mapping[0]:
+                return False
+
+        return True
+
+
+class PlanSpace:
+    not_enum_ops = ["create_py_reader", "create_double_buffer_reader", "read"]
+    special_vars = [
+        "lod_tensor_blocking_queue_0", "create_py_reader_0", "double_buffer_0"
+    ]
+
+    @staticmethod
+    def _enum_dims_mapping(process_mesh_topology, visited, path, depth, res,
+                           tensor_shape):
+        """Enumerate dims mapping of tensor by the given process_mesh_topology"""
+        nums = list(range(-1, len(process_mesh_topology)))
+        if depth == len(tensor_shape):
+            valid = True
+            for idx, item in enumerate(path):
+                if item != -1:
+                    if tensor_shape[idx] % process_mesh_topology[
+                            item] != 0 or path.count(item) > 1:
+                        valid = False
+            if valid:
+                res.append(copy.deepcopy(path))
+            return
+
+        for i in range(len(nums)):
+            if not visited[i]:
+                if i != 0:
+                    visited[i] = True
+                path.append(nums[i])
+                PlanSpace._enum_dims_mapping(process_mesh_topology, visited,
+                                             path, depth + 1, res, tensor_shape)
+                visited[i] = False
+                path.pop()
+
+    @staticmethod
+    def enum_process_mesh_topology(processes):
+        """Enumerate all process meshes with the given processes."""
+        assert processes >= 1, "The number of processes must be a positive integer."
+        # compute divisors
+        divisors = []
+        for i in range(1, processes + 1):
+            if processes % i == 0:
+                divisors.append(i)
+
+        # compute valid process mesh
+        results = []
+        for i in range(len(divisors) - 1, 0, -1):
+            result = []
+            result.append(divisors[i])
+            if i == len(divisors) - 1:
+                results.append(copy.deepcopy(result))
+                continue
+
+            j = 1
+            while j < len(divisors):
+                if len(result) == 1:
+                    result.append(divisors[j])
+                elif len(result) == 2:
+                    if processes % (result[0] * result[1]) == 0:
+                        if processes // (result[0] * result[1]) == 1:
+                            results.append(copy.deepcopy(result))
+                            break
+                        else:
+                            result.append(processes // (result[0] * result[1]))
+                            results.append(copy.deepcopy(result))
+                            result.pop(-1)
+                            result.pop(-1)
+                            j += 1
+                    else:
+                        if result[0] * result[1] < processes:
+                            result.pop(-1)
+                            j += 1
+                        else:
+                            break
+        return results
+
+    @staticmethod
+    def _enum_valid_dist_attr_for_op(program, op, process_mesh):
+        """Enumerate the valid distributed attribute for op based on the given process mesh."""
+        vars = program.global_block().vars
+        dims_mapping_dict = OrderedDict()
+        op_valid_dist_attrs = []
+        dist_op_impl_container = get_distributed_operator_impl_container(
+            op.type)
+
+        # enumerate all valid dims mapping of tensor when process mesh given
+        for var_name in chain(op.input_arg_names, op.output_arg_names):
+            visited = [
+                False
+                for _ in range(
+                    len(list(range(-1, len(process_mesh.topology)))))
+            ]
+            depth = 0
+            path = []
+            dims_mapping_list = []
+            PlanSpace._enum_dims_mapping(process_mesh.topology, visited, path,
+                                         depth, dims_mapping_list,
+                                         vars[var_name].shape)
+            dims_mapping_dict[var_name] = copy.deepcopy(dims_mapping_list)
+
+        # compose dims mapping
+        composed_dims_mapping_list = list(
+            product(
+                *[dims_mapping_dict[key] for key in dims_mapping_dict.keys()]))
+        for composed_dims_mapping in composed_dims_mapping_list:
+            op_dist_attr = OperatorDistributedAttribute()
+            op_dist_attr.process_mesh = process_mesh
+            var_names = list(dims_mapping_dict.keys())
+
+            for idx, dims_mapping in enumerate(composed_dims_mapping):
+                if var_names[idx] in op.input_arg_names:
+                    op_dist_attr.set_input_dims_mapping(var_names[idx],
+                                                        dims_mapping)
+                elif var_names[idx] in op.output_arg_names:
+                    op_dist_attr.set_output_dims_mapping(var_names[idx],
+                                                         dims_mapping)
+                else:
+                    raise ValueError(
+                        "The {varname} is not an input or output of op {op}.".
+                        format(
+                            varname=var_names[idx], op=op.type))
+
+            dist_op = DistributedOperator(op, op_dist_attr)
+            if dist_op_impl_container is None:
+                if is_elementwise_like_op(op.type):
+                    changed = True
+                    valid = True
+                    try:
+                        changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
+                            dist_op)
+                    except Exception as e:
+                        valid = False
+                    if valid and not changed:
+                        if PlanFilter.check_dims_mapping_for_op(
+                                op, dist_op.dist_attr, vars
+                        ) and PlanFilter.check_dims_mapping_for_special_op(
+                                op, dist_op.dist_attr, vars):
+                            dist_op.dist_attr.impl_idx = -1
+                            op_valid_dist_attrs.append(dist_op.dist_attr)
+                    continue
+                else:
+                    changed = True
+                    valid = True
+                    try:
+                        changed = update_op_dims_mapping_by_default_dist_impl(
+                            dist_op)
+                    except Exception as e:
+                        valid = False
+                    if valid and not changed:
+                        if PlanFilter.check_dims_mapping_for_op(
+                                op, dist_op.dist_attr, vars
+                        ) and PlanFilter.check_dims_mapping_for_special_op(
+                                op, dist_op.dist_attr, vars):
+                            dist_op.dist_attr.impl_idx = -2
+                            op_valid_dist_attrs.append(dist_op.dist_attr)
+                    continue
+
+            # if the op has distributed implementations, find all valid dist attrs of this op
+            impls = dist_op_impl_container.get_impls()
+            for idx, impl in enumerate(impls):
+                if impl.is_auto_compatible(dist_op):
+                    if PlanFilter.check_dims_mapping_for_op(
+                            op, dist_op.dist_attr, vars):
+                        dist_op.dist_attr.impl_idx = idx
+                        op_valid_dist_attrs.append(dist_op.dist_attr)
+
+        # set default dist attr for some special ops whose distributed attributes can not be enumerated
+        if not op_valid_dist_attrs:
+            op_dist_attr = OperatorDistributedAttribute()
+            op_dist_attr.process_mesh = process_mesh
+            dist_op = DistributedOperator(op, op_dist_attr)
+            for var_name in op.input_arg_names:
+                op_dist_attr.set_input_dims_mapping(
+                    var_name, [-1 for i in vars[var_name].shape])
+            for var_name in op.output_arg_names:
+                op_dist_attr.set_output_dims_mapping(
+                    var_name, [-1 for i in vars[var_name].shape])
+            dist_op.dist_attr.impl_idx = -1
+            op_valid_dist_attrs.append(dist_op.dist_attr)
+
+        return op_valid_dist_attrs
+
+    @staticmethod
+    def enum_valid_dist_attr_for_program(program,
+                                         process_mesh_topology,
+                                         is_pipeline=False):
+        """Enumerate valid distributed attributes for all ops in program."""
+        valid_dist_attr_dict = OrderedDict()
+        ops = program.global_block().ops
+        vars = program.global_block().vars
+
+        processes = reduce(lambda x, y: x * y, process_mesh_topology)
+        global_group = [i for i in range(processes)]
+        global_process_mesh = None
+        pipeline_process_meshes = None
+
+        # in the pipeline mode, there are some process meshes
+        if is_pipeline:
+            pipeline_stages = process_mesh_topology[-1]
+            op_count_per_stage = len(ops) // pipeline_stages
+            if len(process_mesh_topology) > 1:
+                process_mesh_shape = process_mesh_topology[:-1]
+                per_process_mesh_group = processes // pipeline_stages
+                pipeline_process_meshes = [auto.ProcessMesh(mesh=np.array(global_group[i*per_process_mesh_group: \
+                    (i+1)*per_process_mesh_group]).reshape(process_mesh_shape).tolist()) for i in range(pipeline_stages)]
+            elif len(process_mesh_topology) == 1:
+                pipeline_process_meshes = [
+                    auto.ProcessMesh(mesh=[i]) for i in range(pipeline_stages)
+                ]
+        else:
+            if len(process_mesh_topology) > 1:
+                global_process_mesh = auto.ProcessMesh(mesh=np.array(
+                    global_group).reshape(process_mesh_topology).tolist())
+            else:
+                global_process_mesh = auto.ProcessMesh(mesh=global_group)
+
+        # enumerate valid distributed attribute for each op in the program
+        for idx, op in enumerate(ops):
+            op_valid_dist_attrs = None
+            op_process_mesh = global_process_mesh
+            pipeline_stage = -1
+            if pipeline_process_meshes is not None:
+                pipeline_stage = idx // op_count_per_stage if idx // op_count_per_stage < len(
+                    pipeline_process_meshes) else idx // op_count_per_stage - 1
+                if pipeline_stage >= len(pipeline_process_meshes):
+                    pipeline_stage = len(pipeline_process_meshes) - 1
+                op_process_mesh = pipeline_process_meshes[pipeline_stage]
+
+            if op.type in PlanSpace.not_enum_ops:
+                op_dist_attr = OperatorDistributedAttribute()
+                op_dist_attr.process_mesh = op_process_mesh
+                for var_name in op.input_arg_names:
+                    if var_name in PlanSpace.special_vars:
+                        op_dist_attr.set_input_dims_mapping(var_name, [])
+                    else:
+                        dims_mapping = [-1 for i in vars[var_name].shape]
+                        op_dist_attr.set_input_dims_mapping(var_name,
+                                                            dims_mapping)
+
+                for var_name in op.output_arg_names:
+                    if var_name in PlanSpace.special_vars:
+                        op_dist_attr.set_output_dims_mapping(var_name, [])
+                    else:
+                        dims_mapping = [-1 for i in vars[var_name].shape]
+                        op_dist_attr.set_output_dims_mapping(var_name,
+                                                             dims_mapping)
+                op_valid_dist_attrs = [op_dist_attr]
+                pipeline_stage = 0 if pipeline_stage != -1 else pipeline_stage
+            else:
+                op_valid_dist_attrs = PlanSpace._enum_valid_dist_attr_for_op(
+                    program, op, op_process_mesh)
+
+            assert op_valid_dist_attrs is not None, "Enumerate {} valid distributed attribute failed.".format(
+                op)
+            valid_dist_attr_dict[op.desc.id(
+            )] = [op_valid_dist_attrs, pipeline_stage]
+
+        return valid_dist_attr_dict, pipeline_process_meshes, global_process_mesh
diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py
index 92918c834a5..c9cb4200e36 100755
--- a/python/paddle/distributed/auto_parallel/utils.py
+++ b/python/paddle/distributed/auto_parallel/utils.py
@@ -13,6 +13,7 @@
 # limitations under the License
 
 import os
+import copy
 import paddle
 import threading
 import numpy as np
@@ -1005,11 +1006,11 @@ def set_grad_var_shape(program, dist_context):
             need_set_shape_list = [
                 "reshape2_grad", "softmax_with_cross_entropy_grad",
                 "transpose2_grad", "softmax_grad", "cross_entropy_grad2",
-                "dropout_grad"
+                "dropout_grad", "unsqueeze2_grad"
             ]
             forward_list = [
                 "reshape2", "softmax_with_cross_entropy", "transpose2",
-                "softmax", "cross_entropy2", "dropout"
+                "softmax", "cross_entropy2", "dropout", "unsqueeze2"
             ]
             if op.type in need_set_shape_list:
                 for forward_op in block.ops:
@@ -1172,3 +1173,88 @@ def update_op_dims_mapping_by_elementwise_like_dist_impl(dist_op):
             changed = True
 
     return changed
+
+
+def get_all_distributed_main_program(serial_program_info, dist_context):
+    "Get all distributed main programs by dist_context."
+    from .dist_context import DistributedOperatorContext
+    cluster = serial_program_info.cluster
+    all_dist_main_program = []
+    ranks = paddle.distributed.get_world_size() if cluster is None else len(
+        cluster.get_all_devices("GPU"))
+    for rank_id in range(ranks):
+        used_dist_context = copy.deepcopy(dist_context)
+        used_dist_context._dist_op_context = DistributedOperatorContext()
+        dist_main_program, dist_startup_program = get_specified_distributed_main_program(
+            serial_program_info, used_dist_context, rank_id)
+        all_dist_main_program.append(dist_main_program)
+
+    return all_dist_main_program
+
+
+def get_specified_distributed_main_program(serial_program_info, dist_context,
+                                           rank_id):
+    "Get distributed main program by the given dist_context and rank_id."
+    from .partitioner import Partitioner
+    from .reshard import reshard, HAS_SENT, HAS_RECV, HAS_ALLGATHER
+    from .process_group import _g_process_group_map, ProcessGroup
+
+    dist_strategy = paddle.distributed.fleet.DistributedStrategy()
+    train_program = serial_program_info.train_program
+    startup_program = serial_program_info.startup_program
+    loss = serial_program_info.loss
+    optimizer = serial_program_info.optimizer
+
+    partitioner = Partitioner(dist_strategy, dist_context, rank_id)
+    dist_main_program, dist_startup_program = partitioner.transpile_forward(
+        train_program, startup_program)
+    dist_params_grads = partitioner.apply_backward(
+        loss, train_program, startup_program, dist_main_program,
+        dist_startup_program)
+    opt_ops = partitioner.apply_optimize(
+        copy.deepcopy(optimizer), dist_params_grads, dist_main_program,
+        dist_startup_program)
+    set_grad_var_shape(dist_main_program, dist_context)
+    make_data_unshard(dist_main_program, dist_startup_program, dist_context)
+    reshard(dist_main_program, dist_startup_program, rank_id, dist_context)
+    HAS_SENT.clear()
+    HAS_RECV.clear()
+    HAS_ALLGATHER.clear()
+
+    _g_process_group_map.clear()
+    _g_process_group_map[0] = ProcessGroup(0, [])
+    return dist_main_program, dist_startup_program
+
+
+class SerialProgramInfo:
+    def __init__(self,
+                 train_program,
+                 startup_program,
+                 loss,
+                 optimizer,
+                 cluster=None):
+        self._train_program = train_program
+        self._startup_program = startup_program
+        self._loss = loss
+        self._optimizer = optimizer
+        self._cluster = cluster
+
+    @property
+    def train_program(self):
+        return self._train_program
+
+    @property
+    def startup_program(self):
+        return self._startup_program
+
+    @property
+    def loss(self):
+        return self._loss
+
+    @property
+    def optimizer(self):
+        return self._optimizer
+
+    @property
+    def cluster(self):
+        return self._cluster
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_cluster.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_cluster.py
index d3942716f56..55b36654437 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_cluster.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_cluster.py
@@ -209,6 +209,10 @@ class TestAutoParallelCluster(unittest.TestCase):
         cluster = Cluster()
         cluster.build_from_file("./auto_parallel_cluster.json")
         os.remove("./auto_parallel_cluster.json")
+
+        self.assertEqual(len(cluster.get_all_devices("GPU")), 4)
+        self.assertEqual(len(cluster.get_all_devices("CPU")), 2)
+        self.assertEqual(len(cluster.get_all_devices("NIC")), 2)
 
         self.assertEqual(len(cluster.machines), 2)
 
         # machine0
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_searcher.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_searcher.py
index 665a16c862c..92d11801902 100644
--- a/python/paddle/fluid/tests/unittests/test_auto_parallel_searcher.py
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_searcher.py
@@ -14,9 +14,9 @@
 
 from __future__ import print_function
 
-# import os
-# import copy
-# import json
+import os
+import copy
+import json
 import unittest
 
 import paddle
@@ -24,13 +24,13 @@ import paddle.nn as nn
 import paddle.static as static
 import paddle.nn.functional as F
 import paddle.utils as utils
-# from paddle.distributed import fleet
+from paddle.distributed import fleet
 import paddle.distributed.auto_parallel as auto
-# from paddle.distributed.auto_parallel.cluster import Cluster
-# from paddle.distributed.auto_parallel.utils import SerialProgramInfo
-# from paddle.distributed.auto_parallel.searcher import Checker, Enumerater
+from paddle.distributed.auto_parallel.cluster import Cluster
+from paddle.distributed.auto_parallel.utils import SerialProgramInfo
+from paddle.distributed.auto_parallel.planner import PlanSpace, PlanFilter
 from paddle.distributed.auto_parallel.dist_context import DistributedContext
-# from paddle.distributed.auto_parallel.utils import get_all_distributed_main_program
+from paddle.distributed.auto_parallel.utils import get_all_distributed_main_program
 from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute
 from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute
 from paddle.distributed.auto_parallel.utils import update_op_dims_mapping_by_default_dist_impl
@@ -118,6 +118,30 @@ def set_default_dist_attr(program, dist_context, process_mesh):
     dist_context.add_process_mesh(process_mesh)
 
+
+def check_process_meshes(processes):
+    result = PlanSpace.enum_process_mesh_topology(processes)
+    if result:
+        return True
+    return False
+
+
+def check_pipeline_enumerater(program, process_mesh_topology):
+    valid_dist_attr_dict, pipeline_process_meshes, global_process_mesh = PlanSpace.enum_valid_dist_attr_for_program(
+        program, process_mesh_topology, True)
+    if valid_dist_attr_dict and len(
+            pipeline_process_meshes) > 1 and not global_process_mesh:
+        return True
+    return False
+
+
+def check_nonpipeline_enumerater(program, process_mesh_topology):
+    valid_dist_attr_dict, pipeline_process_meshes, global_process_mesh = PlanSpace.enum_valid_dist_attr_for_program(
+        program, process_mesh_topology, False)
+    if valid_dist_attr_dict and not pipeline_process_meshes and global_process_mesh:
+        return True
+    return False
+
+
 class TestMLPSearcher(unittest.TestCase):
     def test_update(self):
         train_program = paddle.static.Program()
         startup_program = paddle.static.Program()
@@ -174,6 +198,55 @@ class TestMLPSearcher(unittest.TestCase):
                 continue
             self.assertTrue(changed)
 
+    def test_enumerater_and_checker(self):
+        processes = 4
+        self.assertTrue(check_process_meshes(processes))
+
+        train_program = paddle.static.Program()
+        startup_program = paddle.static.Program()
+        _, train_program, startup_program = mlp_forward(train_program,
+                                                        startup_program)
+        process_mesh_topology = [4]
+        self.assertTrue(
+            check_pipeline_enumerater(train_program, process_mesh_topology))
+        self.assertTrue(
+            check_nonpipeline_enumerater(train_program, process_mesh_topology))
+
+    def test_get_dist_programs(self):
+        train_program = paddle.static.Program()
+        startup_program = paddle.static.Program()
+        loss, train_program, startup_program = mlp_forward(train_program,
+                                                           startup_program)
+        process_mesh_topology = [4]
+        optimizer = paddle.optimizer.Adam(
+            learning_rate=0.00001,
+            beta1=0.9,
+            beta2=0.999,
+            epsilon=1e-08,
+            grad_clip=None)
+        valid_dist_attr_dict, pipeline_process_meshes, global_process_mesh = PlanSpace.enum_valid_dist_attr_for_program(
+            train_program, process_mesh_topology, False)
+        from test_auto_parallel_cluster import cluster_json
+        cluster_json_file = ""
+        cluster_json_object = json.loads(cluster_json)
+        with open("./auto_parallel_cluster.json", "w") as cluster_json_file:
+            json.dump(cluster_json_object, cluster_json_file)
+        cluster = Cluster()
+        cluster.build_from_file("./auto_parallel_cluster.json")
+        os.remove("./auto_parallel_cluster.json")
+
+        ops = train_program.global_block().ops
+        vars = train_program.global_block().vars
+        new_dist_context = DistributedContext()
+        set_default_dist_attr(train_program, new_dist_context,
+                              global_process_mesh)
+
+        serial_program_info = SerialProgramInfo(train_program, startup_program,
+                                                loss, optimizer, cluster)
+        result = get_all_distributed_main_program(serial_program_info,
+                                                  new_dist_context)
+        self.assertEqual(len(result), 4)
+
 
 if __name__ == "__main__":
     unittest.main()
-- 
GitLab
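
A minimal usage sketch of the pieces this patch introduces, condensed from the new test_auto_parallel_searcher.py test above. It assumes the mlp_forward model builder, the set_default_dist_attr helper, and the cluster_json fixture that the existing unit tests already define (the cluster json is dumped to ./auto_parallel_cluster.json beforehand); everything else comes from the modules added or extended in this patch:

    import paddle
    from paddle.distributed.auto_parallel.cluster import Cluster
    from paddle.distributed.auto_parallel.dist_context import DistributedContext
    from paddle.distributed.auto_parallel.planner import PlanSpace
    from paddle.distributed.auto_parallel.utils import (
        SerialProgramInfo, get_all_distributed_main_program)

    paddle.enable_static()

    # Candidate process meshes for 4 processes: [[4], [2, 2]].
    meshes = PlanSpace.enum_process_mesh_topology(4)

    # Build the serial program (mlp_forward is defined in the unit test).
    train_program = paddle.static.Program()
    startup_program = paddle.static.Program()
    loss, train_program, startup_program = mlp_forward(train_program,
                                                       startup_program)

    # Enumerate every valid operator distributed attribute for a flat [4] mesh.
    valid_dist_attr_dict, pipeline_meshes, global_mesh = \
        PlanSpace.enum_valid_dist_attr_for_program(train_program, [4], False)

    # Describe the hardware and bundle the serial program with loss/optimizer.
    cluster = Cluster()
    cluster.build_from_file("./auto_parallel_cluster.json")
    optimizer = paddle.optimizer.Adam(learning_rate=1e-5)
    serial_info = SerialProgramInfo(train_program, startup_program, loss,
                                    optimizer, cluster)

    # Pick one dist attr per op (set_default_dist_attr in the test does this),
    # then materialize one distributed main program per rank.
    dist_context = DistributedContext()
    set_default_dist_attr(train_program, dist_context, global_mesh)
    dist_mains = get_all_distributed_main_program(serial_info, dist_context)
    assert len(dist_mains) == len(cluster.get_all_devices("GPU"))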