# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import os
import numpy as np
from .base import topology as tp
from .base.topology import ParallelMode
from .meta_parallel import TensorParallel, model_parallel_random_seed
from .meta_parallel import PipelineParallel, ShardingParallel
from paddle.fluid import core
from paddle.distributed.fleet.utils.recompute import LegacyRecomputeFunction
from paddle.fluid.dygraph.varbase_patch_methods import _grad_scalar
from paddle.distributed import fleet


class _RecomputeModelWrapper(paddle.nn.Layer):

    def __init__(self, model, segments=2, preserve_rng_state=True):
        super(_RecomputeModelWrapper, self).__init__()
        assert isinstance(model, paddle.nn.Sequential), (
            "The model passed to RecomputeModelWrapper must be of type "
            "paddle.nn.Sequential.")
        self._model = model
        self._segments = segments
        self._preserve_rng_state = preserve_rng_state
        self._layers = list(model.children())
        self._segment_size = len(self._layers) // segments

    def _run_func(self, begin, end):
        # Return a callable that runs layers [begin, end) sequentially.
        def do_run(input):
            for i in range(begin, end):
                input = self._layers[i](input)
            return input

        return do_run

    def _checkpoint(self, func, *args, **kwargs):
        return LegacyRecomputeFunction.apply(func, self._preserve_rng_state,
                                             *args)

    def forward(self, input):
        # Run all segments except the last one with activation recomputation,
        # then run the final segment without checkpointing.
        end = 0
        for begin in range(0, self._segment_size * (self._segments - 1),
                           self._segment_size):
            end = begin + self._segment_size
            input = self._checkpoint(self._run_func(begin, end), input)
        return self._run_func(end, len(self._layers))(input)


_grad_scalar = None


def distributed_model(model):
    """
    Return a distributed data parallel model (only works in dygraph mode).

    Args:
        model (Layer): the user-defined model which inherits Layer.

    Returns:
        distributed data parallel model which inherits Layer.

    Examples:

        .. code-block:: python

            import paddle
            import paddle.nn as nn
            from paddle.distributed import fleet

            class LinearNet(nn.Layer):
                def __init__(self):
                    super(LinearNet, self).__init__()
                    self._linear1 = nn.Linear(10, 10)
                    self._linear2 = nn.Linear(10, 1)

                def forward(self, x):
                    return self._linear2(self._linear1(x))

            # 1. initialize fleet environment
            fleet.init(is_collective=True)

            # 2. create layer & optimizer
            layer = LinearNet()
            loss_fn = nn.MSELoss()
            adam = paddle.optimizer.Adam(
                learning_rate=0.001, parameters=layer.parameters())

            # 3. get data_parallel model using fleet
            adam = fleet.distributed_optimizer(adam)
            dp_layer = fleet.distributed_model(layer)

            # 4. run layer
            inputs = paddle.randn([10, 10], 'float32')
            outputs = dp_layer(inputs)
            labels = paddle.randn([10, 1], 'float32')
            loss = loss_fn(outputs, labels)
            print("loss:", loss.numpy())
            loss.backward()
            adam.step()
            adam.clear_grad()

    """
    fleet_env = fleet.fleet
    assert model is not None, "model should not be None"
    if fleet_env.worker_num() <= 1:
        return model

    amp_enable = False
    recompute_enable = False
    strategy = fleet_env._user_defined_strategy
    if strategy.amp == True:
        amp_enable = True
        amp_level = "O2" if strategy.amp_configs['use_pure_fp16'] else "O1"
        if amp_level.upper() == "O2":
            model = paddle.amp.decorate(models=model,
                                        optimizers=None,
                                        level="O2",
                                        master_weight=None,
                                        save_dtype=None)
        init_loss_scaling = strategy.amp_configs['init_loss_scaling']
        incr_ratio = strategy.amp_configs['incr_ratio']
        decr_ratio = strategy.amp_configs['decr_ratio']
        incr_every_n_steps = strategy.amp_configs['incr_every_n_steps']
        decr_every_n_nan_or_inf = strategy.amp_configs[
            'decr_every_n_nan_or_inf']
        use_dynamic_loss_scaling = strategy.amp_configs[
            'use_dynamic_loss_scaling']

        global _grad_scalar
        _grad_scalar = paddle.amp.GradScaler(
            init_loss_scaling=init_loss_scaling,
            incr_ratio=incr_ratio,
            decr_ratio=decr_ratio,
            incr_every_n_steps=incr_every_n_steps,
            decr_every_n_nan_or_inf=decr_every_n_nan_or_inf,
            use_dynamic_loss_scaling=use_dynamic_loss_scaling)

    if strategy.recompute == True:
        recompute_enable = True
        model = _RecomputeModelWrapper(model)

    if strategy.heter_ccl_mode == True:
        distributed_model = paddle.DataParallel(
            model,
            comm_buffer_size=strategy.fuse_grad_size_in_MB,
            last_comm_buffer_size=strategy.last_comm_group_size_MB,
            find_unused_parameters=strategy.find_unused_parameters)
        return distributed_model

    if fleet_env._hcg.get_parallel_mode() == ParallelMode.SHARDING_PARALLEL:
        model = ShardingParallel(model, fleet_env._hcg, strategy=strategy)
    elif fleet_env._hcg.get_parallel_mode() == ParallelMode.DATA_PARALLEL:

        # NOTE (JZ-LIANG) init parameters broadcast within sharding group
        # normally it should be done inside DataParallel
        if fleet_env.sharding_degree > 1:
            from paddle.distributed.fleet.utils.hybrid_parallel_util import broadcast_mp_parameters, broadcast_sharding_parameters
            assert fleet_env.sharding_degree == fleet_env._hcg.get_sharding_parallel_world_size(
            )
            broadcast_sharding_parameters(model, fleet_env._hcg)
        model = paddle.DataParallel(
            model,
            comm_buffer_size=strategy.fuse_grad_size_in_MB,
            last_comm_buffer_size=strategy.last_comm_group_size_MB,
            find_unused_parameters=strategy.find_unused_parameters)
    elif fleet_env._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL:
        model = TensorParallel(model, fleet_env._hcg, strategy=strategy)
    elif fleet_env._hcg.get_parallel_mode() == ParallelMode.PIPELINE_PARALLEL:
        model = PipelineParallel(model, fleet_env._hcg, strategy=strategy)

    return model
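

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the public API, never called from this
# module): a minimal, hedged example of how the amp and recompute branches in
# distributed_model() are driven from user code via fleet.DistributedStrategy.
# The function name `_example_usage`, the toy model, and the chosen config
# values are hypothetical; the amp_configs keys simply mirror the ones read
# above. Expected to run under a collective launch, e.g.
# `python -m paddle.distributed.launch --gpus 0,1 train.py`.
# ---------------------------------------------------------------------------
def _example_usage():
    import paddle.nn as nn

    strategy = fleet.DistributedStrategy()
    # Enable O1 AMP; distributed_model() builds the module-level _grad_scalar
    # from these keys.
    strategy.amp = True
    strategy.amp_configs = {
        'use_pure_fp16': False,
        'init_loss_scaling': 32768.0,
        'incr_ratio': 2.0,
        'decr_ratio': 0.5,
        'incr_every_n_steps': 1000,
        'decr_every_n_nan_or_inf': 2,
        'use_dynamic_loss_scaling': True,
    }
    # Enable recompute; the model must be a paddle.nn.Sequential so that
    # _RecomputeModelWrapper can split it into segments.
    strategy.recompute = True

    fleet.init(is_collective=True, strategy=strategy)

    layer = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1))
    adam = paddle.optimizer.Adam(learning_rate=0.001,
                                 parameters=layer.parameters())
    adam = fleet.distributed_optimizer(adam)
    dp_layer = fleet.distributed_model(layer)

    inputs = paddle.randn([4, 10], 'float32')
    loss = paddle.mean(dp_layer(inputs))
    loss.backward()
    adam.step()
    adam.clear_grad()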