diff --git a/paddle/fluid/operators/controlflow/fetch_v2_op.cc b/paddle/fluid/operators/controlflow/fetch_v2_op.cc
index 9105fe88862966018b749321549e5a9aaeb5cad8..bd0df3ba04e648680c20dd35db74e8f6fdf47761 100644
--- a/paddle/fluid/operators/controlflow/fetch_v2_op.cc
+++ b/paddle/fluid/operators/controlflow/fetch_v2_op.cc
@@ -140,6 +140,8 @@ class FetchV2Kernel {
           "operator 'Fetch') of current fetching variable to be "
           "no less than 0. But received column index = %d.",
           col));
+    VLOG(3) << "Fetch variable " << fetch_var_name << "'s " << col
+            << " column.";
 
     auto *fetch_list = out_var->GetMutable<framework::FetchList>();
diff --git a/python/paddle/distributed/auto_parallel/static/engine.py b/python/paddle/distributed/auto_parallel/static/engine.py
index 70df712d3ba221f7705f5c3746e02d8b00bbd0bb..96c289235906dee228df7bdcb8abc6337ffbdb47 100644
--- a/python/paddle/distributed/auto_parallel/static/engine.py
+++ b/python/paddle/distributed/auto_parallel/static/engine.py
@@ -400,8 +400,11 @@ class Engine:
         dist_main_block._sync_with_cpp()
         self._has_prepared_reader[self._mode] = True
 
-        # Insert read op to forward TaskNode if 1F1B pass is setted
-        if self.main_program._pipeline_opt:
+        # Insert read op to forward TaskNode for fleet executor if 1F1B pass is set
+        if (
+            self.main_program._pipeline_opt
+            and not auto_utils.use_new_executor()
+        ):
             assert "tasks" in self.main_program._pipeline_opt["fleet_opt"]
             fleet_opt = self.main_program._pipeline_opt["fleet_opt"]
             fwd_task = None
@@ -471,8 +474,6 @@ class Engine:
                     if var_name not in fetch_names:
                         fetch_names.append(var_name)
                     group_indices.append(fetch_names.index(var_name))
-            if not group_indices:
-                fetch_names.append([])
             fetch_indices.append(group_indices)
 
         dist_context = self._dist_contexts[mode]
diff --git a/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py
index fd5da8193f3abbb478012ffbd4046f3e8a2eaa2e..2cbed1ee39819d43aa25857dcfa8e09cb438fa67 100644
--- a/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py
+++ b/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py
@@ -25,7 +25,7 @@ from ..random import init_auto_parallel_rng
 from .partitioner import Partitioner
 from .process_group import get_world_process_group
 from .reshard import Resharder
-from .utils import set_grad_var_shape
+from .utils import set_grad_var_shape, use_new_executor
 
 
 class Parallelizer:
@@ -38,6 +38,14 @@ class Parallelizer:
         self._strategy = self._dist_context.strategy
         self._logger = get_logger(logging.INFO)
 
+    @property
+    def is_train(self):
+        return self._mode == "train"
+
+    @property
+    def is_test(self):
+        return self._mode in ["eval", "predict"]
+
     def parallel_all(self):
         world_process_group = get_world_process_group()
         all_ranks = world_process_group.ranks
@@ -50,7 +58,7 @@ class Parallelizer:
         serial_main_program = self._dist_context.serial_main_program
         serial_startup_program = self._dist_context.serial_startup_program
         serial_optimizer = self._dist_context.serial_optimizer
-        if self._mode == "train" and serial_optimizer:
+        if self.is_train and serial_optimizer:
             # Generate backward
             serial_loss = self._dist_context.serial_loss
             params_grads = self._generate_backward(
@@ -191,8 +199,9 @@ class Parallelizer:
                     time.time() - time0, self._mode
                 )
             )
+
         # Clone program for test
-        if self._mode != 'train':
+        if self.is_test:
             pipeline_opt = dist_main_prog._pipeline_opt
             dist_main_prog = dist_main_prog.clone(for_test=True)
             dist_startup_prog = dist_startup_prog.clone(for_test=True)
@@ -263,7 +272,7 @@ class Parallelizer:
 
         # apply quantization pass
         # The pass can be applied when mode must be 'train'
-        if self._mode == 'train' and self._strategy.qat.enable:
+        if self.is_train and self._strategy.qat.enable:
             config = copy.deepcopy(self._strategy.qat.to_dict())
             config["dist_context"] = self._dist_context
             config["params_grads"] = params_grads
@@ -282,7 +291,7 @@ class Parallelizer:
 
         # apply recompute pass
         # recompute is then train-only optimization
-        if self._mode == "train" and self._strategy.recompute.enable:
+        if self.is_train and self._strategy.recompute.enable:
             config = copy.deepcopy(self._strategy.recompute.to_dict())
             config["dist_context"] = self._dist_context
             config["no_grad_set"] = None
@@ -326,7 +335,7 @@ class Parallelizer:
             )
             params_grads = self._pass_context.get_attr("params_grads")
 
-        if self._mode == "train":
+        if self.is_train:
             # GradClip is train-only optimization
             config = copy.deepcopy(self._strategy.sharding.to_dict())
             config["dist_context"] = self._dist_context
@@ -349,7 +358,7 @@ class Parallelizer:
                 [main_program], [startup_program], self._pass_context
             )
 
-        if self._strategy.pipeline.enable:
+        if self.is_train and self._strategy.pipeline.enable:
             self._strategy.gradient_merge.enable = True
             self._strategy.gradient_merge.k_steps = (
                 self._strategy.pipeline.accumulate_steps
@@ -357,7 +366,7 @@ class Parallelizer:
             )
             self._strategy.gradient_merge.avg = True
         # gradient_merge is then train-only optimization
-        if self._mode == "train" and self._strategy.gradient_merge.enable:
+        if self.is_train and self._strategy.gradient_merge.enable:
             config = copy.deepcopy(self._strategy.gradient_merge.to_dict())
             config["dist_context"] = self._dist_context
             config["params_grads"] = params_grads
@@ -368,7 +377,7 @@ class Parallelizer:
                 [main_program], [startup_program], self._pass_context
             )
 
-        if self._strategy.pipeline.enable:
+        if self._strategy.pipeline.enable and not use_new_executor():
             config = copy.deepcopy(self._strategy.pipeline.to_dict())
             config["dist_context"] = self._dist_context
             auto_parallel_pipeline_pass = new_pass(
@@ -378,10 +387,17 @@ class Parallelizer:
                 [main_program], [startup_program], self._pass_context
             )
 
-        if self._mode == "train" and self._strategy.fused_passes.enable:
+        if self.is_train and self._strategy.fused_passes.enable:
             if len(self._strategy.fused_passes.fused_passes_list) > 0:
                 new_pass_list = []
                 for op in self._strategy.fused_passes.fused_passes_list:
                     new_pass_list.append(new_pass(op))
                 pass_manager = PassManager(new_pass_list)
                 pass_manager.apply([main_program], [startup_program])
+
+        if self._strategy.pipeline.enable and use_new_executor():
+            main_program._pipeline_opt = {}
+            main_program._pipeline_opt["standalone_opt"] = {
+                "schedule_mode": self._strategy.pipeline.schedule_mode,
+                "num_micro_batches": self._strategy.pipeline.accumulate_steps,
+            }
diff --git a/python/paddle/distributed/auto_parallel/static/utils.py b/python/paddle/distributed/auto_parallel/static/utils.py
index 937fc89ee441c6985fbed177db51d48ac8b0d786..cfd5e9b844c16b94394a0b7bea115a7bd219f146 100644
--- a/python/paddle/distributed/auto_parallel/static/utils.py
+++ b/python/paddle/distributed/auto_parallel/static/utils.py
@@ -2367,3 +2367,16 @@ def _dygraph_guard_(func):
 
 
 dygraph_guard = wrap_decorator(_dygraph_guard_)
+
+
+def use_new_executor():
+    new_executor_micro_batching = os.environ.get(
+        'FLAGS_new_executor_micro_batching', None
+    )
+    return new_executor_micro_batching in [
+        1,
+        '1',
+        True,
+        'True',
+        'true',
+    ]
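Note (not part of the patch): the two changes above cooperate. use_new_executor() only inspects the FLAGS_new_executor_micro_batching environment variable, and when it is truthy the parallelizer skips the fleet-executor pipeline pass and instead records the schedule under main_program._pipeline_opt["standalone_opt"]. A minimal, self-contained sketch of that flag check:

import os

# Turn the new code path on, exactly as the FThenB test below does.
os.environ['FLAGS_new_executor_micro_batching'] = 'True'

# Mirrors use_new_executor() from utils.py: any truthy spelling enables it.
flag = os.environ.get('FLAGS_new_executor_micro_batching', None)
use_new_executor = flag in [1, '1', True, 'True', 'true']
assert use_new_executor

# With the flag on, parallelizer_v2 stores something like
#   {"standalone_opt": {"schedule_mode": "FThenB", "num_micro_batches": 2}}
# on main_program._pipeline_opt for the executor to pick up later.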
diff --git a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py
index 8a87ac7f599d2f20c8c127723cf31c1e9c056948..21816c34ee4450e3f61e0bb19c547a4dcdee7795 100644
--- a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py
+++ b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py
@@ -94,8 +94,13 @@ def _get_gm_cond_var(main_program, k_steps, dist_context):
     )
     set_var_dist_attr(dist_context, step_var, [-1], world_process_group.ranks)
 
-    cond_var = main_block.create_var(
-        name="gradient_merge_cond", shape=[1], dtype='bool'
+    cond_var = paddle.static.create_global_var(
+        name="gradient_merge_cond",
+        shape=[1],
+        value=bool(0),
+        dtype='bool',
+        persistable=True,
+        force_cpu=True,
     )
     set_var_dist_attr(dist_context, cond_var, [-1], world_process_group.ranks)
 
diff --git a/python/paddle/distributed/passes/pipeline_scheduler_pass.py b/python/paddle/distributed/passes/pipeline_scheduler_pass.py
index 8ff5f2b35e7b584dabbcecf5a0f61aa590375305..3d63c14dde65cdec2dfebf441e0818153a474b30 100644
--- a/python/paddle/distributed/passes/pipeline_scheduler_pass.py
+++ b/python/paddle/distributed/passes/pipeline_scheduler_pass.py
@@ -22,7 +22,7 @@ from paddle.distributed.fleet.meta_optimizers.common import OpRole
 from paddle.fluid import core
 from paddle.fluid.framework import Parameter, Program
 
-from .pass_base import PassBase, register_pass
+from .pass_base import PassBase, PassContext, new_pass, register_pass
 
 __not_shape_var_type__ = [
     core.VarDesc.VarType.READER,
@@ -257,7 +257,7 @@ def _program_for_fthenb_and_1f1b(program):
     }
 
 
-@register_pass("pipeline_fthenb_scheduler")
+@register_pass("pipeline_scheduler_FThenB")
 class PipelineFThenBPass(PassBase):
     def __init__(self):
         super().__init__()
@@ -272,12 +272,12 @@ class PipelineFThenBPass(PassBase):
         job_list = []
         lr_job = core.Job("lr")
         job_list.append(lr_job)
-        for i in range(self._micro_batch_size):
+        for i in range(self._num_micro_batches):
             forward_job = core.Job("forward")
             forward_job.set_micro_batch_id(i)
             job_list.append(forward_job)
 
-        for i in range(self._micro_batch_size):
+        for i in range(self._num_micro_batches):
             backward_job = core.Job("backward")
             backward_job.set_micro_batch_id(i)
             job_list.append(backward_job)
@@ -287,7 +287,7 @@ class PipelineFThenBPass(PassBase):
         return job_list
 
     def _apply_single_impl(self, main_program, startup_program, context):
-        self._micro_batch_size = self.get_attr("micro_batch_size")
+        self._num_micro_batches = self.get_attr("num_micro_batches")
         self._program = main_program
 
         _insert_sync_for_fthenb_1f1b(self._program)
@@ -296,3 +296,16 @@ class PipelineFThenBPass(PassBase):
 
         plan = core.Plan(job_list, type_to_program)
         context.set_attr("plan", plan)
+
+
+def apply_pass(main_program, startup_program, pass_name, pass_attr={}):
+    assert pass_name in [
+        "FThenB"
+    ], "pipeline scheduler only supports FThenB, but received {}".format(
+        pass_name
+    )
+    pipeline_pass = new_pass("pipeline_scheduler_" + pass_name, pass_attr)
+    pass_context = PassContext()
+    pipeline_pass.apply([main_program], [startup_program], pass_context)
+    plan = pass_context.get_attr("plan")
+    return plan
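For orientation (illustrative, not part of the patch): with the pass renamed to pipeline_scheduler_FThenB and the new apply_pass helper, a standalone-executor plan can be built the same way the updated unit test at the end of this diff does. A condensed sketch, assuming a Paddle build that contains this change:

import paddle
from paddle import static
from paddle.distributed.passes import PassContext, new_pass

paddle.enable_static()

main_program, startup_program = static.Program(), static.Program()

# "micro_batch_size" was renamed to "num_micro_batches" by this change.
config = {"num_micro_batches": 4}
fthenb_pass = new_pass("pipeline_scheduler_FThenB", config)

pass_context = PassContext()
fthenb_pass.apply([main_program], [startup_program], pass_context)
plan = pass_context.get_attr("plan")  # core.Plan consumed by the standalone executor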
diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py
index d1445e227226d385c3506da352d3157109fd71f0..5123c9b4ed102d46bfe5d892d7c7605fd02dd2ea 100755
--- a/python/paddle/fluid/executor.py
+++ b/python/paddle/fluid/executor.py
@@ -414,6 +414,36 @@ def _add_feed_fetch_ops(
     return tmp_program
 
 
+def _set_micro_batch_fetch(plan):
+    if plan.micro_batch_num() <= 1:
+        return
+
+    valid_fetch_types = ["fetch", "fetch_v2"]
+    for job in plan.job_list():
+        idx_to_col_attr = {}
+        prog = plan.program(job.type())
+        for i in range(prog.block(0).op_size()):
+            op = prog.block(0).op(i)
+            if op.type() in valid_fetch_types:
+                idx_to_col_attr[i] = op.attr('col')
+
+        for idx, col in idx_to_col_attr.items():
+            job.set_col_attr_for_fetch_op(
+                idx, col * plan.micro_batch_num() + job.micro_batch_id()
+            )
+
+
+def _merge_tensors(tensor, micro_batch_num):
+    if micro_batch_num <= 1:
+        return tensor
+    assert len(tensor) % micro_batch_num == 0
+    chunk_tensor = [
+        tensor[i : i + micro_batch_num]
+        for i in range(0, len(tensor), micro_batch_num)
+    ]
+    return [np.array(chunk) for chunk in chunk_tensor]
+
+
 def _apply_inplace_addto_pass(
     program, enable_inplace, enable_addto, skip_var_names
 ):
@@ -653,8 +683,13 @@ class _StandaloneExecutor:
         """
         tensors = self._new_exe.run(feed_names)._move_to_list()
         if return_numpy:
-            return as_numpy(tensors, copy=True)
+            tensors = as_numpy(tensors, copy=True)
+            return _merge_tensors(tensors, self._plan.micro_batch_num())
         else:
+            if self._plan.micro_batch_num() > 1:
+                raise RuntimeError(
+                    "`_merge_tensors` does not support return_numpy=False."
+                )
             return tensors
 
     def _create_new_executor(self):
@@ -831,12 +866,30 @@ class _ExecutorCache:
         _apply_inplace_addto_pass(
             program, enable_inplace, enable_addto, skip_var_names
         )
+
        new_program = program.clone()
-        new_exe = _StandaloneExecutor(
-            place,
-            core.Plan([core.Job("default")], {"default": new_program.desc}),
-            scope,
-        )
+        if (
+            new_program._pipeline_opt
+            and "standalone_opt" in new_program._pipeline_opt
+        ):
+            from paddle.distributed.passes.pipeline_scheduler_pass import (
+                apply_pass,
+            )
+
+            standalone_opt = new_program._pipeline_opt["standalone_opt"]
+            pass_name = standalone_opt["schedule_mode"]
+            pass_attr = {
+                "num_micro_batches": standalone_opt["num_micro_batches"]
+            }
+            plan = apply_pass(new_program, new_program, pass_name, pass_attr)
+        else:
+            default_job = core.Job("default")
+            type_to_program = {"default": new_program.desc}
+            plan = core.Plan([default_job], type_to_program)
+
+        _set_micro_batch_fetch(plan)
+
+        new_exe = _StandaloneExecutor(place, plan, scope)
 
         return new_program, new_exe
 
@@ -1408,7 +1461,15 @@ class Executor:
 
         fetch_list = self._check_fetch_list(fetch_list)
 
-        if isinstance(program, Program) and program._pipeline_opt:
+        from paddle.distributed.auto_parallel.static.utils import (
+            use_new_executor,
+        )
+
+        if (
+            isinstance(program, Program)
+            and program._pipeline_opt
+            and not use_new_executor()
+        ):
             if "fleet_opt" in program._pipeline_opt:
                 # Move prepare here for port conflict with nccl in startup program
                 if self._fleet_executor is None:
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index b774afd8d5666ece92ff33cafdc920b8813886e1..a9ba6f91a1e259f89ed12497f313eb5aa6936259 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -5921,6 +5921,8 @@ class Program:
             p._appending_grad_times = self._appending_grad_times
             if hasattr(self, 'lr_scheduler'):
                 p.lr_scheduler = self.lr_scheduler
+            if hasattr(self, '_pipeline_opt'):
+                p._pipeline_opt = self._pipeline_opt
 
             # NOTE(zhiqiu): we sync the cloned program, to update its program by
             # its desc.
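A reading aid for _set_micro_batch_fetch and _merge_tensors (assumed example values, not part of the patch): each fetch op's col attribute is remapped to col * micro_batch_num + micro_batch_id, so the flat fetch result list stores all micro-batch values of one fetch target contiguously, and _merge_tensors then folds every micro_batch_num consecutive entries into a single array:

import numpy as np

micro_batch_num = 2
# Layout produced by the remapping above, for two fetch targets:
# [var0_mb0, var0_mb1, var1_mb0, var1_mb1]
fetched = [np.array(0.1), np.array(0.2), np.array(7.0), np.array(8.0)]

# Same chunking that _merge_tensors performs.
merged = [
    np.array(fetched[i : i + micro_batch_num])
    for i in range(0, len(fetched), micro_batch_num)
]
assert len(merged) == 2          # one entry per fetch target
assert merged[0].shape == (2,)   # one row per micro batch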
diff --git a/test/auto_parallel/CMakeLists.txt b/test/auto_parallel/CMakeLists.txt
index 43c62bd9b1e083dee0feab6253b9992c331e26af..08d0c4fd34a39fb756a292bd0c3bba16006cb2bf 100644
--- a/test/auto_parallel/CMakeLists.txt
+++ b/test/auto_parallel/CMakeLists.txt
@@ -68,6 +68,10 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_auto_tuner MODULES test_auto_tuner)
   set_tests_properties(test_auto_tuner PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
                                                   TIMEOUT 100)
+  py_test_modules(test_pipeline_scheduler_FThenB MODULES
+                  test_pipeline_scheduler_FThenB)
+  set_tests_properties(test_pipeline_scheduler_FThenB
+                       PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
   py_test_modules(test_auto_tuner_compare MODULES test_auto_tuner_compare)
   set_tests_properties(test_auto_tuner_compare PROPERTIES LABELS
                                                "RUN_TYPE=EXCLUSIVE" TIMEOUT 100)
diff --git a/test/auto_parallel/pipeline_scheduler_FThenB.py b/test/auto_parallel/pipeline_scheduler_FThenB.py
new file mode 100644
index 0000000000000000000000000000000000000000..b02ed3a4739d03219cff30d40ca4334364264db3
--- /dev/null
+++ b/test/auto_parallel/pipeline_scheduler_FThenB.py
@@ -0,0 +1,118 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import random
+import unittest
+
+import numpy as np
+from get_gpt_model import FakeDataset, generate_model
+
+import paddle
+from paddle.distributed import ParallelEnv
+from paddle.distributed.fleet import auto
+
+paddle.enable_static()
+
+
+def apply_pass(use_standalone_exe=False):
+    strategy = auto.Strategy()
+    strategy.auto_mode = "semi"
+    strategy.reinit = True
+
+    pipeline = strategy.pipeline
+    pipeline.enable = True
+    pipeline.schedule_mode = "1F1B" if not use_standalone_exe else "FThenB"
+    pipeline.accumulate_steps = 2
+
+    return strategy
+
+
+def reset_prog():
+    paddle.fluid.framework.switch_main_program(paddle.static.Program())
+    paddle.fluid.framework.switch_startup_program(paddle.static.Program())
+
+
+class Test1F1BPass(unittest.TestCase):
+    def setUp(self):
+        self.rtol = 1e-5
+        self.atol = 1e-8
+        self.batch_size = 2
+        self.batch_num = 10
+        self.clip_norm = 0.2
+        self.dataset = FakeDataset(self.batch_size * self.batch_num)
+
+    def init(self, engine):
+        paddle.seed(2021)
+        np.random.seed(2021)
+        random.seed(2021)
+        paddle.distributed.fleet.init(is_collective=True)
+        place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
+        engine._executor = paddle.static.Executor(place)
+
+    def get_engine(self, use_standalone_exe=False):
+        reset_prog()
+
+        strategy = apply_pass(use_standalone_exe)
+        clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
+        opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
+        model, loss = generate_model("pp")
+
+        engine = auto.Engine(model, loss, opt, strategy=strategy)
+        self.init(engine)
+        return engine
+
+    def check_results(self, ref_losses, check_losses):
+        np.testing.assert_allclose(
+            ref_losses,
+            check_losses,
+            rtol=self.rtol,
+            atol=self.atol,
+            err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format(
+                __class__, ref_losses, check_losses, ref_losses - check_losses
+            ),
+        )
+
+    def test_pp_pass(self):
+        # pp2 1f1b training with fleet executor
+        os.environ['FLAGS_new_executor_micro_batching'] = 'False'
+        engine_1f1b = self.get_engine(use_standalone_exe=False)
+        history_1f1b = engine_1f1b.fit(
+            self.dataset, 3, batch_size=self.batch_size, log_freq=1
+        )
+        assert engine_1f1b._strategy.pipeline.schedule_mode == "1F1B"
+        assert os.environ.get('FLAGS_new_executor_micro_batching') == "False"
+
+        # pp2 fthenb training with standalone executor
+        os.environ['FLAGS_new_executor_micro_batching'] = 'True'
+        engine_fthenb = self.get_engine(use_standalone_exe=True)
+        history_fthenb = engine_fthenb.fit(
+            self.dataset, 3, batch_size=self.batch_size, log_freq=1
+        )
+        assert engine_fthenb._strategy.pipeline.schedule_mode == "FThenB"
+        assert os.environ.get('FLAGS_new_executor_micro_batching') == "True"
+
+        # NOTE: every sample drawn from the dataset is identical
+        if paddle.distributed.get_rank() == 1:
+            losses_1f1b = np.array(history_1f1b.history["loss"])
+            losses_fthenb = np.array(history_fthenb.history["loss"])
+            # accumulate_steps is 2
+            assert losses_fthenb[0].shape[0] == 2
+            # losses_1f1b keeps only the loss of the last micro batch per step,
+            # while losses_fthenb keeps the losses of all micro batches
+            self.check_results(losses_1f1b[0], losses_fthenb[0][-1])
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/auto_parallel/test_pipeline_scheduler_FThenB.py b/test/auto_parallel/test_pipeline_scheduler_FThenB.py
new file mode 100644
index 0000000000000000000000000000000000000000..fab5ed0863d93b2d9c5c83daa2708279539b9d0c
--- /dev/null
+++ b/test/auto_parallel/test_pipeline_scheduler_FThenB.py
@@ -0,0 +1,57 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import subprocess
+import sys
+import tempfile
+import unittest
+
+
+class TestFThenBPass(unittest.TestCase):
+    def test_pp2(self):
+        file_dir = os.path.dirname(os.path.abspath(__file__))
+        launch_model_path = os.path.join(
+            file_dir, "pipeline_scheduler_FThenB.py"
+        )
+
+        if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
+            coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
+        else:
+            coverage_args = []
+
+        tmp_dir = tempfile.TemporaryDirectory()
+        cmd = (
+            [sys.executable, "-u"]
+            + coverage_args
+            + [
+                "-m",
+                "paddle.distributed.launch",
+                "--devices",
+                "0,1",
+                "--log_dir",
+                tmp_dir.name,
+                launch_model_path,
+            ]
+        )
+
+        process = subprocess.Popen(cmd)
+        process.wait()
+        self.assertEqual(process.returncode, 0)
+
+        tmp_dir.cleanup()
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/standalone_executor/test_standalone_executor_fthenb_plan.py b/test/standalone_executor/test_standalone_executor_fthenb_plan.py
index 6912dc0e609f8e84b3b462a0ef5966a287be9104..76557231b83e4e3e0005062865d68ddc876e7618 100644
--- a/test/standalone_executor/test_standalone_executor_fthenb_plan.py
+++ b/test/standalone_executor/test_standalone_executor_fthenb_plan.py
@@ -21,13 +21,13 @@ from paddle.distributed.passes import PassContext, new_pass
 class TestStandaloneExecutorFThenBPlan(unittest.TestCase):
     def test_standalone_executor_fthenb_plan(self):
         config = {}
-        config["micro_batch_size"] = 4
+        config["num_micro_batches"] = 4
         pass_context = PassContext()
 
         startup_program = static.Program()
         main_program = static.Program()
 
-        pipeline_fthenb_pass = new_pass("pipeline_fthenb_scheduler", config)
+        pipeline_fthenb_pass = new_pass("pipeline_scheduler_FThenB", config)
         pipeline_fthenb_pass.apply(
            [main_program], [startup_program], pass_context
         )
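Taken together, the new tests show the intended user-facing flow: enable FLAGS_new_executor_micro_batching, select the FThenB schedule in the pipeline strategy, and train through auto.Engine as usual. A condensed sketch of that setup; model, loss, optimizer and dataset are placeholders that the real test builds via get_gpt_model:

import os

from paddle.distributed.fleet import auto

# Route pipeline execution through the standalone executor.
os.environ['FLAGS_new_executor_micro_batching'] = 'True'

strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.pipeline.enable = True
strategy.pipeline.schedule_mode = "FThenB"
strategy.pipeline.accumulate_steps = 2

# engine = auto.Engine(model, loss, optimizer, strategy=strategy)
# engine.fit(dataset, 3, batch_size=2, log_freq=1)  # mirrors pipeline_scheduler_FThenB.py above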