Unverified commit 9d8cfc1b authored by lilong12, committed by GitHub

Wrap dist api for dygraph mode (#40408)

Parent bff9e28e
......@@ -37,9 +37,45 @@ from ..meta_optimizers import HybridParallelOptimizer, HeterParallelOptimizer
from paddle import _C_ops
from paddle.fluid import core
from paddle.fluid.dygraph import to_variable
from paddle.distributed.fleet.utils.recompute import RecomputeFunction
from paddle.fluid.dygraph.varbase_patch_methods import _grad_scalar
__all__ = []
_grad_scalar = None
class _RecomputeModelWrapper(paddle.nn.Layer):
def __init__(self, model, segments=2, preserve_rng_state=True):
super(_RecomputeModelWrapper, self).__init__()
assert isinstance(model, paddle.nn.Sequential), (
"The model passed to RecomputeModelWrapper must be of type "
"paddle.nn.Sequential.")
self._model = model
self._segments = segments
self._preserve_rng_state = preserve_rng_state
self._layers = list(model.children())
self._segment_size = len(self._layers) // segments
def _run_func(self, begin, end):
def do_run(input):
for i in range(begin, end):
input = self._layers[i](input)
return input
return do_run
def _checkpoint(self, func, *args, **kwargs):
return RecomputeFunction.apply(func, self._preserve_rng_state, *args)
def forward(self, input):
end = 0
for begin in range(0, self._segment_size * (self._segments - 1),
self._segment_size):
end = begin + self._segment_size
input = self._checkpoint(self._run_func(begin, end), input)
return self._run_func(end, len(self._layers))(input)
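Usage illustration only, not part of the diff: `_RecomputeModelWrapper` is a private helper that `distributed_model()` applies when `strategy.recompute` is enabled, so user code normally never constructs it directly. The sketch below (assuming a GPU device, since recompute with `preserve_rng_state=True` requires one, and a toy five-layer `paddle.nn.Sequential`) shows the effect: with `segments=2` the first half of the layers is checkpointed and recomputed during backward, while the remaining layers run normally.

import paddle

layers = paddle.nn.Sequential(
    paddle.nn.Linear(16, 64), paddle.nn.ReLU(),
    paddle.nn.Linear(64, 64), paddle.nn.ReLU(),
    paddle.nn.Linear(64, 1))
wrapped = _RecomputeModelWrapper(layers, segments=2)  # segment_size = 5 // 2 = 2

x = paddle.randn([8, 16])
loss = wrapped(x).mean()  # layers 0-1 run under RecomputeFunction, layers 2-4 run normally
loss.backward()           # the checkpointed segment re-runs its forward pass here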
def apply_ir_passes(main_program, startup_program, config):
build_strategy = config._user_defined_strategy.build_strategy._copy()
......@@ -952,6 +988,41 @@ class Fleet(object):
if self.worker_num() <= 1:
return model
amp_enable = False
recompute_enable = False
strategy = self._user_defined_strategy
if strategy.amp == True:
amp_enable = True
amp_level = "O2" if strategy.amp_configs['use_pure_fp16'] else "O1"
if amp_level.upper() == "O2":
model = paddle.amp.decorate(
models=model,
optimizers=None,
level="O2",
master_weight=None,
save_dtype=None)
init_loss_scaling = strategy.amp_configs['init_loss_scaling']
incr_ratio = strategy.amp_configs['incr_ratio']
decr_ratio = strategy.amp_configs['decr_ratio']
incr_every_n_steps = strategy.amp_configs['incr_every_n_steps']
decr_every_n_nan_or_inf = strategy.amp_configs[
'decr_every_n_nan_or_inf']
use_dynamic_loss_scaling = strategy.amp_configs[
'use_dynamic_loss_scaling']
global _grad_scalar
_grad_scalar = paddle.amp.GradScaler(
init_loss_scaling=init_loss_scaling,
incr_ratio=incr_ratio,
decr_ratio=decr_ratio,
incr_every_n_steps=incr_every_n_steps,
decr_every_n_nan_or_inf=decr_every_n_nan_or_inf,
use_dynamic_loss_scaling=use_dynamic_loss_scaling)
if strategy.recompute == True:
recompute_enable = True
model = _RecomputeModelWrapper(model)
if self._user_defined_strategy.heter_ccl_mode == True:
distributed_model = paddle.DataParallel(
model,
......@@ -964,7 +1035,7 @@ class Fleet(object):
return distributed_model
if self._hcg.get_parallel_mode() == ParallelMode.SHARDING_PARALLEL:
-            distributed_model = ShardingParallel(
+            model = ShardingParallel(
model, self._hcg, strategy=self._user_defined_strategy)
elif self._hcg.get_parallel_mode() == ParallelMode.DATA_PARALLEL:
......@@ -975,22 +1046,23 @@ class Fleet(object):
assert self.sharding_degree == self._hcg.get_sharding_parallel_world_size(
)
broadcast_sharding_parameters(model, self._hcg)
-            distributed_model = paddle.DataParallel(
+            model = paddle.DataParallel(
model,
comm_buffer_size=self._user_defined_strategy.
fuse_grad_size_in_MB,
last_comm_buffer_size=self._user_defined_strategy.
last_comm_group_size_MB,
find_unused_parameters=self._user_defined_strategy.
-                find_unused_parameters)
+                find_unused_parameters,
+                static_graph=True if recompute_enable else False)
elif self._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL:
-            distributed_model = TensorParallel(
+            model = TensorParallel(
model, self._hcg, strategy=self._user_defined_strategy)
elif self._hcg.get_parallel_mode() == ParallelMode.PIPELINE_PARALLEL:
-            distributed_model = PipelineParallel(
+            model = PipelineParallel(
model, self._hcg, strategy=self._user_defined_strategy)
-        return distributed_model
+        return model
@dygraph_only
def state_dict(self):
......
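For orientation (an editor's sketch, not part of the commit): when `strategy.amp` is set, the new branch in `distributed_model()` above amounts to the explicit dygraph AMP setup below, except that the commit stores the scaler in the module-level `_grad_scalar` instead of handing it back to the caller. The model and the numeric values are illustrative; the real values come from `strategy.amp_configs`.

import paddle

model = paddle.nn.Linear(10, 2)                        # illustrative model
model = paddle.amp.decorate(models=model, level="O2")  # only when use_pure_fp16 is set
scaler = paddle.amp.GradScaler(
    init_loss_scaling=32768.0,
    incr_ratio=2.0,
    decr_ratio=0.5,
    incr_every_n_steps=1000,
    decr_every_n_nan_or_inf=2,
    use_dynamic_loss_scaling=True)
# distributed_model() keeps such a scaler in `_grad_scalar`; the patched
# Tensor.backward() in the varbase hunk below then scales the loss with it implicitly.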
......@@ -31,6 +31,8 @@ import paddle.utils.deprecated as deprecated
import paddle.profiler as profiler
from paddle import _C_ops
_grad_scalar = None
class TensorHookRemoveHelper(object):
"""
......@@ -265,6 +267,9 @@ def monkey_patch_varbase():
grad_tensor = []
else:
grad_tensor = [grad_tensor]
if _grad_scalar:
# When using amp with Fleet DistributedStrategy, we do loss scaling implicitly.
self = _grad_scalar.scale(self)
if paddle.is_compiled_with_xpu() or paddle.is_compiled_with_npu():
# TODO(liuyuhui): Currently only for xpu. Will be removed in the future.
scaled_loss = scale_loss(self)
......
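Net effect of the hook above (a sketch with illustrative values, not part of the diff): once Fleet has installed `_grad_scalar`, a plain `loss.backward()` in user code behaves as if the loss had been scaled explicitly first, roughly:

import paddle

scaler = paddle.amp.GradScaler(init_loss_scaling=32768.0)
linear = paddle.nn.Linear(10, 1)
x = paddle.randn([4, 10])
with paddle.amp.auto_cast():
    loss = linear(x).mean()
scaler.scale(loss).backward()  # what `self = _grad_scalar.scale(self)` inserts before backward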
......@@ -939,6 +939,7 @@ if (WITH_DISTRIBUTE)
set_tests_properties(test_dist_fleet_infer PROPERTIES TIMEOUT 200)
set_tests_properties(test_dist_fleet_raw_program_optimizer PROPERTIES TIMEOUT 120)
set_tests_properties(test_dist_fleet_raw_program_optimizer_fuse_allreduce PROPERTIES TIMEOUT 60)
set_tests_properties(test_dist_dygraph_apis PROPERTIES TIMEOUT 120)
endif()
if (WITH_DISTRIBUTE AND NOT APPLE)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import random
import numpy as np
import os
import shutil
import paddle
import paddle.nn as nn
from paddle.fluid import core
import datetime
from datetime import timedelta
import paddle.fluid.core as core
from paddle.fluid.framework import _test_eager_guard
from paddle.fluid.dygraph.parallel import ParallelEnv
class TestDygraphFleetAPI(unittest.TestCase):
def setUp(self):
paddle.seed(2022)
random.seed(2022)
np.random.seed(2022)
self.config()
def config(self):
self.dtype = "float32"
self.shape = (2, 10, 5)
def test_dygraph_fleet_api(self):
import paddle.distributed.fleet as fleet
import paddle.distributed as dist
strategy = fleet.DistributedStrategy()
strategy.amp = True
strategy.recompute = True
fleet.init(is_collective=True, strategy=strategy)
net = paddle.nn.Sequential(
paddle.nn.Linear(10, 1), paddle.nn.Linear(1, 2))
net = dist.fleet.distributed_model(net)
data = np.random.uniform(-1, 1, [30, 10]).astype('float32')
data = paddle.to_tensor(data)
net(data)
if __name__ == "__main__":
unittest.main()
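The `TestDygraphFleetApi` wrapper below drives this script on two GPUs through the existing `run_mnist_2gpu` helper, and the CMake change above registers it as `test_dist_dygraph_apis`. For manual experimentation, a roughly equivalent launch (a sketch, assuming a machine with at least two visible GPUs) is to spawn the same body with `paddle.distributed.spawn`:

import numpy as np
import paddle
import paddle.distributed as dist
import paddle.distributed.fleet as fleet


def train():
    # Same steps as TestDygraphFleetAPI.test_dygraph_fleet_api(), run per process.
    strategy = fleet.DistributedStrategy()
    strategy.amp = True
    strategy.recompute = True
    fleet.init(is_collective=True, strategy=strategy)
    net = paddle.nn.Sequential(paddle.nn.Linear(10, 1), paddle.nn.Linear(1, 2))
    net = fleet.distributed_model(net)
    data = paddle.to_tensor(np.random.uniform(-1, 1, [30, 10]).astype('float32'))
    net(data)


if __name__ == "__main__":
    dist.spawn(train, nprocs=2)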
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
from test_parallel_dygraph_dataparallel import TestMultipleGpus
class TestDygraphFleetApi(TestMultipleGpus):
def test_dygraph_fleet_api(self):
self.run_mnist_2gpu('dygraph_fleet_api.py')
if __name__ == "__main__":
unittest.main()