From c47853f6a29d044a650e936666bfac27877a03e7 Mon Sep 17 00:00:00 2001 From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com> Date: Wed, 19 Apr 2023 11:02:04 +0800 Subject: [PATCH] [AutoParallel] add gradient_merge master_grad & 1F1B pass (#52647) --- .../fleet_executor/fleet_executor.cc | 42 +- .../distributed/auto_parallel/constants.py | 13 +- .../distributed/auto_parallel/dist_loader.py | 160 +++--- .../distributed/auto_parallel/dist_saver.py | 51 +- .../distributed/auto_parallel/engine.py | 63 ++- .../auto_parallel/parallelizer_v2.py | 16 +- .../distributed/auto_parallel/partitioner.py | 373 ++++++++------ .../distributed/auto_parallel/reshard.py | 82 +++- .../distributed/auto_parallel/strategy.py | 9 + .../paddle/distributed/auto_parallel/utils.py | 7 + ...uto_parallel_data_parallel_optimization.py | 279 +++++++---- .../passes/auto_parallel_gradient_merge.py | 462 ++++++++++++------ .../passes/auto_parallel_pipeline.py | 413 ++++++++++++---- .../auto_parallel/1F1B_pass_unittest.py | 126 +++++ .../unittests/auto_parallel/CMakeLists.txt | 3 + .../unittests/auto_parallel/get_gpt_model.py | 7 + .../gradient_merge_pass_unittest.py | 17 +- .../unittests/auto_parallel/test_pass_1F1B.py | 57 +++ 18 files changed, 1532 insertions(+), 648 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/auto_parallel/1F1B_pass_unittest.py create mode 100644 python/paddle/fluid/tests/unittests/auto_parallel/test_pass_1F1B.py diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc index 05f75ad79ce..915b1f82804 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc @@ -110,12 +110,15 @@ void PreventVarsDelete( std::vector GetUnusedVarsAfterWhile( const framework::ProgramDesc& program_desc, + TaskNode* cond_task, const std::vector& vars_not_gc) { // NOTE: Since while op won't appear in task node, in order to analyze // the vars which should be free after calling while op, we rebuild the // whole program and get the unused vars after calling while op. - // vars in parent block should not be free until the while op is finished. - // The local vars will be free while running op in sub block. + // The vars in while block should not be free until the while op is finished. + // In a word, the vars need to be free after while op is: + // 1. Vars in parent block and being used in while block. + // 2. Local vars only defined in while block. // The unused vars above will be free in cond interceptor. 
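The comment above states the rule for which variables may be garbage-collected once the while op finishes. A minimal Python sketch of that rule, using toy variable lists instead of real ProgramDesc blocks (all names below are illustrative, not Paddle API):

def vars_to_free_after_while(parent_block_vars, while_block_reads,
                             while_block_locals, vars_not_gc):
    # 1. vars owned by the parent block but read inside the while block
    from_parent = set(parent_block_vars) & set(while_block_reads)
    # 2. vars defined only inside the while block
    local_only = set(while_block_locals) - set(parent_block_vars)
    # vars the caller asked to keep (e.g. inference outputs) are never freed
    return sorted((from_parent | local_only) - set(vars_not_gc))

assert vars_to_free_after_while(
    parent_block_vars=["x", "h", "cond"],
    while_block_reads=["x", "cond"],
    while_block_locals=["tmp_0", "tmp_1"],
    vars_not_gc=["cond"],
) == ["tmp_0", "tmp_1", "x"]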
std::vector while_block_vars; std::vector> ops; @@ -129,29 +132,14 @@ std::vector GetUnusedVarsAfterWhile( for (const auto& var_name : pair.second) { while_block_vars.emplace_back(var_name); } + for (auto& var : program_desc.Block(1).AllVars()) { + while_block_vars.emplace_back(var->Name()); + } } } return while_block_vars; } -std::unordered_map> -GetSubUnusedVars(const framework::ProgramDesc& program_desc, - const std::set& sub_block_tasks, - const std::vector& vars_not_gc) { - std::vector> ops; - for (auto* task_node : sub_block_tasks) { - for (const auto& op : task_node->ops()) { - ops.emplace_back(std::unique_ptr(op)); - } - } - auto unused_vars = framework::GetUnusedVars(program_desc.Block(1), ops, {}); - for (auto& unique_op : ops) { - unique_op.release(); - } - PreventVarsDelete(&unused_vars, vars_not_gc); - return unused_vars; -} - } // namespace void FleetExecutor::Init( @@ -174,13 +162,8 @@ void FleetExecutor::Init( for (const auto& task_node : task_nodes) { if (task_node->type() == "Cond") { GetSubBlockTask(task_nodes, task_node, &sub_block_tasks); - while_block_vars = - GetUnusedVarsAfterWhile(program_desc, inference_root_scope_vars); - for (auto* task_node : sub_block_tasks) { - for (auto iter : task_node->vars_to_dtype()) { - while_block_vars.emplace_back(iter.first); - } - } + while_block_vars = GetUnusedVarsAfterWhile( + program_desc, task_node, inference_root_scope_vars); VLOG(3) << "Vars will be gced after while op"; for (auto var : while_block_vars) { VLOG(3) << var; @@ -210,9 +193,6 @@ void FleetExecutor::Init( unique_op.release(); } - auto sub_unused_vars = - GetSubUnusedVars(program_desc, sub_block_tasks, while_block_vars); - // NOTE: For inference, the vars in inference_root_scope_vars // shouldn't be deleted during inf, for that they may be the result of the // inf. If they are GCed, it will cause error during ZeroCopy the result. 
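Likewise, the protection described in this note can be modelled as filtering protected names out of an op-to-unused-vars map before it reaches the garbage collector; a toy Python illustration (not the actual C++ PreventVarsDelete helper, and all op and var names are made up):

def prevent_vars_delete(unused_vars, vars_not_gc):
    # drop protected names so inference results survive garbage collection
    keep = set(vars_not_gc)
    return {
        op: [v for v in names if v not in keep]
        for op, names in unused_vars.items()
    }

unused = {"matmul_op": ["tmp_0", "fetch_res"], "relu_op": ["tmp_1"]}
assert prevent_vars_delete(unused, ["fetch_res"]) == {
    "matmul_op": ["tmp_0"],
    "relu_op": ["tmp_1"],
}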
@@ -223,8 +203,6 @@ void FleetExecutor::Init( for (auto task_node : task_nodes) { if (sub_block_tasks.find(task_node) == sub_block_tasks.end()) { task_node->SetUnusedVars(global_unused_vars); - } else { - task_node->SetUnusedVars(sub_unused_vars); } int64_t interceptor_id = task_node->task_id(); interceptor_id_to_task.emplace(interceptor_id, task_node); diff --git a/python/paddle/distributed/auto_parallel/constants.py b/python/paddle/distributed/auto_parallel/constants.py index 19d444248fa..d9313845664 100644 --- a/python/paddle/distributed/auto_parallel/constants.py +++ b/python/paddle/distributed/auto_parallel/constants.py @@ -117,9 +117,9 @@ set_field_default_config(QAT, "activation_bits", 8) set_field_default_config(QAT, "not_quant_pattern", ['skip_quant']) set_field_default_config(QAT, "algo", None) -# ######################################### +######################################### # auto tuning configuration -# ######################################### +######################################### TUNING = "tuning" set_field_default_config(TUNING, "enable", False) set_field_default_config(TUNING, "batch_size", 1) @@ -135,3 +135,12 @@ set_field_default_config(TUNING, "verbose", True) DATASET = "dataset" set_field_default_config(DATASET, "enable", False) set_field_default_config(DATASET, "num_shards", 1) + +######################################### +# data parallel configuration +######################################### +DP_OPTIMIZATION = "dp_optimization" +set_field_default_config(DP_OPTIMIZATION, "enable", False) +set_field_default_config(DP_OPTIMIZATION, "fuse_all_reduce_ops", True) +set_field_default_config(DP_OPTIMIZATION, "fuse_grad_size_in_MB", 32) +set_field_default_config(DP_OPTIMIZATION, "overlap_comm_cacl", True) diff --git a/python/paddle/distributed/auto_parallel/dist_loader.py b/python/paddle/distributed/auto_parallel/dist_loader.py index 38b537799e5..f1e10fb5184 100644 --- a/python/paddle/distributed/auto_parallel/dist_loader.py +++ b/python/paddle/distributed/auto_parallel/dist_loader.py @@ -17,12 +17,18 @@ import numpy as np import paddle from paddle.io import BatchSampler, IterableDataset -from paddle.fluid.dataloader.batch_sampler import _InfiniteIterableSampler, DistributedBatchSampler -from paddle.fluid.dataloader.dataloader_iter import _DatasetKind, default_collate_fn, default_convert_fn +from paddle.fluid.dataloader.batch_sampler import ( + _InfiniteIterableSampler, + DistributedBatchSampler, +) +from paddle.fluid.dataloader.dataloader_iter import ( + _DatasetKind, + default_collate_fn, + default_convert_fn, +) class DistributedDataLoaderBase(metaclass=abc.ABCMeta): - @abc.abstractmethod def __iter__(self): raise NotImplementedError @@ -43,24 +49,26 @@ class DistributedDataLoaderBase(metaclass=abc.ABCMeta): class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase): - - def __init__(self, - dataset, - feed_list=None, - capacity=None, - use_double_buffer=True, - iterable=True, - return_list=False, - use_multiprocess=False, - drop_last=True, - places=None, - batch_size=1, - epochs=1, - steps_per_epoch=None, - collate_fn=None, - split_data=True, - data_parallel_world_size=[], - data_parallel_rank=[]): + def __init__( + self, + dataset, + feed_list=None, + capacity=None, + use_double_buffer=True, + iterable=True, + return_list=False, + use_multiprocess=False, + drop_last=True, + places=None, + batch_size=1, + epochs=1, + steps_per_epoch=None, + collate_fn=None, + split_data=True, + data_parallel_world_size=[], + data_parallel_rank=[], + acc_steps=1, + ): 
self.dataset = dataset self.feed_list = feed_list self.capacity = capacity @@ -79,6 +87,7 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase): assert len(data_parallel_rank) == len(feed_list) self.dp_world_sizes = data_parallel_world_size self.dp_ranks = data_parallel_rank + self.acc_steps = acc_steps if isinstance(dataset, IterableDataset): self.dataset_kind = _DatasetKind.ITER @@ -90,12 +99,15 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase): else: if isinstance(dataset, IterableDataset): self.batch_sampler = _InfiniteIterableSampler( - dataset, batch_size) + dataset, batch_size + ) else: - self.batch_sampler = BatchSampler(dataset, - batch_size=batch_size, - shuffle=False, - drop_last=drop_last) + self.batch_sampler = BatchSampler( + dataset, + batch_size=batch_size, + shuffle=False, + drop_last=drop_last, + ) self.auto_collate_batch = self.batch_sampler is not None self.sampler_iter = iter(self.index_sampler) @@ -106,8 +118,12 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase): self.collate_fn = collate_fn or default_convert_fn self.dataset_fetcher = _DatasetKind.create_fetcher( - self.dataset_kind, self.dataset, self.auto_collate_batch, - self.collate_fn, self.drop_last) + self.dataset_kind, + self.dataset, + self.auto_collate_batch, + self.collate_fn, + self.drop_last, + ) self._steps = self._infer_steps() self._inner_dataloader = self._create_inner_dataloader() @@ -136,9 +152,11 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase): if isinstance(self.dataset, IterableDataset): steps_per_epoch = None elif self.batch_size is None: - steps_per_epoch = len(self.dataset) + steps_per_epoch = len(self.dataset) // self.acc_steps else: - steps_per_epoch = len(self.dataset) // self.batch_size + steps_per_epoch = ( + len(self.dataset) // self.batch_size // self.acc_steps + ) except: raise ValueError( "Pleace set `steps_per_epoch` or implement `__len__` methond in dataset class." 
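The `_infer_steps` change above folds gradient-accumulation steps into the per-epoch step count. A standalone sketch of the arithmetic with made-up numbers (not the dataloader code itself):

def infer_steps_per_epoch(num_samples, batch_size, acc_steps):
    # every optimizer step consumes acc_steps micro-batches of batch_size samples
    if batch_size is None:
        return num_samples // acc_steps
    return num_samples // batch_size // acc_steps

# 8000 samples, micro-batch size 8, 4 accumulation steps -> 250 optimizer steps per epoch
assert infer_steps_per_epoch(8000, 8, 4) == 250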
@@ -156,18 +174,21 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase): return _InfiniteIterableSampler(self.dataset, 1) def _create_inner_dataloader(self): - def data_generator(): while True: try: indices = next(self.sampler_iter) batch = self.dataset_fetcher.fetch(indices) - if batch is None: break + if batch is None: + break except StopIteration: self.dataset_fetcher = _DatasetKind.create_fetcher( - self.dataset_kind, self.dataset, - self.auto_collate_batch, self.collate_fn, - self.drop_last) + self.dataset_kind, + self.dataset, + self.auto_collate_batch, + self.collate_fn, + self.drop_last, + ) break partial_data = [] @@ -178,11 +199,16 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase): continue batch_size = array.shape[0] - assert batch_size % self.dp_world_sizes[i] == 0, \ - "batch_size [{}] is not divisible by dp_world_size [{}]".format(str(batch_size), str(self.dp_world_sizes[i])) + assert ( + batch_size % self.dp_world_sizes[i] == 0 + ), "batch_size [{}] is not divisible by dp_world_size [{}]".format( + str(batch_size), str(self.dp_world_sizes[i]) + ) partial_data.append( - np.split(array, - self.dp_world_sizes[i])[self.dp_ranks[i]]) + np.split(array, self.dp_world_sizes[i])[ + self.dp_ranks[i] + ] + ) yield partial_data @@ -194,33 +220,35 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase): iterable=False, return_list=self.return_list, use_multiprocess=self.use_multiprocess, - drop_last=self.drop_last) + drop_last=self.drop_last, + ) dataloader.set_batch_generator(data_generator, self.places) return dataloader class DistributedDataLoader(DistributedDataLoaderBase): - - def __init__(self, - dataset, - feed_list=None, - places=None, - return_list=True, - batch_size=1, - shuffle=False, - drop_last=False, - collate_fn=None, - num_workers=0, - use_buffer_reader=True, - use_shared_memory=True, - timeout=0, - worker_init_fn=None, - epochs=1, - steps_per_epoch=None, - split_data=True, - data_parallel_world_size=[], - data_parallel_rank=[]): + def __init__( + self, + dataset, + feed_list=None, + places=None, + return_list=True, + batch_size=1, + shuffle=False, + drop_last=False, + collate_fn=None, + num_workers=0, + use_buffer_reader=True, + use_shared_memory=True, + timeout=0, + worker_init_fn=None, + epochs=1, + steps_per_epoch=None, + split_data=True, + data_parallel_world_size=[], + data_parallel_rank=[], + ): self.dataset = dataset self.feed_list = feed_list self.return_list = return_list @@ -241,8 +269,13 @@ class DistributedDataLoader(DistributedDataLoaderBase): self.split_data = split_data # TODO: rank info self.batch_sampler = DistributedBatchSampler( - self.dataset, self.batch_size, self.dp_world_sizes[0], - self.dp_ranks[0], self.shuffle, self.drop_last) + self.dataset, + self.batch_size, + self.dp_world_sizes[0], + self.dp_ranks[0], + self.shuffle, + self.drop_last, + ) self._inner_dataloader = self._create_inner_dataloader() def __iter__(self): @@ -263,7 +296,8 @@ class DistributedDataLoader(DistributedDataLoaderBase): use_buffer_reader=self.use_buffer_reader, use_shared_memory=self.use_shared_memory, timeout=self.timeout, - worker_init_fn=self.worker_init_fn) + worker_init_fn=self.worker_init_fn, + ) self.data = (x for x in dataloader) return dataloader diff --git a/python/paddle/distributed/auto_parallel/dist_saver.py b/python/paddle/distributed/auto_parallel/dist_saver.py index 350e5ac44e7..2bb35a7a3d9 100644 --- a/python/paddle/distributed/auto_parallel/dist_saver.py +++ 
b/python/paddle/distributed/auto_parallel/dist_saver.py @@ -18,6 +18,7 @@ import errno import pickle import warnings import logging +import collections import numpy as np import paddle @@ -53,16 +54,13 @@ def _process_path(path): class DistributedSaver: - def __init__(self): self._logger = get_logger(logging.INFO) def save(self, path, serial_program, dist_main_program, dist_context): - def _save_state(program, path, mode="param"): state = { - k: np.array(v) - for k, v in program.state_dict(mode).items() + k: np.array(v) for k, v in program.state_dict(mode).items() } with open(path, "wb") as f: pickle.dump(state, f) @@ -108,8 +106,9 @@ class DistributedSaver: def _load_file(filename, dirname, suffix="pdparams"): file_list = [] for file in os.listdir(dirname): - if check_filename('{}(.*)_dist(.*).{}'.format(filename, suffix), - file): + if check_filename( + '{}(.*)_dist(.*).{}'.format(filename, suffix), file + ): file_list.append(os.path.join(dirname, file)) file_list.sort() return file_list @@ -137,14 +136,16 @@ class DistributedSaver: # load path.pdparam and path.pdopt param_state_dict = _load_state(filename, dirname) - opt_state_dict = _load_state(filename, dirname, - "pdopt") if load_optimizer else {} + opt_state_dict = ( + _load_state(filename, dirname, "pdopt") if load_optimizer else {} + ) state_dict = dict(param_state_dict, **opt_state_dict) # load path.pdattr dist_attr_file_list = _load_file(filename, dirname, "pdattr") self._logger.info( - "Load distributed attribute file: {}".format(dist_attr_file_list)) + "Load distributed attribute file: {}".format(dist_attr_file_list) + ) dist_attr = {} for dist_attr_file in dist_attr_file_list: with open(dist_attr_file, 'rb') as f: @@ -196,12 +197,24 @@ class DistributedSaver: used_inputs += op.input_arg_names used_outputs += op.output_arg_names - dist_feed_vars_names = list(set(feed_vars_names) & set(used_inputs)) - dist_fetch_vars_names = list(set(fetch_vars_names) & set(used_outputs)) + # delete duplicated elements and keep order + feed_vars_names = list({}.fromkeys(feed_vars_names).keys()) + used_inputs = list({}.fromkeys(used_inputs).keys()) + fetch_vars_names = list({}.fromkeys(fetch_vars_names).keys()) + used_outputs = list({}.fromkeys(used_outputs).keys()) - dist_feed_vars = [ - global_block.vars[name] for name in dist_feed_vars_names + dist_feed_vars_names = [ + var_name for var_name in feed_vars_names if var_name in used_inputs ] + dist_fetch_vars_names = [ + var_name + for var_name in fetch_vars_names + if var_name in used_outputs + ] + + dist_feed_vars = list( + reversed([global_block.vars[name] for name in dist_feed_vars_names]) + ) dist_fetch_vars = [ global_block.vars[name] for name in dist_fetch_vars_names ] @@ -209,11 +222,13 @@ class DistributedSaver: # NOTE: `paddle.static.save_inference_model` does not support subblock. 
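The saver above now de-duplicates feed/fetch variable names while keeping their first-seen order by reusing dict key uniqueness; a self-contained illustration with dummy names:

names = ["input0", "label0", "input0", "loss", "label0"]
deduped = list({}.fromkeys(names).keys())  # equivalent to list(dict.fromkeys(names))
assert deduped == ["input0", "label0", "loss"]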
dist_filename = filename + "_dist" + str(rank_id) dist_path = os.path.join(dirname, dist_filename) - paddle.static.save_inference_model(dist_path, - dist_feed_vars, - dist_fetch_vars, - exe, - program=dist_main_prog) + paddle.static.save_inference_model( + dist_path, + dist_feed_vars, + dist_fetch_vars, + exe, + program=dist_main_prog, + ) def _save_rank_mapping(self, dirname): path = os.path.join(dirname, 'rank_mapping.csv') diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index 7d988c6c95e..26c22297d9d 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -225,6 +225,11 @@ class Engine: self._planned_mode = None self._dygraph_mode = False self._tuning = self._strategy.tuning + self._acc_steps = 1 + if self._strategy.gradient_merge.enable: + self._acc_steps = self._strategy.gradient_merge.k_steps + elif self._strategy.pipeline.enable: + self._acc_steps = self._strategy.pipeline.accumulate_steps self.history = None @@ -388,7 +393,12 @@ class Engine: if self.main_program._pipeline_opt: assert "tasks" in self.main_program._pipeline_opt["fleet_opt"] fleet_opt = self.main_program._pipeline_opt["fleet_opt"] - fwd_task = fleet_opt["tasks"][0] + fwd_task = None + if self._strategy.pipeline.schedule_mode == "1F1B": + fwd_task = fleet_opt["tasks"][1] + elif self._strategy.pipeline.schedule_mode == "stream": + fwd_task = fleet_opt["tasks"][0] + assert fwd_task is not None fwd_prog = fwd_task.get_program() fwd_block = fwd_prog.global_block() @@ -438,8 +448,6 @@ class Engine: ), "user_fetches must be a list, but receive {}".format( type(user_fetches).__name__ ) - else: - user_fetches = [] fetch_names = [] fetch_indices = [] @@ -466,7 +474,7 @@ class Engine: _process_fetch_group("metrics_" + str(i), var_list) if mode == "predict": _process_fetch_group("outputs", fetch_vars["outputs"]) - for usr_fetch in user_fetches: + for usr_fetch in user_fetches or []: var_name = _to_name_str(usr_fetch) fetch(var_name) user_fetches_collection = [ @@ -903,6 +911,7 @@ class Engine: self._inputs_spec, self._labels_spec = self._prepare_data_spec( train_data, train_sample_split, batch_size ) + batch_size = self._validate_batch_size(batch_size) if not self._has_prepared[self._mode]: self._prepare_program(self._mode) else: @@ -931,7 +940,7 @@ class Engine: save_dir=save_dir, verbose=verbose, metrics=self._metrics_name(), - acc_step=self._k_steps, + acc_step=self._acc_steps, ) cbks.on_begin('train') @@ -965,7 +974,7 @@ class Engine: val_logs = self.evaluate( valid_data, valid_sample_split, - batch_size, + batch_size * self._acc_steps, valid_steps, log_freq, collate_fn, @@ -1046,6 +1055,7 @@ class Engine: self._inputs_spec, self._labels_spec = self._prepare_data_spec( valid_data, valid_sample_split, batch_size ) + batch_size = self._validate_batch_size(batch_size) if not self._has_prepared[self._mode]: self._prepare_program(self._mode) else: @@ -1152,6 +1162,7 @@ class Engine: self._inputs_spec, self._labels_spec = self._prepare_data_spec( test_data, test_sample_split, batch_size ) + batch_size = self._validate_batch_size(batch_size) if not self._has_prepared[self._mode]: self._prepare_program(self._mode) else: @@ -1214,6 +1225,7 @@ class Engine: self._inputs_spec, self._labels_spec = self._prepare_data_spec( dataset, sample_split, batch_size ) + batch_size = self._validate_batch_size(batch_size) if not self._has_prepared[self._mode]: self._prepare_program(self._mode) else: @@ -1256,6 +1268,7 
@@ class Engine: self._inputs_spec, self._labels_spec = self._prepare_data_spec( dataset, sample_split, batch_size ) + batch_size = self._validate_batch_size(batch_size) if not self._has_prepared[self._mode]: self._prepare_program(self._mode) else: @@ -1371,14 +1384,6 @@ class Engine: steps_per_epoch=None, ): - if self._strategy.gradient_merge and batch_size is not None: - assert ( - batch_size % self._k_steps == 0 - ), "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format( - batch_size, self._k_steps - ) - batch_size //= self._k_steps - dist_context = self._dist_contexts[self._mode] dist_main_prog = dist_context.dist_main_programs[self._cur_rank] dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank] @@ -1440,14 +1445,6 @@ class Engine: collate_fn=None, ): - if self._strategy.gradient_merge and batch_size is not None: - assert ( - batch_size % self._k_steps == 0 - ), "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format( - batch_size, self._k_steps - ) - batch_size //= self._k_steps - dist_context = self._dist_contexts[self._mode] dist_main_prog = dist_context.dist_main_programs[self._cur_rank] dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank] @@ -1487,6 +1484,9 @@ class Engine: split_data=self._strategy.split_data, data_parallel_world_size=self._dp_world_sizes, data_parallel_rank=self._dp_ranks, + acc_steps=1 + if not self._strategy.pipeline.enable + else self._acc_steps, ) self._prepare_reader(feed_list) return dataloader @@ -1498,9 +1498,18 @@ class Engine: ) self._optimization_tuning(self._mode, tune_data, batch_size) + def _validate_batch_size(self, batch_size): + if batch_size is None: + return None + assert ( + batch_size % self._acc_steps == 0 + ), "Requires batch_size:[{}] to be divisible by acc_steps:[{}].".format( + batch_size, self._acc_steps + ) + return batch_size // self._acc_steps + def _validate_spec(self, specs): specs = to_list(specs) - self._k_steps = self._strategy.gradient_merge.k_steps if specs is not None: for i, spec in enumerate(specs): if not isinstance(spec, InputSpec): @@ -1513,14 +1522,14 @@ class Engine: i, spec ) ) - if self._k_steps > 1: + if self._acc_steps > 1: shape = list(spec.shape) assert ( - shape[0] % self._k_steps == 0 + shape[0] % self._acc_steps == 0 ), "Requires batch_size[{}] to be divisible by k_steps[{}].".format( - spec.shape[0], self._k_steps + spec.shape[0], self._acc_steps ) - shape[0] //= self._k_steps + shape[0] //= self._acc_steps spec.shape = shape return specs or [] diff --git a/python/paddle/distributed/auto_parallel/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/parallelizer_v2.py index 09f5f6464bc..0594f539f0e 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer_v2.py +++ b/python/paddle/distributed/auto_parallel/parallelizer_v2.py @@ -297,13 +297,15 @@ class Parallelizer: if self._strategy is None: return - # data parallel optimization - config = {} - config["dist_context"] = self._dist_context - config["global_rank"] = rank - config["use_sharding"] = self._strategy.sharding.enable - dp_pass = new_pass("auto_parallel_data_parallel_optimization", config) - dp_pass.apply([main_program], [startup_program], self._pass_context) + if self._strategy.dp_optimization.enable: + config = copy.deepcopy(self._strategy.dp_optimization.to_dict()) + config["dist_context"] = self._dist_context + config["global_rank"] = rank + config["use_sharding"] = self._strategy.sharding.enable + dp_pass = new_pass( + "auto_parallel_data_parallel_optimization", config + 
) + dp_pass.apply([main_program], [startup_program], self._pass_context) if self._strategy.sharding.enable: config = copy.deepcopy(self._strategy.sharding.to_dict()) diff --git a/python/paddle/distributed/auto_parallel/partitioner.py b/python/paddle/distributed/auto_parallel/partitioner.py index e12a111dd2a..b7dd197ea65 100644 --- a/python/paddle/distributed/auto_parallel/partitioner.py +++ b/python/paddle/distributed/auto_parallel/partitioner.py @@ -13,24 +13,25 @@ # limitations under the License import copy -import numpy as np -import paddle -import paddle.fluid as fluid -from paddle.fluid import core -from paddle.fluid import framework as framework -from paddle.fluid import core, unique_name -from paddle.fluid.framework import Program, Parameter, Variable, program_guard -from paddle.distributed.auto_parallel.operators.common import get_distributed_operator_impl_container -from paddle.distributed.auto_parallel.dist_context import DistributedContext, DistributedOperatorContext +from paddle.fluid.framework import Program, Parameter, core +from paddle.distributed.auto_parallel.operators.common import ( + get_distributed_operator_impl_container, +) +from paddle.distributed.auto_parallel.dist_context import DistributedContext from .dist_attribute import OperatorDistributedAttribute -from .process_group import new_process_group -from .utils import set_dist_op_desc_original_id -from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op, is_loss_op, is_optimize_op +from .utils import ( + is_forward_op, + is_backward_op, + is_loss_op, + is_optimize_op, + is_fillconst_op_for_micro_batch, +) from .operators.common import BACKWARD_ONLY_DIST_OPS __varname_not_in_block__ = ["lod_tensor_blocking_queue"] __not_shape_var_type__ = [ - core.VarDesc.VarType.READER, core.VarDesc.VarType.STEP_SCOPES + core.VarDesc.VarType.READER, + core.VarDesc.VarType.STEP_SCOPES, ] @@ -39,7 +40,7 @@ class Partitioner(object): warning:: Partitioner is experimental and subject to change. Partitioner convert a program into another program. - Given a serial program which has been auto completed with shard annotation, the Partitioner + Given a serial program which has been auto completed with shard annotation, the Partitioner convert the serial program into a "distributed" program. The Partitioner will modify the serial program in following two ways, which is also the major difference between serial and distributed program: 1. 
partition op: replace a serial op into its corresponding dist op infered from the shard annotation @@ -56,25 +57,29 @@ class Partitioner(object): """ if not isinstance(dist_context, DistributedContext): raise TypeError( - "dist_context be paddle.fluid.DistributedContext, got %s here" % - type(dist_context)) + "dist_context be paddle.fluid.DistributedContext, got %s here" + % type(dist_context) + ) self._dist_context = dist_context self._rank_id = rank_id self._serial2dist_varname_mapping = {} self._dist_varname_suffix = "" - def partition(self, serial_main_program, serial_startup_program, - params_grads): + def partition( + self, serial_main_program, serial_startup_program, params_grads + ): if not isinstance(serial_main_program, (Program)): raise TypeError( - "main_program be paddle.fluid.framework.program, got %s here" % - type(serial_main_program)) + "main_program be paddle.fluid.framework.program, got %s here" + % type(serial_main_program) + ) # check if shard annotated serial program valid if not self._is_valid_annotated_program(serial_main_program): raise RuntimeError( - "Not all vars or ops are annotated in main program !") + "Not all vars or ops are annotated in main program !" + ) # init distop helper dist_op_context = self._dist_context.dist_op_context @@ -86,24 +91,33 @@ class Partitioner(object): partitioned_startup_prog = None else: partitioned_startup_prog = self.partition_startup_program( - serial_main_program, serial_startup_program) + serial_main_program, serial_startup_program + ) dist_op_context.dst_startup_program = partitioned_startup_prog # partition main program - partitioned_main_prog, partitioned_params_grads = self.partition_main_program( - serial_main_program, params_grads) + ( + partitioned_main_prog, + partitioned_params_grads, + ) = self.partition_main_program(serial_main_program, params_grads) - return partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads + return ( + partitioned_main_prog, + partitioned_startup_prog, + partitioned_params_grads, + ) - def partition_startup_program(self, serial_main_program, - serial_startup_program): + def partition_startup_program( + self, serial_main_program, serial_startup_program + ): if not isinstance(serial_startup_program, (Program)): raise TypeError( - "dist_context be paddle.fluid.framework.program, got %s here" % - type(serial_startup_program)) + "dist_context be paddle.fluid.framework.program, got %s here" + % type(serial_startup_program) + ) - partitioned_startup_prog = fluid.Program() + partitioned_startup_prog = Program() ref_block = serial_main_program.global_block() target_block = partitioned_startup_prog.global_block() var2shape = {} @@ -114,27 +128,33 @@ class Partitioner(object): assert var.persistable new_name = var.name + self._dist_varname_suffix temp_varname_map[var.name] = new_name - target_shape = _partition_var(self._dist_context, ref_block, - target_block, var.name, new_name) + target_shape = _partition_var( + self._dist_context, ref_block, target_block, var.name, new_name + ) var2shape[new_name] = target_shape # ops for op in serial_startup_program.global_block().ops: # TODO if var not belong to this rank, should be filtered output_vars = op.desc.output_arg_names() - assert len( - output_vars - ) == 1, "initializer should output only ONE variable, but got [{}]".format( - str(op.desc)) - assert temp_varname_map[output_vars[ - 0]] in var2shape, "try to initialize [{}] which is not a persistable var".format( - output_vars[0]) + assert ( + len(output_vars) == 1 + ), "initializer 
should output only ONE variable, but got [{}]".format( + str(op.desc) + ) + assert ( + temp_varname_map[output_vars[0]] in var2shape + ), "try to initialize [{}] which is not a persistable var".format( + output_vars[0] + ) new_op_desc = target_block.desc.append_op() new_op_desc.copy_from(op.desc) - new_op_desc._rename_output(output_vars[0], - temp_varname_map[output_vars[0]]) - new_op_desc._set_attr("shape", - var2shape[temp_varname_map[output_vars[0]]]) + new_op_desc._rename_output( + output_vars[0], temp_varname_map[output_vars[0]] + ) + new_op_desc._set_attr( + "shape", var2shape[temp_varname_map[output_vars[0]]] + ) target_block._sync_with_cpp() # set distribute atrribute @@ -142,14 +162,17 @@ class Partitioner(object): assert new_op.type == new_op_desc.type() assert new_op.desc == new_op_desc output_var = target_block.var(output_vars[0]) - output_var_attr = self._dist_context.get_tensor_dist_attr_for_program( - output_var) + output_var_attr = ( + self._dist_context.get_tensor_dist_attr_for_program(output_var) + ) op_attr = OperatorDistributedAttribute() op_attr.process_mesh = output_var_attr.process_mesh - op_attr.set_output_dims_mapping(output_var.name, - output_var_attr.dims_mapping) - op_attr.set_input_dims_mapping(output_var.name, - output_var_attr.dims_mapping) + op_attr.set_output_dims_mapping( + output_var.name, output_var_attr.dims_mapping + ) + op_attr.set_input_dims_mapping( + output_var.name, output_var_attr.dims_mapping + ) self._dist_context.set_op_dist_attr_for_program(new_op, op_attr) return partitioned_startup_prog @@ -160,7 +183,7 @@ class Partitioner(object): 2. replace local op with corresponding dist op """ - partitioned_main_prog = fluid.Program() + partitioned_main_prog = Program() dist_op_context = self._dist_context.dist_op_context dist_op_context.dst_main_program = partitioned_main_prog @@ -171,7 +194,8 @@ class Partitioner(object): target_block = partitioned_main_prog.blocks[0] else: target_block = partitioned_main_prog._create_block( - parent_idx=ref_block.parent_idx) + parent_idx=ref_block.parent_idx + ) assert ref_block.idx == target_block.idx target_block._set_forward_block_idx(ref_block.forward_block_idx) dist_op_context.work_block = target_block @@ -186,8 +210,9 @@ class Partitioner(object): for attr_name in op.all_attrs(): if op.attr_type(attr_name) == core.AttrType.BLOCK: relative_id = op._block_attr_id(attr_name) - op._set_attr(attr_name, - partitioned_main_prog.block(relative_id)) + op._set_attr( + attr_name, partitioned_main_prog.block(relative_id) + ) partitioned_params_and_grads = [] for p, g in params_and_grads: @@ -198,7 +223,8 @@ class Partitioner(object): else: assert g.name in self._serial2dist_varname_mapping dist_g = self._get_dist_var_by_serial_var( - g, partitioned_main_prog) + g, partitioned_main_prog + ) partitioned_params_and_grads.append((dist_p, dist_g)) return partitioned_main_prog, partitioned_params_and_grads @@ -222,71 +248,116 @@ class Partitioner(object): for idx in range(len(serial_ops)): if idx <= last_fwd_op_idx: forward_op_id2forward_op[ - serial_ops[idx].desc.original_id()] = serial_ops[idx] + serial_ops[idx].desc.original_id() + ] = serial_ops[idx] # partiiton appended_grad_times = 0 for idx, op in enumerate(serial_ops): op_dist_attr = self._dist_context.get_op_dist_attr_for_program(op) - if is_backward_op(op) and (is_forward_op(serial_ops[idx - 1]) - or is_loss_op(serial_ops[idx - 1])): + if is_backward_op(op) and ( + is_forward_op(serial_ops[idx - 1]) + or is_loss_op(serial_ops[idx - 1]) + ): if not 
op_dist_attr.is_recompute: appended_grad_times += 1 # partititon input variables for serial_input_varname in op.desc.input_arg_names(): - if serial_input_varname not in self._serial2dist_varname_mapping: - new_varname = serial_input_varname + self._dist_varname_suffix + if ( + serial_input_varname + not in self._serial2dist_varname_mapping + ): + new_varname = ( + serial_input_varname + self._dist_varname_suffix + ) if ref_block.has_var(serial_input_varname): - _partition_var(self._dist_context, ref_block, - target_block, serial_input_varname, - new_varname) + _partition_var( + self._dist_context, + ref_block, + target_block, + serial_input_varname, + new_varname, + ) else: for varname_not_in_block in __varname_not_in_block__: - assert varname_not_in_block in serial_input_varname, \ - "{} is not found".format(serial_input_varname) + assert ( + varname_not_in_block in serial_input_varname + ), "{} is not found".format(serial_input_varname) self._serial2dist_varname_mapping[ - serial_input_varname] = new_varname + serial_input_varname + ] = new_varname # partition output vars for serial_output_varname in op.desc.output_arg_names(): - if serial_output_varname not in self._serial2dist_varname_mapping: - new_varname = serial_output_varname + self._dist_varname_suffix - _partition_var(self._dist_context, ref_block, target_block, - serial_output_varname, new_varname) + if ( + serial_output_varname + not in self._serial2dist_varname_mapping + ): + new_varname = ( + serial_output_varname + self._dist_varname_suffix + ) + _partition_var( + self._dist_context, + ref_block, + target_block, + serial_output_varname, + new_varname, + ) self._serial2dist_varname_mapping[ - serial_output_varname] = new_varname + serial_output_varname + ] = new_varname # partition op - if is_forward_op(op) or op_dist_attr.is_recompute: + if ( + is_forward_op(op) + or op_dist_attr.is_recompute + or is_fillconst_op_for_micro_batch(op) + ): kinputs, koutputs = dist_op_context.prepare_context(op) dist_op_forward_impl = _get_dist_op_forward_implement( - op, self._dist_context) - dist_op_forward_impl.forward(self._dist_context, **kinputs, - **koutputs) + op, self._dist_context + ) + dist_op_forward_impl.forward( + self._dist_context, **kinputs, **koutputs + ) elif is_backward_op(op): kinputs, koutputs = dist_op_context.prepare_context(op) dist_op_backward_impl = _get_dist_op_backward_implement( - op, self._dist_context, forward_op_id2forward_op) - grad_var_to_var = self._dist_context.dist_op_context.grad_var_to_var[ - appended_grad_times] + op, self._dist_context, forward_op_id2forward_op + ) + grad_var_to_var = ( + self._dist_context.dist_op_context.grad_var_to_var[ + appended_grad_times + ] + ) dist_op_backward_impl.backward( - self._dist_context, **kinputs, **koutputs, - **{"grad_var_to_var": grad_var_to_var}) + self._dist_context, + **kinputs, + **koutputs, + **{"grad_var_to_var": grad_var_to_var} + ) elif is_optimize_op(op): - # NOTE: BACKWARD_ONLY_DIST_OPS's op_role must 2 because of 1F1B PASS + # NOTE: BACKWARD_ONLY_DIST_OPS's op_role must be 2 because of 1F1B PASS kinputs, koutputs = dist_op_context.prepare_context(op) dist_op_opt_impl = _get_dist_op_backward_implement( - op, self._dist_context, forward_op_id2forward_op) - dist_op_opt_impl.backward(self._dist_context, **kinputs, - **koutputs, **{"grad_var_to_var": {}}) + op, self._dist_context, forward_op_id2forward_op + ) + dist_op_opt_impl.backward( + self._dist_context, + **kinputs, + **koutputs, + **{"grad_var_to_var": {}} + ) else: raise NotImplementedError( - 
"partitioner only support forward and backward, optimize ops, but got {}" - .format(str(op))) + "partitioner only support forward and backward, optimize ops, but got {}".format( + str(op) + ) + ) def _is_valid_annotated_program(self, program): @@ -298,13 +369,16 @@ class Partitioner(object): ] var_dist_attrs = [ self._dist_context.get_tensor_dist_attr_for_program(var) - for var in vars_ if (var.type not in __not_shape_var_type__) + for var in vars_ + if (var.type not in __not_shape_var_type__) ] - all_ops_annotated = all(dist_attr is not None - for dist_attr in op_dist_attrs) - all_vars_annotated = all(dist_attr is not None - for dist_attr in var_dist_attrs) + all_ops_annotated = all( + dist_attr is not None for dist_attr in op_dist_attrs + ) + all_vars_annotated = all( + dist_attr is not None for dist_attr in var_dist_attrs + ) return all_ops_annotated and all_vars_annotated @@ -328,22 +402,26 @@ def _get_dist_shape(var, dist_attr): assert len(var_shape) == len( mapping ), "variable shape [{}] and dim_mapping [{}] is NOT match !".format( - var_shape, mapping) + var_shape, mapping + ) new_shape = [] for idx in range(len(var_shape)): if var_shape[idx] == -1 or mapping[idx] == -1: new_shape.append(var_shape[idx]) else: - assert var_shape[idx] % mesh[mapping[ - idx]] == 0, "un-event partition: var_shape[idx]=[{}], mesh[{}]".format( - var_shape[idx], mesh[mapping[idx]]) + assert ( + var_shape[idx] % mesh[mapping[idx]] == 0 + ), "un-event partition: var_shape[idx]=[{}], mesh[{}]".format( + var_shape[idx], mesh[mapping[idx]] + ) new_shape.append(var_shape[idx] // mesh[mapping[idx]]) return new_shape -def _partition_parameter(dist_context, src_var, dst_block, dst_varname, - dst_shape): +def _partition_parameter( + dist_context, src_var, dst_block, dst_varname, dst_shape +): # NOTE hack to copied Parameter # not initialized parameter, need to initialize it copied_kwargs = {} @@ -353,39 +431,45 @@ def _partition_parameter(dist_context, src_var, dst_block, dst_varname, copied_kwargs['do_model_average'] = src_var.do_model_average copied_kwargs['need_clip'] = src_var.need_clip - param = Parameter(block=dst_block, - type=src_var.type, - name=dst_varname, - shape=dst_shape, - dtype=src_var.dtype, - lod_level=src_var.lod_level, - error_clip=src_var.error_clip, - stop_gradient=src_var.stop_gradient, - is_data=src_var.is_data, - belong_to_optimizer=src_var.belong_to_optimizer, - **copied_kwargs) + param = Parameter( + block=dst_block, + type=src_var.type, + name=dst_varname, + shape=dst_shape, + dtype=src_var.dtype, + lod_level=src_var.lod_level, + error_clip=src_var.error_clip, + stop_gradient=src_var.stop_gradient, + is_data=src_var.is_data, + belong_to_optimizer=src_var.belong_to_optimizer, + **copied_kwargs + ) return param -def _partition_intermediate_var(dist_context, src_var, dst_block, dst_varname, - dst_shape): - var = dst_block.create_var(type=src_var.type, - name=dst_varname, - shape=dst_shape, - dtype=src_var.dtype, - lod_level=src_var.lod_level, - persistable=src_var.persistable, - error_clip=src_var.error_clip, - stop_gradient=src_var.stop_gradient, - is_data=src_var.is_data, - belong_to_optimizer=src_var.belong_to_optimizer) +def _partition_intermediate_var( + dist_context, src_var, dst_block, dst_varname, dst_shape +): + var = dst_block.create_var( + type=src_var.type, + name=dst_varname, + shape=dst_shape, + dtype=src_var.dtype, + lod_level=src_var.lod_level, + persistable=src_var.persistable, + error_clip=src_var.error_clip, + stop_gradient=src_var.stop_gradient, + 
is_data=src_var.is_data, + belong_to_optimizer=src_var.belong_to_optimizer, + ) return var -def _partition_var(dist_context, src_block, dst_block, src_varname, - dst_varname): +def _partition_var( + dist_context, src_block, dst_block, src_varname, dst_varname +): """ partition include: split + replicate """ @@ -393,44 +477,53 @@ def _partition_var(dist_context, src_block, dst_block, src_varname, if src_var.type in __not_shape_var_type__: persist = getattr(src_var, 'persistable', False) - new_var = dst_block.create_var(type=src_var.type, - name=dst_varname, - persistable=persist, - stop_gradient=True) + new_var = dst_block.create_var( + type=src_var.type, + name=dst_varname, + persistable=persist, + stop_gradient=True, + ) target_shape = None else: dist_attr = dist_context.get_tensor_dist_attr_for_program(src_var) target_shape = _get_dist_shape(src_var, dist_attr) if isinstance(src_var, Parameter): - new_var = _partition_parameter(dist_context, src_var, dst_block, - dst_varname, target_shape) + new_var = _partition_parameter( + dist_context, src_var, dst_block, dst_varname, target_shape + ) else: - new_var = _partition_intermediate_var(dist_context, src_var, - dst_block, dst_varname, - target_shape) + new_var = _partition_intermediate_var( + dist_context, src_var, dst_block, dst_varname, target_shape + ) dist_attr = copy.deepcopy( - dist_context.get_tensor_dist_attr_for_program(src_var)) + dist_context.get_tensor_dist_attr_for_program(src_var) + ) assert dist_attr is not None dist_context.set_tensor_dist_attr_for_program(new_var, dist_attr) return target_shape -def _get_dist_op_backward_implement(backward_op, dist_context, - forward_op_id2forward_op): +def _get_dist_op_backward_implement( + backward_op, dist_context, forward_op_id2forward_op +): dist_op_context = dist_context.dist_op_context if backward_op.desc.original_id() in dist_op_context.grad_op_id_to_op_id: forward_op_id = dist_op_context.grad_op_id_to_op_id[ - backward_op.desc.original_id()] + backward_op.desc.original_id() + ] forward_op = forward_op_id2forward_op[forward_op_id] forward_op_dist_attr = dist_context.get_op_dist_attr_for_program( - forward_op) + forward_op + ) dist_op_impl_container = get_distributed_operator_impl_container( - forward_op_dist_attr.impl_type) + forward_op_dist_attr.impl_type + ) dist_op_impl = dist_op_impl_container.get_impl( - forward_op_dist_attr.impl_idx) + forward_op_dist_attr.impl_idx + ) return dist_op_impl # # NOTE trick for dist ops that only have backward implement @@ -438,7 +531,8 @@ def _get_dist_op_backward_implement(backward_op, dist_context, op_dist_attr = dist_context.get_op_dist_attr_for_program(backward_op) assert op_dist_attr.impl_idx >= 0 dist_op_impl = get_distributed_operator_impl_container( - op_dist_attr.impl_type).get_impl(op_dist_attr.impl_idx) + op_dist_attr.impl_type + ).get_impl(op_dist_attr.impl_idx) return dist_op_impl dist_op = get_distributed_operator_impl_container("default") @@ -448,6 +542,7 @@ def _get_dist_op_backward_implement(backward_op, dist_context, def _get_dist_op_forward_implement(forward_op, dist_context): dist_attr = dist_context.get_op_dist_attr_for_program(forward_op) dist_op_impl_container = get_distributed_operator_impl_container( - dist_attr.impl_type) + dist_attr.impl_type + ) dist_op_impl = dist_op_impl_container.get_impl(dist_attr.impl_idx) return dist_op_impl diff --git a/python/paddle/distributed/auto_parallel/reshard.py b/python/paddle/distributed/auto_parallel/reshard.py index 4e5d5b0bf32..b5846b4db9d 100644 --- 
a/python/paddle/distributed/auto_parallel/reshard.py +++ b/python/paddle/distributed/auto_parallel/reshard.py @@ -422,11 +422,11 @@ class Inserter: ) inputs = {'X': [tensor]} outputs = {"Out": [out]} - attrs = {"in_place": False} - slice_op = block._insert_op( + attrs = {"in_place": False, "op_role": op_role} + assign_op = block._insert_op( idx, type="assign", inputs=inputs, outputs=outputs, attrs=attrs ) - slice_op._set_attr('op_namescope', "/auto_parallel/reshard") + assign_op._set_attr('op_namescope', "/auto_parallel/reshard") return out # use split once @@ -1217,6 +1217,8 @@ class Resharder: shape_x[0] <= shape_y[0] < shape_x[1] ): overlapped = True + if shape_x == [0, 0] and shape_y == [0, 0]: + overlapped = True return overlapped def is_unshard(self, dims_mapping): @@ -1304,6 +1306,14 @@ class Resharder: # judge whether need reshard by process_mesh if tensor_process_mesh != op_process_mesh: is_reshard = True + # not reshard data in send/recv scene + if ( + tensor_process_mesh != op_process_mesh + and len(tensor_process_mesh.process_ids) + == len(op_process_mesh.process_ids) + and dist_tensor.serial_tensor.is_data + ): + is_reshard = False else: op_output_dims_mapping = dist_attr[1] if all( @@ -1585,10 +1595,10 @@ class Resharder: if i == 0: all_partition_index_list.append(process_index[j][1]) for process in group: - # append slice op desc - slice_starts = [] - slice_ends = [] - slices_axes = [] + min_comm_group = copy.deepcopy(group) + all_partition_index_list_copied = copy.deepcopy( + all_partition_index_list + ) target_partition_index = Resharder.compute_partition_index( process, complete_shape, @@ -1596,12 +1606,56 @@ class Resharder: target_process_shape, target_process_group, ) - for idx, item in enumerate(target_partition_index): - slice_starts.append(item[0]) - slice_ends.append(item[1]) + for _process in group: + source_partition_index = ( + Resharder.compute_partition_index( + _process, + complete_shape, + source_dims_mapping, + source_process_shape, + source_process_group, + ) + ) + if not all( + _ + for _ in list( + map( + self.is_overlapped, + source_partition_index, + target_partition_index, + ) + ) + ): + min_comm_group.remove(_process) + all_partition_index_list_copied.remove( + source_partition_index + ) + + concatenated_partition_index_list = [] + for partition_index in all_partition_index_list_copied: + Resharder.concat_partitions( + concatenated_partition_index_list, partition_index + ) + + concatenated_partition_index = ( + concatenated_partition_index_list[0] + ) + + slice_starts = [] + slice_ends = [] + slices_axes = [] + to_slice_tensor_shape = [] + + for idx, item in enumerate(concatenated_partition_index): + slice_starts.append( + target_partition_index[idx][0] - item[0] + ) + slice_ends.append( + target_partition_index[idx][1] - item[0] + ) slices_axes.append(idx) + to_slice_tensor_shape.append(item[1] - item[0]) - to_slice_tensor_shape = dist_tensor.global_sizes() slice_op_desc = SliceOpDesc( starts=slice_starts, ends=slice_ends, @@ -1616,16 +1670,16 @@ class Resharder: op_desc_seq[process] = ( [ AllGatherOpDesc( - group=group, + group=min_comm_group, shape=allgather_shape, is_bool=(source_tensor.dtype == paddle.bool), ), ConcatOpDesc( - partition_index_list=all_partition_index_list + partition_index_list=all_partition_index_list_copied ), slice_op_desc, ] - if len(group) > 1 + if len(min_comm_group) > 1 else [slice_op_desc] ) diff --git a/python/paddle/distributed/auto_parallel/strategy.py b/python/paddle/distributed/auto_parallel/strategy.py index 
f7dd7e6697b..224c189d55b 100644 --- a/python/paddle/distributed/auto_parallel/strategy.py +++ b/python/paddle/distributed/auto_parallel/strategy.py @@ -123,6 +123,12 @@ class DatasetConfig(BaseConfig): super(DatasetConfig, self).__init__(category, config_dict) +class DPOptimizationConfig(BaseConfig): + def __init__(self, config_dict=None): + category = constants.DP_OPTIMIZATION + super(DPOptimizationConfig, self).__init__(category, config_dict) + + class Strategy(BaseConfig): """ The `Strategy` object is used to configure the paralleization and optimization beheviors. @@ -194,3 +200,6 @@ class Strategy(BaseConfig): config_dict = self._config_dict.get(constants.DATASET, None) self.dataset = DatasetConfig(config_dict) + + config_dict = self._config_dict.get(constants.DP_OPTIMIZATION, None) + self.dp_optimization = DPOptimizationConfig(config_dict) diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index a08a17288a4..33b30068653 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -1252,6 +1252,7 @@ def set_grad_var_shape(program, dist_context): "fused_softmax_mask_upper_triangle_grad", "flatten_contiguous_range_grad", "relu_grad", + "exp_grad", ] forward_list = [ "reshape2", @@ -1270,6 +1271,7 @@ def set_grad_var_shape(program, dist_context): "fused_softmax_mask_upper_triangle", "flatten_contiguous_range", "relu", + "exp", ] if op.type in need_set_shape_list: for forward_op in block.ops: @@ -1320,6 +1322,11 @@ def is_forward_op(op): ) +def is_fillconst_op_for_micro_batch(op): + op_role = int(op.attr('op_role')) + return OP_ROLE_KEY in op.attr_names and (op_role == int(OpRole.LRSched)) + + def is_backward_op(op): return OP_ROLE_KEY in op.attr_names and int( op.all_attrs()[OP_ROLE_KEY] diff --git a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py index 70592e8b380..c93feb36ed0 100644 --- a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py +++ b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py @@ -18,15 +18,31 @@ import numpy as np import paddle from paddle.fluid import core, unique_name from paddle.fluid.framework import default_main_program -from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY -from paddle.distributed.auto_parallel.operators.common import is_data_parallel_scale_op, is_data_parallel_reduce_op -from paddle.distributed.auto_parallel.utils import is_loss_grad_op, is_optimize_op, is_backward_op, ring_id_to_process_group, find_higher_order_backward_op +from paddle.distributed.fleet.meta_optimizers.common import ( + OpRole, + OP_ROLE_KEY, + OP_ROLE_VAR_KEY, +) +from paddle.distributed.auto_parallel.operators.common import ( + is_data_parallel_scale_op, + is_data_parallel_reduce_op, +) +from paddle.distributed.auto_parallel.utils import ( + is_loss_grad_op, + is_optimize_op, + is_backward_op, + ring_id_to_process_group, + find_higher_order_backward_op, +) from .pass_base import PassBase, PassType, register_pass # add new optimizers supporting rescale_grad here __rescale_grad_supported_opts__ = [ - 'lars_momentum', 'sparse_momentum', 'dgc_momentum', 'momentum', - 'merge_momentum' + 'lars_momentum', + 'sparse_momentum', + 'dgc_momentum', + 'momentum', + 'merge_momentum', ] # a heuristic number @@ -41,7 +57,7 @@ def numel(var): class 
DataParallelOptimizationPass(PassBase): """ Apply Optimizations that specialized for data parallelism in Auto Parallel. - 1. prune grad scaling + 1. prune grad scaling 2. overlap comm and calc 3. fuse allreduce """ @@ -52,6 +68,9 @@ class DataParallelOptimizationPass(PassBase): self.set_attr("dist_context", None) self.set_attr("global_rank", -1) self.set_attr("use_sharding", False) + self.set_attr("fuse_all_reduce_ops", False) + self.set_attr("fuse_grad_size_in_MB", 32) + self.set_attr("overlap_comm_cacl", False) # {grad1: group1, grad2: group1, grad3: group2} # record the order for fuse grad data memory self._grad_name_to_group_map = OrderedDict() @@ -62,8 +81,9 @@ class DataParallelOptimizationPass(PassBase): def _check_self(self): if self.get_attr("dist_context") is None: return False - if (not isinstance(self.get_attr("global_rank"), - int)) or self.get_attr("global_rank") < 0: + if (not isinstance(self.get_attr("global_rank"), int)) or self.get_attr( + "global_rank" + ) < 0: return False return True @@ -80,13 +100,18 @@ class DataParallelOptimizationPass(PassBase): self.global_rank = int(self.get_attr("global_rank")) self.use_sharding = self.get_attr("use_sharding") + overlap_comm_cacl = self.get_attr("overlap_comm_cacl") + fuse_all_reduce_ops = self.get_attr("fuse_all_reduce_ops") + with paddle.static.program_guard(main_program, startup_program): self._analyze_program() if self.is_data_parallel_applied(): - self._prune_grad_scaling() - self._calc_comm_overlap() - grad_group = self._fuse_allreduce() + if overlap_comm_cacl: + self._prune_grad_scaling() + self._calc_comm_overlap() + if fuse_all_reduce_ops: + grad_group = self._fuse_allreduce() # self.summary(grad_group) @@ -140,8 +165,11 @@ class DataParallelOptimizationPass(PassBase): ), "Unexception: comm op [{}] has NOT ring id.".format(str(op)) group = ring_id_to_process_group(op.attr("ring_id")) - assert group is not None, "Unexception: data parallel group of [{}] from op [{}] is None".format( - grad_name, str(op)) + assert ( + group is not None + ), "Unexception: data parallel group of [{}] from op [{}] is None".format( + grad_name, str(op) + ) self._grad_name_to_group_map[grad_name] = group @@ -156,18 +184,21 @@ class DataParallelOptimizationPass(PassBase): # TODO support multiple optimizers in on network in future. # here we assume that the optimizer is unique in network. 
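DataParallelOptimizationPass now takes its knobs from the new `dp_optimization` strategy section rather than running unconditionally. A usage sketch, assuming the `auto` entry point exported by `paddle.distributed.fleet`; the field names come from the constants added in this patch:

from paddle.distributed.fleet import auto

strategy = auto.Strategy()
strategy.dp_optimization.enable = True
strategy.dp_optimization.overlap_comm_cacl = True    # overlap grad allreduce with computation
strategy.dp_optimization.fuse_all_reduce_ops = True  # fuse grads into coalesced allreduce ops
strategy.dp_optimization.fuse_grad_size_in_MB = 32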
- elif is_optimize_op( - op) and op.type in __rescale_grad_supported_opts__: + elif ( + is_optimize_op(op) + and op.type in __rescale_grad_supported_opts__ + ): self._support_rescale_grad = True not_synchronized_grads = [] for grad_name in scaled_grads: if grad_name not in self._grad_name_to_group_map: not_synchronized_grads.append(grad_name) - assert len( + assert ( + len(not_synchronized_grads) == 0 + ), "Unexception: gradients [{}] is scaled BUT NOT synchronized.".format( not_synchronized_grads - ) == 0, "Unexception: gradients [{}] is scaled BUT NOT synchronized.".format( - not_synchronized_grads) + ) def is_data_parallel_applied(self): return len(self._group_to_grad_name_map) > 0 @@ -175,14 +206,21 @@ class DataParallelOptimizationPass(PassBase): def _could_be_prune(self): return self.dist_context.gradient_scale and ( - self._support_rescale_grad or self._all_dp_groups_same_degree()) + self._support_rescale_grad or self._all_dp_groups_same_degree() + ) def _all_dp_groups_same_degree(self): - return len( - set([ - len(group.ranks) - for group in self._group_to_grad_name_map.keys() - ])) == 1 + return ( + len( + set( + [ + len(group.ranks) + for group in self._group_to_grad_name_map.keys() + ] + ) + ) + == 1 + ) def _scale_backward_initial_grad(self): @@ -191,9 +229,10 @@ class DataParallelOptimizationPass(PassBase): for idx, op in reversed(list(enumerate(block.ops))): if is_loss_grad_op(op): - assert op.type == 'fill_constant', \ - "loss_grad_op must be fill_constant op, " \ + assert op.type == 'fill_constant', ( + "loss_grad_op must be fill_constant op, " "but this op is {}".format(op.type) + ) assert op.has_attr('value') loss_scale = float(op.attr('value')) loss_scale = loss_scale / dp_degree @@ -215,28 +254,35 @@ class DataParallelOptimizationPass(PassBase): scaled_grads = set() for idx, op in reversed(list(enumerate(block.ops))): - if is_optimize_op( - op) and op.type in __rescale_grad_supported_opts__: + if ( + is_optimize_op(op) + and op.type in __rescale_grad_supported_opts__ + ): assert op.has_attr( 'rescale_grad' ), "Unexception: op [{}] is supported to have [rescale_grad] attribute.".format( - str(op)) - assert len( - op.input("Grad") - ) == 1, "Unexception: op [{}] is supported to have only one input grad var.".format( - str(op)) + str(op) + ) + assert ( + len(op.input("Grad")) == 1 + ), "Unexception: op [{}] is supported to have only one input grad var.".format( + str(op) + ) grad_name = op.input("Grad")[0] dp_degree = len( - list(self._grad_name_to_group_map[grad_name].ranks)) + list(self._grad_name_to_group_map[grad_name].ranks) + ) scaled_grads.add(grad_name) rescale_grad = float(op.attr('rescale_grad')) / dp_degree op._set_attr('rescale_grad', rescale_grad) - assert scaled_grads == set(self._grad_name_to_group_map.keys( - )), "Unexception: gradients [{}] are unscaled.".format( - set(self._grad_name_to_group_map.keys()) - scaled_grads) + assert scaled_grads == set( + self._grad_name_to_group_map.keys() + ), "Unexception: gradients [{}] are unscaled.".format( + set(self._grad_name_to_group_map.keys()) - scaled_grads + ) def _could_be_overlap(self): # NOTE current different nccl comm will use different cuda stream @@ -266,14 +312,13 @@ class DataParallelOptimizationPass(PassBase): op._set_attr('use_calc_stream', False) ring_id = op.attr("ring_id") - block._insert_op_without_sync(idx, - type='c_wait_compute', - inputs={'X': []}, - outputs={'Out': []}, - attrs={ - 'op_role': OpRole.Backward, - 'ring_id': ring_id - }) + block._insert_op_without_sync( + idx, + 
type='c_wait_compute', + inputs={'X': []}, + outputs={'Out': []}, + attrs={'op_role': OpRole.Backward, 'ring_id': ring_id}, + ) block._sync_with_cpp() @@ -307,8 +352,10 @@ class DataParallelOptimizationPass(PassBase): # other ops that might use communicating grad else: for input_var_name in op.input_arg_names: - for ring_id, unsync_grad_names in ring_id_to_un_sync_grad_map.items( - ): + for ( + ring_id, + unsync_grad_names, + ) in ring_id_to_un_sync_grad_map.items(): if input_var_name in unsync_grad_names: # need to sync before op_i if i in op_idx_to_sync_ring_id_map: @@ -328,14 +375,13 @@ class DataParallelOptimizationPass(PassBase): for i in sorted(indices, reverse=True): for ring_id in op_idx_to_sync_ring_id_map[i]: - block._insert_op_without_sync(i, - type='c_wait_comm', - inputs={'X': []}, - outputs={'Out': []}, - attrs={ - 'op_role': OpRole.Backward, - 'ring_id': ring_id - }) + block._insert_op_without_sync( + i, + type='c_wait_comm', + inputs={'X': []}, + outputs={'Out': []}, + attrs={'op_role': OpRole.Backward, 'ring_id': ring_id}, + ) def _could_be_fuse(self): # TODO support gradient fuse higher order gradient. @@ -350,9 +396,9 @@ class DataParallelOptimizationPass(PassBase): """ conditions for gradients to be grouped: 1. group size < max_fuse_numel - 2. same dp group + 2. same dp group 3. same dtype - 4. dependency: grad would NOT be used by other ops within group segment + 4. dependency: grad would NOT be used by other ops within group segment gradients inside same group would be fuse into one coalesce tensor """ @@ -423,36 +469,51 @@ class DataParallelOptimizationPass(PassBase): for i, group in enumerate(grad_groups[::-1]): # create coalecse tensor - group.coalesce_var = block.create_var(name=unique_name.generate( - 'coalecse_grad_{}'.format(i)), - dtype=group.dtype, - persistable=False, - stop_gradient=True) + group.coalesce_var = block.create_var( + name=unique_name.generate('coalecse_grad_{}'.format(i)), + dtype=group.dtype, + persistable=False, + stop_gradient=True, + ) # update allreduce & scale op if group.scale_op_idx != -1: scale_op = block.ops[group.scale_op_idx] - assert scale_op.type == 'scale', "should found scale op but found {}".format( - str(scale_op)) - scale_op._rename_input(scale_op.input_arg_names[0], - group.coalesce_var.name) - scale_op._rename_output(scale_op.output_arg_names[0], - group.coalesce_var.name) + assert ( + scale_op.type == 'scale' + ), "should found scale op but found {}".format(str(scale_op)) + scale_op._rename_input( + scale_op.input_arg_names[0], group.coalesce_var.name + ) + scale_op._rename_output( + scale_op.output_arg_names[0], group.coalesce_var.name + ) allreduce_op = block.ops[group.allreduce_op_idx] - assert allreduce_op.type == 'c_allreduce_sum', "should found c_allreduce_sum op but found {}".format( - str(allreduce_op)) - allreduce_op._rename_input(allreduce_op.input_arg_names[0], - group.coalesce_var.name) - allreduce_op._rename_output(allreduce_op.output_arg_names[0], - group.coalesce_var.name) + assert ( + allreduce_op.type == 'c_allreduce_sum' + ), "should found c_allreduce_sum op but found {}".format( + str(allreduce_op) + ) + allreduce_op._rename_input( + allreduce_op.input_arg_names[0], group.coalesce_var.name + ) + allreduce_op._rename_output( + allreduce_op.output_arg_names[0], group.coalesce_var.name + ) # remvoe un-used op - remove_op_indices = group.remove_wait_op_indices + group.remove_allreduce_op_indices + group.remove_scale_op_indices + remove_op_indices = ( + group.remove_wait_op_indices + + 
group.remove_allreduce_op_indices + + group.remove_scale_op_indices + ) for idx in sorted(remove_op_indices, reverse=True): - assert block.ops[ - idx].type in remove_op_types, "Unexception: try to remove op {}".format( - str(op)) + assert ( + block.ops[idx].type in remove_op_types + ), "Unexception: try to remove op {}".format( + str(block.ops[idx].type()) + ) block._remove_op(idx) # insert coalecse op @@ -464,22 +525,23 @@ class DataParallelOptimizationPass(PassBase): concated_ranks.append(len(shape)) grad_names = [grad.name for grad in group.gradients] - block._insert_op_without_sync(group.coalesce_op_idx, - type="coalesce_tensor", - inputs={"Input": grad_names}, - outputs={ - "Output": grad_names, - "FusedOutput": group.coalesce_var - }, - attrs={ - "copy_data": False, - "use_align": True, - "dtype": group.dtype, - "concated_shapes": - concated_shapes, - "concated_ranks": concated_ranks, - OP_ROLE_KEY: OpRole.Backward - }) + block._insert_op_without_sync( + group.coalesce_op_idx, + type="coalesce_tensor", + inputs={"Input": grad_names}, + outputs={ + "Output": grad_names, + "FusedOutput": group.coalesce_var, + }, + attrs={ + "copy_data": False, + "use_align": True, + "dtype": group.dtype, + "concated_shapes": concated_shapes, + "concated_ranks": concated_ranks, + OP_ROLE_KEY: OpRole.Backward, + }, + ) block._sync_with_cpp() # TODO update dist attr @@ -487,6 +549,7 @@ class DataParallelOptimizationPass(PassBase): def summary(self, grad_groups=[]): # TODO: add logger module import logging + self._logger = logging.getLogger() self._logger.propagate = False if not self._logger.handlers: @@ -500,26 +563,31 @@ class DataParallelOptimizationPass(PassBase): if len(grad_groups) > 0: self._logger.info( - "origin {} allreduce ops are fused into {} coalecse allreduce ops." 
- .format(len(self._grad_name_to_group_map.keys()), - len(grad_groups))) + "origin {} allreduce ops are fused into {} coalecse allreduce ops.".format( + len(self._grad_name_to_group_map.keys()), len(grad_groups) + ) + ) self._logger.info("gradient fusing group are following: ") fused_grads = set() for i, group in enumerate(grad_groups): self._logger.info( "coalecse gradient [{}] is composed by: {}".format( - i, [grad.name for grad in group.gradients])) + i, [grad.name for grad in group.gradients] + ) + ) fused_grads.update([grad.name for grad in group.gradients]) - individual_grads = set( - self._grad_name_to_group_map.keys()) - set(fused_grads) + individual_grads = set(self._grad_name_to_group_map.keys()) - set( + fused_grads + ) self._logger.info( "the following [{}] gradients are not fused: ".format( - len(individual_grads))) + len(individual_grads) + ) + ) self._logger.info("individual gradient {}".format(individual_grads)) class GradientsGroup(object): - def __init__(self, ops, max_group_size): self.max_group_size = max_group_size self.ops = ops @@ -575,8 +643,11 @@ class GradientsGroup(object): grad_op_idx -= 1 grad_op = self.ops[grad_op_idx] - assert grad_var.name in grad_op.output_arg_names, "grad [{}] should be output of {}".format( - grad_var.name, str(grad_op)) + assert ( + grad_var.name in grad_op.output_arg_names + ), "grad [{}] should be output of {}".format( + grad_var.name, str(grad_op) + ) self.coalesce_op_idx = grad_op_idx def finalize(self): diff --git a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py index c61d944400d..6d9a5e38caa 100644 --- a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py +++ b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py @@ -12,23 +12,40 @@ # See the License for the specific language governing permissions and # limitations under the License. 
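A note on the DataParallelOptimizationPass changes above, before moving on to the gradient-merge pass: gradients are fused into one coalesce_tensor only when they belong to the same data-parallel group, share a dtype, and the packed size stays under max_fuse_numel (the dependency condition listed in the docstring is checked separately and omitted here). A minimal illustrative sketch of that grouping rule, where dp_group_of and max_fuse_numel are assumed placeholders rather than the pass's real API:

    def group_gradients(grads, dp_group_of, max_fuse_numel):
        # Pack consecutive gradients sharing (dp group, dtype) until the size cap is hit.
        groups, current, current_numel, current_key = [], [], 0, None
        for grad in grads:
            key = (dp_group_of(grad.name), grad.dtype)
            numel = 1
            for dim in grad.shape:
                numel *= dim
            if key != current_key or current_numel + numel > max_fuse_numel:
                if current:
                    groups.append(current)
                current, current_numel, current_key = [], 0, key
            current.append(grad)
            current_numel += numel
        if current:
            groups.append(current)
        return groups

Each resulting group maps onto one coalesce_tensor op plus a single scale/c_allreduce_sum pair, which is what the fusion code above rewires.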
-import numpy as np -from collections import OrderedDict -from typing import List, Tuple, Dict, Any - import paddle from paddle.framework import core from paddle.fluid import layers -from paddle.fluid.framework import program_guard, device_guard +from paddle.distributed.fleet.meta_optimizers.common import ( + OpRole, + OP_ROLE_KEY, + OP_ROLE_VAR_KEY, +) +from paddle.distributed.auto_parallel.utils import ( + set_var_dist_attr, + is_optimize_op, + is_backward_op, + naive_set_dist_op_attr_for_program_by_mesh_and_mapping, +) +from paddle.distributed.auto_parallel.process_group import ( + get_world_process_group, +) +from paddle.distributed.auto_parallel.operators.common import ( + is_data_parallel_reduce_op, + is_data_parallel_scale_op, +) + from .pass_base import PassBase, PassType, register_pass -from paddle.distributed.auto_parallel.utils import set_var_dist_attr, is_optimize_op, OpRole, OP_ROLE_KEY -from paddle.distributed.auto_parallel.utils import naive_set_dist_op_attr_for_program_by_mesh_and_mapping -from paddle.distributed.auto_parallel.process_group import get_world_process_group world_process_group = get_world_process_group() -def _remove_and_get_optimizer_op(main_program, dist_context): +def is_gradient_clip_op(op_desc): + return op_desc.has_attr("op_namescope") and op_desc.attr( + "op_namescope" + ).startswith("/gradient_clip") + + +def _remove_and_get_ops(main_program, dist_context): # 1 create tmp block # 2 mv optimizer op from global program to tmp block # 3 del the op from dist_context @@ -36,101 +53,119 @@ def _remove_and_get_optimizer_op(main_program, dist_context): temp_block = main_program._create_block() removed_op_idx = [] optimize_ops_desc = [] + allreduce_sum_desc = [] for idx, op in enumerate(main_block.ops): + # append optimizer op to tmp block if is_optimize_op(op): - # append optimizer op to tmp block new_op_desc = temp_block.desc.append_op() new_op_desc.copy_from(op.desc) optimize_ops_desc.append(new_op_desc) removed_op_idx.append(idx) - - # del op from dist_context - if dist_context: + dist_context.del_dist_op_for_program(op) + + # append allreduce_op and scale_op to tmp block + if is_backward_op(op): + if is_data_parallel_reduce_op(op) or is_data_parallel_scale_op(op): + assert len(op.desc.output_arg_names()) == 1 + new_op_desc = temp_block.desc.append_op() + new_op_desc.copy_from(op.desc) + allreduce_sum_desc.append(new_op_desc) + removed_op_idx.append(idx) dist_context.del_dist_op_for_program(op) for idx in removed_op_idx[::-1]: main_block._remove_op(idx, sync=False) main_block._sync_with_cpp() - return optimize_ops_desc + return optimize_ops_desc, allreduce_sum_desc -def _get_gm_cond_var(main_program, k_steps, dist_context): +def _create_gm_cond_var(main_program, k_steps, dist_context): main_block = main_program.global_block() # Add const var - k_step_var = layers.create_global_var(name="gradient_merge_k", - shape=[1], - value=int(k_steps), - dtype='int32', - persistable=True, - force_cpu=True) + k_step_var = layers.create_global_var( + name="gradient_merge_k", + shape=[1], + value=int(k_steps), + dtype='int32', + persistable=True, + force_cpu=True, + ) set_var_dist_attr(dist_context, k_step_var, [-1], world_process_group.ranks) - zero_var = layers.create_global_var(name="gradient_merge_zero", - shape=[1], - value=int(0), - dtype='int32', - persistable=True, - force_cpu=True) + zero_var = layers.create_global_var( + name="gradient_merge_zero", + shape=[1], + value=int(0), + dtype='int32', + persistable=True, + force_cpu=True, + ) 
set_var_dist_attr(dist_context, zero_var, [-1], world_process_group.ranks) # Add step var & cond var - step_var = layers.create_global_var(name="gradient_merge_step", - shape=[1], - value=int(0), - dtype='int32', - persistable=True, - force_cpu=True) + step_var = layers.create_global_var( + name="gradient_merge_step", + shape=[1], + value=int(0), + dtype='int32', + persistable=True, + force_cpu=True, + ) set_var_dist_attr(dist_context, step_var, [-1], world_process_group.ranks) - cond_var = main_block.create_var(name="gradient_merge_cond", - shape=[1], - dtype='bool') + cond_var = main_block.create_var( + name="gradient_merge_cond", shape=[1], dtype='bool' + ) set_var_dist_attr(dist_context, cond_var, [-1], world_process_group.ranks) - with device_guard("cpu"): + with paddle.static.device_guard("cpu"): # step_var += 1 - increment_op = main_block.append_op(type='increment', - inputs={'X': [step_var]}, - outputs={'Out': [step_var]}, - attrs={ - 'step': float(1.0), - OP_ROLE_KEY: OpRole.Backward - }) + increment_op = main_block.append_op( + type='increment', + inputs={'X': [step_var]}, + outputs={'Out': [step_var]}, + attrs={'step': float(1.0), OP_ROLE_KEY: OpRole.Backward}, + ) naive_set_dist_op_attr_for_program_by_mesh_and_mapping( - increment_op, world_process_group.ranks, [-1], dist_context) + increment_op, world_process_group.ranks, [-1], dist_context + ) # step_var %= k_step - elementwise_mod_op = main_block.append_op(type='elementwise_mod', - inputs={ - 'X': step_var, - 'Y': k_step_var - }, - outputs={'Out': step_var}, - attrs={ - 'axis': -1, - 'use_mkldnn': False, - OP_ROLE_KEY: - OpRole.Backward - }) + elementwise_mod_op = main_block.append_op( + type='elementwise_mod', + inputs={'X': step_var, 'Y': k_step_var}, + outputs={'Out': step_var}, + attrs={ + 'axis': -1, + 'use_mkldnn': False, + OP_ROLE_KEY: OpRole.Backward, + }, + ) naive_set_dist_op_attr_for_program_by_mesh_and_mapping( - elementwise_mod_op, world_process_group.ranks, [-1], dist_context) + elementwise_mod_op, world_process_group.ranks, [-1], dist_context + ) # cond_var = (step_var == 0) - equal_op = main_block.append_op(type='equal', - inputs={ - 'X': step_var, - 'Y': zero_var - }, - outputs={'Out': cond_var}, - attrs={OP_ROLE_KEY: OpRole.Backward}) + equal_op = main_block.append_op( + type='equal', + inputs={'X': step_var, 'Y': zero_var}, + outputs={'Out': cond_var}, + attrs={OP_ROLE_KEY: OpRole.Backward}, + ) naive_set_dist_op_attr_for_program_by_mesh_and_mapping( - equal_op, world_process_group.ranks, [-1], dist_context) + equal_op, world_process_group.ranks, [-1], dist_context + ) return cond_var def _append_gradient_merge_backward_op( - main_program, startup_program, params_grads: List[Tuple[Any, Any]], - dist_context) -> Tuple[List[Tuple[Any, Any]], Dict[str, Any]]: + main_program, + startup_program, + params_grads, + master_grad, + dist_context, +): + main_block = main_program.global_block() startup_block = startup_program.global_block() @@ -148,149 +183,260 @@ def _append_gradient_merge_backward_op( for param, grad in params_grads: param_name = param.name param_var = main_block.var(param_name) - assert (param_var is not None) - ref_dist_attr = dist_context.get_tensor_dist_attr_for_program(param_var) - assert ref_dist_attr is not None - gradient_merge_var = main_block.create_var(name=param_name + - "@GRAD@GradientMerge", - shape=param_var.shape, - dtype=param_var.dtype, - persistable=True) - ref_process_mesh = ref_dist_attr.process_mesh - ref_dims_mapping = ref_dist_attr.dims_mapping + assert param_var is not None 
- set_var_dist_attr(dist_context, gradient_merge_var, ref_dims_mapping, - ref_process_mesh) + dst_dtype = ( + core.VarDesc.VarType.FP32 if master_grad else param_var.dtype + ) + # 2.1 crate param@GRAD@MERGE var in startup_block startup_gradient_merge_var = startup_block.create_var( - name=param_name + "@GRAD@GradientMerge", + name=param_name + "@GRAD@MERGED", + shape=param_var.shape, + dtype=dst_dtype, + persistable=True, + ) + startup_block.append_op( + type="fill_constant", + outputs={"Out": startup_gradient_merge_var}, + attrs={ + "shape": param_var.shape, + "dtype": dst_dtype, + "value": float(0), + }, + ) + + # 2.2 crate param@GRAD@MERGE var in main_block + ref_dist_attr = dist_context.get_tensor_dist_attr_for_program(param_var) + assert ref_dist_attr is not None + gradient_merge_var = main_block.create_var( + name=param_name + "@GRAD@MERGED", shape=param_var.shape, - dtype=param_var.dtype, - persistable=True) - startup_block.append_op(type="fill_constant", - outputs={"Out": startup_gradient_merge_var}, - attrs={ - "shape": param_var.shape, - "dtype": param_var.dtype, - "value": float(0), - }) - - # grad_merge += grad - new_grad_op = main_block.append_op(type="elementwise_add", - inputs={ - 'X': grad, - 'Y': gradient_merge_var - }, - outputs={'Out': gradient_merge_var}, - attrs={ - 'axis': -1, - 'use_mkldnn': False, - OP_ROLE_KEY: OpRole.Backward - }) + dtype=dst_dtype, + persistable=True, + ) + ref_process_mesh = ref_dist_attr.process_mesh + ref_dims_mapping = ref_dist_attr.dims_mapping + set_var_dist_attr( + dist_context, gradient_merge_var, ref_dims_mapping, ref_process_mesh + ) + + # 2.3 grad_merge += grad + grad_name = grad.name + if grad.dtype != dst_dtype: + cast_grad_name = grad_name + "@TMP" + cast_grad_var = main_block.create_var( + name=cast_grad_name, + shape=grad.shape, + dtype=dst_dtype, + persistable=False, + stop_gradient=grad.stop_gradient, + ) + set_var_dist_attr( + dist_context, cast_grad_var, ref_dims_mapping, ref_process_mesh + ) + cast_op = main_block.append_op( + type="cast", + inputs={"X": grad}, + outputs={"Out": cast_grad_var}, + attrs={ + "in_dtype": grad.dtype, + "out_dtype": dst_dtype, + OP_ROLE_KEY: OpRole.Backward, + }, + ) + naive_set_dist_op_attr_for_program_by_mesh_and_mapping( + cast_op, ref_process_mesh, ref_dims_mapping, dist_context + ) + grad = cast_grad_var + + new_grad_op = main_block.append_op( + type="elementwise_add", + inputs={'X': grad, 'Y': gradient_merge_var}, + outputs={'Out': gradient_merge_var}, + attrs={ + 'axis': -1, + 'use_mkldnn': False, + OP_ROLE_KEY: OpRole.Backward, + }, + ) new_params_to_grads.append([param, gradient_merge_var]) - grad_to_gradient_merge[grad.name] = gradient_merge_var.name + grad_to_gradient_merge[grad_name] = gradient_merge_var.name naive_set_dist_op_attr_for_program_by_mesh_and_mapping( - new_grad_op, ref_process_mesh, ref_dims_mapping, dist_context) + new_grad_op, ref_process_mesh, ref_dims_mapping, dist_context + ) + return new_params_to_grads, grad_to_gradient_merge -def _create_cond_block_and_update_optimizer( - main_program, cond_var, new_params_to_grads: List[Tuple[Any, Any]], - grad_to_gradient_merge: Dict[str, str], optimize_ops_desc: List[Any], - k_steps, avg): +def _rename_arg_names(op_desc, var_name_dict): + for input_name in op_desc.input_arg_names(): + if input_name in var_name_dict: + op_desc._rename_input(input_name, var_name_dict[input_name]) + for output_name in op_desc.output_arg_names(): + if output_name in var_name_dict: + op_desc._rename_output(output_name, var_name_dict[output_name]) + 
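The master_grad path added above accumulates into FP32 "@GRAD@MERGED" buffers: when the parameter gradient is FP16/BF16, a cast op is inserted so the elementwise_add always runs in FP32. The point is numerical: repeatedly adding small low-precision gradients into a low-precision buffer loses updates once they fall below the buffer's rounding step. A standalone NumPy illustration of the effect (not Paddle code; the exact stall point depends on the values used):

    import numpy as np

    def accumulate(grads, use_master_grad):
        acc_dtype = np.float32 if use_master_grad else np.float16
        merged = np.zeros_like(grads[0], dtype=acc_dtype)
        for g in grads:
            merged += g.astype(acc_dtype)  # mirrors the inserted cast + elementwise_add
        return merged

    grads = [np.full((4,), 1e-5, dtype=np.float16) for _ in range(10000)]
    print(accumulate(grads, use_master_grad=False).max())  # fp16 buffer stalls around ~0.03
    print(accumulate(grads, use_master_grad=True).max())   # fp32 buffer stays close to 0.1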
+ +def _create_cond_block_and_update_optimizer( + main_program, + cond_var, + params_grads, + new_params_to_grads, + grad_to_gradient_merge, + optimize_ops_desc, + allreduce_sum_desc, + k_steps, + avg, + master_grad, +): def true_apply_gradient(): cur_block_idx = main_program.current_block_idx cur_block = main_program.current_block() # cur_block's forward_block & backward_block is itself cur_block._set_forward_block_idx(cur_block_idx) - op_maker = core.op_proto_and_checker_maker + + # record grads_name to insert c_allreduce_sum op + grads_name = [grad.name for _, grad in params_grads] + # append c_allreduce_sum ops and scale ops + for op_desc in allreduce_sum_desc: + outputs_name = op_desc.output_arg_names() + assert len(outputs_name) == 1 + if outputs_name[0] in grads_name: + new_op_desc = cur_block.desc.append_op() + new_op_desc.copy_from(op_desc) + _rename_arg_names(new_op_desc, grad_to_gradient_merge) + new_op_desc._set_attr(OP_ROLE_KEY, OpRole.Optimize) + cur_block._sync_with_cpp() + if avg: - for param, new_grad in new_params_to_grads: + for _, new_grad in new_params_to_grads: # grad /= k_steps - cur_block.append_op(type='scale', - inputs={'X': new_grad}, - outputs={'Out': new_grad}, - attrs={ - 'scale': 1.0 / k_steps, - 'bias': 0.0, - 'bias_after_scale': False - }) + cur_block.append_op( + type='scale', + inputs={'X': new_grad}, + outputs={'Out': new_grad}, + attrs={ + 'scale': 1.0 / k_steps, + 'bias': 0.0, + 'bias_after_scale': False, + }, + ) new_grad.op._set_attr(OP_ROLE_KEY, OpRole.Optimize) + cast_name_dict = {} # append optimizer ops for op_desc in optimize_ops_desc: + if master_grad and is_gradient_clip_op(op_desc): + if op_desc.type() == "cast": + if ( + op_desc.attr('out_dtype') in [4, 22] + and op_desc.attr('in_dtype') == 5 + ): + cast_name_dict[ + op_desc.output_arg_names()[0] + ] = op_desc.input_arg_names()[0] + elif ( + op_desc.attr('in_dtype') in [4, 22] + and op_desc.attr('out_dtype') == 5 + ): + cast_name_dict[ + op_desc.output_arg_names()[0] + ] = op_desc.input_arg_names()[0] + continue + + for out_name in op_desc.output_arg_names(): + out_var = cur_block._var_recursive(out_name) + out_var.desc.set_dtype(core.VarDesc.VarType.FP32) + + _rename_arg_names(op_desc, cast_name_dict) + new_op_desc = cur_block.desc.append_op() new_op_desc.copy_from(op_desc) - #update input/output - for input_name in new_op_desc.input_arg_names(): - if input_name in grad_to_gradient_merge: - new_op_desc._rename_input( - input_name, grad_to_gradient_merge[input_name]) - - for output_name in new_op_desc.output_arg_names(): - if output_name in grad_to_gradient_merge: - new_op_desc._rename_output( - output_name, grad_to_gradient_merge[output_name]) + # update input/output + _rename_arg_names(new_op_desc, grad_to_gradient_merge) # remove op_role_var - if new_op_desc.has_attr(op_maker.kOpRoleVarAttrName()): - new_op_desc.remove_attr(op_maker.kOpRoleVarAttrName()) + if new_op_desc.has_attr(OP_ROLE_VAR_KEY): + new_op_desc.remove_attr(OP_ROLE_VAR_KEY) # op's update Grad if core.grad_var_suffix() in new_op_desc.input_arg_names(): grad_value = new_op_desc.input("Grad")[0] # TODO FIXME(xym) support fp16 - grad_merge_value = grad_value + '@GradientMerge' + grad_merge_value = grad_value + '@MERGED' new_op_desc.set_input("Grad", [grad_merge_value]) - main_program.global_block()._sync_with_cpp() cur_block._sync_with_cpp() # clear gradient_merge_vars - for param, new_grad in new_params_to_grads: - layers.fill_constant(shape=new_grad.shape, - dtype=new_grad.dtype, - value=0.0, - out=new_grad) - 
new_grad.op._set_attr(OP_ROLE_KEY, op_maker.OpRole.Optimize) + for _, new_grad in new_params_to_grads: + layers.fill_constant( + shape=new_grad.shape, + dtype=new_grad.dtype, + value=0.0, + out=new_grad, + ) + new_grad.op._set_attr(OP_ROLE_KEY, OpRole.Optimize) layers.cond(cond_var, true_fn=true_apply_gradient, false_fn=None) cond_op = main_program.global_block().ops[-1] cond_op._set_attr(OP_ROLE_KEY, OpRole.Optimize) -def parse_program(main_program, startup_program, params_grads, k_steps, avg, - dist_context): - # 1 remove optimizer_op from main_program - optimize_ops_desc = _remove_and_get_optimizer_op(main_program, dist_context) +def parse_program( + main_program, + startup_program, + params_grads, + k_steps, + avg, + master_grad, + dist_context, +): + # 1 remove optimizer_op, allreduce_sum_op and scale_op from main_program + optimize_ops_desc, allreduce_sum_desc = _remove_and_get_ops( + main_program, dist_context + ) # back to block 0 main_program._rollback() # 2 append gradient merge backward op to main_program - new_params_to_grads, grad_to_gradient_merge = _append_gradient_merge_backward_op( - main_program, startup_program, params_grads, dist_context) + ( + new_params_to_grads, + grad_to_gradient_merge, + ) = _append_gradient_merge_backward_op( + main_program, startup_program, params_grads, master_grad, dist_context + ) # 3 create gradient_merge_cond - cond_var = _get_gm_cond_var(main_program, k_steps, dist_context) + cond_var = _create_gm_cond_var(main_program, k_steps, dist_context) # 4 create ConditionalBlock and append gradient merge optimizer ops - _create_cond_block_and_update_optimizer(main_program, cond_var, - new_params_to_grads, - grad_to_gradient_merge, - optimize_ops_desc, k_steps, avg) + _create_cond_block_and_update_optimizer( + main_program, + cond_var, + params_grads, + new_params_to_grads, + grad_to_gradient_merge, + optimize_ops_desc, + allreduce_sum_desc, + k_steps, + avg, + master_grad, + ) @register_pass("auto_parallel_gradient_merge_pass") class GradientMergePass(PassBase): - def __init__(self): super(GradientMergePass, self).__init__() self.set_attr("k_steps", -1) self.set_attr("avg", True) + self.set_attr("master_grad", False) def _check_self(self): if self.get_attr("k_steps") < 1: @@ -306,10 +452,20 @@ class GradientMergePass(PassBase): def _apply_single_impl(self, main_program, startup_program, context): k_steps = self.get_attr("k_steps", -1) avg = self.get_attr("avg", False) + master_grad = self.get_attr("master_grad", False) dist_context = self.get_attr("dist_context") params_grads = self.get_attr("params_grads") + # TODO(zyl): make master_grad configurable + master_grad = True with paddle.static.program_guard(main_program, startup_program): - parse_program(main_program, startup_program, params_grads, k_steps, - avg, dist_context) + parse_program( + main_program, + startup_program, + params_grads, + k_steps, + avg, + master_grad, + dist_context, + ) main_program._sync_with_cpp() diff --git a/python/paddle/distributed/passes/auto_parallel_pipeline.py b/python/paddle/distributed/passes/auto_parallel_pipeline.py index 982fd7c228a..318ea03f498 100644 --- a/python/paddle/distributed/passes/auto_parallel_pipeline.py +++ b/python/paddle/distributed/passes/auto_parallel_pipeline.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
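Taken together, parse_program above rewrites the program so that, per parameter, training behaves like the sketch below: gradients are accumulated into the "@MERGED" buffer every micro step, and only when the step counter wraps to zero does the conditional block replay the staged c_allreduce_sum ops, optionally average by 1/k_steps, run the staged optimizer ops, and clear the buffer. Note that the pass currently forces master_grad = True regardless of the attribute (see the TODO above). The sketch is pure Python with an assumed apply_update callback, not Paddle API:

    def gradient_merge_step(step, k_steps, grads, merged, apply_update, avg=True):
        for name, grad in grads.items():
            merged[name] += grad           # elementwise_add into the @MERGED buffer
        step = (step + 1) % k_steps        # increment + elementwise_mod on the step var
        if step == 0:                      # cond_var produced by the equal op
            if avg:
                for name in merged:
                    merged[name] *= 1.0 / k_steps  # scale op inside the cond block
            apply_update(merged)           # replayed c_allreduce_sum + optimizer ops
            for name in merged:
                merged[name] = 0.0         # fill_constant(value=0) clears the buffer
        return step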
-from logging import exception import os from paddle.fluid import core @@ -26,6 +25,7 @@ from paddle.distributed.auto_parallel.utils import ( is_backward_op, is_optimize_op, is_lr_sched_op, + is_fillconst_op_for_micro_batch, ) @@ -38,6 +38,12 @@ __not_shape_var_type__ = [ ] +def is_reshard_op(op): + return op.has_attr('op_namescope') and "/auto_parallel/reshard" in op.attr( + 'op_namescope' + ) + + @register_pass("auto_parallel_pipeline") class PipelinePass(PassBase): def __init__(self): @@ -59,8 +65,17 @@ class PipelinePass(PassBase): self._gen_bsz = self.get_attr("generation_batch_size") self._program = main_program + self._cur_rank = int(os.getenv("PADDLE_TRAINER_ID", 0)) + trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", "").split(',') + self._nrank = len(trainer_endpoints) + + # compute current pp stage + self._pp_stages = len(self._dist_context.process_meshes) + self._cur_pp_stage = self._get_pp_stage(self._cur_rank) + if self._mode == "1F1B": - raise NotImplementedError("1F1B has not been implemented") + self._insert_sync_ops_for_1f1b() + self._task_1f1b() elif self._mode == "F-Then-B": raise NotImplementedError("F-Then-B has not been implemented") elif self._mode == "stream": @@ -103,6 +118,93 @@ class PipelinePass(PassBase): block._sync_with_cpp() + def _insert_sync_ops_for_1f1b(self): + """ + This implementation refers to lots of Paddle/python/paddle/fluid/optimizer.py. + The difference between this function with 'PipelineOptimizer' is that + 'send_v2' op and 'recv_v2' op have been inserted in program by 'reshard'. + """ + + for block in self._program.blocks: + offset = 0 + first_optimize_index = None + for index, op in enumerate(list(block.ops)): + if is_optimize_op(op): + first_optimize_index = index + break + + # insert sync ops + for index, op in enumerate(list(block.ops)): + if op.type == 'send_v2': + # step1: set 'use_calc_stream' False + op._set_attr("use_calc_stream", False) + op_role = op.attr('op_role') + ring_id = op.attr('ring_id') + # step2: insert 'c_sync_calc_stream' op before 'send_v2' op + var_name = op.input_arg_names[0] + var = block.var(var_name) + block._insert_op_without_sync( + index=index + offset, + type="c_sync_calc_stream", + inputs={'X': [var]}, + outputs={'Out': [var]}, + attrs={'op_role': op_role}, + ) + offset += 1 + # step3: insert 'c_sync_comm_stream' op after 'send_v2' op or + # before the first optimize op + if int(op_role) == int(OpRole.Backward): + index = first_optimize_index + offset + new_op_role = OpRole.Optimize + else: + index = index + offset + 1 + new_op_role = OpRole.Backward + sync_comm_op = block._insert_op_without_sync( + index=index, + type="c_sync_comm_stream", + inputs={'X': [var]}, + outputs={'Out': [var]}, + attrs={ + 'op_role': new_op_role, + 'ring_id': ring_id, + }, + ) + # step4: If 'send_v2' op in forward parse, set 'pipeline_flag' to distinguish + # whether the 'c_sync_comm_stream' op is inserted for pipeline. 
+ if int(op_role) == int(OpRole.Forward): + sync_comm_op._set_attr('pipeline_flag', '') + offset += 1 + block._sync_with_cpp() + + offset = 0 + backward_recv_index = None + for index, op in enumerate(block.ops): + if op.type == "recv_v2" and is_backward_op(op): + backward_recv_index = index + break + if backward_recv_index is None: + continue + + # replace 'c_sync_comm_stream' op with 'nop' op + for index, op in enumerate(list(block.ops)): + if index >= backward_recv_index: + break + if op.type == 'c_sync_comm_stream' and op.has_attr( + 'pipeline_flag' + ): + var_name = op.output_arg_names[0] + var = block.var(var_name) + block._remove_op(index + offset, sync=False) + offset -= 1 + block._insert_op_without_sync( + index=backward_recv_index, + type="nop", + inputs={'X': [var]}, + outputs={'Out': [var]}, + attrs={'op_role': OpRole.Backward}, + ) + block._sync_with_cpp() + def _create_param(self, dst_block, src_var): copied_kwargs = {} copied_kwargs['trainable'] = src_var.trainable @@ -190,16 +292,185 @@ class PipelinePass(PassBase): break return pp_idx - def _task_stream(self): - cur_rank = int(os.getenv("PADDLE_TRAINER_ID", 0)) - trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", "").split(',') - nrank = len(trainer_endpoints) - num_of_functionality = 5 + def _task_1f1b(self): + # create fwd, bwd, opt program with op_role + num_of_functionality = 4 + lr_prog = Program() + fwd_prog = Program() + bwd_prog = Program() + opt_prog = Program() + + for idx, src_block in enumerate(self._program.blocks): + if idx == 0: + lr_block = lr_prog.block(0) + fwd_block = fwd_prog.block(0) + bwd_block = bwd_prog.block(0) + opt_block = opt_prog.block(0) + else: + lr_block = lr_prog._create_block( + parent_idx=src_block.parent_idx + ) + fwd_block = fwd_prog._create_block( + parent_idx=src_block.parent_idx + ) + bwd_block = bwd_prog._create_block( + parent_idx=src_block.parent_idx + ) + opt_block = opt_prog._create_block( + parent_idx=src_block.parent_idx + ) + lr_block._set_forward_block_idx(src_block.forward_block_idx) + fwd_block._set_forward_block_idx(src_block.forward_block_idx) + bwd_block._set_forward_block_idx(src_block.forward_block_idx) + opt_block._set_forward_block_idx(src_block.forward_block_idx) + + # split the program based on the op_role + for op in src_block.ops: + if is_lr_sched_op(op): + self._create_program(src_block, lr_block, op) + if is_forward_op(op) or is_fillconst_op_for_micro_batch(op): + self._create_program(src_block, fwd_block, op) + elif is_backward_op(op): + self._create_program(src_block, bwd_block, op) + elif is_optimize_op(op): + self._create_program(src_block, opt_block, op) + else: + raise ValueError( + "The op role: " + + str(op.attr('op_role')) + + " isn't one of LRSched, Forward, Backward or Optimizer." + ) - # compute current pp stage - pp_stages = len(self._dist_context.process_meshes) - cur_pp_stage = self._get_pp_stage(cur_rank) + lr_prog._sync_with_cpp() + fwd_prog._sync_with_cpp() + bwd_prog._sync_with_cpp() + opt_prog._sync_with_cpp() + + lr_prog._rollback() + fwd_prog._rollback() + bwd_prog._rollback() + opt_prog._rollback() + + # Create task nodes. 
+ lr_task_node = TaskNode( + rank=self._cur_rank, + max_run_times=self._acc_steps, + program=lr_prog, + task_id=int(self._cur_rank * num_of_functionality + 0), + node_type="Amplifier", + lazy_initialize=True, + ) + lr_task_node.set_run_pre_steps(self._acc_steps) + fwd_task_node = TaskNode( + rank=self._cur_rank, + max_run_times=self._acc_steps, + program=fwd_prog, + task_id=int(self._cur_rank * num_of_functionality + 1), + node_type="Compute", + lazy_initialize=True, + ) + bwd_task_node = TaskNode( + rank=self._cur_rank, + max_run_times=self._acc_steps, + program=bwd_prog, + task_id=int(self._cur_rank * num_of_functionality + 2), + node_type="Compute", + lazy_initialize=True, + ) + opt_task_node = TaskNode( + rank=self._cur_rank, + max_run_times=self._acc_steps, + program=opt_prog, + task_id=int(self._cur_rank * num_of_functionality + 3), + node_type="Amplifier", + lazy_initialize=True, + ) + opt_task_node.set_run_pre_steps(self._acc_steps) + opt_task_node.set_run_at_offset(self._acc_steps - 1) + task_nodes = { + "lr": lr_task_node, + "fwd": fwd_task_node, + "bwd": bwd_task_node, + "opt": opt_task_node, + } + + # get upstream ranks and downstream ranks of cur_rank + up_down_streams = self._dist_context.up_down_streams + pp_upstream = up_down_streams.ups(self._cur_rank) + pp_downstream = up_down_streams.downs(self._cur_rank) + + # set upstream/downstream for task_nodes of cur_rank + for i, (task_role, task_node) in enumerate(task_nodes.items()): + + cur_id = int(self._cur_rank * num_of_functionality + i) + ups = [] + downs = [] + + # set upstream/downstream and buffersize in pipeline stage + pp_buff_size = int(self._pp_stages - self._cur_pp_stage) + prev_id = cur_id - 1 + next_id = cur_id + 1 + if task_role != "lr": + buf_size = pp_buff_size if task_role == "bwd" else 2 + ups.append((prev_id, buf_size)) + if task_role != "opt": + buf_size = pp_buff_size if task_role == "fwd" else 2 + downs.append((next_id, buf_size)) + + # set upstream/downstream and buffersize cross pipeline stage + for upstream in pp_upstream: + upstream_id = int(upstream * num_of_functionality + i) + if task_role == "fwd": + if upstream != -1: + ups.append((upstream_id, 2)) + elif task_role == "bwd": + if upstream != -1: + downs.append((upstream_id, 2)) + for downstream in pp_downstream: + downstream_id = int(downstream * num_of_functionality + i) + if task_role == "fwd": + if downstream != -1: + downs.append((downstream_id, 2)) + elif task_role == "bwd": + if downstream != -1: + ups.append((downstream_id, 2)) + + for up in ups: + print( + "Task:", + cur_id, + "'s upstream includes:", + up[0], + ", buffer size is:", + up[1], + ) + task_node.add_upstream_task(up[0], up[1]) + for down in downs: + print( + "Task:", + cur_id, + "'s downstream includes:", + down[0], + ", buffer size is:", + down[1], + ) + task_node.add_downstream_task(down[0], down[1]) + + # record global message: task_id_to_rank + task_id_to_rank = {} + for i in range(self._nrank): + for j in range(num_of_functionality): + task_id_to_rank[int(i * num_of_functionality + j)] = i + self._program._pipeline_opt = {} + self._program._pipeline_opt['fleet_opt'] = { + "tasks": list(task_nodes.values()), + "task_id_to_rank": task_id_to_rank, + "num_micro_batches": self._acc_steps, + } + + def _task_stream(self): + num_of_functionality = 5 start_prog = Program() cond_prog = Program() end_prog = Program() @@ -207,6 +478,7 @@ class PipelinePass(PassBase): recv_prog = Program() cond_var_name = None + # record the varnames related to the while cond vars and communicate by 
nccl send_vars_name = set() recv_vars_name = dict() for ib, src_block in enumerate(self._program.blocks): @@ -231,38 +503,23 @@ class PipelinePass(PassBase): src_block, end_block, op, force_create=True ) elif ib == 1: + # NOTE: The while block will be split to two separate blocks. + # The send_block: + # include all ops about tansformer generation + # execlude the nccl op about the while cond var + # The recv_block: + # include all ops about the while cond var + # execlude the nccl op about the while cond var + # the nccl op about cond var: + # put these varnames in the task node and do communication by brpc send_block = send_prog.block(0) recv_block = recv_prog.block(0) is_after_send_op = False is_after_recv_op = False - for op in src_block.ops: + for i, op in enumerate(src_block.ops): if op.type == "send_v2" and not is_after_send_op: is_after_send_op = True - if cur_pp_stage == pp_stages - 1: - if op.type in ["c_sync_calc_stream", "nop"]: - continue - if ( - op.type not in ["recv_2", "assign"] - and op.has_attr('op_namescope') - and "/auto_parallel/reshard" - in op.attr('op_namescope') - ): - if ( - len(op.desc.input_arg_names()) > 0 - and "@RESHARD" - not in op.desc.input_arg_names()[0] - ): - send_vars_name.add( - op.desc.input_arg_names()[0] - ) - continue - if op.type == "send_v2": - continue - self._create_program( - src_block, send_block, op, force_create=True - ) - continue if ( is_after_send_op @@ -270,45 +527,21 @@ class PipelinePass(PassBase): and op.type == "recv_v2" ): is_after_recv_op = True - if op.has_attr( - 'op_namescope' - ) and "/auto_parallel/reshard" in op.attr( - 'op_namescope' - ): - var_name = op.desc.output_arg_names()[0] - index = var_name.find("@") - if index > 0: - old_var_name = var_name[:index] - else: - old_var_name = var_name - recv_vars_name[var_name] = old_var_name - if not src_block._find_var_recursive(old_var_name): - src_var = src_block._var_recursive(var_name) - recv_block.create_var( - type=src_var.type, - name=old_var_name, - shape=src_var.shape, - dtype=src_var.dtype, - lod_level=src_var.lod_level, - persistable=src_var.persistable, - error_clip=src_var.error_clip, - stop_gradient=src_var.stop_gradient, - is_data=src_var.is_data, - belong_to_optimizer=src_var.belong_to_optimizer, - ) - continue - - self._create_program( - src_block, recv_block, op, force_create=True - ) - continue if not is_after_send_op or not is_after_recv_op: - if cur_pp_stage == pp_stages - 1: - if op.type in ["c_sync_calc_stream", "nop"]: + if self._cur_pp_stage == self._pp_stages - 1: + # the c_sync_calc_stream about c_allgather cannot be removed + if ( + op.type == "c_sync_calc_stream" + and src_block.ops[i + 1].type == "send_v2" + ): + continue + if op.type == "nop": continue + # HACKCODE: the varname of send_v2 op, cast op should be recorded for brpc comm if ( - op.type not in ["recv_2", "assign"] + op.type + not in ["recv_2", "assign", "c_allgather"] and op.has_attr('op_namescope') and "/auto_parallel/reshard" in op.attr('op_namescope') @@ -327,13 +560,16 @@ class PipelinePass(PassBase): self._create_program( src_block, send_block, op, force_create=True ) + continue if is_after_send_op and is_after_recv_op: + # HACKCODE: the varname of recv_v2 op, assign op should be recorded for brpc comm if op.has_attr( 'op_namescope' ) and "/auto_parallel/reshard" in op.attr( 'op_namescope' ): + # remove the suffix of "@RESHARD" var_name = op.desc.output_arg_names()[0] index = var_name.find("@") if index > 0: @@ -365,6 +601,7 @@ class PipelinePass(PassBase): self._create_program( 
src_block, recv_block, op, force_create=True ) + continue else: raise Exception("Only support generation condition.") @@ -406,52 +643,52 @@ class PipelinePass(PassBase): vars_to_shape = recv_task_node_var_shape start_task_node = TaskNode( - rank=cur_rank, + rank=self._cur_rank, max_run_times=self._acc_steps, node_type="Start", - task_id=int(cur_rank * num_of_functionality + 0), + task_id=int(self._cur_rank * num_of_functionality + 0), program=start_prog, lazy_initialize=True, ) cond_task_node = TaskNode( - rank=cur_rank, + rank=self._cur_rank, max_run_times=self._acc_steps, node_type="Cond", - task_id=int(cur_rank * num_of_functionality + 1), + task_id=int(self._cur_rank * num_of_functionality + 1), program=cond_prog, cond_var_name=cond_var_name, lazy_initialize=True, ) send_task_node = TaskNode( - rank=cur_rank, + rank=self._cur_rank, max_run_times=self._acc_steps, node_type="Compute", - task_id=int(cur_rank * num_of_functionality + 2), + task_id=int(self._cur_rank * num_of_functionality + 2), program=send_prog, lazy_initialize=True, ) recv_task_node = TaskNode( - rank=cur_rank, + rank=self._cur_rank, max_run_times=self._acc_steps, node_type="Compute", - task_id=int(cur_rank * num_of_functionality + 3), + task_id=int(self._cur_rank * num_of_functionality + 3), program=recv_prog, lazy_initialize=True, vars_to_dtype=vars_to_dtype, vars_to_shape=vars_to_shape, ) end_task_node = TaskNode( - rank=cur_rank, + rank=self._cur_rank, max_run_times=self._acc_steps, node_type="Compute", - task_id=int(cur_rank * num_of_functionality + 4), + task_id=int(self._cur_rank * num_of_functionality + 4), program=end_prog, lazy_initialize=True, ) # add dependencies for task nodes intra stage inf = -1 - pp_buff_size = int(pp_stages - cur_pp_stage) + pp_buff_size = int(self._pp_stages - self._cur_pp_stage) start_task_node.add_downstream_task( cond_task_node.task_id(), self._gen_bsz ) @@ -560,12 +797,12 @@ class PipelinePass(PassBase): # add dependencies for task nodes inter stage # get upstream ranks and downstream ranks of cur_rank up_down_streams = self._dist_context.up_down_streams - pp_upstream_ranks = up_down_streams.ups(cur_rank) - pp_downstream_ranks = up_down_streams.downs(cur_rank) + pp_upstream = up_down_streams.ups(self._cur_rank) + pp_downstream = up_down_streams.downs(self._cur_rank) - for upstream_rank in pp_upstream_ranks: + for upstream_rank in pp_upstream: upstream_pp_stage = self._get_pp_stage(upstream_rank) - if upstream_pp_stage < pp_stages - 1: + if upstream_pp_stage < self._pp_stages - 1: upstream_task_id = int(upstream_rank * num_of_functionality + 2) send_task_node.add_upstream_task(upstream_task_id) print( @@ -587,8 +824,8 @@ class PipelinePass(PassBase): ", buffer size is:", 2, ) - for downstream_rank in pp_downstream_ranks: - if cur_pp_stage < pp_stages - 1: + for downstream_rank in pp_downstream: + if self._cur_pp_stage < self._pp_stages - 1: downstream_task_id = int( downstream_rank * num_of_functionality + 2 ) @@ -616,7 +853,7 @@ class PipelinePass(PassBase): ) task_id_to_rank = {} - for i in range(nrank): + for i in range(self._nrank): for j in range(num_of_functionality): task_id_to_rank[int(i * num_of_functionality + j)] = i self._program._pipeline_opt = { diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/1F1B_pass_unittest.py b/python/paddle/fluid/tests/unittests/auto_parallel/1F1B_pass_unittest.py new file mode 100644 index 00000000000..3261ed6418d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/1F1B_pass_unittest.py @@ -0,0 +1,126 @@ +# 
Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import random +import numpy as np +import paddle + +from paddle.distributed.fleet import auto +from paddle.fluid.dygraph.parallel import ParallelEnv +from get_gpt_model import generate_model, FakeDataset + +paddle.enable_static() + + +def apply_pass(use_1f1b=False): + strategy = auto.Strategy() + strategy.auto_mode = "semi" + strategy.reinit = True + + if use_1f1b: + pipeline = strategy.pipeline + pipeline.enable = True + pipeline.schedule_mode = "1F1B" + pipeline.accumulate_steps = 2 + else: + gradient_merge = strategy.gradient_merge + gradient_merge.enable = True + gradient_merge.k_steps = 2 + gradient_merge.avg = True + + amp = strategy.amp + amp.enable = True + amp.custom_white_list = ['softmax', 'layer_norm', 'gelu'] + amp.custom_black_list = [ + 'c_softmax_with_cross_entropy', + 'elementwise_div', + 'reduce_sum', + ] + amp.init_loss_scaling = 32768 + amp.use_fp16_guard = False + amp.use_pure_fp16 = True + + return strategy + + +def reset_prog(): + paddle.fluid.framework.switch_main_program(paddle.static.Program()) + paddle.fluid.framework.switch_startup_program(paddle.static.Program()) + + +class Test1F1BPass(unittest.TestCase): + def setUp(self): + self.rtol = 1e-5 + self.atol = 1e-8 + self.batch_size = 2 + self.batch_num = 10 + self.clip_norm = 0.2 + self.dataset = FakeDataset(self.batch_size * self.batch_num) + + def init(self, engine): + paddle.seed(2021) + np.random.seed(2021) + random.seed(2021) + paddle.distributed.fleet.init(is_collective=True) + place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id) + engine._executor = paddle.static.Executor(place) + + def get_engine(self, use_1f1b=False): + reset_prog() + + strategy = apply_pass(use_1f1b) + clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm) + opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip) + model, loss = generate_model("pp") + + engine = auto.Engine(model, loss, opt, strategy=strategy) + self.init(engine) + return engine + + def check_results(self, ref_losses, check_losses): + np.testing.assert_allclose( + ref_losses, + check_losses, + rtol=self.rtol, + atol=self.atol, + err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format( + __class__, ref_losses, check_losses, ref_losses - check_losses + ), + ) + + def test_1f1b_pass(self): + # navie_pp+gradient_merge training + engine_pp = self.get_engine() + history = engine_pp.fit( + self.dataset, 3, batch_size=self.batch_size, log_freq=1 + ) + assert engine_pp._strategy.pipeline.enable == False + + # pp2 1f1b merge training + engine_1f1b = self.get_engine(True) + history = engine_1f1b.fit( + self.dataset, 3, batch_size=self.batch_size, log_freq=1 + ) + assert engine_1f1b._strategy.pipeline.enable == True + + # NOTE: every sample data from dataset is all the same + if paddle.distributed.get_rank() == 1: + losses_pp = np.array(history.history["loss"]) + losses_1f1b = np.array(history.history["loss"]) + 
self.check_results(losses_pp, losses_1f1b) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt index ee3f855b1c4..6902c556298 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt @@ -69,6 +69,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_engine_callbacks MODULES test_engine_callbacks) set_tests_properties(test_engine_callbacks PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50) + py_test_modules(test_pass_1F1B MODULES test_pass_1F1B) + set_tests_properties(test_pass_1F1B PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" + TIMEOUT 50) py_test_modules(test_parallel_tuner MODULES test_parallel_tuner ENVS ${dist_ENVS}) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py b/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py index 71f16f97206..457820af93b 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py @@ -89,6 +89,12 @@ def generate_model(strategy, dropout_prob=0.0): modeling._global_parallel_strategy = "mp" elif strategy == "dp": modeling._global_parallel_strategy = "dp" + elif strategy == "pp": + modeling._global_parallel_strategy = "pp" + modeling.PP_MESH_LIST = [ + auto.ProcessMesh(mesh=[0]), + auto.ProcessMesh(mesh=[1]), + ] else: raise ValueError("Only support serial, mp2 and dp2.") @@ -108,6 +114,7 @@ def generate_model(strategy, dropout_prob=0.0): eos_token_id=7, bos_token_id=0, eol_token_id=3, + pp_degree=2 if strategy == "pp" else None, ) model = GPTForPretraining( gpt, vocab_size=1000, hidden_size=64, initializer_range=0.02 diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/gradient_merge_pass_unittest.py b/python/paddle/fluid/tests/unittests/auto_parallel/gradient_merge_pass_unittest.py index 58bc1143885..26fc3026f32 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/gradient_merge_pass_unittest.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/gradient_merge_pass_unittest.py @@ -19,7 +19,7 @@ import paddle from paddle.distributed.fleet import auto from paddle.fluid.dygraph.parallel import ParallelEnv -from get_gpt_model import generate_model, create_data_holder, FakeDataset +from get_gpt_model import generate_model, FakeDataset paddle.enable_static() @@ -28,12 +28,25 @@ def apply_pass(use_gradient_merge=False): strategy = auto.Strategy() strategy.auto_mode = "semi" strategy.reinit = True + if use_gradient_merge: gradient_merge = strategy.gradient_merge gradient_merge.enable = True gradient_merge.k_steps = 4 gradient_merge.avg = True + amp = strategy.amp + amp.enable = True + amp.custom_white_list = ['softmax', 'layer_norm', 'gelu'] + amp.custom_black_list = [ + 'c_softmax_with_cross_entropy', + 'elementwise_div', + 'reduce_sum', + ] + amp.init_loss_scaling = 32768 + amp.use_fp16_guard = False + amp.use_pure_fp16 = True + return strategy @@ -88,6 +101,7 @@ class TestGradientMergePass(unittest.TestCase): history = dp_engine.fit( self.dataset, 3, batch_size=self.batch_size, log_freq=1 ) + assert dp_engine._strategy.gradient_merge.enable == False dp_losses = np.array(history.history["loss"]) # dp2 gradient merge training @@ -95,6 +109,7 @@ class TestGradientMergePass(unittest.TestCase): history = gm_engine.fit( self.dataset, 3, batch_size=self.batch_size, log_freq=1 ) + 
assert gm_engine._strategy.gradient_merge.enable == True gm_losses = np.array(history.history["loss"]) # avg_loss = 0 diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_1F1B.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_1F1B.py new file mode 100644 index 00000000000..9a503af9bf4 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_1F1B.py @@ -0,0 +1,57 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest +import os +import sys +import shutil +import subprocess +from paddle.distributed.fleet.launch_utils import run_with_coverage + + +class Test1F1BPass(unittest.TestCase): + def test_pp2(self): + file_dir = os.path.dirname(os.path.abspath(__file__)) + launch_model_path = os.path.join(file_dir, "1F1B_pass_unittest.py") + + if os.environ.get("WITH_COVERAGE", "OFF") == "ON": + coverage_args = ["-m", "coverage", "run", "--branch", "-p"] + else: + coverage_args = [] + + tmp_dir = tempfile.TemporaryDirectory() + cmd = ( + [sys.executable, "-u"] + + coverage_args + + [ + "-m", + "paddle.distributed.launch", + "--devices", + "0,1", + "--log_dir", + tmp_dir.name, + launch_model_path, + ] + ) + + process = subprocess.Popen(cmd) + process.wait() + self.assertEqual(process.returncode, 0) + + tmp_dir.cleanup() + + +if __name__ == "__main__": + unittest.main() -- GitLab
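One closing detail on the 1F1B support in auto_parallel_pipeline.py above: every rank owns four task nodes (lr, fwd, bwd, opt), task ids are laid out as rank * num_of_functionality + offset, and the fleet executor receives a task_id_to_rank table built from that layout; the fwd/bwd edge inside a stage gets a buffer size of pp_stages - cur_pp_stage, so earlier stages keep more micro-batches in flight. A small illustrative sketch of the id layout (values assumed; it mirrors _task_1f1b rather than reusing its code):

    NUM_OF_FUNCTIONALITY = 4  # lr, fwd, bwd, opt

    def task_ids(rank):
        roles = ["lr", "fwd", "bwd", "opt"]
        return {role: rank * NUM_OF_FUNCTIONALITY + i for i, role in enumerate(roles)}

    def build_task_id_to_rank(nranks):
        return {
            rank * NUM_OF_FUNCTIONALITY + i: rank
            for rank in range(nranks)
            for i in range(NUM_OF_FUNCTIONALITY)
        }

    # With the two-rank pp2 test above: rank 0 owns ids 0-3 and rank 1 owns ids 4-7.
    print(task_ids(1))              # {'lr': 4, 'fwd': 5, 'bwd': 6, 'opt': 7}
    print(build_task_id_to_rank(2))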