diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py index e5d04aac1551e64f63625722b08088eb3d8552b6..41c6f92230ab3e0e8de9aec0abdf920fad1ef232 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py @@ -33,7 +33,7 @@ from paddle.fluid.framework import ParamBase from paddle.fluid.clip import ClipGradByGlobalNorm from paddle.distributed.collective import _get_global_group -from .sharding_utils import Type, ShardingClipGrad +from .sharding_utils import Type, ShardingClipGrad, device_guard from ..pp_utils.utils import _all_gather # CUDA alignment 256 bytes @@ -56,6 +56,13 @@ class ShardingStage3(nn.Layer): .. ZeRO: https://arxiv.org/pdf/1910.02054.pdf. """ + # TODO (Baibaifan) + # Feature Notes:: + # 1. The model supports the segmentation of parameters by global ranks in layers. + # 2. Support communication flow and computing flow. + # 3. Support offload function. + # 4. Support the establishment of independent communication groups. + def __init__(self, layer, optimizer, @@ -77,6 +84,15 @@ class ShardingStage3(nn.Layer): self._offload = offload self._sync_comm = sync_comm + global DEV + DEV = "cpu" if paddle.get_device() == "cpu" else paddle.get_device( + ).split(":")[0] + global DEV_ID + DEV_ID = 0 if paddle.get_device() == "cpu" else int(paddle.get_device() + .split(":")[1]) + global param2dtype + param2dtype = dict() + # Communication group establishment self._group = dist.new_group(_get_global_group() .ranks) if group is None else group @@ -85,6 +101,9 @@ class ShardingStage3(nn.Layer): self._rank = self._group.rank self._global_root_rank = 0 # picking rank 0 as the reference self._global_ranks = self._group.ranks + + # Parameter segmentation for global ranks + # After flatten -> self._param2buffer_size, self._param2buffer, self._trainable_params self._param2buffer_size = dict() # {param.name: size} self._param2buffer = dict( ) # {param.name: [(start0, end0),(start1, end1), ...]} @@ -116,12 +135,16 @@ class ShardingStage3(nn.Layer): self._order_tracer = OrderedDict() self._order_tracer["order"] = 0 self._order_tracer["layer"] = [] + # Register task flow self._task_flow = TaskFlow() + # Register forward hooks self._register_forward_hooks(self._layer) + # Register backward parameter hooks self._register_backward_hooks() + # Redefine optimizer step and clear function self._redefine_opt_step() self._redefine_opt_clear() @@ -152,7 +175,6 @@ class ShardingStage3(nn.Layer): param, "fw_storage" ), "Find {} don't have fw_storage attribute.".format(param.name) - # param.bw_storage.zero_() param.fw_storage.clear_gradient(False) param.fw_storage._gradient_set_empty(False) param.bw_storage._clear() @@ -192,6 +214,9 @@ class ShardingStage3(nn.Layer): return fw def _segment_rank_params(self, layer, name="last_layer"): + """ + Flatten parameters according to layer. + """ current_layer_params = _current_layer_params(layer) if current_layer_params: CHECK_LAYER[id(layer)] = name @@ -201,6 +226,10 @@ class ShardingStage3(nn.Layer): self._segment_rank_params(sub_layer, name) def _flatten_layer_params(self, layer, current_layer_params): + """ + Parameter segmentation and memory integration. 
+ """ + def _add_manage_info(trainable_param): return _PartitionParam(trainable_param) @@ -238,8 +267,13 @@ class ShardingStage3(nn.Layer): # 3.Flatten layer params and release other rank buffer self._param_storage(param, buffer_size) + # Record param's dtype + param2dtype[param.name] = param.dtype def _param_storage(self, param, buffer_size): + """ + This is a function to simplify the handling of parameter InternalStorages. + """ assert isinstance(buffer_size, int) value = np.zeros( buffer_size, @@ -264,16 +298,31 @@ class ShardingStage3(nn.Layer): param._clear() # Current rank param_storage - param.fw_storage = core.VarBase( - buffer._slice(start, end), "slice@" + param.name) + if self._offload: + param.fw_storage = core.VarBase( + buffer._slice(start, end), + core.CPUPlace(), "slice@" + param.name) + else: + param.fw_storage = core.VarBase( + buffer._slice(start, end), "slice@" + param.name) param.status = "part" # Updata optimizer master weights - if param.dtype == Type.fp16.value: + if param.dtype == Type.fp16.value and not self._offload: self._optim._master_weights[param.fw_storage.name] = paddle.cast( param.fw_storage, Type.fp32.value) def _register_forward_hooks(self, layer): + """ + Register pylayer to manage memory slices. + There are four stages: + FW + 1. Before the forward layers, synchronize the full parameters. + 2. After the forward layers, release the full parameter and keep the parameter slice. + BW + 3. Before the backward layers, synchronize the full parameters and create param's grad. + 4. After the gradient accumulation, release the full parameter and keep the parameter slice. + """ current_layer_params = _current_layer_params(layer) if current_layer_params: self._register_forward_all_hooks(layer, self._task_flow) @@ -286,13 +335,13 @@ class ShardingStage3(nn.Layer): return ForwardPreHooks(layer, self._order_tracer, self._trainable_params, self._param2buffer, self._rank, self._group, self._sync_comm, - task_flow) + self._offload, task_flow) def _forward_post_hook(layer, inputs, outputs): return ForwardPostHooks.apply( outputs, layer, self._order_tracer, self._trainable_params, self._param2buffer, self._param2buffer_size, self._rank, - self._group, self._sync_comm, task_flow) + self._group, self._sync_comm, self._offload, task_flow) # register previous forward hooks sub_layer.register_forward_pre_hook(_forward_pre_hook) @@ -302,6 +351,10 @@ class ShardingStage3(nn.Layer): @paddle.no_grad() def _sync_buffers(self): + """ + Sync all the param buffers from all ranks (exp: batch norm statistics). + """ + for buffer in self._layer.buffers(include_sublayers=True): dist.broadcast( buffer, @@ -319,6 +372,9 @@ class ShardingStage3(nn.Layer): return getattr(self._layer, name) def _update_params(self): + """ + Update parameters to optimizer memory slice. 
+ """ update_list = [] assert len(self._trainable_params.keys()) > 0 current_layer_params = self._layer.parameters(include_sublayers=True) @@ -331,36 +387,35 @@ class ShardingStage3(nn.Layer): param.name) if self._accumulate_grads: - param.bw_storage.scale_(scale=self._world_size_scaling) + if self._offload: + with device_guard(device="cpu"): + param.bw_storage.scale_(scale=self._world_size_scaling) + else: + param.bw_storage.scale_(scale=self._world_size_scaling) param.fw_storage = _VarBaseWrapper(param) param.fw_storage._copy_gradient_from(param.bw_storage) update_list.append(param) return update_list - def get_all_parameters(self): + def get_all_parameters(self, convert2cpu=False): + """ + Get the full parameters and return the corresponding task flows. + """ assert len(self._trainable_params.keys()) > 0 current_layer_params = self._layer.parameters(include_sublayers=True) trainable_params = list( filter(lambda x: x.trainable, current_layer_params)) - for param in trainable_params: - if param.use_count > 0: - continue - assert hasattr( - param, - "fw_storage"), "Find {} don't have fw_storage attribute".format( - param.name) - - full_param = _all_gather( - param.fw_storage, self._group, use_calc_stream=True) - dist.wait( - tensor=full_param, group=self._group, use_calc_stream=True) - core.VarBase(full_param._slice(0, param._numel()))._share_buffer_to( - param) - param.value().get_tensor()._set_dims(param.shape) - param.fw_storage._clear() - param.fw_storage = None - param.status = "all" - param.use_count += 1 + t_flow = _allgather_buffer( + trainable_params, + self._group, + use_calc_stream=True, + task_flow=TaskFlow(), + sync_wait=True, + offload=self._offload, + convert2cpu=convert2cpu) + if convert2cpu: + for param in current_layer_params: + t_flow.full_param[param.name]._share_buffer_to(param) self._optim._parameter_list = self._ori_parameter_list self._optim._param_groups = self._ori_param_groups @@ -393,13 +448,28 @@ class ShardingStage3(nn.Layer): use_calc_stream=True) start, end = self._param2buffer[param.name][self._rank] - if not self._accumulate_grads or param.bw_storage is None: + if not self._accumulate_grads or param.bw_storage is None or not param.bw_storage.value( + ).get_tensor()._is_initialized(): param.bw_storage = core.VarBase( full_grad._slice(start, end)).detach().clone() + if self._offload: + param.bw_storage = _device2cpu(param.bw_storage, + True) else: - param.bw_storage.add_( - core.VarBase(full_grad._slice(start, end)).detach() - .clone()) + if self._offload: + cpu_grad = _device2cpu( + core.VarBase(full_grad._slice(start, end)) + .detach().clone(), True) + param.bw_storage = paddle.add(param.bw_storage, + cpu_grad) + else: + # param.bw_storage.add_( + # core.VarBase(full_grad._slice(start, end)) + # .detach().clone()) + param.bw_storage = paddle.add( + param.bw_storage, + core.VarBase(full_grad._slice( + start, end)).detach().clone()) param.clear_gradient(False) param._gradient_set_empty(False) tmp_var = self._task_flow.full_grad.pop(param.name) @@ -410,15 +480,16 @@ class ShardingStage3(nn.Layer): param.use_count = 0 param._clear() start, end = self._param2buffer[param.name][self._rank] - with paddle.amp.auto_cast(enable=False): - param.fw_storage = core.VarBase( - self._task_flow.full_param[param.name]._slice(start, - end), - param.name + "@slice").detach().clone() + param.fw_storage = core.VarBase( + self._task_flow.full_param[param.name]._slice( + start, end), param.name + "@slice").detach().clone() param.status = "part" tmp_var = 
self._task_flow.full_param.pop(param.name) tmp_var._clear() + if self._offload: + param.fw_storage = _device2cpu(param.fw_storage, True) + return reduce def _redefine_opt_step(self): @@ -429,7 +500,11 @@ class ShardingStage3(nn.Layer): def _opt_step(self): if not update_scaler: params_slice_func() - opt_step() + if self.offload: + with device_guard(device="cpu"): + opt_step() + else: + opt_step() self._optim.step = MethodType(_opt_step, self._optim) @@ -443,7 +518,7 @@ class ShardingStage3(nn.Layer): def ForwardPreHooks(layer, order_tracer, trainable_params, param2buffer, rank, - group, sync_comm, task_flow): + group, sync_comm, offload, task_flow): # Record layer's id layer_id = id(layer) @@ -451,21 +526,28 @@ def ForwardPreHooks(layer, order_tracer, trainable_params, param2buffer, rank, if layer_id not in order_tracer.keys() or sync_comm: use_calc, sync_wait = True, True + + # Whether to use calc stream task_flow.use_calc[layer_id] = use_calc else: + # Whether to use calc stream task_flow.use_calc[layer_id] = use_calc - _wait_layer(trainable_params, layer_id, task_flow, group, use_calc) + # wait current layer params + _wait_layer(trainable_params[layer_id], task_flow, group, use_calc, + offload) if layer_id == order_tracer["layer"][-1]: return order_ = order_tracer[layer_id] layer_id = order_tracer["layer"][order_ + 1] + _allgather_buffer( - layer_id, - trainable_params, + trainable_params[layer_id], group, use_calc_stream=use_calc, task_flow=task_flow, - sync_wait=sync_wait) + sync_wait=sync_wait, + offload=offload) + return @@ -473,15 +555,20 @@ class ForwardPostHooks(PyLayer): @staticmethod def forward(ctx, inputs, layer, order_tracer, trainable_params, param2buffer, param2buffer_size, rank, group, sync_comm, - task_flow): - _release_param(layer, trainable_params, param2buffer, rank, task_flow) + offload, task_flow): layer_id = id(layer) + # release current layer full params + _release_param(trainable_params[layer_id], param2buffer, rank, + task_flow, offload) + if layer_id not in order_tracer.keys(): order_ = order_tracer["order"] order_tracer[layer_id] = order_ order_tracer["order"] += 1 order_tracer["layer"].append(layer_id) + + #Record bw info ctx.order_tracer = order_tracer ctx.task_flow = task_flow ctx.group = group @@ -489,6 +576,7 @@ class ForwardPostHooks(PyLayer): ctx.sync_comm = sync_comm ctx.trainable_params = trainable_params ctx.param2buffer_size = param2buffer_size + ctx.offload = offload return inputs @@ -502,31 +590,39 @@ class ForwardPostHooks(PyLayer): trainable_params = ctx.trainable_params param2buffer_size = ctx.param2buffer_size sync_comm = ctx.sync_comm + offload = ctx.offload layer_id = id(layer) use_calc, sync_wait = False, False + + # Allgather params synchronization if sync_comm: use_calc, sync_wait = True, True _allgather_buffer( - layer_id, - trainable_params, + trainable_params[layer_id], group, use_calc_stream=use_calc, task_flow=task_flow, - sync_wait=sync_wait) + sync_wait=sync_wait, + offload=offload) else: - _wait_layer(trainable_params, layer_id, task_flow, group, use_calc) - _create_params_grad(layer, trainable_params, param2buffer_size, + _wait_layer(trainable_params[layer_id], task_flow, group, use_calc, + offload) + + # Create params's grad + _create_params_grad(trainable_params[layer_id], param2buffer_size, task_flow) + + # Whether to use calc stream task_flow.use_calc[layer_id] = use_calc if layer_id != order_tracer["layer"][0] and not sync_comm: layer_next_id = order_tracer["layer"][order_tracer[layer_id] - 1] _allgather_buffer( - 
layer_next_id, - trainable_params, + trainable_params[layer_next_id], group, use_calc_stream=use_calc, task_flow=task_flow, - sync_wait=sync_wait) + sync_wait=sync_wait, + offload=offload) return args @@ -547,8 +643,12 @@ class TaskFlow: self.callback = callback -def _release_param(layer, trainable_params, param2buffer, rank, task_flow): - for param in trainable_params[id(layer)]: +def _release_param(trainable_params, + param2buffer, + rank, + task_flow, + offload=False): + for param in trainable_params: # async communicate share weight not clear param.use_count -= 1 if param.use_count == 0: @@ -562,11 +662,18 @@ def _release_param(layer, trainable_params, param2buffer, rank, task_flow): param.status = "part" tmp_var = task_flow.full_param.pop(param.name) tmp_var._clear() + + if offload: + param.fw_storage = _device2cpu(param.fw_storage) return -def _wait_layer(trainable_params, layer_id, task_flow, group, use_calc_stream): - for param in trainable_params[layer_id]: +def _wait_layer(trainable_params, + task_flow, + group, + use_calc_stream, + offload=False): + for param in trainable_params: if param.status == "all": param.use_count += 1 continue @@ -576,36 +683,43 @@ def _wait_layer(trainable_params, layer_id, task_flow, group, use_calc_stream): paddle.device.cuda.synchronize() core.VarBase(full_param._slice(0, param._numel()))._share_buffer_to( param) - param.value().get_tensor()._set_dims(param.shape) param.fw_storage._clear() param.fw_storage = None param.status = "all" param.use_count += 1 else: _allgather_buffer( - layer_id, trainable_params, group, - use_calc_stream, - task_flow, - sync_wait=True) + use_calc_stream=True, + task_flow=task_flow, + sync_wait=True, + offload=offload) break return task_flow -def _allgather_buffer(layer_id, - trainable_params, +def _allgather_buffer(trainable_params, group, use_calc_stream, task_flow, - sync_wait=False): - for param in trainable_params[layer_id]: + sync_wait=False, + offload=False, + convert2cpu=False): + + for param in trainable_params: if param.status == "all": param.use_count += 1 continue + + if offload: + param.fw_storage = _cpu2device(param) + with paddle.amp.auto_cast(enable=False): full_param = _all_gather( param.fw_storage, group, use_calc_stream=use_calc_stream) + + # Allgather current layer in the 1st step if sync_wait: with paddle.amp.auto_cast(enable=False): dist.wait( @@ -614,18 +728,26 @@ def _allgather_buffer(layer_id, use_calc_stream=use_calc_stream) core.VarBase(full_param._slice(0, param._numel()))._share_buffer_to( param) - param.value().get_tensor()._set_dims(param.shape) param.fw_storage._clear() param.fw_storage = None param.status = "all" param.use_count += 1 task_flow.full_param[param.name] = full_param + + # parameter converts to cpu + if convert2cpu: + p_name = param.name + param = _device2cpu(param) + tmp_var = task_flow.full_param.pop(p_name) + tmp_var._clear() + task_flow.full_param[p_name] = param + return task_flow @paddle.no_grad() -def _create_params_grad(layer, trainable_params, param2buffer_size, task_flow): - for param in trainable_params[id(layer)]: +def _create_params_grad(trainable_params, param2buffer_size, task_flow): + for param in trainable_params: if param.name in task_flow.full_grad.keys(): continue assert isinstance(param2buffer_size[param.name], int) @@ -668,6 +790,23 @@ def _OptimizerWrapper(optimizer, offload, group, update_params_slice): return optimizer +def _device2cpu(trans_param, convert_dtype=False): + if convert_dtype: + trans_param = paddle.cast(trans_param, Type.fp32.value) + tmp_p 
= trans_param.cpu() + trans_param._clear() + return tmp_p + + +def _cpu2device(param): + tmp_p = param.fw_storage.cuda(DEV_ID) + param.fw_storage._clear() + if tmp_p.dtype == Type.fp32.value and param2dtype[ + param.name] == Type.fp16.value: + tmp_p = paddle.cast(tmp_p, Type.fp16.value) + return tmp_p + + def _current_layer_params(layer): return layer.parameters( include_sublayers=False) + list(layer.extra_parameters) if hasattr( diff --git a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py index 5b0bec9c454b0fdfaea4d96ac821bfe8f859eff5..ddd31bc057f2e3f6eeeae571615f5e2991e6a8a2 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py +++ b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py @@ -30,7 +30,6 @@ from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import Shar from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import ShardingScaler epoch = 10 -batch_size = 32 paddle.seed(2021) np.random.seed(2021) base_lr = 0.1 @@ -66,10 +65,10 @@ def reader_decorator(linear_size=1000): def optimizer_setting(model, use_pure_fp16, opt_group=False): clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0) - optimizer = paddle.optimizer.AdamW( + optimizer = paddle.optimizer.Momentum( parameters=[{ - "params": model.parameters() - }] if opt_group else model.parameters(), + "params": list(model.parameters()) + }] if opt_group else list(model.parameters()), learning_rate=0.001, weight_decay=0.00001, grad_clip=clip, @@ -82,6 +81,7 @@ def train_mlp(model, sharding_stage, use_pure_fp16=False, accumulate_grad=False, + batch_size=100, opt_group=False, recompute=False): group = paddle.distributed.new_group([0, 1]) @@ -104,10 +104,14 @@ def train_mlp(model, optimizer, group=group, buffer_max_size=2**21, - accumulate_grads=accumulate_grad) + accumulate_grads=batch_size == 20) elif sharding_stage == 3: model = ShardingStage3( - model, optimizer=optimizer, group=group, sync_comm=recompute) + model, + optimizer=optimizer, + group=group, + accumulate_grads=batch_size == 20, + sync_comm=recompute) train_reader = paddle.batch( reader_decorator(), batch_size=batch_size, drop_last=True) @@ -131,21 +135,22 @@ def train_mlp(model, loss = paddle.nn.functional.cross_entropy( input=out, label=label) avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32)) + if not use_pure_fp16: + avg_loss.backward() + else: + scaler.scale(avg_loss).backward() + if not accumulate_grad: if not use_pure_fp16: - avg_loss.backward() optimizer.step() else: - scaler.scale(avg_loss).backward() scaler.step(optimizer) scaler.update() optimizer.clear_grad() if accumulate_grad: if not use_pure_fp16: - avg_loss.backward() optimizer.step() else: - scaler.scale(avg_loss).backward() scaler.step(optimizer) scaler.update() optimizer.clear_grad() @@ -168,48 +173,50 @@ def test_stage2_stage3(): mlp8.set_state_dict(state_dict) # fp32 stage2_params = train_mlp( - mlp1, sharding_stage=2, use_pure_fp16=False, opt_group=True) + mlp1, sharding_stage=2, use_pure_fp16=False, opt_group=False) stage3_params = train_mlp( - mlp2, sharding_stage=3, use_pure_fp16=False, opt_group=True) + mlp2, sharding_stage=3, use_pure_fp16=False, opt_group=False) + for i in range(len(stage2_params)): - for j in range(len(stage3_params)): - if stage2_params[i].name == stage3_params[j].name: - np.testing.assert_allclose( - stage2_params[i].numpy(), - stage3_params[j].numpy(), - rtol=1e-6) + np.testing.assert_allclose( + stage2_params[i].numpy(), + 
stage3_params[i].numpy(), + rtol=1e-6, + atol=1e-6) + # fp32 accumulate grad - stage2_params = train_mlp( + stage3_params = train_mlp( mlp3, - sharding_stage=2, + sharding_stage=3, use_pure_fp16=False, accumulate_grad=True, opt_group=True) - stage3_params = train_mlp( + stage3_params_add = train_mlp( mlp4, sharding_stage=3, use_pure_fp16=False, accumulate_grad=True, + batch_size=20, opt_group=True) - for i in range(len(stage2_params)): - for j in range(len(stage3_params)): - if stage2_params[i].name == stage3_params[j].name: - np.testing.assert_allclose( - stage2_params[i].numpy(), - stage3_params[j].numpy(), - rtol=1e-6) + for i in range(len(stage3_params)): + np.testing.assert_allclose( + stage3_params[i].numpy(), + stage3_params_add[i].numpy(), + rtol=1e-6, + atol=1e-6) + # fp16 stage2_params = train_mlp( mlp5, sharding_stage=2, use_pure_fp16=True, opt_group=False) stage3_params = train_mlp( mlp6, sharding_stage=3, use_pure_fp16=True, opt_group=False) for i in range(len(stage2_params)): - for j in range(len(stage3_params)): - if stage2_params[i].name == stage3_params[j].name: - np.testing.assert_allclose( - stage2_params[i].numpy(), - stage3_params[j].numpy(), - rtol=1e-6) + np.testing.assert_allclose( + stage2_params[i].numpy(), + stage3_params[i].numpy(), + rtol=1e-4, + atol=1e-4) + # fp16 recompute stage3_params = train_mlp( mlp7, sharding_stage=3, use_pure_fp16=True, opt_group=False) @@ -220,12 +227,8 @@ def test_stage2_stage3(): opt_group=False, recompute=True) for i in range(len(stage3_params)): - for j in range(len(stage3_params_re)): - if stage3_params[i].name == stage3_params_re[j].name: - np.testing.assert_allclose( - stage3_params[i].numpy(), - stage3_params_re[j].numpy(), - rtol=1e-6) + np.testing.assert_allclose( + stage3_params[i].numpy(), stage3_params_re[i].numpy(), rtol=1e-6) return diff --git a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3_offload.py b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3_offload.py new file mode 100644 index 0000000000000000000000000000000000000000..4d4b4c02068aa71c51f02dd8df74fac14e65fafc --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3_offload.py @@ -0,0 +1,192 @@ +# -*- coding: UTF-8 -*- + +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +import argparse +import ast +import time +import paddle +import paddle.fluid as fluid +from paddle.fluid.dygraph.nn import Linear +from paddle.distributed import fleet +from paddle.fluid.dygraph import nn + +from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import ShardingStage3 +from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import ShardingScaler + +epoch = 10 +batch_size = 32 +paddle.seed(2022) +np.random.seed(2022) +base_lr = 0.1 +momentum_rate = 0.9 +l2_decay = 1e-4 +fleet.init(is_collective=True) + + +class MLP(fluid.Layer): + def __init__(self, linear_size=1000, param_attr=None, bias_attr=None): + super(MLP, self).__init__() + + self._linear1 = Linear(linear_size, linear_size) + self._linear2 = Linear(linear_size, linear_size) + self._linear3 = Linear(linear_size, 10) + + def forward(self, inputs): + y = self._linear1(inputs) + y = self._linear2(y) + y = self._linear3(y) + return y + + +def reader_decorator(linear_size=1000): + def __reader__(): + for _ in range(100): + img = np.random.rand(linear_size).astype('float32') + label = np.ones(1).astype('int64') + yield img, label + + return __reader__ + + +def optimizer_setting(model, use_pure_fp16, opt_group=False): + clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0) + optimizer = paddle.optimizer.AdamW( + parameters=[{ + "params": model.parameters() + }] if opt_group else model.parameters(), + learning_rate=0.001, + weight_decay=0.00001, + grad_clip=clip, + multi_precision=use_pure_fp16) + + return optimizer + + +def train_mlp(model, + use_pure_fp16=False, + accumulate_grad=False, + offload=False, + convert2cpu=False): + group = paddle.distributed.new_group([0, 1]) + optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16) + + if use_pure_fp16: + model = paddle.amp.decorate( + models=model, level='O2', save_dtype='float32') + scaler = paddle.amp.GradScaler(init_loss_scaling=32768) + scaler = ShardingScaler(scaler) + + model = ShardingStage3( + model, optimizer=optimizer, group=group, offload=offload) + + train_reader = paddle.batch( + reader_decorator(), batch_size=batch_size, drop_last=True) + + train_loader = paddle.io.DataLoader.from_generator( + capacity=32, + use_double_buffer=True, + iterable=True, + return_list=True, + use_multiprocess=True) + train_loader.set_sample_list_generator(train_reader) + + for eop in range(epoch): + model.train() + for batch_id, data in enumerate(train_loader()): + img, label = data + label.stop_gradient = True + img.stop_gradient = True + with paddle.amp.auto_cast(True, level='O2'): + out = model(img) + loss = paddle.nn.functional.cross_entropy( + input=out, label=label) + avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32)) + if not use_pure_fp16: + avg_loss.backward() + else: + scaler.scale(avg_loss).backward() + if not accumulate_grad: + if not use_pure_fp16: + optimizer.step() + else: + scaler.step(optimizer) + scaler.update() + optimizer.clear_grad() + if accumulate_grad: + if not use_pure_fp16: + optimizer.step() + else: + scaler.step(optimizer) + scaler.update() + optimizer.clear_grad() + if not convert2cpu: + model.get_all_parameters() + else: + model.get_all_parameters(convert2cpu) + return model.parameters() + + +def test_stage3_offload(): + mlp, mlp1, mlp2, mlp3, mlp4, mlp5, mlp6 = MLP(), MLP(), MLP(), MLP(), MLP( + ), MLP(), MLP() + state_dict = mlp.state_dict() + mlp1.set_state_dict(state_dict) + mlp2.set_state_dict(state_dict) + mlp3.set_state_dict(state_dict) + mlp4.set_state_dict(state_dict) + 
mlp5.set_state_dict(state_dict) + mlp6.set_state_dict(state_dict) + + # fp32 offload + stage3_params = train_mlp(mlp1, use_pure_fp16=False) + stage3_params_offload = train_mlp(mlp2, use_pure_fp16=False, offload=True) + for i in range(len(stage3_params)): + np.testing.assert_allclose( + stage3_params[i].numpy(), + stage3_params_offload[i].numpy(), + rtol=1e-6, + atol=1e-8) + + # fp16 offload + stage3_params = train_mlp(mlp3, use_pure_fp16=True) + stage3_params_offload = train_mlp(mlp4, use_pure_fp16=True, offload=True) + for i in range(len(stage3_params)): + np.testing.assert_allclose( + stage3_params[i].numpy(), + stage3_params_offload[i].numpy(), + rtol=1e-2, + atol=1e-2) + + # fp32 accumulate grad offload + stage3_params = train_mlp(mlp5, use_pure_fp16=False, accumulate_grad=True) + stage3_params_offload = train_mlp( + mlp6, + use_pure_fp16=False, + accumulate_grad=True, + offload=True, + convert2cpu=True) + for i in range(len(stage3_params)): + np.testing.assert_allclose( + stage3_params[i].numpy(), + stage3_params_offload[i].numpy(), + rtol=1e-6, + atol=1e-8) + return + + +if __name__ == '__main__': + test_stage3_offload() diff --git a/python/paddle/fluid/tests/unittests/test_dygraph_sharding_stage2.py b/python/paddle/fluid/tests/unittests/test_dygraph_sharding_stage2.py index f76dcb5687c2ab77e411e7ef3c4de64200d99c66..669ab7d8f7f342653bef7cb6b48abf75ee6b2d11 100644 --- a/python/paddle/fluid/tests/unittests/test_dygraph_sharding_stage2.py +++ b/python/paddle/fluid/tests/unittests/test_dygraph_sharding_stage2.py @@ -23,10 +23,10 @@ from test_parallel_dygraph_dataparallel import TestMultipleGpus class TestDygraphShardingStage2(TestMultipleGpus): # check sharding logic as well as the accuracy with single mode - def test_dygraph_sharding_optimizer_stage2(self): + def test_dygraph_sharding_stage2(self): self.run_mnist_2gpu('dygraph_sharding_stage2.py') - def test_dygraph_sharding_optimizer_stage2_offload(self): + def test_dygraph_sharding_stage2_offload(self): self.run_mnist_2gpu('dygraph_sharding_stage2_offload.py') diff --git a/python/paddle/fluid/tests/unittests/test_dygraph_sharding_stage3.py b/python/paddle/fluid/tests/unittests/test_dygraph_sharding_stage3.py index 89d5f2e8c7b292592369651887fc72bcabcb77ea..c7da5d1e941b43c6ae28b2a5a84a59bbea311a24 100644 --- a/python/paddle/fluid/tests/unittests/test_dygraph_sharding_stage3.py +++ b/python/paddle/fluid/tests/unittests/test_dygraph_sharding_stage3.py @@ -23,9 +23,12 @@ from test_parallel_dygraph_dataparallel import TestMultipleGpus class TestDygraphShardingStage3(TestMultipleGpus): # check sharding logic as well as the accuracy with single mode - def test_dygraph_sharding_optimizer_stage3(self): + def test_dygraph_sharding_stage3(self): self.run_mnist_2gpu('dygraph_sharding_stage3.py') + def test_dygraph_sharding_stage3_offload(self): + self.run_mnist_2gpu('dygraph_sharding_stage3_offload.py') + if __name__ == "__main__": unittest.main()
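
Reviewer note: the new `offload=True` path and `get_all_parameters(convert2cpu=True)` are only exercised through the added unit test. Below is a minimal usage sketch distilled from the new `dygraph_sharding_stage3_offload.py` test, not an additional API. The `Sequential` model, the random batch, and the two-GPU launch command are illustrative assumptions; every call shown (`fleet.init`, `paddle.distributed.new_group`, `paddle.amp.decorate`, `ShardingScaler`, `ShardingStage3(..., offload=True)`, `get_all_parameters(convert2cpu=True)`) appears in this patch.

```python
# Minimal sketch of stage3 sharding with CPU offload, assuming two visible GPUs
# and a launch such as:  python -m paddle.distributed.launch --gpus=0,1 demo.py
import numpy as np
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import ShardingStage3
from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import ShardingScaler

fleet.init(is_collective=True)
group = paddle.distributed.new_group([0, 1])

# Toy model and optimizer; the real test uses a 3-layer MLP with AdamW.
model = paddle.nn.Sequential(
    paddle.nn.Linear(1000, 1000), paddle.nn.Linear(1000, 10))
optimizer = paddle.optimizer.AdamW(
    parameters=model.parameters(),
    learning_rate=0.001,
    weight_decay=0.00001,
    grad_clip=paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0),
    multi_precision=True)

# Pure fp16 (O2) training; offload=True keeps the parameter slices and the
# optimizer step on the CPU.
model = paddle.amp.decorate(models=model, level='O2', save_dtype='float32')
scaler = ShardingScaler(paddle.amp.GradScaler(init_loss_scaling=32768))
model = ShardingStage3(model, optimizer=optimizer, group=group, offload=True)

for step in range(10):
    img = paddle.to_tensor(np.random.rand(32, 1000).astype('float32'))
    label = paddle.to_tensor(np.ones((32, 1), dtype='int64'))
    with paddle.amp.auto_cast(True, level='O2'):
        out = model(img)
        loss = paddle.nn.functional.cross_entropy(input=out, label=label)
    avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
    scaler.scale(avg_loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.clear_grad()

# Re-gather the full parameters after training; convert2cpu=True keeps the
# gathered copies on the host instead of the GPU.
model.get_all_parameters(convert2cpu=True)
```

Design-wise, `offload=True` stores each rank's `fw_storage` slice on the CPU, the `_device2cpu`/`_cpu2device` helpers move slices across devices around each all-gather, and the redefined `optimizer.step` runs under `device_guard(device="cpu")`, trading extra H2D/D2H copies for lower GPU memory.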