From 3108ba11f15782d00e571d90bbe9b53cc278b33d Mon Sep 17 00:00:00 2001 From: caozhou <48191911+Caozhou1995@users.noreply.github.com> Date: Tue, 18 Oct 2022 20:25:27 +0800 Subject: [PATCH] [Auto Parallel]Add parallel tuner (#46189) * add parallel tuner * add unittest * fix unittest * set timeout of unittest * set unittest timeout * fix auto_mode setting * update unittest * sync from develop and update unittest * remove unused import * update unittest * update cmakelist * add unittests --- .../auto_parallel/operators/dist_matmul.py | 2 + .../distributed/auto_parallel/planner_v2.py | 19 +- .../auto_parallel/tuner/parallel_tuner.py | 968 ++++++++++++++++++ .../auto_parallel/tuner/tunable_space.py | 8 + .../auto_parallel/tuner/tunable_variable.py | 12 +- .../unittests/auto_parallel/CMakeLists.txt | 12 + .../auto_parallel/test_parallel_tuner.py | 141 +++ .../auto_parallel/test_parallel_tuner_full.py | 147 +++ .../test_parallel_tuner_predict.py | 144 +++ .../auto_parallel/test_tunable_space.py | 10 + .../unittests/auto_parallel_gpt_model.py | 41 +- 11 files changed, 1469 insertions(+), 35 deletions(-) create mode 100644 python/paddle/distributed/auto_parallel/tuner/parallel_tuner.py create mode 100644 python/paddle/fluid/tests/unittests/auto_parallel/test_parallel_tuner.py create mode 100644 python/paddle/fluid/tests/unittests/auto_parallel/test_parallel_tuner_full.py create mode 100644 python/paddle/fluid/tests/unittests/auto_parallel/test_parallel_tuner_predict.py diff --git a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py index afd6123a0cb..fd7852e4699 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py @@ -1305,6 +1305,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): process_mesh = dist_attr.process_mesh processes = process_mesh.processes # col parallel: matmul + allreduce + if backward_op.attr("trans_y"): + Y_var_dim_mapping.reverse() assert Y_var_dim_mapping[0] < 0 parallel_axis = Y_var_dim_mapping[1] diff --git a/python/paddle/distributed/auto_parallel/planner_v2.py b/python/paddle/distributed/auto_parallel/planner_v2.py index 3fb41239e7e..8e2c0c4617b 100755 --- a/python/paddle/distributed/auto_parallel/planner_v2.py +++ b/python/paddle/distributed/auto_parallel/planner_v2.py @@ -14,8 +14,7 @@ from .completion import Completer from .dist_context import get_default_distributed_context - -# from .tuner.parallel_tuner import ParallelTuner +from .tuner.parallel_tuner import ParallelTuner class Planner: @@ -38,20 +37,20 @@ class Planner: self._completer = Completer(self._dist_context) self._strategy = dist_context.strategy - # if self._strategy.auto_search: - # self._parallel_tuner = ParallelTuner( - # self._dist_context, mode=self._mode) + # set parallel tuner for auto search + if self._strategy.auto_mode == "full": + self._parallel_tuner = ParallelTuner(self._dist_context, + mode=self._mode) @property def completer(self): return self._completer def plan(self): - self._completer.complete_forward_annotation() - # if self._strategy.auto_search: - # self._parallel_tuner.tune() - # else: - # self._completer.complete_forward_annotation() + if self._strategy.auto_mode == "full": + self._parallel_tuner.tune() + else: + self._completer.complete_forward_annotation() # parse forward sub block self._dist_context.block_state.parse_forward_blocks( self._dist_context.serial_main_program) diff --git a/python/paddle/distributed/auto_parallel/tuner/parallel_tuner.py b/python/paddle/distributed/auto_parallel/tuner/parallel_tuner.py new file mode 100644 index 00000000000..24ee382f7f7 --- /dev/null +++ b/python/paddle/distributed/auto_parallel/tuner/parallel_tuner.py @@ -0,0 +1,968 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import math +import copy +import hashlib +import itertools +from collections import defaultdict +import numpy as np +from ..process_mesh import ProcessMesh +from ..completion import Completer +from ..parallelizer_v2 import Parallelizer +from ..dist_context import _node_id +from ..dist_op import DistributedOperator +from ..operators.common import find_compatible_distributed_operator_impls +from .trial import Trial, TrialStatus +from .tunable_space import TunableSpace +from .tunable_variable import Boolean, IntRange +from ..cost import CostEstimator +from .tunable_variable import Boolean, IntRange + + +class ParallelTuner: + + def __init__(self, + dist_context, + mode="train", + max_trials=25, + tuner_id=None, + seed=None, + logger=None, + loop_count=10): + self._loop_count = loop_count + self._estimator = None + self._dist_context = dist_context + assert self._dist_context._is_initialized + self._mode = mode + self._cluster = self._dist_context.cluster + self._num_machines = self._cluster.get_num_machines() + self._num_devices_per_machine = self._cluster.get_num_devices_per_machine( + ) + self._space = TunableSpace() + self._objective = "time" + self._direction = "min" + self._max_trials = max_trials + self._tuner_id = tuner_id + self._seed = seed if seed is not None else 9999 + + print("seed", + self._seed, + "mode", + self._mode, + "num_machies", + self._num_machines, + "num_devices_per_machine", + self._num_devices_per_machine, + flush=True) + self._seed_state = self._seed + self._logger = logger + self._max_collisions = 3 + self._tried_values = set() + self._num_trials = 0 + self._rng = np.random.default_rng(self._seed) + + # Search the op types in the include_op_types, + # and will search all op types if it is empty. + # Exclude the op types in the exclude_op_types + # from the search list. + self._exclude_op_types = [] + self._include_op_types = [] + # The final dist ops will be searched after considering + # the include_op_types and exclude_op_types. + self._concerned_dist_ops = {} + + self._op_id_to_dist_attr_candidates = defaultdict(list) + self._cached_dims_mapping_candidates = {} + self._cached_candidates_info = defaultdict(list) + + self._special_ops = [ + "create_py_reader", "create_double_buffer_reader", "read", "while", + "read_from_array", "write_to_array" + ] + + # Each parallel strategy has two elements. The First one is for distributed tensors, + # the second element is for distributed tensors, the third element is for process meshes. + self._init_parallel_strategy = [None, None, None] + self._best_parallel_strategy = [None, None, None] + + self._completer = Completer(self._dist_context) + + self._parallelizer = Parallelizer(self._mode, self._completer, + self._dist_context) + + def _generate_combination(self, + elements, + target, + idx, + partial_candidate, + candidates, + num_candidates=None): + if target == 0: + candidates.append(copy.deepcopy(partial_candidate)) + return + + if target < 0 or idx == len(elements) \ + or len(candidates) > num_candidates: + return + + # Use + partial_candidate.append(elements[idx]) + self._generate_combination(elements, target - elements[idx], idx, + partial_candidate, candidates, + num_candidates) + # Not use + partial_candidate.pop() + self._generate_combination(elements, target, idx + 1, partial_candidate, + candidates, num_candidates) + + def _permute_combination(self, + combination, + target, + check, + partial_candidate, + candidates, + num_candidates=None, + skip_prob=None): + if num_candidates is not None \ + and len(candidates) == num_candidates: + return + + if len(partial_candidate) == len(combination): + candidates.append(partial_candidate) + return + + for i in range(len(combination)): + if check[i] == 1: + continue + if self._rng.choice([True, False], p=[skip_prob, 1 - skip_prob]): + continue + if i > 0 and combination[i] == combination[i - 1] \ + and check[i -1] == 0: + continue + check[i] = 1 + self._permute_combination(combination, target, check, + partial_candidate + [combination[i]], + candidates, num_candidates, skip_prob) + check[i] = 0 + + def _partition_number(self, target): + log2_target = int(math.log2(target)) + elements = [pow(2, i) for i in range(log2_target)] + if pow(2, log2_target) == target: + elements.append(target) + seed_candidates = [] + num_seed_candidates = 1000 + partial_results = [] + self._generate_combination(elements, target, 0, partial_results, + seed_candidates, num_seed_candidates) + + candidates = [] + for seed_candidate in seed_candidates: + cur_candidates = [] + num_cur_candidates = 16 + seed_candidate.sort() + check = [0 for i in range(len(seed_candidate))] + if target <= 8: + skip_prob = 0.0 + else: + skip_prob = (len(seed_candidate) / target) + self._permute_combination(seed_candidate, target, check, [], + cur_candidates, num_cur_candidates, + skip_prob) + candidates.extend(cur_candidates) + return candidates + + def _partition_devices(self, num_machines, num_devices_per_machine): + inter_node_partitions = self._partition_number(num_machines) + intra_node_partitions = self._partition_number(num_devices_per_machine) + return inter_node_partitions, intra_node_partitions + + def _generate_process_mesh_list(self, inter_node_partition, + intra_node_partition): + process_mesh_list = [] + start_row = 0 + start_col = 0 + for m in inter_node_partition: + start_col = 0 + for n in intra_node_partition: + process_mesh = [] + for p in range(m): + start = (start_row + + p) * self._num_devices_per_machine + start_col + tmp = [] + for q in range(n): + tmp.append(start + q) + process_mesh.append(tmp) + process_mesh_list.append(copy.deepcopy(process_mesh)) + start_col += n + start_row += m + return process_mesh_list + + def _generate_dims_mapping_candidates_helper(self, dims_mapping, dims_list, + start, visited, candidates): + if start == len(dims_mapping) or all(visited): + candidates.append(copy.deepcopy(dims_mapping)) + return + + for idx, dim in enumerate(dims_list): + if visited[idx] == False: + dims_mapping[start] = dim + visited[idx] = True + self._generate_dims_mapping_candidates_helper( + dims_mapping, dims_list, start + 1, visited, candidates) + visited[idx] = False + dims_mapping[start] = -1 + self._generate_dims_mapping_candidates_helper(dims_mapping, dims_list, + start + 1, visited, + candidates) + + def _generate_dims_mapping_candidates(self, dims_mapping_len, + process_mesh_len): + assert dims_mapping_len >= 1 and process_mesh_len >= 1 + key = (dims_mapping_len, process_mesh_len) + if key in self._cached_dims_mapping_candidates: + return self._cached_dims_mapping_candidates[key] + candidates = [] + dims_mapping = [-1 for i in range(dims_mapping_len)] + dims_list = [i for i in range(process_mesh_len)] + visited = [False for i in range(process_mesh_len)] + self._generate_dims_mapping_candidates_helper(dims_mapping, dims_list, + 0, visited, candidates) + self._cached_dims_mapping_candidates[key] = candidates + return candidates + + def _generate_dist_attr_candidates(self, op_id, dist_op): + # For now, only allow the process meshes have two dimensions + process_mesh_len = 2 + serial_op = dist_op.serial_op + op_dist_attr = dist_op.dist_attr + if serial_op.type in self._special_ops: + return [copy.deepcopy(op_dist_attr)] + key = [] + key.append(serial_op.type) + for input_name in serial_op.input_names: + key.append(input_name) + for input_arg_name in serial_op.input(input_name): + key.append( + len(op_dist_attr.get_input_dims_mapping(input_arg_name))) + for output_name in serial_op.output_names: + key.append(output_name) + for output_arg_name in serial_op.output(output_name): + key.append( + len(op_dist_attr.get_output_dims_mapping(output_arg_name))) + key = tuple(key) + + if key in self._cached_candidates_info: + cached_dist_attr_candidates = [] + cached_input_arg_names = self._cached_candidates_info[key][0] + cached_output_arg_names = self._cached_candidates_info[key][1] + for cached_dist_attr in self._cached_candidates_info[key][2]: + new_op_dist_attr = copy.deepcopy(dist_op.dist_attr) + i = 0 + for input_name in serial_op.input_names: + for input_arg_name in serial_op.input(input_name): + cached_dims_mapping = cached_dist_attr.get_input_dims_mapping( + cached_input_arg_names[i]) + new_op_dist_attr.set_input_dims_mapping( + input_arg_name, cached_dims_mapping) + i += 1 + i = 0 + for output_name in serial_op.output_names: + for output_arg_name in serial_op.output(output_name): + cached_dims_mapping = cached_dist_attr.get_output_dims_mapping( + cached_output_arg_names[i]) + new_op_dist_attr.set_output_dims_mapping( + output_arg_name, cached_dims_mapping) + i += 1 + cached_dist_attr_candidates.append(new_op_dist_attr) + return cached_dist_attr_candidates + + # cached_candidates_info = [] + input_arg_names = [] + for input_name in serial_op.input_names: + for input_arg_name in serial_op.input(input_name): + input_arg_names.append(input_arg_name) + self._cached_candidates_info[key].append(input_arg_names) + # cached_candidates_info.append(input_arg_names) + output_arg_names = [] + for output_name in serial_op.output_names: + for output_arg_name in serial_op.output(output_name): + output_arg_names.append(output_arg_name) + self._cached_candidates_info[key].append(output_arg_names) + # cached_candidates_info.append(output_arg_names) + + new_op_dist_attr = copy.deepcopy(dist_op.dist_attr) + # Find valid dims_mapping candidates for inputs + input_names = [] + dims_mapping_generated = [] + inputs_dist_attrs = op_dist_attr.inputs_dist_attrs + for tensor_name, tensor_dist_attr in inputs_dist_attrs.items(): + original_dims_mapping = tensor_dist_attr.dims_mapping + dims_mapping_len = len(original_dims_mapping) + input_names.append(tensor_name) + if dims_mapping_len < 1: + dims_mapping_generated.append( + [copy.deepcopy(original_dims_mapping)]) + else: + dims_mapping_generated.append( + self._generate_dims_mapping_candidates( + dims_mapping_len, process_mesh_len)) + input_dims_mapping_candidates = [] + for dims_mapping_list in itertools.product(*dims_mapping_generated): + dims_mapping_list = list(dims_mapping_list) + assert len(dims_mapping_list) == len(input_names) + for i, dims_mapping in enumerate(dims_mapping_list): + new_op_dist_attr.set_input_dims_mapping(input_names[i], + dims_mapping) + new_dist_op = DistributedOperator(dist_op.serial_op, + new_op_dist_attr) + dist_op_impls = find_compatible_distributed_operator_impls( + new_dist_op, fwd=True) + if dist_op_impls is not None: + input_dims_mapping_candidates.append(dims_mapping_list) + + # Find valid dims_mapping candidates for outputs + output_names = [] + dims_mapping_generated = [] + outputs_dist_attrs = op_dist_attr.outputs_dist_attrs + for tensor_name, tensor_dist_attr in outputs_dist_attrs.items(): + original_dims_mapping = tensor_dist_attr.dims_mapping + dims_mapping_len = len(original_dims_mapping) + output_names.append(tensor_name) + if dims_mapping_len < 1: + dims_mapping_generated.append( + [copy.deepcopy(original_dims_mapping)]) + else: + dims_mapping_generated.append( + self._generate_dims_mapping_candidates( + dims_mapping_len, process_mesh_len)) + output_dims_mapping_candidates = [] + for dims_mapping_list in itertools.product(*dims_mapping_generated): + dims_mapping_list = list(dims_mapping_list) + assert len(dims_mapping_list) == len(output_names) + for i, dims_mapping in enumerate(dims_mapping_list): + new_op_dist_attr.set_output_dims_mapping( + output_names[i], dims_mapping) + new_dist_op = DistributedOperator(dist_op.serial_op, + new_op_dist_attr) + dist_op_impls = find_compatible_distributed_operator_impls( + new_dist_op, fwd=False) + if dist_op_impls is not None: + output_dims_mapping_candidates.append(dims_mapping_list) + + if not input_dims_mapping_candidates and output_dims_mapping_candidates: + inout_dims_mapping_generated = [[[[-2]]], + output_dims_mapping_candidates] + elif input_dims_mapping_candidates and not output_dims_mapping_candidates: + inout_dims_mapping_generated = [ + input_dims_mapping_candidates, [[[-2]]] + ] + elif not input_dims_mapping_candidates and not output_dims_mapping_candidates: + inout_dims_mapping_generated = [[[[-2]]], [[[-2]]]] + else: + inout_dims_mapping_generated = [ + input_dims_mapping_candidates, output_dims_mapping_candidates + ] + # Find valid dims_mapping generated for both inputs and outputs + cached_dist_attr_candidates = [] + for inout_dims_mapping_list in itertools.product( + *inout_dims_mapping_generated): + assert len(inout_dims_mapping_list) == 2 + if input_dims_mapping_candidates: + assert len(inout_dims_mapping_list[0]) == len(input_names) + if output_dims_mapping_candidates: + assert len(inout_dims_mapping_list[1]) == len(output_names) + # set the dims_mappings for inputs + for i, dims_mapping in enumerate(inout_dims_mapping_list[0]): + if dims_mapping != [-2]: + new_op_dist_attr.set_input_dims_mapping( + input_names[i], dims_mapping) + # set the dims_mappings for outputs + for i, dims_mapping in enumerate(inout_dims_mapping_list[1]): + if dims_mapping != [-2]: + new_op_dist_attr.set_output_dims_mapping( + output_names[i], dims_mapping) + new_dist_op = DistributedOperator(dist_op.serial_op, + new_op_dist_attr) + dist_op_impls = find_compatible_distributed_operator_impls( + new_dist_op, partial=False) + if dist_op_impls is None: + continue + for dist_op_impl in dist_op_impls: + new_op_dist_attr.impl_type = dist_op_impl.type + new_op_dist_attr.impl_idx = dist_op_impl.idx + cached_dist_attr_candidates.append( + copy.deepcopy(new_op_dist_attr)) + self._cached_candidates_info[key].append(cached_dist_attr_candidates) + return self._cached_candidates_info[key][2] + + def construct_space(self): + inter_node_partitions, intra_node_partitions = self._partition_devices( + self._num_machines, self._num_devices_per_machine) + self._space.choice("inter_node_partitions", + inter_node_partitions, + default=inter_node_partitions[0]) + self._space.choice("intra_node_partitions", + intra_node_partitions, + default=intra_node_partitions[0]) + + dist_ops = self._dist_context._dist_ops_for_program + for op_id, dist_op in dist_ops.items(): + op_type = dist_op.serial_op.type + if self._include_op_types: + if op_type in self._include_op_types: + self._concerned_dist_ops[op_id] = dist_op + else: + self._concerned_dist_ops[op_id] = dist_op + + for op_id, dist_op in self._concerned_dist_ops.items(): + op_type = dist_op.serial_op.type + if op_type in self._exclude_op_types: + del self._concerned_dist_ops[op_id] + + print("Number of the concered dist ops", + len(self._concerned_dist_ops), + flush=True) + search_space = 1 + for op_id, dist_op in self._concerned_dist_ops.items(): + op_dist_attr_candidates = self._generate_dist_attr_candidates( + op_id, dist_op) + search_space *= len(op_dist_attr_candidates) + self._space.choice(str(op_id), + op_dist_attr_candidates, + default=op_dist_attr_candidates[0]) + + def _compute_values_hash(self, values): + keys = sorted(values.keys()) + s = "".join(str(k) + "=" + str(values[k]) for k in keys) + return hashlib.sha256(s.encode("utf-8")).hexdigest()[:32] + + def _random_values(self): + space = TunableSpace() + collisions = 0 + while True: + for v in self._space.variables.values(): + space._register(v) + space.values[v.name] = v.random(self._seed_state) + self._seed_state += 1 + values = space.values + values_hash = self._compute_values_hash(values) + if values_hash in self._tried_values: + collisions += 1 + if collisions > self._max_collisions: + return None + continue + self._tried_values.add(values_hash) + break + return values + + def _populate_space(self): + values = self._random_values() + if values is None: + return {"status": TrialStatus.STOPPED, "values": None} + return {"status": TrialStatus.RUNNING, "values": values} + + def _create_trial(self): + trial_id = "{{:0{}d}}".format(len(str(self._max_trials))) + trial_id = trial_id.format(self._num_trials) + + if self._max_trials and self._num_trials >= self._max_trials: + status = TrialStatus.STOPPED + values = None + else: + results = self._populate_space() + status = results["status"] + values = results["values"] + + space = TunableSpace() + space.variables = self._space.variables + space.values = values + trial = Trial(tunable_space=space, trial_id=trial_id, status=status) + self._num_trials += 1 + return trial + + def _generate_pipeline_starts(self, process_mesh_list): + total_ops = len(self._dist_context._dist_ops_for_program) + total_stages = len(process_mesh_list) + ops_per_stage = total_ops // total_stages + if ops_per_stage == 0: + return None + # Compute the initial pipeline starts + pipeline_starts = [] + start = 0 + pipeline_starts.append(0) + # The pipeline_starts have total_stages+1 items, and + # at least have 2 items. + for _ in process_mesh_list: + start += ops_per_stage + pipeline_starts.append(start) + pipeline_starts[-1] = total_ops + # Adjust the pipeline starts by random selection + directions = [] + sizes = [] + half_ops_per_stage = ops_per_stage // 2 + if half_ops_per_stage > 0 and total_stages > 1: + new_pipeline_starts = [] + # Don't change the first start + new_pipeline_starts.append(0) + # Consider the starts except the first and the last one + for _ in pipeline_starts[1:-1]: + directions.append(Boolean("direction")) + sizes.append( + IntRange("size", + start=0, + stop=half_ops_per_stage, + endpoint=True)) + for i, start in enumerate(pipeline_starts[1:-1]): + direction = directions[i].random(self._seed) + size = sizes[i].random(self._seed) + if direction: + # Substract 1 from size to avoid the overlapping of new starts + new_start = start - (size - 1) + else: + new_start = start + size + new_pipeline_starts.append(new_start) + # Don't change the last start + new_pipeline_starts.append(pipeline_starts[-1]) + # Validate the new starts + print("Adjusted pipeline starts", + new_pipeline_starts, + half_ops_per_stage, + pipeline_starts, + flush=True) + for i, new_start in enumerate(new_pipeline_starts[1:]): + assert new_start > new_pipeline_starts[i] + return new_pipeline_starts + else: + print("Non-adjusted pipeline starts", + pipeline_starts, + half_ops_per_stage, + flush=True) + return pipeline_starts + + def _apply_pipeline_partition(self, process_mesh_list): + op_id_to_process_mesh = {} + total_ops = len(self._dist_context._dist_ops_for_program) + total_stages = len(process_mesh_list) + ops_per_stage = total_ops // total_stages + if ops_per_stage == 0: + return None + pipeline_starts = self._generate_pipeline_starts(process_mesh_list) + start_idx = 1 + sorted_op_ids = sorted(self._dist_context._dist_ops_for_program.keys()) + for idx, op_id in enumerate(sorted_op_ids): + if idx < pipeline_starts[start_idx]: + op_id_to_process_mesh[op_id] = process_mesh_list[start_idx - 1] + else: + start_idx += 1 + op_id_to_process_mesh[op_id] = process_mesh_list[start_idx - 1] + return op_id_to_process_mesh + + def _amend_dist_attr(self): + # 1) Reshape the process mesh of [1, x] to [x] or [x, 1] to [x], + # and amend the corresponding dims_mapping. + # 2) Set the dim_mapping to -1 when the shape cannot be divided + # by the corresponding processes. + for dist_op in self._dist_context._dist_ops_for_program.values(): + dist_attr = dist_op.dist_attr + process_mesh = dist_attr.process_mesh + if process_mesh is None: + continue + assert process_mesh.ndim == 2 + dim_of_one = None + dim_of_other = None + if process_mesh.topology[0] == 1: + dim_of_one = 0 + dim_of_other = 1 + elif process_mesh.topology[1] == 1: + dim_of_one = 1 + dim_of_other = 0 + + if dim_of_one is not None: + dist_attr.process_mesh = ProcessMesh(process_mesh.processes) + self._dist_context.add_process_mesh(dist_attr.process_mesh) + + for arg_name in dist_attr.inputs_dist_attrs.keys(): + new_dims_mapping = [] + dims_mapping = dist_attr.get_input_dims_mapping(arg_name) + for dim_mapping in dims_mapping: + if dim_mapping == dim_of_one: + new_dims_mapping.append(-1) + elif dim_mapping == dim_of_other: + new_dims_mapping.append(0) + else: + new_dims_mapping.append(dim_mapping) + dist_attr.set_input_dims_mapping(arg_name, new_dims_mapping) + + dims_mapping = dist_attr.get_input_dims_mapping(arg_name) + # dynamic_dims = dist_attr.get_input_dynamic_dims(arg_name) + process_mesh = dist_attr.process_mesh + process_shape = process_mesh.topology + tensor = dist_op.get_serial_input(arg_name) + if dims_mapping: + tensor_shape = tensor.shape + else: + continue + for i, dim_mapping in enumerate(dims_mapping): + # if dim_mapping != -1 \ + # and (tensor_shape[i] % process_shape[dim_mapping] != 0 \ + # or dynamic_dims[i] == 1): + if dim_mapping != -1 \ + and (tensor_shape[i] % process_shape[dim_mapping] != 0): + dims_mapping[i] = -1 + # it is a fix-bug + if dim_mapping != -1 \ + and process_shape[dim_mapping] == 1: + dims_mapping[i] = -1 + + for arg_name in dist_attr.outputs_dist_attrs.keys(): + new_dims_mapping = [] + dims_mapping = dist_attr.get_output_dims_mapping(arg_name) + for dim_mapping in dims_mapping: + if dim_mapping == dim_of_one: + new_dims_mapping.append(-1) + elif dim_mapping == dim_of_other: + new_dims_mapping.append(0) + else: + new_dims_mapping.append(dim_mapping) + dist_attr.set_output_dims_mapping(arg_name, new_dims_mapping) + + dims_mapping = dist_attr.get_output_dims_mapping(arg_name) + # dynamic_dims = dist_attr.get_output_dynamic_dims(arg_name) + process_mesh = dist_attr.process_mesh + process_shape = process_mesh.topology + + tensor = dist_op.get_serial_output(arg_name) + if dims_mapping: + tensor_shape = tensor.shape + else: + continue + for i, dim_mapping in enumerate(dims_mapping): + if dim_mapping != -1 \ + and (tensor_shape[i] % process_shape[dim_mapping] != 0): + dims_mapping[i] = -1 + # it is a fix-bug + if dim_mapping != -1 \ + and process_shape[dim_mapping] == 1: + dims_mapping[i] = -1 + dist_op_impls = find_compatible_distributed_operator_impls( + dist_op, partial=False) + serial_op_type = dist_op.serial_op.type + + if dist_op_impls is not None and ( + serial_op_type != "fused_softmax_mask_upper_triangle" + or self._check_fused_softmax_mask_upper_triangle(dist_op)): + dist_op.dist_attr.impl_type = dist_op_impls[0].type + dist_op.dist_attr.impl_idx = dist_op_impls[0].idx + else: + # Use the default dist op impl + for arg_name in dist_attr.inputs_dist_attrs.keys(): + dims_mapping = dist_attr.get_input_dims_mapping(arg_name) + for i, _ in enumerate(dims_mapping): + dims_mapping[i] = -1 + for arg_name in dist_attr.outputs_dist_attrs.keys(): + dims_mapping = dist_attr.get_output_dims_mapping(arg_name) + for i, _ in enumerate(dims_mapping): + dims_mapping[i] = -1 + dist_op.dist_attr.impl_type = "default" + dist_op.dist_attr.impl_idx = 0 + + def _check_fused_softmax_mask_upper_triangle(self, dist_op): + """The last_but_one dim shoule be equal to last dim.""" + input_name = dist_op.serial_op.input_arg_names[0] + input_dims_mapping = dist_op.dist_attr.get_input_dims_mapping( + input_name) + topology = dist_op.dist_attr.process_mesh.topology + input_tensor = dist_op.get_serial_input(input_name) + last_but_one_dim = input_tensor.shape[-2] // topology[ + input_dims_mapping[-2]] if input_dims_mapping[ + -2] != -1 else input_tensor.shape[-2] + last_dim = input_tensor.shape[-1] // topology[input_dims_mapping[ + -1]] if input_dims_mapping[-1] != -1 else input_tensor.shape[-1] + if last_but_one_dim == last_dim: + return True + return False + + def _eval_trial(self, trial): + if self._num_trials == 0: + num_prev_trials = 0 + else: + num_prev_trials = self._num_trials - 1 + + results = None + + start_time = time.time() + + inter_node_partition = trial.space.values["inter_node_partitions"] + intra_node_partition = trial.space.values["intra_node_partitions"] + process_mesh_list = self._generate_process_mesh_list( + inter_node_partition, intra_node_partition) + print("\tprocess_mesh list", process_mesh_list, flush=True) + op_id_to_process_mesh = self._apply_pipeline_partition( + process_mesh_list) + if op_id_to_process_mesh is None: + print("Operators are less than pipeline stages", flush=True) + return results + + op_id_to_dist_attr = {} + for name, value in trial.space.values.items(): + if name != "inter_node_partitions" \ + and name !="intra_node_partitions": + op_id_to_dist_attr[int(name)] = value + + end_time = time.time() + cur_sample_time = end_time - start_time + self._sample_time = (num_prev_trials * self._sample_time + + cur_sample_time) / self._num_trials + print("\tsample_time", + num_prev_trials, + self._num_trials, + self._sample_time, + cur_sample_time, + flush=True) + + assert len(op_id_to_process_mesh) == len(op_id_to_dist_attr) + + start_time = time.time() + for op_id, process_mesh in op_id_to_process_mesh.items(): + dist_op = self._dist_context._dist_ops_for_program[op_id] + dist_op.dist_attr = copy.deepcopy(op_id_to_dist_attr[op_id]) + assert dist_op.dist_attr.impl_type == op_id_to_dist_attr[ + op_id].impl_type + assert dist_op.dist_attr.impl_idx == op_id_to_dist_attr[ + op_id].impl_idx + dist_op.dist_attr.process_mesh = process_mesh + self._amend_dist_attr() + + self._completer._complete_tensor_dist_attr_by_op() + + self._dist_context.block_state.parse_forward_blocks( + self._dist_context.serial_main_program) + + end_time = time.time() + cur_complete_time = end_time - start_time + self._complete_time = (num_prev_trials * self._complete_time + + cur_complete_time) / self._num_trials + print("\tcomplete_time", + num_prev_trials, + self._num_trials, + self._complete_time, + cur_complete_time, + flush=True) + + start_time = time.time() + estimate_time = self._estimate_trial() + end_time = time.time() + cur_estimate_time = end_time - start_time + self._estimate_time = (num_prev_trials * self._estimate_time + + cur_estimate_time) / self._num_trials + print("\testimate_time", + num_prev_trials, + self._num_trials, + self._estimate_time, + cur_estimate_time, + estimate_time, + flush=True) + + results = {"estimate_time": estimate_time} + return results + + def _update_trail(self, trial, metrics, step=0): + for metric_name, metric_value in metrics.items(): + trial.recorder.update(metric_name, metric_value, step=step) + return trial.status + + def _estimate_trial(self): + assert self._cluster is not None + if self._mode == "eval": + self._estimator = CostEstimator( + self._dist_context.serial_main_program, + self._cluster, + loop_count=self._loop_count) + elif self._mode == "predict": + self._estimator = CostEstimator( + self._dist_context.serial_main_program, + self._cluster, + loop_count=self._loop_count) + elif self._mode == "train": + # get serial main program with backward + serial_main_program = self._dist_context.serial_main_program + serial_startup_program = self._dist_context.serial_startup_program + serial_optimizer = self._dist_context.serial_optimizer + + # Generate backward + serial_loss = self._dist_context.serial_fetch_vars["loss"][0] + params_grads = self._parallelizer._generate_backward( + serial_main_program, serial_startup_program, serial_loss) + + # Generate optimizer + optimizer_ops = self._parallelizer._generate_optimizer( + serial_main_program, serial_startup_program, serial_optimizer, + params_grads) + self._estimator = CostEstimator(serial_main_program, + self._cluster, + loop_count=self._loop_count) + + max_memory = self._estimator._estimate_max_memory_by_dist_op( + self._dist_context) + print("\tmax_memory", "{:,}".format(max_memory), flush=True) + # The max memory must be less than 80% 32GB (hard code) + if max_memory > 32 * 0.8 * 1024 * 1024 * 1024: + return math.inf + else: + global_cost = self._estimator.estimate(self._dist_context) + return global_cost.time + + def _store_init_parallel_strategy(self): + # If there is no annotation information, use the dp as the initial parallel strategy. + # TODO: we should need a better way to set up the initial parallel strategy. + if not self._dist_context.has_annotation \ + or not self._dist_context.process_meshes: + ranks = self._num_machines * self._num_devices_per_machine + tensor_node = self._dist_context._serial_ordered_tensor_nodes[0] + tensor_node_id = _node_id(tensor_node) + tensor = self._dist_context._dist_tensors_for_graph[ + tensor_node_id].serial_tensor + tensor_dist_attr = self._dist_context._dist_tensors_for_graph[ + tensor_node_id].dist_attr + tensor_dist_attr.process_mesh = ProcessMesh(list(range(ranks))) + self._dist_context._process_meshes.append( + tensor_dist_attr.process_mesh) + tensor_dist_attr.dims_mapping = [0] + [ + -1 for _ in range(len(tensor.shape) - 1) + ] + tensor_dist_attr.mark_annotated("process_mesh") + tensor_dist_attr.mark_annotated("dims_mapping") + print("Use dp as the init parallel strategy!", flush=True) + + # Do the sharding propagation + self._completer.complete_forward_annotation() + self._dist_context.block_state.parse_forward_blocks( + self._dist_context.serial_main_program) + + # Backup the intital parallel strategy + self._init_parallel_strategy[0] = copy.deepcopy( + self._dist_context._dist_tensors_for_program) + self._init_parallel_strategy[1] = copy.deepcopy( + self._dist_context._dist_ops_for_program) + self._init_parallel_strategy[2] = copy.deepcopy( + self._dist_context.process_meshes) + + # Initialize the best parallel strategy to the initial one + self._best_parallel_strategy[0] = copy.deepcopy( + self._dist_context._dist_tensors_for_program) + self._best_parallel_strategy[1] = copy.deepcopy( + self._dist_context._dist_ops_for_program) + self._best_parallel_strategy[2] = copy.deepcopy( + self._dist_context._process_meshes) + + def _store_best_parallel_strategy(self): + # Swap the best and the current parallel strategy + tmp = [None, None, None] + tmp[0] = self._best_parallel_strategy[0] + tmp[1] = self._best_parallel_strategy[1] + tmp[2] = self._best_parallel_strategy[2] + self._best_parallel_strategy[ + 0] = self._dist_context._dist_tensors_for_program + self._best_parallel_strategy[ + 1] = self._dist_context._dist_ops_for_program + self._best_parallel_strategy[2] = self._dist_context._process_meshes + self._dist_context._dist_tensors_for_program = tmp[0] + self._dist_context._dist_ops_for_program = tmp[1] + self._dist_context._process_meshes = tmp[2] + + def tune(self): + global_start_time = time.time() + self._dist_context._backup(serial=True, dist=True) + # This store statement must follow the above backup statement + self._store_init_parallel_strategy() + init_time = self._estimate_trial() # estimate_trial when init + # print_program_with_dist_attr(self._dist_context.serial_main_program, self._dist_context) + # We have to restore the distributed context, because the estimation of one trail need to + # generate the backward and update parts. Since we will do the tuning process, + # here we only need to reset all distributed information to the default one. + self._dist_context._restore(serial=True, + serial_mode="to_backup", + dist=True, + dist_mode="to_default") + + best_time = init_time + start_time = time.time() + self.construct_space() + end_time = time.time() + print("construct_space time", + self._num_trials, + end_time - start_time, + flush=True) + create_trial_time = 0.0 + eval_trial_time = 0.0 + self._sample_time = 0.0 + self._complete_time = 0.0 + self._estimate_time = 0.0 + while True: + start_time = time.time() + trial = self._create_trial() + if self._num_trials == 0: + num_prev_trials = 0 + else: + num_prev_trials = self._num_trials - 1 + end_time = time.time() + cur_create_trial_time = end_time - start_time + create_trial_time = (num_prev_trials * create_trial_time + + cur_create_trial_time) / self._num_trials + print("create_trial time", + num_prev_trials, + self._num_trials, + create_trial_time, + cur_create_trial_time, + flush=True) + if trial.status == TrialStatus.STOPPED: + break + # We need to backup the distributed context, because the evaluation of one trail will + # generate the backward and update parts which may change the context. + # However, the distributed information of the context aren't backup since a new one is used. + self._dist_context._backup(serial=True, dist=False) + + start_time = time.time() + results = self._eval_trial(trial) + end_time = time.time() + cur_eval_trial_time = end_time - start_time + eval_trial_time = (num_prev_trials * eval_trial_time + + cur_eval_trial_time) / self._num_trials + print("eval_trial time", + num_prev_trials, + self._num_trials, + eval_trial_time, + cur_eval_trial_time, + "\n", + flush=True) + + cur_time = results["estimate_time"] + if cur_time < best_time: + self._update_trail(trial, results) + self._store_best_parallel_strategy() + best_time = cur_time + # We need to restore the distributed context and reset the distributed information to the default. + self._dist_context._restore(serial=True, + serial_mode="to_backup", + dist=True, + dist_mode="to_default") + # Select the best parallel strategy + self._dist_context._dist_tensors_for_program = self._best_parallel_strategy[ + 0] + self._dist_context._dist_ops_for_program = self._best_parallel_strategy[ + 1] + self._dist_context._process_meshes = self._best_parallel_strategy[2] diff --git a/python/paddle/distributed/auto_parallel/tuner/tunable_space.py b/python/paddle/distributed/auto_parallel/tuner/tunable_space.py index 01212563e80..2009f1d911c 100644 --- a/python/paddle/distributed/auto_parallel/tuner/tunable_space.py +++ b/python/paddle/distributed/auto_parallel/tuner/tunable_space.py @@ -37,10 +37,18 @@ class TunableSpace(object): def variables(self): return self._variables + @variables.setter + def variables(self, variables): + self._variables = variables + @property def values(self): return self._values + @values.setter + def values(self, values): + self._values = values + def get_value(self, name): if name in self.values: return self.values[name] diff --git a/python/paddle/distributed/auto_parallel/tuner/tunable_variable.py b/python/paddle/distributed/auto_parallel/tuner/tunable_variable.py index 424b6b74bb1..31dd07aad37 100644 --- a/python/paddle/distributed/auto_parallel/tuner/tunable_variable.py +++ b/python/paddle/distributed/auto_parallel/tuner/tunable_variable.py @@ -90,6 +90,7 @@ class Choice(TunableVariable): raise TypeError( "Choice can contain only one type of value, but found values: {} with types: {}." .format(str(values), str(types))) + self._is_unknown_type = False if isinstance(values[0], str): values = [str(v) for v in values] @@ -108,9 +109,8 @@ class Choice(TunableVariable): if default is not None: default = bool(default) else: - raise TypeError( - "Choice can only contain str, int, float, or boll, but found: {} " - .format(str(values))) + self._is_unknown_type = True + self._indices = [i for i in range(len(values))] self.values = values if default is not None and default not in values: @@ -129,7 +129,11 @@ class Choice(TunableVariable): def random(self, seed=None): rng = np.random.default_rng(seed) - return rng.choice(self.values) + if self._is_unknown_type: + indice = rng.choice(self._indices) + return self.values[indice] + else: + return rng.choice(self.values) def get_state(self): state = super(Choice, self).get_state() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt index c70028051f1..3d34ed4fcdb 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt @@ -99,8 +99,20 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_interface MODULES test_interface) py_test_modules(test_strategy MODULES test_strategy) py_test_modules(test_pass_quantization MODULES test_pass_quantization) + py_test_modules(test_dist_shape MODULES test_dist_shape) py_test_modules(test_dist_assign MODULES test_dist_assign) py_test_modules(test_conditional_block_reshard MODULES test_conditional_block_reshard) + + py_test_modules(test_parallel_tuner MODULES test_parallel_tuner ENVS + ${dist_ENVS}) + set_tests_properties(test_parallel_tuner PROPERTIES TIMEOUT 120) + py_test_modules(test_parallel_tuner_full MODULES test_parallel_tuner_full + ENVS ${dist_ENVS}) + set_tests_properties(test_parallel_tuner_full PROPERTIES TIMEOUT 120) + py_test_modules(test_parallel_tuner_predict MODULES + test_parallel_tuner_predict ENVS ${dist_ENVS}) + set_tests_properties(test_parallel_tuner_predict PROPERTIES TIMEOUT 120) + endif() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_parallel_tuner.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_parallel_tuner.py new file mode 100644 index 00000000000..ab48e2838f9 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_parallel_tuner.py @@ -0,0 +1,141 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import paddle +import paddle.static as static + +from paddle.distributed import fleet +from paddle.distributed.auto_parallel.cluster import Cluster +from paddle.distributed.auto_parallel.dist_context import DistributedContext, set_default_distributed_context +from paddle.distributed.auto_parallel.tuner.parallel_tuner import ParallelTuner +from paddle.distributed.auto_parallel.process_mesh import ProcessMesh +import sys + +sys.path.append("..") +import auto_parallel_gpt_model as modeling +from auto_parallel_gpt_model import GPTModel, GPTForPretraining, GPTPretrainingCriterion + +paddle.enable_static() + +batch_size = 4 +epoch_num = 10 +hidden_size = 1024 +sequence_len = 512 +_g_process_mesh = [ + ProcessMesh([0, 1], dim_names=["x"]), + ProcessMesh([2, 3], dim_names=["x"]) +] + + +def get_program_v3(): + dist_strategy = fleet.DistributedStrategy() + dist_strategy.semi_auto = True + # fleet.init(is_collective=True, strategy=dist_strategy) + place = paddle.set_device("gpu") + gpus = [0, 1] + batch_size = 8 + sequence_len = 512 + vocab_size = 1000 + + train_program = static.Program() + start_program = static.Program() + modeling.init_global() + modeling._global_parallel_strategy = None + # modeling.DPMPPP_MESH_LIST = [ + # ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"]), + # ProcessMesh([[4, 5], [6, 7]], dim_names=["x", "y"]) + # ] + with static.program_guard(train_program, start_program): + tokens = paddle.static.data(name="tokens", + shape=[batch_size, sequence_len], + dtype='int64') + position_ids = paddle.static.data(name="position_ids", + shape=[batch_size, sequence_len], + dtype='int64') + attention_mask = paddle.static.data( + name="attention_mask", + shape=[batch_size, 1, sequence_len, sequence_len], + dtype='float32') + labels = paddle.static.data(name="labels", + shape=[batch_size, sequence_len], + dtype='int64') + loss_mask = paddle.static.data(name="loss_mask", + shape=[batch_size, sequence_len], + dtype='float32') + data_holder = [tokens, position_ids, attention_mask, labels, loss_mask] + + gpt = GPTModel(vocab_size=1000, + hidden_size=1024, + num_hidden_layers=2, + num_attention_heads=16, + intermediate_size=4 * 1024, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + max_position_embeddings=1024, + type_vocab_size=1, + initializer_range=0.02, + pad_token_id=0, + eos_token_id=7, + bos_token_id=0, + eol_token_id=3, + pp_degree=1) + + model = GPTForPretraining(gpt, + vocab_size=1000, + hidden_size=64, + initializer_range=0.02) + preds = model(tokens, position_ids, attention_mask) + criterion = GPTPretrainingCriterion() + loss = criterion(preds, labels, loss_mask) + + optimizer = paddle.fluid.optimizer.AdamOptimizer(learning_rate=0.00001, + beta1=0.9, + beta2=0.999, + epsilon=1e-08, + grad_clip=None) + + feed_vars = { + "inputs": [tokens, position_ids, attention_mask, loss_mask], + "labels": [labels] + } + fetch_vars = {"loss": [loss]} + + return train_program, start_program, None, loss, optimizer, feed_vars, fetch_vars + + +class TestParallelTunerTrain(unittest.TestCase): + + def test_tune_with_train(self): + flag = False + set_default_distributed_context(DistributedContext()) + train_program, start_program, dataloader, loss, optimizer, feed_vars, fetch_vars = get_program_v3( + ) + cluster = Cluster() + cluster.gen_default_config_cluster(node_count=1, device_count=8) + dist_context = DistributedContext(train_program, start_program, + optimizer, loss, feed_vars, + fetch_vars, cluster) + dist_context.initialize() + parallel_tuner = ParallelTuner(dist_context, max_trials=3, mode="train") + parallel_tuner.tune() + parallel_tuner._store_best_parallel_strategy() + flag = True + self.assertTrue(flag) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_parallel_tuner_full.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_parallel_tuner_full.py new file mode 100644 index 00000000000..27833a6a185 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_parallel_tuner_full.py @@ -0,0 +1,147 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import paddle +import paddle.static as static + +from paddle.distributed import fleet +from paddle.distributed.auto_parallel.cluster import Cluster +from paddle.distributed.auto_parallel.dist_context import DistributedContext, set_default_distributed_context +from paddle.distributed.auto_parallel.tuner.parallel_tuner import ParallelTuner +from paddle.distributed.auto_parallel.process_mesh import ProcessMesh +from paddle.distributed.auto_parallel.planner_v2 import Planner +from paddle.distributed.auto_parallel.strategy import Strategy +import sys + +sys.path.append("..") +import auto_parallel_gpt_model as modeling +from auto_parallel_gpt_model import GPTModel, GPTForPretraining, GPTPretrainingCriterion + +paddle.enable_static() + +batch_size = 4 +epoch_num = 10 +hidden_size = 1024 +sequence_len = 512 +_g_process_mesh = [ + ProcessMesh([0, 1], dim_names=["x"]), + ProcessMesh([2, 3], dim_names=["x"]) +] + + +def get_program_v3(): + dist_strategy = fleet.DistributedStrategy() + dist_strategy.semi_auto = True + # fleet.init(is_collective=True, strategy=dist_strategy) + place = paddle.set_device("gpu") + gpus = [0, 1] + batch_size = 8 + sequence_len = 512 + vocab_size = 1000 + + train_program = static.Program() + start_program = static.Program() + modeling.init_global() + modeling._global_parallel_strategy = "dp_mp_pp" + modeling.DPMPPP_MESH_LIST = [ + ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"]), + ProcessMesh([[4, 5], [6, 7]], dim_names=["x", "y"]) + ] + with static.program_guard(train_program, start_program): + tokens = paddle.static.data(name="tokens", + shape=[batch_size, sequence_len], + dtype='int64') + position_ids = paddle.static.data(name="position_ids", + shape=[batch_size, sequence_len], + dtype='int64') + attention_mask = paddle.static.data( + name="attention_mask", + shape=[batch_size, 1, sequence_len, sequence_len], + dtype='float32') + labels = paddle.static.data(name="labels", + shape=[batch_size, sequence_len], + dtype='int64') + loss_mask = paddle.static.data(name="loss_mask", + shape=[batch_size, sequence_len], + dtype='float32') + data_holder = [tokens, position_ids, attention_mask, labels, loss_mask] + + gpt = GPTModel(vocab_size=1000, + hidden_size=1024, + num_hidden_layers=2, + num_attention_heads=16, + intermediate_size=4 * 1024, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + max_position_embeddings=1024, + type_vocab_size=1, + initializer_range=0.02, + pad_token_id=0, + eos_token_id=7, + bos_token_id=0, + eol_token_id=3, + pp_degree=len(modeling.DPMPPP_MESH_LIST)) + + model = GPTForPretraining(gpt, + vocab_size=1000, + hidden_size=64, + initializer_range=0.02) + preds = model(tokens, position_ids, attention_mask) + criterion = GPTPretrainingCriterion() + loss = criterion(preds, labels, loss_mask) + + optimizer = paddle.fluid.optimizer.AdamOptimizer(learning_rate=0.00001, + beta1=0.9, + beta2=0.999, + epsilon=1e-08, + grad_clip=None) + + feed_vars = { + "inputs": [tokens, position_ids, attention_mask, loss_mask], + "labels": [labels] + } + fetch_vars = {"loss": [loss]} + + return train_program, start_program, None, loss, optimizer, feed_vars, fetch_vars + + +class TestParallelTunerFull(unittest.TestCase): + + def test_tune_with_planner(self): + flag = False + set_default_distributed_context(DistributedContext()) + train_program, start_program, dataloader, loss, optimizer, feed_vars, fetch_vars = get_program_v3( + ) + cluster = Cluster() + cluster.gen_default_config_cluster(node_count=1, device_count=8) + strategy = Strategy() + strategy.auto_mode = "full" + dist_context = DistributedContext(train_program, start_program, + optimizer, loss, feed_vars, + fetch_vars, cluster, strategy) + dist_context.initialize() + planner = Planner("train", dist_context) + planner._parallel_tuner = ParallelTuner(planner._dist_context, + mode=planner._mode, + max_trials=3) + planner.plan() + flag = True + self.assertTrue(flag) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_parallel_tuner_predict.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_parallel_tuner_predict.py new file mode 100644 index 00000000000..2d7a2c10579 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_parallel_tuner_predict.py @@ -0,0 +1,144 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import paddle +import paddle.static as static + +from paddle.distributed import fleet +from paddle.distributed.auto_parallel.cluster import Cluster +from paddle.distributed.auto_parallel.dist_context import DistributedContext, set_default_distributed_context +from paddle.distributed.auto_parallel.tuner.parallel_tuner import ParallelTuner +from paddle.distributed.auto_parallel.process_mesh import ProcessMesh +import sys + +sys.path.append("..") +import auto_parallel_gpt_model as modeling +from auto_parallel_gpt_model import GPTModel, GPTForPretraining, GPTPretrainingCriterion + +paddle.enable_static() + +batch_size = 4 +epoch_num = 10 +hidden_size = 1024 +sequence_len = 512 +_g_process_mesh = [ + ProcessMesh([0, 1], dim_names=["x"]), + ProcessMesh([2, 3], dim_names=["x"]) +] + + +def get_program_v3(): + dist_strategy = fleet.DistributedStrategy() + dist_strategy.semi_auto = True + # fleet.init(is_collective=True, strategy=dist_strategy) + place = paddle.set_device("gpu") + gpus = [0, 1] + batch_size = 8 + sequence_len = 512 + vocab_size = 1000 + + train_program = static.Program() + start_program = static.Program() + modeling.init_global() + modeling._global_parallel_strategy = "dp_mp_pp" + modeling.DPMPPP_MESH_LIST = [ + ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"]), + ProcessMesh([[4, 5], [6, 7]], dim_names=["x", "y"]) + ] + with static.program_guard(train_program, start_program): + tokens = paddle.static.data(name="tokens", + shape=[batch_size, sequence_len], + dtype='int64') + position_ids = paddle.static.data(name="position_ids", + shape=[batch_size, sequence_len], + dtype='int64') + attention_mask = paddle.static.data( + name="attention_mask", + shape=[batch_size, 1, sequence_len, sequence_len], + dtype='float32') + labels = paddle.static.data(name="labels", + shape=[batch_size, sequence_len], + dtype='int64') + loss_mask = paddle.static.data(name="loss_mask", + shape=[batch_size, sequence_len], + dtype='float32') + data_holder = [tokens, position_ids, attention_mask, labels, loss_mask] + + gpt = GPTModel(vocab_size=1000, + hidden_size=1024, + num_hidden_layers=2, + num_attention_heads=16, + intermediate_size=4 * 1024, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + max_position_embeddings=1024, + type_vocab_size=1, + initializer_range=0.02, + pad_token_id=0, + eos_token_id=7, + bos_token_id=0, + eol_token_id=3, + pp_degree=len(modeling.DPMPPP_MESH_LIST)) + + model = GPTForPretraining(gpt, + vocab_size=1000, + hidden_size=64, + initializer_range=0.02) + preds = model(tokens, position_ids, attention_mask) + criterion = GPTPretrainingCriterion() + loss = criterion(preds, labels, loss_mask) + + optimizer = paddle.fluid.optimizer.AdamOptimizer(learning_rate=0.00001, + beta1=0.9, + beta2=0.999, + epsilon=1e-08, + grad_clip=None) + + feed_vars = { + "inputs": [tokens, position_ids, attention_mask, loss_mask], + "labels": [labels] + } + fetch_vars = {"loss": [loss]} + + return train_program, start_program, None, loss, optimizer, feed_vars, fetch_vars + + +class TestParallelTunerPredict(unittest.TestCase): + + def test_tune_predict(self): + flag = False + set_default_distributed_context(DistributedContext()) + train_program, start_program, dataloader, loss, optimizer, feed_vars, fetch_vars = get_program_v3( + ) + cluster = Cluster() + cluster.gen_default_config_cluster(node_count=1, device_count=8) + dist_context = DistributedContext(train_program, start_program, + optimizer, loss, feed_vars, + fetch_vars, cluster) + dist_context.initialize() + + parallel_tuner = ParallelTuner(dist_context, + max_trials=3, + mode="predict") + parallel_tuner.tune() + flag = True + + self.assertTrue(flag) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_tunable_space.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_tunable_space.py index f0c6a0b7cdf..58ff36aba09 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_tunable_space.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_tunable_space.py @@ -136,6 +136,16 @@ class TestTunableSpace(unittest.TestCase): self.assertEqual(new_space.variables["int_range"].step, 1) self.assertEqual(new_space.variables["int_range"].endpoint, False) + def test_expection(self): + space = ts.TunableSpace() + flag = True + try: + val = space.get_value("test") + flag = False + except: + pass + self.assertTrue(flag) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py b/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py index 027f4a8e0d7..3940aa0170b 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py @@ -298,14 +298,14 @@ class TransformerDecoder(nn.Layer): auto.shard_tensor(output, PP_MESH_LIST[0], [None for i in range(len(output.shape))]) if _global_parallel_strategy == "dp_pp": - auto.shard_tensor(output, DPPP_MESH_LIST[0], ["x"].extends( - [None for i in range(len(output.shape) - 1)])) + auto.shard_tensor(output, DPPP_MESH_LIST[0], ["x"] + + [None for i in range(len(output.shape) - 1)]) if _global_parallel_strategy == "mp_pp": auto.shard_tensor(output, MPPP_MESH_LIST[0], [None for i in range(len(output.shape))]) if _global_parallel_strategy == "dp_mp_pp": - auto.shard_tensor(output, DPMPPP_MESH_LIST[0], ["x"].extends( - [None for i in range(len(output.shape) - 1)])) + auto.shard_tensor(output, DPMPPP_MESH_LIST[0], ["x"] + + [None for i in range(len(output.shape) - 1)]) for i, mod in enumerate(self.layers): if cache is None: if use_cache: @@ -323,8 +323,8 @@ class TransformerDecoder(nn.Layer): tgt_mask, use_cache, cache) auto.shard_tensor( - output, DPPP_MESH_LIST[mod.mesh_idx], ["x"].extends( - [None for i in range(len(output.shape) - 1)])) + output, DPPP_MESH_LIST[mod.mesh_idx], ["x"] + + [None for i in range(len(output.shape) - 1)]) elif _global_parallel_strategy == "mp_pp": output, new_cache = auto.shard_op( mod, MPPP_MESH_LIST[mod.mesh_idx])(output, memory, @@ -362,8 +362,8 @@ class TransformerDecoder(nn.Layer): tgt_mask, use_cache, cache) auto.shard_tensor( - output, DPPP_MESH_LIST[mod.mesh_idx], ["x"].extends( - [None for i in range(len(output.shape) - 1)])) + output, DPPP_MESH_LIST[mod.mesh_idx], ["x"] + + [None for i in range(len(output.shape) - 1)]) elif _global_parallel_strategy == "mp_pp": output = auto.shard_op( mod, MPPP_MESH_LIST[mod.mesh_idx])(output, memory, @@ -378,9 +378,8 @@ class TransformerDecoder(nn.Layer): output, memory, tgt_mask, use_cache, cache) auto.shard_tensor( - output, DPMPPP_MESH_LIST[mod.mesh_idx], - ["x"].extends( - [None for i in range(len(output.shape) - 1)])) + output, DPMPPP_MESH_LIST[mod.mesh_idx], ["x"] + + [None for i in range(len(output.shape) - 1)]) else: output = mod(output, memory, @@ -400,9 +399,9 @@ class TransformerDecoder(nn.Layer): mod, DPPP_MESH_LIST[mod.mesh_idx])(output, memory, tgt_mask, use_cache, cache) - auto.shard_tensor(output, DPPP_MESH_LIST[mod.mesh_idx], [ - "x" - ].extends([None for i in range(len(output.shape) - 1)])) + auto.shard_tensor( + output, DPPP_MESH_LIST[mod.mesh_idx], + ["x"] + [None for i in range(len(output.shape) - 1)]) elif _global_parallel_strategy == "mp_pp": output, new_cache = auto.shard_op( mod, @@ -415,9 +414,9 @@ class TransformerDecoder(nn.Layer): mod, DPMPPP_MESH_LIST[mod.mesh_idx])(output, memory, tgt_mask, use_cache, cache) - auto.shard_tensor(output, DPMPPP_MESH_LIST[mod.mesh_idx], [ - "x" - ].extends([None for i in range(len(output.shape) - 1)])) + auto.shard_tensor( + output, DPMPPP_MESH_LIST[mod.mesh_idx], + ["x"] + [None for i in range(len(output.shape) - 1)]) else: output, new_cache = mod(output, memory, @@ -682,11 +681,11 @@ class GPTModel(nn.Layer): auto.shard_tensor(input_ids, PP_MESH_LIST[0], [None for i in range(len(input_ids.shape))]) if _global_parallel_strategy == "dp_pp": - auto.shard_tensor(input_ids, DPPP_MESH_LIST[0], ["x"].extends( - [None for i in range(len(input_ids.shape) - 1)])) + auto.shard_tensor(input_ids, DPPP_MESH_LIST[0], ["x"] + + [None for i in range(len(input_ids.shape) - 1)]) if _global_parallel_strategy == "dp_mp_pp": - auto.shard_tensor(input_ids, DPMPPP_MESH_LIST[0], ["x"].extends( - [None for i in range(len(input_ids.shape) - 1)])) + auto.shard_tensor(input_ids, DPMPPP_MESH_LIST[0], ["x"] + + [None for i in range(len(input_ids.shape) - 1)]) encoder_outputs = self.decoder(embedding_output, memory=None, tgt_mask=attention_mask, -- GitLab