From 72f2ed43756218cb125d9cb3ba3b949c94e636a0 Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Fri, 29 Jul 2022 14:56:53 +0800 Subject: [PATCH] [Auto parallel] Optimization Tuning (#43782) * fixed bug for pass & engine * fixed bug for benchmark GPT-3 * add tuner & profiler * add algorithms & config --- .../framework/distributed_strategy.proto | 2 + .../distributed/auto_parallel/engine.py | 141 ++++- .../distributed/auto_parallel/parallelizer.py | 1 - .../auto_parallel/process_group.py | 12 +- .../auto_parallel/tuner/__init__.py | 4 + .../auto_parallel/tuner/algorithms.py | 159 +++++ .../distributed/auto_parallel/tuner/config.py | 135 +++++ .../auto_parallel/tuner/optimization_tuner.py | 547 ++++++++++++++++++ .../auto_parallel/tuner/profiler.py | 287 +++++++++ .../distributed/auto_parallel/tuner/trial.py | 49 ++ .../paddle/distributed/auto_parallel/utils.py | 8 + .../distributed/passes/auto_parallel_amp.py | 1 - .../passes/auto_parallel_sharding.py | 8 +- .../unittests/auto_parallel/CMakeLists.txt | 5 + .../auto_parallel/optimization_tuner_api.py | 157 +++++ .../test_optimization_tuner_api.py | 50 ++ 16 files changed, 1539 insertions(+), 27 deletions(-) create mode 100644 python/paddle/distributed/auto_parallel/tuner/algorithms.py create mode 100644 python/paddle/distributed/auto_parallel/tuner/config.py create mode 100644 python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py create mode 100644 python/paddle/distributed/auto_parallel/tuner/profiler.py create mode 100644 python/paddle/fluid/tests/unittests/auto_parallel/optimization_tuner_api.py create mode 100644 python/paddle/fluid/tests/unittests/auto_parallel/test_optimization_tuner_api.py diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 832d91d131a..6e2bab8c5b3 100755 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -27,6 +27,7 @@ message RecomputeConfig { repeated string checkpoints = 1; optional bool enable_offload = 2 [ default = false ]; repeated int32 checkpoint_shape = 3; + optional bool enable_tuning = 4 [ default = false ]; // incubate for auto parallel } message ShardingConfig { @@ -46,6 +47,7 @@ message ShardingConfig { // Optimizer sharding. Temporary plans and may be deprecated optional bool _dp_as_optimizer_sharding = 13 [ default = false ]; optional int32 stage = 14 [ default = 1 ]; + optional bool enable_tuning = 15 [ default = false ]; // incubate for auto parallel } message HybridConfig { diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index b9ff116d244..58778042b13 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -58,7 +58,8 @@ class Engine: inputs_spec=None, labels_spec=None, cluster=None, - strategy=None): + strategy=None, + user_tuning_config=None): self.model = model self.inputs_spec = self._validate_spec(inputs_spec) self.labels_spec = self._validate_spec(labels_spec) @@ -68,6 +69,7 @@ class Engine: self.strategy = strategy if self.strategy is None: self.strategy = fleet.DistributedStrategy() + self._user_tuning_config = user_tuning_config self._executor = None self._cur_rank = paddle.distributed.get_rank() @@ -127,19 +129,21 @@ class Engine: self._prepare_single_mode("train") def _prepare_single_mode(self, mode): - self._modes = [mode] - self._build(self._modes[0]) - # Do auto parallel process - for mode in self._modes: - # Do the planning process - self._plan(mode) - for mode in self._modes: - # Do the parallel process - self._parallel(mode, self._all_ranks) - - # Init comm and startup program - self._initialize(mode) - self._mode_init_states[mode] = True + + self._build(mode) + # Do the planning process + self._plan(mode) + + # Do the Optimization tuning + if self._user_tuning_config and mode == "train": + self._optimization_tuning(mode) + + # Do the parallel process + self._parallel(mode, self._all_ranks) + + # Init comm and startup program + self._initialize(mode) + self._mode_init_states[mode] = True def _build(self, mode): if _non_static_mode() or self._dygraph_mode: @@ -174,6 +178,7 @@ class Engine: metrics = [] serial_main_prog = self._orig_main_prog.clone() serial_startup_prog = self._orig_startup_prog.clone() + # FIXME to support grad clip with static.program_guard(serial_main_prog, serial_startup_prog), \ utils.unique_name.guard(): inputs_spec = self.inputs_spec @@ -204,12 +209,41 @@ class Engine: "metrics": metrics } + self._set_recompute_ckpts() self._dist_contexts[mode] = DistributedContext( serial_main_prog, serial_startup_prog, self._optimizer, losses, feed_vars, fetch_vars, self.cluster, self.strategy) self._dist_contexts[mode].gradient_scale = self._gradient_scale self._dist_contexts[mode]._dygraph_mode = self._dygraph_mode + def _optimization_tuning(self, mode): + + self.mode = mode + assert "batch_size" in self._user_tuning_config, "Optimization Tuning should provide with batch size." + assert "dataset" in self._user_tuning_config, "Optimization Tuning should provide with dataset." + batch_size = self._user_tuning_config["batch_size"] + dataset = self._user_tuning_config["dataset"] + dataset.dp_world_size = self._dp_world_size + dataset.dp_rank = self._dp_rank + + from .tuner.optimization_tuner import OptimizationTuner + self._optimization_tuner = OptimizationTuner(self._user_tuning_config, + self._dist_contexts[mode], + dataset, + self.inputs_spec, + self.labels_spec, + batch_size=batch_size, + rank=self._cur_rank) + + self._optimization_tuner.tune() + + if self._user_tuning_config["run_after_tuning"]: + # update the strategy + self._dist_contexts[ + mode]._strategy = self._optimization_tuner.get_best_config() + else: + return + def _plan(self, mode): if self._planned_mode is None: self._planned_mode = mode @@ -219,6 +253,18 @@ class Engine: self._planners[mode] = Planner(mode, self._dist_contexts[mode]) self._planners[mode].plan() + # infer data parallel info + inputs_var = self._dist_contexts[mode].serial_feed_vars["inputs"] + labels_var = self._dist_contexts[mode].serial_feed_vars["labels"] + block = self._dist_contexts[mode].serial_main_program.global_block() + feed_list = [] + for var in inputs_var + labels_var: + if var.name in block.vars: + feed_list.append(block.vars[var.name]) + + self._dp_world_size, self._dp_rank = self._get_data_parallel_info( + feed_list[0], self._dist_contexts[mode]) + def _parallel(self, mode, all_ranks): # Parallelize program based on the planner's results # For now, the completer has to be passed to the planner, @@ -317,6 +363,40 @@ class Engine: prune_startup_prog = dist_startup_prog._prune(uninitialized) self._executor.run(prune_startup_prog) + if self.strategy.amp and self.strategy.amp_configs['use_pure_fp16']: + # from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_parameters_to_fp16 + def cast_parameters_to_fp16(place, + program, + scope=None, + to_fp16_var_names=None): + """ + Traverse all parameters in the whole model and set them to the FP16 data type. + Whereas, this function will keep parameters of batchnorms in FP32. + Args: + place(fluid.CPUPlace|fluid.CUDAPlace): `place` is used to restore the FP16 weight tensors. + program (Program): The used program. + scope(fluid.Scope, optional): `scope` is used to get the FP32 weight tensor values. + Default is None. + to_fp16_var_names(set|list, optional): The data types of vars in `to_fp16_var_names` + will be set to FP16. Usually, it is the returned + value of `cast_model_to_fp16` API. + """ + from paddle.framework import core + import numpy as np + all_parameters = [] + for block in program.blocks: + all_parameters.extend(block.all_parameters()) + + var_scope = scope if scope else paddle.static.global_scope() + for param in all_parameters: + if param.dtype == core.VarDesc.VarType.FP16: + param_t = var_scope.find_var( + param.name).get_tensor() + data = np.array(param_t) + param_t.set(np.float16(data), place) + + cast_parameters_to_fp16(self._place, prune_startup_prog) + def fit(self, train_data, batch_size=1, @@ -342,7 +422,6 @@ class Engine: usr_fetch = self._validate_fetches(fetches) fetch_loss = self._validate_fetches(self.fetch_vars["loss"]) fetch_list, fetch_map = self._fetch_map(fetch_loss, usr_fetch) - for epoch in range(epochs): train_logs = {"epoch": epoch} for step, _ in enumerate(train_dataloader): @@ -457,8 +536,6 @@ class Engine: for var in inputs_var + labels_var: if var.name in dist_main_block.vars: feed_list.append(dist_main_block.vars[var.name]) - dp_world_size, dp_rank = self._get_data_parallel_info( - feed_list[0], dist_context) # remove the first three ops if multi run fit/evaluate/predict op_size = len(dist_main_block.ops) @@ -477,8 +554,8 @@ class Engine: batch_size, epochs, steps_per_epoch, - data_parallel_world_size=dp_world_size, - data_parallel_rank=dp_rank) + data_parallel_world_size=self._dp_world_size, + data_parallel_rank=self._dp_rank) # move read op from the end of program to the start of program new_op_size = len(dist_main_block.ops) @@ -561,6 +638,32 @@ class Engine: return None, None + def _set_recompute_ckpts(self): + # NOTE hack to enable recompute in engine api for GPT-3 + # TODO support more PaddleNLP/CV models here + + config = self.strategy.recompute_configs + + # extract ckpts by specific model + self.model + if isinstance(self.model, paddle.nn.Layer): + if hasattr( + self.model, "model" + ) and self.model.model.__class__.__name__ == 'GPTForPretraining': + exact_ckpts = self.model.model.gpt.checkpoints + else: + exact_ckpts = config["checkpoints"] + + # modify strategy + if self.strategy.recompute: + config["checkpoints"] = exact_ckpts[:] + self.strategy.recompute_configs = config + logs = { + 'Model Class': self.model.model.__class__.__name__, + 'Applied Recompute ckpts': exact_ckpts + } + self._logger.info(logs) + def save(self, path, training=True, mode=None): if not mode: mode = self.mode diff --git a/python/paddle/distributed/auto_parallel/parallelizer.py b/python/paddle/distributed/auto_parallel/parallelizer.py index 1ad85598101..4b538431bb0 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer.py +++ b/python/paddle/distributed/auto_parallel/parallelizer.py @@ -48,7 +48,6 @@ from .mapper import mapping from .dist_op import DistributedOperator from .dist_tensor import DistributedTensor from .planner import Planner -from paddle.distributed.passes import new_pass, PassContext _logger = get_logger(logging.INFO) diff --git a/python/paddle/distributed/auto_parallel/process_group.py b/python/paddle/distributed/auto_parallel/process_group.py index 245c5c955e8..17f960381aa 100644 --- a/python/paddle/distributed/auto_parallel/process_group.py +++ b/python/paddle/distributed/auto_parallel/process_group.py @@ -42,7 +42,13 @@ def get_world_process_group(): return _g_process_group_map[0] -def new_process_group(ranks): +def clear_all_process_groups(): + global _g_process_group_map + _g_process_group_map = {} + _g_process_group_map[0] = ProcessGroup(0, []) + + +def new_process_group(ranks, group_id=None): global _g_process_group_map # A key constructed from ranks is used for avoiding duplication new_key = ''.join(map(str, sorted(ranks))) @@ -54,7 +60,9 @@ def new_process_group(ranks): num_groups = len(_g_process_group_map) # Note: our process group may interfere with the original implementation # so the created group id should start from the original _new_ring_id() - group_id = _new_ring_id() + num_groups + 1 + if group_id == None: + group_id = _new_ring_id() + num_groups + 1 + new_pg = ProcessGroup(group_id, ranks) _g_process_group_map[group_id] = new_pg return new_pg diff --git a/python/paddle/distributed/auto_parallel/tuner/__init__.py b/python/paddle/distributed/auto_parallel/tuner/__init__.py index 513558501a0..23559cd2ad0 100644 --- a/python/paddle/distributed/auto_parallel/tuner/__init__.py +++ b/python/paddle/distributed/auto_parallel/tuner/__init__.py @@ -11,3 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from .profiler import profiler + +__all__ = [] diff --git a/python/paddle/distributed/auto_parallel/tuner/algorithms.py b/python/paddle/distributed/auto_parallel/tuner/algorithms.py new file mode 100644 index 00000000000..8440ab91a81 --- /dev/null +++ b/python/paddle/distributed/auto_parallel/tuner/algorithms.py @@ -0,0 +1,159 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from abc import ABC, abstractmethod +import logging + +from paddle.distributed.utils import get_logger +from .trial import TrialStatus +from .trial import OptimizationTunerTrial as Trial + + +class AlgorithmBase(ABC): + """ + An Tuning alogrithm is a class to find out an optimal configuration + given the selected tuning optimization pass(es) and the arguments to be tuned. + Different optimization pass(es) will correspond to a different algorithm, + where different search space **pruning rules** will applied. + + In another word, the key "algorithm" for this class is the + search space pruning rules specific for the given optimization scenario. + """ + _REGISTERED_ALGORITHMS = {} + + name = None + + @staticmethod + def _register(algo_name, algo_class): + assert issubclass(algo_class, AlgorithmBase) + AlgorithmBase._REGISTERED_ALGORITHMS[algo_name] = algo_class + + def __init__(self, config): + self._config = config + self._init_spaces() + self._logger = get_logger(logging.INFO) + self._changed_configs = [] + + @property + def changed_configs(self): + return self._changed_configs[:] + + def collect_model_info(self, main_prog, startup_prog): + """ + Collect the model static info (from programs) that could be used to + pruning candidate trials and saving tuning time.For instance, + model info like number of model parameters and activation memory could be + used to prune candidated trial and decide the next trial. + """ + pass + + @abstractmethod + def _init_spaces(self): + pass + + @abstractmethod + def next_trial(self): + pass + + @abstractmethod + def update(self, results): + """ + Update the algorthim with the results of last trial. Using this information is used to + pruning the search space of the future trial. + """ + pass + + def get_config_from_trial(self, trial): + """ + Return a new fleet.DistributedStrategy with the configurations in trial. + """ + assert len(self._changed_configs) > 0 + new_strategy = copy.deepcopy(self._config.dist_strategy) + for name in self._changed_configs: + config = getattr(trial.space, name) + setattr(new_strategy, name, config) + return new_strategy + + +def register_algor(name): + + def impl(cls): + AlgorithmBase._register(name, cls) + cls.name = name + return cls + + return impl + + +def new_algorithm(name, config): + algor_class = AlgorithmBase._REGISTERED_ALGORITHMS.get(name) + assert algor_class is not None, "Algorithm {} is not defined.".format(name) + algor_obj = algor_class(config) + return algor_obj + + +@register_algor("sharding") +class ShardingStageAlgorithm(AlgorithmBase): + + # TODO import trial class & copy strategy + def __init__(self, config): + super().__init__(config) + self._changed_configs = ["sharding_configs"] + + def _init_spaces(self): + self._max_stage = 3 + self._trial_idx = 0 + + stage_range = self._config.sharding_configs.get("stage_range", None) + if stage_range: + assert set(stage_range).issubset( + set([0, 1, 2, 3]) + ), "Sharding Stage should belong into range within 0 - 3 but got {}.".format( + stage_range) + stage_range.sort(reverse=True) + else: + stage_range = list(range(self._max_stage + 1)).sort(reverse=True) + + self._stage_range = stage_range[:] + self._total_num_trial = len(self._stage_range) + + def next_trial(self): + + if self._trial_idx < self._total_num_trial: + + stage = self._stage_range[self._trial_idx] + + new_strategy = copy.deepcopy(self._config.dist_strategy) + config_dict = new_strategy.sharding_configs + config_dict["stage"] = stage + new_strategy.sharding_configs = config_dict + + name = "trial-sharding-stage{}".format(stage) + trial = Trial(new_strategy, name, self.changed_configs) + + return trial + else: + return Trial(None, None, None, status=TrialStatus.STOPPED) + + def update(self, results): + + et = results.get("ErrorType", None) + if et and et == "ResourceExhaustedError": + self._trial_idx = self._total_num_trial + self._logger.info( + "Last trial is failed with OOM, all remaining trials are pruned to save time !" + ) + else: + self._trial_idx += 1 diff --git a/python/paddle/distributed/auto_parallel/tuner/config.py b/python/paddle/distributed/auto_parallel/tuner/config.py new file mode 100644 index 00000000000..19818a3a655 --- /dev/null +++ b/python/paddle/distributed/auto_parallel/tuner/config.py @@ -0,0 +1,135 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import copy +import pathlib + +import paddle +from paddle.distributed import fleet + +_tuning_supported_passes = ["sharding", "recompute"] +_strategy_config_suffiex = "_configs" + + +def _get_pass_config(strategy, pass_name): + config_name = pass_name + _strategy_config_suffiex + config = getattr(strategy, config_name) + return config + + +class TuningConfig(object): + """ + A uniform config wrap: + distributed strategy: the user defined configuration for optimization pass + tuning config: configuration for the tuning process: mode (profile or cost model), log dir, extra tuning config for optimization like search range for specific + """ + + def __init__(self, user_config, strategy): + + if not isinstance(strategy, fleet.DistributedStrategy): + raise TypeError( + "'strategy' must be object of class `fleet.DistributedStrategy`." + ) + + if not user_config: + user_config = {} + + self._tuning_passes_name = set() + self._dist_strategy = copy.deepcopy(strategy) + self._mode = None + self._profile_start_step = None + self._profile_end_step = None + self._project_dir = None + self._max_num_trial = None + self._early_stop = None + self._verbose = None + + self._initialize(user_config) + + @property + def mode(self): + return self._mode + + @property + def profile_start_step(self): + return self._profile_start_step + + @property + def profile_end_step(self): + return self._profile_end_step + + @property + def project_dir(self): + return self._project_dir + + @property + def tuning_passes_name(self): + return self._tuning_passes_name + + @property + def max_num_trial(self): + return self._max_num_trial + + @property + def early_stop(self): + return self._early_stop + + @property + def verbose(self): + return self._verbose + + @property + def dist_strategy(self): + return self._dist_strategy + + # initialize config with user define value or default value + def _initialize(self, user_config): + + self._mode = user_config.get("mode", "PROFILE") + + self._profile_start_step = user_config.get("profile_start_step", 10) + + self._profile_end_step = user_config.get("profile_end_step", 30) + + self._max_num_trial = user_config.get("max_num_trial", 50) + + self._early_stop = user_config.get("early_stop", None) + + self._verbose = user_config.get("verbose", False) + + project_dir = user_config.get("project_dir", None) + if not project_dir: + project_dir = os.path.join(os.getcwd(), "OptimizationTuning") + self._project_dir = project_dir + + for p in _tuning_supported_passes: + if getattr(self._dist_strategy, p) and _get_pass_config( + self._dist_strategy, p)["enable_tuning"]: + # TODO distinguish different args of each passes + self._tuning_passes_name.add(p) + + config_name = p + _strategy_config_suffiex + p_dict = getattr(self._dist_strategy, config_name) + self.__dict__[config_name] = p_dict + + # TODO verify the user defined configs + user_config_for_pass = user_config.get(p, None) + if user_config_for_pass: + for k, v in user_config_for_pass.items(): + self.__dict__[config_name][k] = v + + # (NOTE)tuning config ONLY wraps dist strategy for pass config which is to be tuned + def __getattr__(self, item): + return getattr(self._dist_strategy, item) diff --git a/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py b/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py new file mode 100644 index 00000000000..89b6d22e32a --- /dev/null +++ b/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py @@ -0,0 +1,547 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import copy +import shlex +import pathlib +import time +import shutil +import pickle +import json +import logging +import subprocess +import traceback + +import paddle +from paddle.fluid import program_guard +from paddle.fluid.backward import append_backward +from paddle.distributed.passes import new_pass, PassContext +from paddle.distributed.utils import get_logger + +from paddle.distributed.auto_parallel.dist_context import DistributedContext, get_default_distributed_context +from paddle.distributed.auto_parallel.completion import Completer +from paddle.distributed.auto_parallel.reshard import Resharder +from paddle.distributed.auto_parallel.partitioner import Partitioner +from paddle.distributed.auto_parallel.process_group import clear_all_process_groups, get_all_process_groups +from paddle.distributed.auto_parallel.utils import debug_program +from paddle.distributed.auto_parallel.utils import make_data_unshard, set_grad_var_shape + +from .config import TuningConfig +from .algorithms import new_algorithm +from .trial import TrialStatus + + +def _get_new_params_grads(target_program, ref_program, ref_params_grads): + ref_block = ref_program.global_block() + target_block = target_program.global_block() + target_params_grads = [] + + for p, g in ref_params_grads: + # NOTE grad var might not be generated + assert ref_block.has_var(p.name) + assert target_block.has_var(p.name) + new_p = target_block.var(p.name) + if g: + new_g = target_block.var(g.name) + else: + new_g = None + + target_params_grads.append((new_p, new_g)) + + return target_params_grads + + +def _get_new_loss(target_program, ref_program, loss): + ref_block = ref_program.global_block() + target_block = target_program.global_block() + assert ref_block.has_var(loss.name) + + return target_block.var(loss.name) + + +def parse_process_groups(): + group_map = {} + all_process_groups = get_all_process_groups() + for process_group in all_process_groups: + group_map[process_group.id] = process_group.ranks + return group_map + + +def get_metric(results): + assert isinstance( + results, + dict), "results should be type of dictionary, but got {}.".format( + type(results)) + if 'Throughtput' in results and isinstance(results['Throughtput'], float): + return float(results['Throughtput']) + else: + return -1.0 + + +def parse_results(results): + if results['Throughtput'] > 0: + return "Throughtput: {} step / s.".format(results['Throughtput']) + et = results.get("ErrorType", None) + if et == "ResourceExhaustedError": + return "Fail with OOM" + else: + return "Fail with UNKWON ERROR" + + +# TODO only dependent on dist context +# all env need to be start a new pass are member of dist context +def _copy_context(ref_dist_context): + + clear_all_process_groups() + + new_dist_context = DistributedContext() + new_dist_context._serial_main_program = ref_dist_context.serial_main_program.clone( + for_test=False) + new_dist_context._serial_startup_program = ref_dist_context.serial_startup_program.clone( + for_test=False) + + # mapping variable into new dist context + if getattr(ref_dist_context, '_params_grads', None): + new_dist_context._params_grads = _get_new_params_grads( + new_dist_context.serial_main_program, + ref_dist_context.serial_main_program, + ref_dist_context._params_grads) + new_dist_context._serial_loss = _get_new_loss( + new_dist_context.serial_main_program, + ref_dist_context.serial_main_program, ref_dist_context.serial_loss) + + for key, var_list in ref_dist_context._serial_feed_vars.items(): + new_var_list = [] + for var in var_list: + block_idx = var.block.idx + var_name = var.name + var = new_dist_context._serial_main_program.blocks[ + block_idx]._var_recursive(var_name) + new_var_list.append(var) + new_dist_context._serial_feed_vars[key] = new_var_list + + for key, var_list in ref_dist_context._serial_fetch_vars.items(): + new_var_list = [] + for var in var_list: + block_idx = var.block.idx + var_name = var.name + var = new_dist_context._serial_main_program.blocks[ + block_idx]._var_recursive(var_name) + new_var_list.append(var) + new_dist_context._serial_fetch_vars[key] = new_var_list + + # copy information in forward and backward + new_dist_context._serial_optimizer = copy.deepcopy( + ref_dist_context.serial_optimizer) + new_dist_context._dist_tensors_for_program = copy.deepcopy( + ref_dist_context._dist_tensors_for_program) + new_dist_context._dist_ops_for_program = copy.deepcopy( + ref_dist_context._dist_ops_for_program) + for pm in ref_dist_context.process_meshes: + new_dist_context.add_process_mesh(pm) + new_dist_context._dist_op_context = copy.deepcopy( + ref_dist_context._dist_op_context) + new_dist_context._block_state = copy.deepcopy(ref_dist_context.block_state) + + return new_dist_context + + +class OptimizationTuner: + """ + OptimizationTuner is used to manage the tuning procedure of hyper-parameters (configs) + of Optimization Pass in AutoParallel. + """ + + def __init__( + self, + user_configs, + dist_context, + dataset, + inputs_spec, + labels_spec, + batch_size, + rank, + ): + + self._config = TuningConfig(user_configs, dist_context._strategy) + # should not modify dist context from calling function + self._baseline_dist_context = _copy_context(dist_context) + self._baseline_completer = Completer(self._baseline_dist_context) + + self._rank = rank + self._inputs_spec = inputs_spec + self._labels_spec = labels_spec + self._dataset = dataset + self._batch_size = batch_size + + self._finished_trials = [] + self._best_metric = None + self._best_iter = float("-inf") + + self._logger = get_logger(logging.INFO) + + self._build_programs_without_optimization() + self._select_tuning_algorithm() + + @property + def project_dir(self): + dirname = self._config.project_dir + if not os.path.exists(dirname): + if self.rank == 0: + pathlib.Path(dirname).mkdir(parents=True, exist_ok=True) + return dirname + + @property + def rank(self): + return self._rank + + @property + def device_id(self): + return paddle.distributed.ParallelEnv().device_id + + # TODO Generate compelet program with all parts like forward, backward, update + # as well as parallelism transformation. + def _build_programs_without_optimization(self): + + serial_main_program = self._baseline_dist_context.serial_main_program + serial_startup_program = self._baseline_dist_context.serial_startup_program + serial_loss = self._baseline_dist_context.serial_loss + + with program_guard(serial_main_program, serial_startup_program): + params_grads = append_backward( + serial_loss, + distop_context=self._baseline_dist_context.dist_op_context) + + self._baseline_completer.complete_backward_annotation( + serial_main_program) + self._baseline_dist_context.block_state.parse_backward_blocks( + serial_main_program) + self._baseline_dist_context._params_grads = params_grads + + if self._config.verbose: + baseline_dir = os.path.join(self.project_dir, "baseline") + if not os.path.exists(baseline_dir): + pathlib.Path(baseline_dir).mkdir(parents=True, exist_ok=True) + debug_program(self._baseline_dist_context._serial_main_program, + baseline_dir, "main") + debug_program(self._baseline_dist_context._serial_startup_program, + baseline_dir, "startup") + + def _select_tuning_algorithm(self): + + selected_passes_set = self._config.tuning_passes_name + algorithm_name = "_".join(sorted(selected_passes_set)) + self._algorithm = new_algorithm(algorithm_name, self._config) + + def _apply_optimization(self, trial): + new_strategy = trial.space + dist_context = _copy_context(self._baseline_dist_context) + pass_context = PassContext() + completer = Completer(dist_context) + + main_program = dist_context.serial_main_program + startup_program = dist_context.serial_startup_program + + # applying optimization pass + if new_strategy.amp: + config = copy.deepcopy(new_strategy.amp_configs) + config["dist_context"] = dist_context + config["params_grads"] = dist_context._params_grads + + # TODO AMP Pass should not use loss var + config["loss"] = dist_context.serial_loss + config["input_data"] = self._baseline_dist_context.serial_feed_vars["inputs"] \ + + self._baseline_dist_context.serial_feed_vars["labels"] + if config["use_pure_fp16"]: + config["base_opt"] = dist_context.optimizer + auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config) + auto_parallel_fp16_pass.apply([main_program], [startup_program], + pass_context) + else: + auto_parallel_amp_pass = new_pass("auto_parallel_amp", config) + auto_parallel_amp_pass.apply([main_program], [startup_program], + pass_context) + + if new_strategy.recompute: + config = copy.deepcopy(new_strategy.recompute_configs) + config["dist_context"] = dist_context + config["no_grad_set"] = None + config["loss"] = dist_context.serial_loss + auto_parallel_recompute_pass = new_pass("auto_parallel_recompute", + config) + auto_parallel_recompute_pass.apply([main_program], + [startup_program], pass_context) + + # Do logical partition + partitioner = Partitioner(dist_context, self.rank) + dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition( + main_program, startup_program, dist_context._params_grads) + + # Generate optimizer + # FIXME should be remove from apply pass after pass support optimizers + with program_guard(dist_main_prog, dist_startup_prog): + optimizer_ops = dist_context.serial_optimizer.apply_gradients( + dist_params_grads) + completer.complete_update_annotation(dist_main_prog) + + # Do reshard process + set_grad_var_shape(dist_main_prog, dist_context) + resharder = Resharder(dist_main_prog, dist_startup_prog, self.rank, + dist_context, dist_params_grads) + resharder.reshard() + + if new_strategy.sharding: + config = copy.deepcopy(new_strategy.sharding_configs) + config["dist_context"] = dist_context + config["params_grads"] = dist_params_grads + config["global_rank"] = self.rank + auto_parallel_sharding_pass = new_pass("auto_parallel_sharding", + config) + auto_parallel_sharding_pass.apply([dist_main_prog], + [dist_startup_prog], pass_context) + + if new_strategy.gradient_merge: + config = copy.deepcopy(new_strategy.gradient_merge_configs) + config["dist_context"] = dist_context + config["params_grads"] = dist_params_grads + auto_parallel_gradient_merge_pass = new_pass( + "auto_parallel_gradient_merge_pass", config) + auto_parallel_gradient_merge_pass.apply([dist_main_prog], + [dist_startup_prog], + pass_context) + trial.main_program, trial.startup_program = dist_main_prog, dist_startup_prog + return trial + + def _get_profile_context(self, trial, result_path): + + profile_ctx = {} + + profile_ctx['distributed_env'] = copy.deepcopy( + paddle.distributed.ParallelEnv()) + profile_ctx['group_map'] = parse_process_groups() + profile_ctx[ + "loss_var_name"] = self._baseline_dist_context.serial_loss.name + profile_ctx[ + "main_program_decs"] = trial.main_program.desc.serialize_to_string( + ) + profile_ctx[ + "startup_program_decs"] = trial.startup_program.desc.serialize_to_string( + ) + self._dataset.batch_size = self._batch_size + self._dataset.input_names = self._get_input_names() + + profile_ctx["dataset"] = self._dataset + profile_ctx["result_filename"] = result_path + + return profile_ctx + + def _get_input_names(self): + input_names = [] + for input_spec in self._inputs_spec[:] + self._labels_spec[:]: + input_names.append(input_spec.name) + return input_names + + def _launch_profile(self, ctx_path, trial_dir): + + if os.environ.get("WITH_COVERAGE", "OFF") == "ON": + coverage_args = ["-m", "coverage", "run", "--branch", "-p"] + else: + coverage_args = [] + + profile_args = " ".join([ + "--rank", + str(self.rank), + "--device_id", + str(self.device_id), + "--ctx_filename", + ctx_path, + ]) + cmd_args = "-m paddle.distributed.auto_parallel.tuner.profiler" + " " + profile_args + cmd = [sys.executable, "-u"] + coverage_args + shlex.split(cmd_args) + + parent_env = copy.copy(os.environ.copy()) + # env flags need for profile + new_env = { + "FLAGS_USE_STANDALONE_EXECUTOR": "False", + } + new_env.update(parent_env) + + # TODO if any rank hang or fail, kill all processes + self._logger.debug("Executing cmd:\n{} .".format(" ".join(cmd))) + # new_process = subprocess.Popen(cmd, env=new_env) + with open(os.path.join(trial_dir, "stdout.log" + str(self.rank)), + "wb") as out, open( + os.path.join(trial_dir, "stderr.log" + str(self.rank)), + "wb") as err: + result = subprocess.Popen(cmd, stdout=out, stderr=err, env=new_env) + result.wait() + out.flush() + err.flush() + os.fsync(out) + os.fsync(err) + + def _profile_trial(self, trial): + # Making working directory + trial_dir = self._get_trial_dir(trial) + if not os.path.exists(trial_dir): + if self.rank == 0: + pathlib.Path(trial_dir).mkdir(parents=True, exist_ok=True) + else: + while not os.path.exists(trial_dir): + pass + ctx_filename = "profile_ctx." + str(self.rank) + ctx_path = os.path.join(trial_dir, ctx_filename) + result_path = os.path.join(trial_dir, "result.json") + + # Prepare Profile Context + profile_ctx = self._get_profile_context(trial, result_path) + with open(ctx_path, 'wb') as f: + pickle.dump(profile_ctx, f, protocol=4) + + if self._config.verbose: + debug_program(trial.main_program, trial_dir, "main_program") + debug_program(trial.startup_program, trial_dir, "startup_program") + + # Run + self._launch_profile(ctx_path, trial_dir) + + # Load results + try: + with open(result_path, 'r') as fp: + results = json.load(fp) + return results + except FileNotFoundError: + Error_results = {"Throughtput": -1, "ErrorType": 'FatalError'} + return Error_results + + def _evaluate_trial(self, trial): + + self._logger.info("Trial {} evaluation start.".format(trial.name)) + self._apply_optimization(trial) + + if self._config.mode == "PROFILE": + results = self._profile_trial(trial) + + elif self._config.mode == "COSTMODEL": + raise NotImplementedError( + "COSTMODEL mode for optimization tuning is not supported yet!") + else: + raise NotImplementedError("invalid evaluation mode: {}".format( + self._config.mode)) + + self._logger.info("Trial {} evaluation finish with {}.".format( + trial.name, parse_results(results))) + return results + + def _update(self, i, trial, results): + self._finished_trials.append(trial) + + cur_mertic = get_metric(results) + if self._best_metric == None or cur_mertic > self._best_metric: + self._best_metric = cur_mertic + self._best_iter = i + + def _get_trial_dir(self, trial): + return os.path.join(self.project_dir, trial.name) + + def get_best_config(self): + """ + Return the best optimization configuration found in the tuning. + + Returns: + A object of fleet.DistributedStrategy with best configuration. + """ + assert self._best_iter >= 0, "The best configuration is not found yet !" + best_trial = self._finished_trials[self._best_iter] + return self._algorithm.get_config_from_trial(best_trial) + + def summary(self): + """ + Display tuning result summary. + """ + # TODO summary with the trial_name with metric_of_trial + best_trial = self._finished_trials[self._best_iter] + summary_ = """ +Tuning Result Summary +Run total {} trials with {} min. +The best trial is: [{}], whose configuration is following: + """.format(len(self._finished_trials), + (time.time() - self._tuning_start_time) / 60, + best_trial.name) + summary_ += "\n" + best_trial.summary() + "\n"\ + + self._logger.info(summary_) + with open(os.path.join(self.project_dir, "summary.txt"), "w+") as fw: + for line in summary_.split("\n"): + fw.write(line + "\n") + + full_strategy = self.get_best_config() + full_strategy.save_to_prototxt( + os.path.join(self.project_dir, "tuned_dist_strategy.prototxt")) + + def clear(self): + """ + Clear the temporary file generated in tuning procedure. + """ + # TODO clear up zombie process created by tuning + if not self._config.verbose: + for trial in self._finished_trials: + trial_dir = self._get_trial_dir(trial) + shutil.rmtree(trial_dir, ignore_errors=True) + + def tune(self): + """ + Performs the search for best hyperparameter configuations + for the selected optimization pass(es). + """ + + # step1: collect model info which might be used for + # pruning the search space of the algorithm + self._tuning_start_time = time.time() + self._algorithm.collect_model_info( + self._baseline_dist_context.serial_main_program, + self._baseline_dist_context.serial_startup_program) + + # main search loop + i = 0 + while i < self._config.max_num_trial: + # step2: create a new trial + trial = self._algorithm.next_trial() + + if trial.status == TrialStatus.STOPPED: + break + + # step3: evaluate the trial + results = self._evaluate_trial(trial) + + # step4: update the algorithm with last result, + # which could be used by algorithm to pruning the + # remaining search space. + self._algorithm.update(results) + self._update(i, trial, results) + + # early stop + i += 1 + if self._config.early_stop and self._config.early_stop <= i - self._best_iter: + self._logger.info( + "Early stop the Tuning since there is no better trial found within [{}] trials" + .format(self._config.early_stop)) + break + + # step5: summary the best config and return + self.summary() + + self.clear() diff --git a/python/paddle/distributed/auto_parallel/tuner/profiler.py b/python/paddle/distributed/auto_parallel/tuner/profiler.py new file mode 100644 index 00000000000..a894554c2fa --- /dev/null +++ b/python/paddle/distributed/auto_parallel/tuner/profiler.py @@ -0,0 +1,287 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import argparse +import traceback +import pickle +import json +import time +import numpy as np +from functools import partial + +import paddle +from paddle.fluid.framework import Program, _current_expected_place +from paddle.fluid.framework import Operator, Parameter +from paddle.distributed.auto_parallel.process_group import clear_all_process_groups, get_all_process_groups, new_process_group +from paddle.distributed.auto_parallel.dist_loader import NonIterableGeneratorLoader +from paddle.distributed.collective import _get_global_env + +paddle.enable_static() + + +def str2bool(v): + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Unsupported value encountered.') + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--profile_start_step", + default=10, + type=int, + help="integer indicates the warmup step before starting profile.") + parser.add_argument("--profile_end_step", + default=30, + type=int, + help="integer indicates at the end step of profile.") + parser.add_argument("--rank", + type=int, + required=True, + help="the rank id of the this process.") + parser.add_argument("--device_id", + type=int, + required=True, + help="the device id of the this process.") + parser.add_argument( + "--ctx_filename", + type=str, + required=True, + help= + "the filename to the profile context file saved by optimizaiton tuner") + + args = parser.parse_args() + + return args + + +def init_process_groups(group_map, rank): + for group_id, ranks in group_map.items(): + if group_id == 0: + continue + new_process_group(ranks=ranks, group_id=group_id) + + # TODO should instantiate global group first + all_process_groups = get_all_process_groups() + for process_group in all_process_groups: + if process_group.id == 0 or rank not in process_group.ranks: + continue + print(process_group) + process_group.instantiate() + + +def get_cpp_error_type(error): + + msg = str(error).splitlines() + cpp_error_types = [ + 'InvalidArgumentError', + 'NotFoundError', + 'OutOfRangeError', + 'AlreadyExistsError', + 'ResourceExhaustedError', + 'PreconditionNotMetError', + 'PermissionDeniedError', + 'ExecutionTimeoutError', + 'UnimplementedError', + 'UnavailableError', + 'FatalError', + 'ExternalError', + ] + error_type = 'FatalError' + for et in cpp_error_types: + for line in msg: + if et in line: + return et + return error_type + + +def create_dataloader(main_program, + startup_program, + profile_ctx, + epochs=1, + steps_per_epoch=None): + + dataset = profile_ctx["dataset"] + main_block = main_program.global_block() + feed_list = [] + for name in dataset.input_names: + if name in main_block.vars: + feed_list.append(main_block.vars[name]) + + # remove the first three ops if multi run fit/evaluate/predict + op_size = len(main_block.ops) + if main_block.ops[0].type == 'create_py_reader': + op_size -= 3 + for _ in range(3): + main_block._remove_op(0, sync=False) + + # insert read op at the end of program + places = paddle.static.cuda_places() + with paddle.static.program_guard(main_program, startup_program): + dataloader = NonIterableGeneratorLoader( + dataset, + feed_list, + places, + dataset.batch_size, + epochs, + steps_per_epoch, + data_parallel_world_size=dataset.dp_world_size, + data_parallel_rank=dataset.dp_rank) + + # move read op from the end of program to the start of program + new_op_size = len(main_block.ops) + for _ in range(new_op_size - 1, op_size - 1, -1): + op = main_block.ops[new_op_size - 1] + new_op_desc = main_block.desc._prepend_op() + new_op_desc.copy_from(op.desc) + new_op = Operator(main_block, new_op_desc, type=new_op_desc.type()) + main_block.ops.insert(0, new_op) + for _ in range(new_op_size - op_size): + main_block._remove_op(new_op_size, sync=False) + main_block._sync_with_cpp() + return dataloader + + +def init_comm(profile_ctx): + # override the env for current process + dist_env = profile_ctx['distributed_env'] + genv = _get_global_env() + genv = dist_env + print("current process rank: {}, device_id: {}, ip: {}.", genv.rank, + genv.device_id, genv.current_endpoint) + + # init nccl comm + group_map = profile_ctx['group_map'] + init_process_groups(group_map, args.rank) + + +def load_programs(profile_ctx): + main_program_desc_str = profile_ctx['main_program_decs'] + main_program = Program.parse_from_string(main_program_desc_str) + + startup_program_decs_str = profile_ctx['startup_program_decs'] + startup_program = Program.parse_from_string(startup_program_decs_str) + + loss_var_name = profile_ctx["loss_var_name"] + assert main_program.global_block().has_var(loss_var_name) + loss_var = main_program.global_block().var(loss_var_name) + + return main_program, startup_program, loss_var + + +def get_executor(): + place_type = _current_expected_place() + if not isinstance(place_type, paddle.CUDAPlace): + raise RuntimeError("OptimizationTuner only support CUDA GPU right now.") + + genv = _get_global_env() + place = paddle.CUDAPlace(genv.device_id) + exe = paddle.static.Executor(place) + return exe + + +def profiler(args): + """ + main function to profile experiment for each pass hyper-parameter. + """ + # load ctx + if not os.path.isfile(args.ctx_filename): + raise ValueError("There is no profile context named {}.".format( + args.ctx_filename)) + with open(args.ctx_filename, 'rb') as f: + profile_ctx = pickle.load(f, encoding='latin1') + + init_comm(profile_ctx) + + main_program, startup_program, loss_var = load_programs(profile_ctx) + + data_loader = create_dataloader(main_program, startup_program, profile_ctx) + + result_path = profile_ctx["result_filename"] + + exe = get_executor() + + exe.run(startup_program) + + # profile main + duration = 0 + eval_step = 0 + data_loader._inner_dataloader.start() + try: + while eval_step < args.profile_end_step: + start_time = time.time() + + loss = exe.run( + main_program, + fetch_list=[loss_var], + use_program_cache=True, + ) + + end_time = time.time() + + if eval_step >= args.profile_start_step: + duration += end_time - start_time + + print("step: %d, loss_print: %f" % (eval_step, loss[0])) + eval_step += 1 + + avg_tput = 1.0 * (args.profile_end_step - + args.profile_start_step) / duration + + result_dict = { + "Throughtput": avg_tput, + "ErrorType": None, + } + + if paddle.distributed.get_rank() == 0: + with open(result_path, 'w') as fp: + json.dump(result_dict, fp) + + print("profile done! avg speed : {} step / s.".format((avg_tput))) + + except paddle.framework.core.EOFException: + data_loader._inner_dataloader.reset() + + except Exception as e: + + error_type = get_cpp_error_type(e) + result_dict = { + "Throughtput": -1, + "ErrorType": error_type, + } + if not os.path.isfile(result_path): + with open(result_path, 'w') as fp: + json.dump(result_dict, fp) + + print("profile failed with error: [{}]".format(error_type)) + print(e) + print(traceback.format_exc()) + + data_loader._inner_dataloader.reset() + del data_loader._inner_dataloader + exit(1) + + data_loader._inner_dataloader.reset() + del data_loader._inner_dataloader + + +if __name__ == "__main__": + args = parse_args() + profiler(args) diff --git a/python/paddle/distributed/auto_parallel/tuner/trial.py b/python/paddle/distributed/auto_parallel/tuner/trial.py index 78139cbd58b..3937ca98651 100644 --- a/python/paddle/distributed/auto_parallel/tuner/trial.py +++ b/python/paddle/distributed/auto_parallel/tuner/trial.py @@ -115,6 +115,55 @@ class Trial(Storable): return trial +class OptimizationTunerTrial(Trial): + + def __init__(self, + config, + name, + changed_configs, + trial_id=None, + status=TrialStatus.RUNNING): + super(OptimizationTunerTrial, self).__init__(config, trial_id, status) + self._name = name + self._changed_configs = changed_configs + + @property + def name(self): + return self._name + + def summary(self): + + spacing = 2 + max_k = 38 + max_v = 38 + + length = max_k + max_v + spacing + + h1_format = " " + "|{{:^{}s}}|\n".format(length) + h2_format = " " + "|{{:>{}s}}{}{{:^{}s}}|\n".format( + max_k, " " * spacing, max_v) + + border = " +" + "".join(["="] * length) + "+" + line = " +" + "".join(["-"] * length) + "+" + + draws = border + "\n" + draws += h1_format.format("") + draws += h1_format.format("Tuned Configuartions Overview") + draws += h1_format.format("") + + for name in self._changed_configs: + draws += border + "\n" + draws += h1_format.format("{} auto=True <-> {}".format(name, name)) + draws += line + "\n" + my_configs = getattr(self.space, name) + keys = my_configs.keys() + for key in keys: + draws += h2_format.format(key, str(my_configs.get(key, None))) + + result_res = draws + border + return result_res + + def _generate_trial_id(): s = str(time.time()) + str(random.randint(1, int(1e7))) return hashlib.sha256(s.encode("utf-8")).hexdigest()[:32] diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index 5d9499d9286..b0d4963140e 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -1473,3 +1473,11 @@ def to_list(value): if isinstance(value, (list, tuple)): return list(value) return [value] + + +def debug_program(program, path, name): + + filename = os.path.join( + path, name + '_program' + ".%d" % (paddle.distributed.get_rank())) + with open(filename, 'w') as f: + f.write(str(program)) diff --git a/python/paddle/distributed/passes/auto_parallel_amp.py b/python/paddle/distributed/passes/auto_parallel_amp.py index 7afba8c0f13..d97209f7fe5 100644 --- a/python/paddle/distributed/passes/auto_parallel_amp.py +++ b/python/paddle/distributed/passes/auto_parallel_amp.py @@ -142,7 +142,6 @@ class AMPState(object): modified from paddle.fluid.contrib.mixed_precision """ num_cast_ops = 0 - var_name_dict = {} for in_name in op.input_names: if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input( diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index 3c1f0443e03..c6a8f11574f 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -245,10 +245,10 @@ class ShardingPass(PassBase): }) dist_attr = self._dist_context.get_tensor_dist_attr_for_program( main_block.var(sum_op_output)) - assert dist_attr is not None - naive_set_dist_op_attr_for_program_by_mesh_and_mapping( - new_op, dist_attr.process_mesh, dist_attr.dims_mapping, - self._dist_context) + # assert dist_attr is not None + # naive_set_dist_op_attr_for_program_by_mesh_and_mapping( + # new_op, dist_attr.process_mesh, dist_attr.dims_mapping, + # self._dist_context) break main_block._sync_with_cpp() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt index 85b13b38a43..500cb91094f 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt @@ -25,6 +25,11 @@ if(WITH_DISTRIBUTE AND WITH_GPU) set_tests_properties(test_engine_api_dp PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 80) + py_test_modules(test_optimization_tuner_api MODULES + test_optimization_tuner_api ENVS ${dist_ENVS}) + set_tests_properties(test_optimization_tuner_api + PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 80) + py_test_modules(test_converter MODULES test_converter ENVS ${dist_ENVS}) set_tests_properties(test_converter PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/optimization_tuner_api.py b/python/paddle/fluid/tests/unittests/auto_parallel/optimization_tuner_api.py new file mode 100644 index 00000000000..8e058d16b87 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/optimization_tuner_api.py @@ -0,0 +1,157 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import time +import tempfile +import copy +import os +import numpy as np +import subprocess +import paddle +import paddle.nn as nn +import paddle.fluid as fluid +import paddle.static as static +import paddle.nn.functional as F +import paddle.utils as utils +from paddle.fluid import layers +from paddle.io import Dataset, IterableDataset, DataLoader +from paddle.static import InputSpec +from paddle.distributed import fleet +import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.engine import Engine +from engine_api_dp import MyDataset + +paddle.enable_static() +batch_size = 16 +batch_num = 5 +hidden_size = 1024 +sequence_len = 512 +image_size = hidden_size +class_num = 10 + +paddle.seed(44) + +# class MyDataset(Dataset): + +# def __init__(self, num_samples): +# super(MyDataset, self).__init__() +# self.num_samples = num_samples + +# def __getitem__(self, index): +# input = np.random.uniform(size=image_size).astype("float32") +# label = np.random.randint(0, class_num - 1, dtype="int64") +# return input, label + +# def __len__(self): +# return self.num_samples + + +class MLPLayer(nn.Layer): + + def __init__(self, + hidden_size=1024, + intermediate_size=4 * 1024, + dropout_ratio=0.1, + initializer_range=0.02): + super(MLPLayer, self).__init__() + d_model = hidden_size + dim_feedforward = intermediate_size + weight_attr = paddle.ParamAttr( + initializer=nn.initializer.Normal(mean=0.0, std=initializer_range)) + bias_attr = None + + self.linear0 = nn.Linear(d_model, + dim_feedforward, + weight_attr, + bias_attr=bias_attr) + self.linear1 = nn.Linear(dim_feedforward, + d_model, + weight_attr, + bias_attr=bias_attr) + self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr) + self.norm = nn.LayerNorm(d_model, epsilon=1e-5) + self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train") + + def forward(self, input): + out = self.norm(input) + out = self.linear0(out) + out = F.gelu(out, approximate=True) + out = self.linear1(out) + out = self.dropout(out) + out = self.linear2(out) + self.out = out + return out + + +def train(fetch): + mlp = MLPLayer(hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + dropout_ratio=0.1, + initializer_range=0.02) + loss = paddle.nn.CrossEntropyLoss() + optimizer = paddle.fluid.optimizer.AdamOptimizer(learning_rate=0.00001, + beta1=0.9, + beta2=0.999, + epsilon=1e-08, + grad_clip=None) + + inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x') + labels_spec = InputSpec([batch_size], 'int64', 'label') + + dist_strategy = fleet.DistributedStrategy() + dist_strategy.amp = False + dist_strategy.pipeline = False + dist_strategy.recompute = False + # init parallel optimizer + dist_strategy.semi_auto = True + dist_strategy.sharding = True + dist_strategy.sharding_configs = { + "sharding_degree": 2, + "stage": 3, + "enable_tuning": True, + } + fleet.init(is_collective=True, strategy=dist_strategy) + + # init engine + import tempfile + tmp_dir = tempfile.TemporaryDirectory() + dataset = MyDataset(batch_num * batch_size) + + # Tuning configuration + tuning_config = { + "batch_size": batch_size, + "dataset": dataset, + "profile_start_step": 1, + "profile_end_step": 5, + "run_after_tuning": True, + "sharding": { + "stage_range": [0, 1, 2, 3] + }, + "verbose": True, + } + engine = Engine(mlp, + inputs_spec=inputs_spec, + labels_spec=labels_spec, + strategy=dist_strategy, + user_tuning_config=tuning_config) + engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy()) + + # check tuned + assert (engine._dist_contexts['train'].strategy.sharding_configs['stage'] != + 3) + + +if __name__ == "__main__": + train(True) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_optimization_tuner_api.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_optimization_tuner_api.py new file mode 100644 index 00000000000..c88d0810f15 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_optimization_tuner_api.py @@ -0,0 +1,50 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest +import os +import sys +import shutil +import subprocess +from paddle.distributed.fleet.launch_utils import run_with_coverage + + +class TestOptimizationTunerAPI(unittest.TestCase): + + def test_engine_api(self): + file_dir = os.path.dirname(os.path.abspath(__file__)) + launch_model_path = os.path.join(file_dir, "optimization_tuner_api.py") + + if os.environ.get("WITH_COVERAGE", "OFF") == "ON": + coverage_args = ["-m", "coverage", "run", "--branch", "-p"] + else: + coverage_args = [] + + tmp_dir = tempfile.TemporaryDirectory() + cmd = [sys.executable, "-u"] + coverage_args + [ + "-m", "launch", "--gpus", "0,1", "--log_dir", tmp_dir.name, + launch_model_path + ] + + process = subprocess.Popen(cmd) + process.wait() + self.assertEqual(process.returncode, 0) + + tmp_dir.cleanup() + shutil.rmtree('./OptimizationTuning', ignore_errors=True) + + +if __name__ == "__main__": + unittest.main() -- GitLab