diff --git a/python/paddle/distributed/auto_parallel/constants.py b/python/paddle/distributed/auto_parallel/constants.py index ee42d4df4202be74e5eef1c946281526ef502be0..3f7d2e7a4948d631647848fa7761fff081fcfd9c 100644 --- a/python/paddle/distributed/auto_parallel/constants.py +++ b/python/paddle/distributed/auto_parallel/constants.py @@ -136,3 +136,10 @@ set_field_default_config(TUNING, "debug", False) DATASET = "dataset" set_field_default_config(DATASET, "enable", False) set_field_default_config(DATASET, "num_shards", 1) + +######################################### +# fused passes configuration +######################################### +FUSED_PASSES = "fused_passes" +set_field_default_config(FUSED_PASSES, "enable", False) +set_field_default_config(FUSED_PASSES, "fused_passes_list", []) diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index ead912764f089a41764f75d204caf8d79fc2642c..a270d754597fa1200ffb7d6f4e8df761e2c46e6e 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -702,7 +702,9 @@ class Engine: # For now, the completer has to be passed to the planner, # because we may use it to complete the annotation of the backwarkward and update. parallelizer = Parallelizer( - mode, self._planners[mode].completer, self._dist_contexts[mode] + mode, + self._planners[mode].completer, + self._dist_contexts[mode], ) if not all_ranks: parallelizer.parallel(self._cur_rank) diff --git a/python/paddle/distributed/auto_parallel/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/parallelizer_v2.py index 48831f3ff2c502586f34a223ab02941ad0a31b26..175a1d263498671a879c2dabc7d9898f7232a491 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer_v2.py +++ b/python/paddle/distributed/auto_parallel/parallelizer_v2.py @@ -16,7 +16,7 @@ import copy import logging import time -from paddle.distributed.passes import new_pass +from paddle.distributed.passes import PassManager, new_pass from paddle.static import append_backward, program_guard from paddle.utils import unique_name @@ -338,3 +338,11 @@ class Parallelizer: auto_parallel_gradient_merge_pass.apply( [main_program], [startup_program], self._pass_context ) + + if self._mode == "train" and self._strategy.fused_passes.enable: + if len(self._strategy.fused_passes.fused_passes_list) > 0: + new_pass_list = [] + for op in self._strategy.fused_passes.fused_passes_list: + new_pass_list.append(new_pass(op)) + pass_manager = PassManager(new_pass_list) + pass_manager.apply([main_program], [startup_program]) diff --git a/python/paddle/distributed/auto_parallel/strategy.py b/python/paddle/distributed/auto_parallel/strategy.py index 7e6b98665a8d011bed4e051a45bba701eda9ce03..eb2e09b3a264316be91545c68a1056003510fbc6 100644 --- a/python/paddle/distributed/auto_parallel/strategy.py +++ b/python/paddle/distributed/auto_parallel/strategy.py @@ -120,6 +120,12 @@ class DatasetConfig(BaseConfig): super().__init__(category, config_dict) +class FusedPassesConfig(BaseConfig): + def __init__(self, config_dict=None): + category = constants.FUSED_PASSES + super().__init__(category, config_dict) + + class Strategy(BaseConfig): """ The `Strategy` object is used to configure the paralleization and optimization beheviors. @@ -188,3 +194,6 @@ class Strategy(BaseConfig): config_dict = self._config_dict.get(constants.DATASET, None) self.dataset = DatasetConfig(config_dict) + + config_dict = self._config_dict.get(constants.FUSED_PASSES, None) + self.fused_passes = FusedPassesConfig(config_dict) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt index 249387a0781d1ffa1502289434169935835bd2a4..02076545b014dcda6ed9d08ed98e2f89fa3e8be2 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt @@ -79,6 +79,8 @@ if(WITH_DISTRIBUTE AND WITH_GPU) set_tests_properties(test_pass_quantization PROPERTIES TIMEOUT 60) py_test_modules(test_tuning_recompute MODULES test_tuning_recompute) set_tests_properties(test_tuning_recompute PROPERTIES TIMEOUT 240) + py_test_modules(test_fused_linear_pass MODULES test_fused_linear_pass) + set_tests_properties(test_fused_linear_pass PROPERTIES TIMEOUT 20) py_test_modules(test_while_op_completion MODULES test_while_op_completion ENVS ${dist_ENVS}) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_fused_linear_pass.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_fused_linear_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..8aa75551c513e89cb4a7823b5787b7e0b1e9f160 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_fused_linear_pass.py @@ -0,0 +1,99 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import sys +import unittest + +import numpy as np +from get_gpt_model import FakeDataset, generate_model + +import paddle +from paddle.distributed.fleet import auto +from paddle.fluid.dygraph.parallel import ParallelEnv + +sys.path.append("..") +from test_sparse_addmm_op import get_cuda_version + + +def apply_pass(use_fused_passes=False, fused_passes_list=[]): + strategy = auto.Strategy() + strategy.auto_mode = "semi" + strategy.reinit = True + fused_passes = strategy.fused_passes + fused_passes.enable = use_fused_passes + fused_passes.fused_passes_list = fused_passes_list + return strategy + + +def reset_prog(): + paddle.fluid.framework.switch_main_program(paddle.static.Program()) + paddle.fluid.framework.switch_startup_program(paddle.static.Program()) + + +class TestAMPPass(unittest.TestCase): + def setUp(self): + self.rtol = 1e-5 + self.atol = 1e-8 + self.batch_size = 1 + self.batch_num = 1 + self.clip_norm = 0.2 + self.dataset = FakeDataset(self.batch_size * self.batch_num) + + def init(self, engine): + paddle.seed(2021) + np.random.seed(2021) + random.seed(2021) + place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id) + engine._executor = paddle.static.Executor(place) + + def get_engine(self, use_fused_passes=False, fused_passes_list=[]): + reset_prog() + + strategy = apply_pass(use_fused_passes, fused_passes_list) + clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm) + opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip) + model, loss = generate_model("serial") + + engine = auto.Engine(model, loss, opt, strategy=strategy) + self.init(engine) + return engine + + def check_results(self, ref_losses, check_losses, rtol=None, atol=None): + np.testing.assert_allclose( + ref_losses, + check_losses, + rtol=rtol or self.rtol, + atol=atol or self.atol, + err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format( + __class__, ref_losses, check_losses, ref_losses - check_losses + ), + ) + + def test_passes(self): + losses = [] + if get_cuda_version() >= 11060: + for use_fused_passes in [True, False]: + engine = self.get_engine( + use_fused_passes, ["fuse_gemm_epilogue"] + ) + history = engine.fit( + self.dataset, 3, batch_size=self.batch_size + ) + losses.append(np.array(history.history["loss"])) + self.check_results(losses[0], losses[1]) + + +if __name__ == "__main__": + unittest.main()