From dfed4a63f2334d6e4034b583e17085fa9b0fe35d Mon Sep 17 00:00:00 2001 From: Haohongxiang <86215757+haohongxiang@users.noreply.github.com> Date: Thu, 9 Dec 2021 18:49:31 +0800 Subject: [PATCH] support offload in sharding stage2 (#37904) * merge latest develop branch * fix bugs * update * fix bugs for unittest * modify for less use of gpu mem * fix bugs of using _reset_grad_inplace_version * update * update * modify for CI-Coverage * retrick all CIs --- .../sharding_optimizer_stage2.py | 82 ++++++++-- .../meta_parallel/sharding/sharding_stage2.py | 68 ++++++-- .../meta_parallel/sharding/sharding_utils.py | 151 ++++++++++++++++++ .../fleet/utils/internal_storage.py | 93 +++++++++-- .../unittests/dygraph_sharding_stage2.py | 19 +-- .../dygraph_sharding_stage2_offload.py | 115 +++++++++++++ .../unittests/test_dygraph_sharding_stage2.py | 3 + 7 files changed, 482 insertions(+), 49 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/dygraph_sharding_stage2_offload.py diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py index ffd24add50a..dc313c33ee3 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py @@ -27,11 +27,13 @@ from collections import OrderedDict import paddle import paddle.fluid as fluid from paddle import framework +from paddle.fluid import core import paddle.distributed as dist from paddle.optimizer import Optimizer +from paddle.fluid.clip import ClipGradByGlobalNorm from ...utils.internal_storage import ParamStorage -from ...meta_parallel.sharding.sharding_utils import Type +from ...meta_parallel.sharding.sharding_utils import Type, device_guard, ShardingClipGrad # CUDA alignment 256 bytes alignment = {"gpu": 256, } @@ -99,16 +101,41 @@ class ShardingOptimizerStage2(Optimizer): self.broadcast_fp16 = broadcast_fp16 self.param_storages = {} # {dtype: {rank: InternalStorage}} + + if isinstance(self._optim._grad_clip, ClipGradByGlobalNorm): + logging.warning( + "While using ClipGradByGlobalNorm in ShardingOptimizer, the grad clip of original optimizer will be changed." + ) + self._optim._grad_clip = ShardingClipGrad(self._optim._grad_clip, + group, + paddle.get_device()) + + if offload: + assert self._pfp16, "Only support offload strategy while using \'Adam\', \'AdamW\' and \'Momentum\' optimizer with AMP/Pure FP16" + self.offload = offload # Using for offload + self.offload_device = "cpu" + + self._master_params = {} # Update optimizer parameters and adjust parameter storage and use according to rank. 
self.update_opt_status() def _generate_master_params(self, trainable_params): - for param in trainable_params: - if param.dtype == Type.fp16.value: - self._optim._master_weights[param.name] = paddle.cast( - param, Type.fp32.value) + if self.offload: + for param in trainable_params: + if param.name not in self._master_params.keys(): + self._master_params[param.name] = core.VarBase( + name=param.name, + value=param.cast(dtype=Type.fp32.value).numpy(), + place=core.CPUPlace(), + stop_gradient=param.stop_gradient) + self._optim._master_weights = self._master_params + else: + for param in trainable_params: + if param.dtype == Type.fp16.value: + self._optim._master_weights[param.name] = paddle.cast( + param, Type.fp32.value) def update_opt_status(self): """Update optimizer status and parameter storage information, and special functions to be developed. @@ -243,22 +270,43 @@ class ShardingOptimizerStage2(Optimizer): A wrapper for Optimizer's step function to finish the update operation of the optimizer. """ - # Synchronize optimizer parameters for the current rank - if len(self.dtype_rank_params.keys( - )) == 1 and Type.fp32.value in self.dtype_rank_params.keys(): - self._optim._parameter_list = self.dtype_rank_params[ - Type.fp32.value][self.rank] - elif len(self.dtype_rank_params.keys( - )) == 1 and Type.fp16.value in self.dtype_rank_params.keys(): - self._optim._parameter_list = self.dtype_rank_params[ - Type.fp16.value][self.rank] + if self.offload: + self._optim._parameter_list = [ + param for name, param in self._master_params.items() + ] else: - self._optim._parameter_list = self.dtype_rank_params[ - Type.fp16.value][self.rank] + self.dtype_rank_params[ + # Synchronize optimizer parameters for the current rank + if len(self.dtype_rank_params.keys( + )) == 1 and Type.fp32.value in self.dtype_rank_params.keys(): + self._optim._parameter_list = self.dtype_rank_params[ Type.fp32.value][self.rank] + elif len(self.dtype_rank_params.keys( + )) == 1 and Type.fp16.value in self.dtype_rank_params.keys(): + self._optim._parameter_list = self.dtype_rank_params[ + Type.fp16.value][self.rank] + else: + self._optim._parameter_list = self.dtype_rank_params[ + Type.fp16.value][self.rank] + self.dtype_rank_params[ + Type.fp32.value][self.rank] # Run the optimizer of the current rank step - self._optim.step() + if self.offload: + with device_guard(self.rank, self.offload_device): + self._optim.step() + + for param in self._optim._parameter_list: + self._master_params[param.name].set_value(param) + + dev_id = 0 if paddle.get_device() == "cpu" else int( + paddle.get_device().split(":")[1]) + + for param in self._local_params: + if param.name in self._master_params.keys(): + param.set_value(self._master_params[param.name].cuda(dev_id) + .cast(dtype=param.dtype)) + self._master_params[param.name].clear_gradient(False) + else: + self._optim.step() # Synchronize all the updated shards in between the ranks self._broadcast_params() diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py index 37b85751149..fd49c2a7d65 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py @@ -112,6 +112,18 @@ class ShardingStage2(nn.Layer): self._has_grad_storage = [] self._grad_storage_list = [] + # offload + # TODO(haohongxiang): Now it's not supported for multi-optimizers using Offload strategy + 
self._offload_optims = list( + filter(lambda optim: optim.offload, self._sharding_optimizers)) + if len(self._offload_optims) > 0: + assert len( + self._sharding_optimizers + ) == 1, "Only support offload strategy for single optimizer" + + self._offload = self._sharding_optimizers[0].offload + self._offload_device = "cpu" + # Set backward pass hooks self._bw_hooks = [] @@ -156,7 +168,8 @@ class ShardingStage2(nn.Layer): # Release grad storages for dtype in self._grad_storages.keys(): if self._rank in self._grad_storages[dtype].keys(): - self._grad_storages[dtype][self._rank].buffer.zero_() + if not self._offload: + self._grad_storages[dtype][self._rank].buffer.zero_() # Release params for param in self._trainable_params: @@ -167,17 +180,24 @@ class ShardingStage2(nn.Layer): """ Before the gradient accumulation, scale the gradient. """ - # Scale grad storages - for dtype in self._grad_storages.keys(): - if self._rank in self._grad_storages[dtype].keys(): - self._grad_storages[dtype][self._rank].buffer.scale_( - scale=self._world_size_scaling) - - # Scale params - for param in self._trainable_params: - if param.name in self._param_grads and param.grad is not None: - param.grad.scale_(scale=self._world_size_scaling) - param._reset_grad_inplace_version(True) + if self._offload: + for param in self._trainable_params: + if param.name in self._sharding_optimizers[ + 0]._master_params.keys(): + self._sharding_optimizers[0]._master_params[ + param.name].grad.scale_(scale=self._world_size_scaling) + else: + # Scale grad storages + for dtype in self._grad_storages.keys(): + if self._rank in self._grad_storages[dtype].keys(): + self._grad_storages[dtype][self._rank].buffer.scale_( + scale=self._world_size_scaling) + + # Scale params + for param in self._trainable_params: + if param.name in self._param_grads and param.grad is not None: + param.grad.scale_(scale=self._world_size_scaling) + param._reset_grad_inplace_version(True) def _init_internal_storage(self, needs_fresh): """ @@ -195,8 +215,14 @@ class ShardingStage2(nn.Layer): """ Synchronously or asynchronously convert the data type of the layer, the device is not supported now. """ + assert isinstance(device, str), "Device must be type str" assert device == self._default_device, "New devices are not supported, because of the optimizer state is not sync" + self._layer.to(device=device, dtype=dtype, blocking=blocking) + + # Re-build the buckets, hooks, etc.. + self._fresh_trainable() + def _fresh_trainable(self): """ Whether to update training parameters. 
""" @@ -283,12 +309,17 @@ class ShardingStage2(nn.Layer): self._grad_reduced[index] = False if not self._accumulate_grads: param.grad.scale_(scale=self._world_size_scaling) - param._reset_grad_inplace_version(True) + param._reset_grad_inplace_version(True) # Clear the gradient that does not belong to the current rank through the callback function def cleanup(): if dst_rank != self._rank: param.clear_gradient(False) + elif self._offload: + self._sharding_optimizers[0]._master_params[ + param.name]._copy_gradient_from(param.grad.cpu( + ).cast(dtype=Type.fp32.value)) + param.clear_gradient(False) # Synchronize the reduce parameter gradient self._tasks_flow.append( @@ -339,6 +370,15 @@ class ShardingStage2(nn.Layer): grad_storage.buffer.value().get_tensor()._clear( ) + elif self._offload: + grad_storage.to(device=self._offload_device) + for param in grad_storage._params: + self._sharding_optimizers[0]._master_params[ + param.name]._copy_gradient_from( + param.grad.cast( + dtype=Type.fp32.value)) + grad_storage.buffer.value().get_tensor()._clear( + ) # Reduce the bucket grad_storage.sent = True @@ -478,7 +518,7 @@ class ShardingStage2(nn.Layer): # Rebuild fp16/fp32 grad storages for dtype in self._grad_storages.keys(): for dst_rank, grad_storage in self._grad_storages[dtype].items(): - if dst_rank != self._rank: + if self._offload or dst_rank != self._rank: grad_storage.manumal_relase() grad_storage.rebuild() diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py index d4c443e385f..651bed82396 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py @@ -17,10 +17,17 @@ import contextlib from collections import abc from enum import Enum from math import inf +import numpy as np +from types import MethodType import paddle import paddle.distributed as dist +from paddle import _C_ops from paddle.fluid import core +from paddle.fluid import layers +from paddle.fluid.dygraph import to_variable +from paddle.fluid.framework import dygraph_only +from paddle.fluid.dygraph import base as imperative_base class Taskflow: @@ -41,6 +48,88 @@ class Type(Enum): fp32 = paddle.float32 +class ShardingClipGrad: + def __init__(self, clip, group, device): + self._clip = clip + self._group = group + self._device = device + + @imperative_base.no_grad + def _dygraph_clip(self, params_grads): + params_and_grads = [] + + sum_square_fp16 = [] + sum_square_fp32 = [] + + for p, g in params_grads: + if g is None or getattr(p, 'need_clip', True) is False: + continue + + merge_grad = g + if g.type == core.VarDesc.VarType.SELECTED_ROWS: + merge_grad = layers.get_tensor_from_selected_rows( + layers.merge_selected_rows(g)) + square = layers.square(merge_grad) + sum_square = layers.reduce_sum(square) + + if p.dtype == paddle.float16: + sum_square_fp16.append(sum_square) + elif p.dtype == paddle.float32: + sum_square_fp32.append(sum_square) + + # global norm of non-distributed FP16 params_and_grads + if len(sum_square_fp16) == 0: + global_norm_fp16 = paddle.to_tensor([0.], dtype=paddle.float32) + else: + global_norm_fp16 = layers.concat(sum_square_fp16) + global_norm_fp16 = layers.reduce_sum(global_norm_fp16) + global_norm_fp16 = paddle.cast( + global_norm_fp16, dtype=paddle.float32) + + # global norm of non-distributed FP32 params_and_grads + global_norm_fp32 = layers.concat(sum_square_fp32) if len( + sum_square_fp32) != 0 
else paddle.to_tensor( + [0.], dtype=paddle.float32) + global_norm_fp32 = layers.reduce_sum(global_norm_fp32) + + global_norm_var = global_norm_fp16 + global_norm_fp32 + + # add all reduce to get global norm of distributed params_and_grads + dev_id = int(self._device.split(":")[1]) + with device_guard(dev_id, "gpu"): + paddle.distributed.all_reduce(global_norm_var, group=self._group) + + global_norm_var = layers.sqrt(global_norm_var) + max_global_norm = layers.fill_constant( + shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm) + + clip_var = layers.elementwise_div( + x=max_global_norm, + y=layers.elementwise_max( + x=global_norm_var, y=max_global_norm)) + clip_var_fp16 = paddle.cast(clip_var, paddle.float16) + + for p, g in params_grads: + if g is None: + continue + if getattr(p, 'need_clip', True) is False: + params_and_grads.append((p, g)) + continue + if p.dtype == paddle.float16: + new_grad = layers.elementwise_mul(x=g, y=clip_var_fp16) + else: + new_grad = layers.elementwise_mul(x=g, y=clip_var) + params_and_grads.append((p, new_grad)) + + return params_and_grads + + def __getattr__(self, item): + return getattr(self._clip, item) + + def __call__(self, params_grads): + return self._dygraph_clip(params_grads) + + @contextlib.contextmanager def device_guard(dev_id, device="cpu"): origin_device = paddle.device.get_device() @@ -52,3 +141,65 @@ def device_guard(dev_id, device="cpu"): yield finally: paddle.set_device(origin_device) + + +@dygraph_only +def ShardingScaler(scaler, sharding_group): + def unscale_method(self, optimizer): + if not self._enable: + return + param_grads = [] + param_grads_fp16 = [] + param_grads_fp32 = [] + + if getattr(optimizer, '_param_groups', None) and isinstance( + optimizer._param_groups[0], dict): + + for group in optimizer._param_groups: + for param in group['params']: + if param._grad_ivar() is not None: + param_grads.append(param._grad_ivar()) + if param._grad_ivar( + ).dtype == core.VarDesc.VarType.FP16: + param_grads_fp16.append(param._grad_ivar()) + else: + param_grads_fp32.append(param._grad_ivar()) + else: + param_grads = [ + param._grad_ivar() for param in optimizer._parameter_list + if param._grad_ivar() is not None + ] + param_grads_fp16 = [ + param._grad_ivar() for param in optimizer._parameter_list + if (param._grad_ivar() is not None + ) and (param._grad_ivar().dtype == core.VarDesc.VarType.FP16 + ) + ] + param_grads_fp32 = [ + param._grad_ivar() for param in optimizer._parameter_list + if (param._grad_ivar() is not None + ) and (param._grad_ivar().dtype == core.VarDesc.VarType.FP32 + ) + ] + temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool)) + temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool)) + if len(param_grads_fp16): + _C_ops.check_finite_and_unscale(param_grads_fp16, self._scale, + param_grads_fp16, + temp_found_inf_fp16) + if len(param_grads_fp32): + _C_ops.check_finite_and_unscale(param_grads_fp32, self._scale, + param_grads_fp32, + temp_found_inf_fp32) + + self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0 + is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32") + + paddle.distributed.all_reduce( + is_found_inf, + op=paddle.distributed.ReduceOp.MAX, + group=sharding_group) + self._found_inf = is_found_inf.numpy()[0] + + scaler._unscale = MethodType(unscale_method, scaler) + return scaler diff --git a/python/paddle/distributed/fleet/utils/internal_storage.py b/python/paddle/distributed/fleet/utils/internal_storage.py index ff41ca217e4..f44b57ede46 100644 --- 
a/python/paddle/distributed/fleet/utils/internal_storage.py +++ b/python/paddle/distributed/fleet/utils/internal_storage.py @@ -50,6 +50,29 @@ class InternalStorage: else: self.buffer = paddle.zeros(size, dtype=dtype) + def to(self, device, dtype=None, keep_alignment=True): + """ + Move the underlying buffer + """ + assert self.buffer is not None, "Cannot move a collapsed bucket, please rebuild it" + assert (dtype == Type.fp32.value or + Type.fp16.value), "Conversion type is not supported now" + + dev_id = 0 if paddle.get_device() == "cpu" else int(paddle.get_device() + .split(":")[1]) + + if self._device != device: + tmp_buffer = self.buffer.cuda( + dev_id) if device == "gpu" else self.buffer.cpu() + for param in self._params: + param.clear_gradient(False) + param._gradient_set_empty(False) + self.buffer.value().get_tensor()._clear() + self.buffer = tmp_buffer + + if dtype is not None: + self.buffer = self.buffer.cast(dtype=dtype) + class ParamStorage(InternalStorage): """ @@ -60,6 +83,16 @@ class ParamStorage(InternalStorage): super().__init__(size, dtype, device, convert_cpu=True) self.param2align = None + def to(self, device, dtype=None, keep_alignment=True): + """ + Move the underlying buffer + """ + + super().to(device, dtype) + + if keep_alignment: + self._array_params() + @fluid.dygraph.no_grad def add_rank_params(self, trainable_params, param2align): """ @@ -78,7 +111,7 @@ class ParamStorage(InternalStorage): p_shape = self._add_param_as_view(param, param2align[param.name]) cpu_param_shape.append(p_shape) - # buffer covert from cpu to cuda + # buffer convert from cpu to cuda dev_id = int(paddle.get_device().split(":")[1]) self.buffer = self.buffer.cuda(dev_id) self._fill = 0 @@ -109,7 +142,8 @@ class ParamStorage(InternalStorage): param.stop_gradient = origin_state # Copy the current param value - dev_id = int(paddle.get_device().split(":")[1]) + dev_id = 0 if paddle.get_device() == "cpu" else int(paddle.get_device() + .split(":")[1]) with device_guard(dev_id, "cpu"): tmp_var = core.VarBase(tensor=self.buffer._slice(self._fill, var_end)) @@ -134,6 +168,18 @@ class ParamStorage(InternalStorage): self._fill = offset + @fluid.dygraph.no_grad + def _array_params(self): + """ + Given the parameters which have been registered previously, rebuild the whole InternalStorage. + """ + assert len(self._params) > 0 + assert self.param2align is not None + + self._fill = 0 + for p in self._params: + self._convert_buffer(p, p.shape, self.param2align[p.name]) # modify + class GradStorage(InternalStorage): """ @@ -171,6 +217,18 @@ class GradStorage(InternalStorage): param.shape) + align <= self._max_size and id( param) not in self._param_ids + def to(self, device, dtype=None, keep_alignment=True): + """ + Move the underlying buffer + """ + if self._release: + self.rebuild() + + super().to(device, dtype) + + if keep_alignment: + self._array_grads() + @fluid.dygraph.no_grad def add_grad(self, param, align): """ @@ -206,17 +264,25 @@ class GradStorage(InternalStorage): """ Given the parameter gradients which have been registered previously, rebuild the whole InternalStorage. 
""" - assert len(self._params) > 0 if self._release: - self.buffer = paddle.zeros( - [self._max_size], dtype=self._params[0].dtype) + self.buffer = paddle.zeros([self._max_size], dtype=self._dtype) for p in self._params: self._add_grad_as_view(p, self._parm2align[p.name]) self._release = False + @fluid.dygraph.no_grad + def _array_grads(self): + """ + Given the parameters gradients which have been registered previously, rebuild the whole InternalStorage. + """ + if len(self._params) > 0: + self._fill = 0 + for p in self._params: + self._add_grad_as_view(p, self._parm2align[p.name]) + @fluid.dygraph.no_grad def _add_grad_as_view(self, param, align): assert np.prod( @@ -229,8 +295,17 @@ class GradStorage(InternalStorage): assert offset <= np.prod(self.buffer.shape) # Copy the current grad value to InternalStorage - assert self._device == "gpu" - tmp_var = core.VarBase(self.buffer._slice(self._fill, grad_end)) - param._copy_gradient_from(tmp_var) - tmp_var.value().get_tensor()._clear() + dev_id = 0 if paddle.get_device() == "cpu" else int(paddle.get_device() + .split(":")[1]) + if self._device == "cpu": + with device_guard(dev_id, self._device): + tmp_var = core.VarBase(self.buffer._slice(self._fill, grad_end)) + param._copy_gradient_from(tmp_var) + tmp_var.value().get_tensor()._clear() + + elif self._device == "gpu": + tmp_var = core.VarBase(self.buffer._slice(self._fill, grad_end)) + param._copy_gradient_from(tmp_var) + tmp_var.value().get_tensor()._clear() + self._fill = offset diff --git a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py index 05008a3bc12..2b4002ab9c9 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py +++ b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py @@ -30,6 +30,7 @@ from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import Shar seed = 2021 epoch = 2 batch_size = 32 +linear_size = 10000 strategy = fleet.DistributedStrategy() strategy.hybrid_configs = { @@ -45,12 +46,12 @@ paddle.seed(seed) class MLP(fluid.Layer): - def __init__(self, param_attr=None, bias_attr=None): + def __init__(self, linear_size=10000, param_attr=None, bias_attr=None): super(MLP, self).__init__() - self._linear1 = Linear(10000, 10000) - self._linear2 = Linear(10000, 10000) - self._linear3 = Linear(10000, 10) + self._linear1 = Linear(linear_size, linear_size) + self._linear2 = Linear(linear_size, linear_size) + self._linear3 = Linear(linear_size, 10) def forward(self, inputs): y = self._linear1(inputs) @@ -59,10 +60,10 @@ class MLP(fluid.Layer): return y -def reader_decorator(): +def reader_decorator(linear_size=10000): def __reader__(): for _ in range(100): - img = np.random.rand(10000).astype('float32') + img = np.random.rand(linear_size).astype('float32') label = np.ones(1).astype('int64') yield img, label @@ -120,6 +121,9 @@ def train_mlp(model, use_multiprocess=True) train_loader.set_sample_list_generator(train_reader) + if sharding_stage == 2: + model.to(device="gpu") + for eop in range(epoch): model.train() @@ -153,9 +157,6 @@ def train_mlp(model, if all_test and batch_id == 2: return model.parameters() - if sharding_stage == 2: - model.to(device="gpu") - return model.parameters() diff --git a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2_offload.py b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2_offload.py new file mode 100644 index 00000000000..8adcda9d24e --- /dev/null +++ 
b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2_offload.py @@ -0,0 +1,115 @@ +# -*- coding: UTF-8 -*- + +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import argparse +import ast +import time +import paddle +import paddle.fluid as fluid +from paddle.fluid.dygraph.nn import Linear +from paddle.distributed import fleet +from paddle.fluid.dygraph import nn + +from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2 +from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2 +from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import ShardingScaler + +from dygraph_sharding_stage2 import MLP, reader_decorator, optimizer_setting + +seed = 2021 +epoch = 2 +batch_size = 32 +linear_size = 8000 + +np.random.seed(seed) +paddle.seed(seed) + + +def train_mlp(model, offload=False): + group = paddle.distributed.new_group([0, 1]) + optimizer = optimizer_setting(model=model, use_pure_fp16=True) + + model = paddle.amp.decorate(models=model, level='O2', save_dtype='float32') + scaler = paddle.amp.GradScaler(init_loss_scaling=32768) + scaler = ShardingScaler(scaler, group) + + optimizer = ShardingOptimizerStage2( + params=model.parameters(), + optim=optimizer, + group=group, + offload=offload) + model = ShardingStage2(model, optimizer, group=group, accumulate_grads=True) + + train_reader = paddle.batch( + reader_decorator(linear_size), batch_size=batch_size, drop_last=True) + + train_loader = paddle.io.DataLoader.from_generator( + capacity=32, + use_double_buffer=True, + iterable=True, + return_list=True, + use_multiprocess=True) + train_loader.set_sample_list_generator(train_reader) + + for eop in range(epoch): + model.train() + + for batch_id, data in enumerate(train_loader()): + img, label = data + label.stop_gradient = True + img.stop_gradient = True + + with paddle.amp.auto_cast(True, level='O2'): + out = model(img) + loss = paddle.nn.functional.cross_entropy( + input=out, label=label) + + avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32)) + scaler.scale(avg_loss).backward() + + model.grad_scale() + scaler.step(optimizer) + scaler.update() + model.clear_gradients() + + for dtype in optimizer.param_storages: + for dst_rank, param_storage in optimizer.param_storages[dtype].items(): + param_storage.to(device="gpu", dtype=dtype) + + return model.parameters() + + +def test_sharding_stage2_offload(): + mlp = MLP(linear_size) + mlp_offload = MLP(linear_size) + mlp_offload.set_state_dict(mlp.state_dict()) + + mlp_params = train_mlp(mlp, offload=False) + mlp_offload_params = train_mlp(mlp_offload, offload=True) + + for i in range(len(mlp_params)): + for j in range(len(mlp_offload_params)): + if mlp_params[i].name == mlp_offload_params[j].name: + np.testing.assert_allclose( + mlp_params[i].numpy(), + mlp_offload_params[j].numpy(), + rtol=1e-6) + return + + +if __name__ == '__main__': + 
test_sharding_stage2_offload() diff --git a/python/paddle/fluid/tests/unittests/test_dygraph_sharding_stage2.py b/python/paddle/fluid/tests/unittests/test_dygraph_sharding_stage2.py index c5cf8c5d5ed..f76dcb5687c 100644 --- a/python/paddle/fluid/tests/unittests/test_dygraph_sharding_stage2.py +++ b/python/paddle/fluid/tests/unittests/test_dygraph_sharding_stage2.py @@ -26,6 +26,9 @@ class TestDygraphShardingStage2(TestMultipleGpus): def test_dygraph_sharding_optimizer_stage2(self): self.run_mnist_2gpu('dygraph_sharding_stage2.py') + def test_dygraph_sharding_optimizer_stage2_offload(self): + self.run_mnist_2gpu('dygraph_sharding_stage2_offload.py') + if __name__ == "__main__": unittest.main() -- GitLab
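
For reference, the new offload path is exercised end to end by the dygraph_sharding_stage2_offload.py test added above. The sketch below condenses that test into a minimal two-rank training step. It is not part of the patch: the Linear model, AdamW hyperparameters, and batch shape are illustrative stand-ins for the MLP and optimizer_setting helpers used in the test, and distributed initialization (paddle.distributed.launch on two GPUs plus the hybrid fleet strategy from the existing stage2 tests) is assumed.

import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2
from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import ShardingScaler

fleet.init(is_collective=True)        # strategy/hybrid config as in the existing stage2 tests
group = paddle.distributed.new_group([0, 1])

model = paddle.nn.Linear(1024, 1024)  # stand-in for the MLP used in the unit test
optim = paddle.optimizer.AdamW(
    learning_rate=0.001, parameters=model.parameters(), multi_precision=True)

# Offload requires pure FP16 parameters (see the assert on self._pfp16 above),
# so decorate the model for AMP O2 before building the sharded optimizer.
model = paddle.amp.decorate(models=model, level='O2', save_dtype='float32')
scaler = ShardingScaler(paddle.amp.GradScaler(init_loss_scaling=32768), group)

# offload=True keeps the FP32 master weights on CPU; the optimizer step runs
# there under device_guard and the updated values are copied back to the GPU params.
optim = ShardingOptimizerStage2(
    params=model.parameters(), optim=optim, group=group, offload=True)
model = ShardingStage2(model, optim, group=group, accumulate_grads=True)

img = paddle.randn([32, 1024])
with paddle.amp.auto_cast(True, level='O2'):
    out = model(img)
loss = paddle.mean(x=out.cast(dtype=paddle.float32))

scaler.scale(loss).backward()
model.grad_scale()        # rescale accumulated grads by 1/world_size
scaler.step(optim)        # unscale, CPU optimizer step, copy updates back to GPU
scaler.update()
model.clear_gradients()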