diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py
index f163da4fb999b3b6708ddd846e7e19c2e0c291d1..83a0ff8099cba31c76167d9fe33c28a9ef5ca9f8 100755
--- a/python/paddle/distributed/fleet/base/fleet_base.py
+++ b/python/paddle/distributed/fleet/base/fleet_base.py
@@ -37,9 +37,45 @@ from ..meta_optimizers import HybridParallelOptimizer, HeterParallelOptimizer
 from paddle import _C_ops
 from paddle.fluid import core
 from paddle.fluid.dygraph import to_variable
+from paddle.distributed.fleet.utils.recompute import RecomputeFunction
+from paddle.fluid.dygraph.varbase_patch_methods import _grad_scalar
 
 __all__ = []
 
+_grad_scalar = None
+
+
+class _RecomputeModelWrapper(paddle.nn.Layer):
+    def __init__(self, model, segments=2, preserve_rng_state=True):
+        super(_RecomputeModelWrapper, self).__init__()
+        assert isinstance(model, paddle.nn.Sequential), (
+            "The model passed to RecomputeModelWrapper must be of type "
+            "paddle.nn.Sequential.")
+        self._model = model
+        self._segments = segments
+        self._preserve_rng_state = preserve_rng_state
+        self._layers = list(model.children())
+        self._segment_size = len(self._layers) // segments
+
+    def _run_func(self, begin, end):
+        def do_run(input):
+            for i in range(begin, end):
+                input = self._layers[i](input)
+            return input
+
+        return do_run
+
+    def _checkpoint(self, func, *args, **kwargs):
+        return RecomputeFunction.apply(func, self._preserve_rng_state, *args)
+
+    def forward(self, input):
+        end = 0
+        for begin in range(0, self._segment_size * (self._segments - 1),
+                           self._segment_size):
+            end = begin + self._segment_size
+            input = self._checkpoint(self._run_func(begin, end), input)
+        return self._run_func(end, len(self._layers))(input)
+
 
 def apply_ir_passes(main_program, startup_program, config):
     build_strategy = config._user_defined_strategy.build_strategy._copy()
@@ -952,6 +988,41 @@ class Fleet(object):
         if self.worker_num() <= 1:
             return model
 
+        amp_enable = False
+        recompute_enable = False
+        strategy = self._user_defined_strategy
+        if strategy.amp == True:
+            amp_enable = True
+            amp_level = "O2" if strategy.amp_configs['use_pure_fp16'] else "O1"
+            if amp_level.upper() == "O2":
+                model = paddle.amp.decorate(
+                    models=model,
+                    optimizers=None,
+                    level="O2",
+                    master_weight=None,
+                    save_dtype=None)
+            init_loss_scaling = strategy.amp_configs['init_loss_scaling']
+            incr_ratio = strategy.amp_configs['incr_ratio']
+            decr_ratio = strategy.amp_configs['decr_ratio']
+            incr_every_n_steps = strategy.amp_configs['incr_every_n_steps']
+            decr_every_n_nan_or_inf = strategy.amp_configs[
+                'decr_every_n_nan_or_inf']
+            use_dynamic_loss_scaling = strategy.amp_configs[
+                'use_dynamic_loss_scaling']
+
+            global _grad_scalar
+            _grad_scalar = paddle.amp.GradScaler(
+                init_loss_scaling=init_loss_scaling,
+                incr_ratio=incr_ratio,
+                decr_ratio=decr_ratio,
+                incr_every_n_steps=incr_every_n_steps,
+                decr_every_n_nan_or_inf=decr_every_n_nan_or_inf,
+                use_dynamic_loss_scaling=use_dynamic_loss_scaling)
+
+        if strategy.recompute == True:
+            recompute_enable = True
+            model = _RecomputeModelWrapper(model)
+
         if self._user_defined_strategy.heter_ccl_mode == True:
             distributed_model = paddle.DataParallel(
                 model,
@@ -964,7 +1035,7 @@ class Fleet(object):
             return distributed_model
 
         if self._hcg.get_parallel_mode() == ParallelMode.SHARDING_PARALLEL:
-            distributed_model = ShardingParallel(
+            model = ShardingParallel(
                 model, self._hcg, strategy=self._user_defined_strategy)
         elif self._hcg.get_parallel_mode() == ParallelMode.DATA_PARALLEL:
 
@@ -975,22 +1046,23 @@ class Fleet(object):
             assert self.sharding_degree == self._hcg.get_sharding_parallel_world_size(
             )
             broadcast_sharding_parameters(model, self._hcg)
-            distributed_model = paddle.DataParallel(
+            model = paddle.DataParallel(
                 model,
                 comm_buffer_size=self._user_defined_strategy.
                 fuse_grad_size_in_MB,
                 last_comm_buffer_size=self._user_defined_strategy.
                 last_comm_group_size_MB,
                 find_unused_parameters=self._user_defined_strategy.
-                find_unused_parameters)
+                find_unused_parameters,
+                static_graph=True if recompute_enable else False)
         elif self._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL:
-            distributed_model = TensorParallel(
+            model = TensorParallel(
                 model, self._hcg, strategy=self._user_defined_strategy)
         elif self._hcg.get_parallel_mode() == ParallelMode.PIPELINE_PARALLEL:
-            distributed_model = PipelineParallel(
+            model = PipelineParallel(
                 model, self._hcg, strategy=self._user_defined_strategy)
 
-        return distributed_model
+        return model
 
     @dygraph_only
     def state_dict(self):
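Note: `_RecomputeModelWrapper.forward` splits the wrapped `paddle.nn.Sequential` into `segments` chunks of `len(layers) // segments` layers. The first `segments - 1` chunks run under activation recomputation, so their intermediate activations are freed after the forward pass and rebuilt during backward; the last chunk, plus any remainder layers when the division is not exact, runs normally. Below is a minimal standalone sketch of that control flow using the public `paddle.distributed.fleet.utils.recompute` helper rather than the internal `RecomputeFunction.apply`; the helper function, toy network, and shapes are illustrative, and a CUDA device is assumed since recompute preserves RNG state by default.

```python
import paddle
from paddle.distributed.fleet.utils import recompute


def run_segments(layers, x, segments=2):
    # Mirrors _RecomputeModelWrapper.forward over a plain list of layers.
    segment_size = len(layers) // segments

    def make_run(begin, end):
        # Bind begin/end eagerly so each closure keeps its own bounds.
        def do_run(inp):
            for i in range(begin, end):
                inp = layers[i](inp)
            return inp

        return do_run

    end = 0
    # Every segment except the last is checkpointed: its activations are
    # dropped in forward and recomputed during backward.
    for begin in range(0, segment_size * (segments - 1), segment_size):
        end = begin + segment_size
        x = recompute(make_run(begin, end), x)
    # The trailing segment (including any leftover layers) runs normally.
    return make_run(end, len(layers))(x)


net = paddle.nn.Sequential(*[paddle.nn.Linear(8, 8) for _ in range(5)])
inp = paddle.randn([2, 8])
inp.stop_gradient = False
run_segments(list(net.children()), inp).mean().backward()
```

With 5 layers and 2 segments, layers 0-1 are checkpointed and layers 2-4 run without checkpointing, since the final chunk's activations are consumed immediately by the loss.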
diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py
index 878fc1c68e4c193e7056a65fc2c45ac121474125..b8a2d958a7311ea8b81a05727838f9aa2d59e6f9 100644
--- a/python/paddle/fluid/dygraph/varbase_patch_methods.py
+++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py
@@ -31,6 +31,8 @@ import paddle.utils.deprecated as deprecated
 import paddle.profiler as profiler
 from paddle import _C_ops
 
+_grad_scalar = None
+
 
 class TensorHookRemoveHelper(object):
     """
@@ -265,6 +267,9 @@ def monkey_patch_varbase():
                 grad_tensor = []
             else:
                 grad_tensor = [grad_tensor]
+            if _grad_scalar:
+                # When using amp with Fleet DistributedStrategy, we do loss scaling implicitly.
+                self = _grad_scalar.scale(self)
             if paddle.is_compiled_with_xpu() or paddle.is_compiled_with_npu():
                 # TODO(liuyuhui): Currently only for xpu. Will be removed in the future.
                 scaled_loss = scale_loss(self)
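Note: this hook is how `strategy.amp = True` works without any change to user training loops. `fleet.distributed_model` builds a `paddle.amp.GradScaler` from `strategy.amp_configs` and publishes it as the `_grad_scalar` global, and the patched `backward()` then swaps the loss for `_grad_scalar.scale(loss)` before autograd runs. For comparison, a minimal sketch of the explicit scaling this makes implicit; the model, data, and optimizer here are illustrative, not from the patch.

```python
import paddle

model = paddle.nn.Linear(10, 1)
data = paddle.randn([4, 10])
opt = paddle.optimizer.SGD(parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=2.**15)

with paddle.amp.auto_cast():
    loss = model(data).mean()

scaled = scaler.scale(loss)   # the step the patched backward() now performs
scaled.backward()
scaler.minimize(opt, scaled)  # unscale gradients, then apply the update
```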
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 2acf530eea3fbd2d7fbc9cf04d2c792b7175035c..f1a90553283c3e85f29d9842bae6951b02f576f4 100755
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -939,6 +939,7 @@ if (WITH_DISTRIBUTE)
     set_tests_properties(test_dist_fleet_infer PROPERTIES TIMEOUT 200)
     set_tests_properties(test_dist_fleet_raw_program_optimizer PROPERTIES TIMEOUT 120)
     set_tests_properties(test_dist_fleet_raw_program_optimizer_fuse_allreduce PROPERTIES TIMEOUT 60)
+    set_tests_properties(test_dist_dygraph_apis PROPERTIES TIMEOUT 120)
 endif()
 
 if (WITH_DISTRIBUTE AND NOT APPLE)
diff --git a/python/paddle/fluid/tests/unittests/dygraph_fleet_api.py b/python/paddle/fluid/tests/unittests/dygraph_fleet_api.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a9d74e4afd4b7c62f57b5fd39856a18fe799619
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/dygraph_fleet_api.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import random
+import numpy as np
+import os
+import shutil
+
+import paddle
+import paddle.nn as nn
+from paddle.fluid import core
+import datetime
+from datetime import timedelta
+import paddle.fluid.core as core
+from paddle.fluid.framework import _test_eager_guard
+from paddle.fluid.dygraph.parallel import ParallelEnv
+
+
+class TestDygraphFleetAPI(unittest.TestCase):
+    def setUp(self):
+        paddle.seed(2022)
+        random.seed(2022)
+        np.random.seed(2022)
+        self.config()
+
+    def config(self):
+        self.dtype = "float32"
+        self.shape = (2, 10, 5)
+
+    def test_dygraph_fleet_api(self):
+        import paddle.distributed.fleet as fleet
+        import paddle.distributed as dist
+        strategy = fleet.DistributedStrategy()
+        strategy.amp = True
+        strategy.recompute = True
+        fleet.init(is_collective=True, strategy=strategy)
+        net = paddle.nn.Sequential(
+            paddle.nn.Linear(10, 1), paddle.nn.Linear(1, 2))
+        net = dist.fleet.distributed_model(net)
+        data = np.random.uniform(-1, 1, [30, 10]).astype('float32')
+        data = paddle.to_tensor(data)
+        net(data)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_dist_dygraph_apis.py b/python/paddle/fluid/tests/unittests/test_dist_dygraph_apis.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e6fb99ae9355eefdb6de4f3a1bd0b2712535b83
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_dist_dygraph_apis.py
@@ -0,0 +1,27 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+from test_parallel_dygraph_dataparallel import TestMultipleGpus
+
+
+class TestDygraphFleetApi(TestMultipleGpus):
+    def test_dygraph_fleet_api(self):
+        self.run_mnist_2gpu('dygraph_fleet_api.py')
+
+
+if __name__ == "__main__":
+    unittest.main()
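Note: `test_dist_dygraph_apis.py` reuses `TestMultipleGpus.run_mnist_2gpu` to spawn `dygraph_fleet_api.py` across two devices, and the CMake entry registers it with a 120-second timeout. For manual experimentation, a hypothetical standalone script in the same spirit, with an optimizer step added for completeness; the file name and launch command are illustrative, and a 2-GPU environment is assumed.

```python
# Save as demo_fleet_amp_recompute.py and launch with:
#   python -m paddle.distributed.launch --gpus 0,1 demo_fleet_amp_recompute.py
import numpy as np
import paddle
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.amp = True        # installs the implicit GradScaler
strategy.recompute = True  # wraps the Sequential in _RecomputeModelWrapper
fleet.init(is_collective=True, strategy=strategy)

net = paddle.nn.Sequential(
    paddle.nn.Linear(10, 1), paddle.nn.Linear(1, 2))
net = fleet.distributed_model(net)
opt = fleet.distributed_optimizer(
    paddle.optimizer.SGD(parameters=net.parameters()))

data = paddle.to_tensor(
    np.random.uniform(-1, 1, [30, 10]).astype('float32'))
loss = net(data).mean()
loss.backward()  # implicitly scaled because strategy.amp is True
opt.step()
opt.clear_grad()
```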