Unverified commit 46823104, authored by Baibaifan, committed by GitHub

Add sharding stage3 offload (#38989)

Parent f4623876
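The new offload switch is exposed through the ShardingStage3 constructor added in this diff. Below is a minimal usage sketch assembled from the offload test script further down in this commit; it assumes a two-GPU collective launch (e.g. via paddle.distributed.launch), and the Sequential model and random batch are placeholders for the tests' MLP and data loader.

import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import ShardingStage3
from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import ShardingScaler

fleet.init(is_collective=True)
group = paddle.distributed.new_group([0, 1])

# Placeholder model/input; the tests use a small 3-layer MLP with linear_size=1000.
model = paddle.nn.Sequential(paddle.nn.Linear(1000, 1000), paddle.nn.Linear(1000, 10))
batch = paddle.randn([32, 1000])

optimizer = paddle.optimizer.AdamW(
    parameters=model.parameters(), learning_rate=0.001, multi_precision=True)

# Pure fp16 (O2) training with sharded loss scaling, as in the offload test.
model = paddle.amp.decorate(models=model, level='O2', save_dtype='float32')
scaler = ShardingScaler(paddle.amp.GradScaler(init_loss_scaling=32768))

# offload=True keeps the fp32 parameter slices on the CPU and runs the optimizer
# update there; slices are cast/copied back to the GPU only for the all-gather
# in forward/backward.
model = ShardingStage3(model, optimizer=optimizer, group=group, offload=True)

with paddle.amp.auto_cast(True, level='O2'):
    loss = paddle.mean(paddle.cast(model(batch), 'float32'))
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.clear_grad()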
@@ -33,7 +33,7 @@ from paddle.fluid.framework import ParamBase
from paddle.fluid.clip import ClipGradByGlobalNorm
from paddle.distributed.collective import _get_global_group
from .sharding_utils import Type, ShardingClipGrad
from .sharding_utils import Type, ShardingClipGrad, device_guard
from ..pp_utils.utils import _all_gather
# CUDA alignment 256 bytes
@@ -56,6 +56,13 @@ class ShardingStage3(nn.Layer):
.. ZeRO: https://arxiv.org/pdf/1910.02054.pdf.
"""
# TODO (Baibaifan)
# Feature Notes:
# 1. Supports segmenting layer parameters across global ranks.
# 2. Supports separate communication and computation streams.
# 3. Supports parameter offload.
# 4. Supports establishing independent communication groups.
def __init__(self,
layer,
optimizer,
@@ -77,6 +84,15 @@ class ShardingStage3(nn.Layer):
self._offload = offload
self._sync_comm = sync_comm
global DEV
DEV = "cpu" if paddle.get_device() == "cpu" else paddle.get_device(
).split(":")[0]
global DEV_ID
DEV_ID = 0 if paddle.get_device() == "cpu" else int(paddle.get_device()
.split(":")[1])
global param2dtype
param2dtype = dict()
# Communication group establishment
self._group = dist.new_group(_get_global_group()
.ranks) if group is None else group
@@ -85,6 +101,9 @@ class ShardingStage3(nn.Layer):
self._rank = self._group.rank
self._global_root_rank = 0 # picking rank 0 as the reference
self._global_ranks = self._group.ranks
# Parameter segmentation for global ranks
# After flatten -> self._param2buffer_size, self._param2buffer, self._trainable_params
self._param2buffer_size = dict() # {param.name: size}
self._param2buffer = dict(
) # {param.name: [(start0, end0),(start1, end1), ...]}
@@ -116,12 +135,16 @@ class ShardingStage3(nn.Layer):
self._order_tracer = OrderedDict()
self._order_tracer["order"] = 0
self._order_tracer["layer"] = []
# Register task flow
self._task_flow = TaskFlow()
# Register forward hooks
self._register_forward_hooks(self._layer)
# Register backward parameter hooks
self._register_backward_hooks()
# Redefine optimizer step and clear function
self._redefine_opt_step()
self._redefine_opt_clear()
@@ -152,7 +175,6 @@ class ShardingStage3(nn.Layer):
param, "fw_storage"
), "Find {} don't have fw_storage attribute.".format(param.name)
# param.bw_storage.zero_()
param.fw_storage.clear_gradient(False)
param.fw_storage._gradient_set_empty(False)
param.bw_storage._clear()
@@ -192,6 +214,9 @@ class ShardingStage3(nn.Layer):
return fw
def _segment_rank_params(self, layer, name="last_layer"):
"""
Flatten parameters according to layer.
"""
current_layer_params = _current_layer_params(layer)
if current_layer_params:
CHECK_LAYER[id(layer)] = name
@@ -201,6 +226,10 @@ class ShardingStage3(nn.Layer):
self._segment_rank_params(sub_layer, name)
def _flatten_layer_params(self, layer, current_layer_params):
"""
Parameter segmentation and memory integration.
"""
def _add_manage_info(trainable_param):
return _PartitionParam(trainable_param)
@@ -238,8 +267,13 @@ class ShardingStage3(nn.Layer):
# 3.Flatten layer params and release other rank buffer
self._param_storage(param, buffer_size)
# Record param's dtype
param2dtype[param.name] = param.dtype
def _param_storage(self, param, buffer_size):
"""
Create the parameter's InternalStorage buffer and keep only the slice owned by the current rank.
"""
assert isinstance(buffer_size, int)
value = np.zeros(
buffer_size,
@@ -264,16 +298,31 @@ class ShardingStage3(nn.Layer):
param._clear()
# Current rank param_storage
param.fw_storage = core.VarBase(
buffer._slice(start, end), "slice@" + param.name)
if self._offload:
param.fw_storage = core.VarBase(
buffer._slice(start, end),
core.CPUPlace(), "slice@" + param.name)
else:
param.fw_storage = core.VarBase(
buffer._slice(start, end), "slice@" + param.name)
param.status = "part"
# Update optimizer master weights
if param.dtype == Type.fp16.value:
if param.dtype == Type.fp16.value and not self._offload:
self._optim._master_weights[param.fw_storage.name] = paddle.cast(
param.fw_storage, Type.fp32.value)
def _register_forward_hooks(self, layer):
"""
Register PyLayer hooks to manage memory slices.
There are four stages:
FW
1. Before a forward layer runs, synchronize its full parameters.
2. After a forward layer runs, release the full parameters and keep only the parameter slices.
BW
3. Before a backward layer runs, synchronize its full parameters and create the parameters' grads.
4. After gradient accumulation, release the full parameters and keep only the parameter slices.
"""
current_layer_params = _current_layer_params(layer)
if current_layer_params:
self._register_forward_all_hooks(layer, self._task_flow)
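The four-stage scheme in the docstring above is realized with paddle.nn.Layer forward hooks: a pre-hook gathers the full parameters just before a sub-layer runs, and a post-hook (implemented via a PyLayer, ForwardPostHooks, whose backward re-gathers parameters and creates their grads) frees everything except the local slice afterwards. A much simplified sketch of the forward half only, where gather/release stand in for _allgather_buffer/_release_param:

import paddle

def make_pre_hook(gather):
    def _pre_hook(layer, inputs):
        gather(layer)   # all-gather this layer's parameter slices into full parameters
        return inputs
    return _pre_hook

def make_post_hook(release):
    def _post_hook(layer, inputs, outputs):
        release(layer)  # drop the full parameters, keep only this rank's slice
        return outputs
    return _post_hook

def register_hooks(root, gather, release):
    # Mirrors _register_forward_all_hooks: every parameterized sub-layer gets both hooks.
    for sub_layer in root.sublayers():
        if len(sub_layer.parameters(include_sublayers=False)) > 0:
            sub_layer.register_forward_pre_hook(make_pre_hook(gather))
            sub_layer.register_forward_post_hook(make_post_hook(release))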
@@ -286,13 +335,13 @@ class ShardingStage3(nn.Layer):
return ForwardPreHooks(layer, self._order_tracer,
self._trainable_params, self._param2buffer,
self._rank, self._group, self._sync_comm,
task_flow)
self._offload, task_flow)
def _forward_post_hook(layer, inputs, outputs):
return ForwardPostHooks.apply(
outputs, layer, self._order_tracer, self._trainable_params,
self._param2buffer, self._param2buffer_size, self._rank,
self._group, self._sync_comm, task_flow)
self._group, self._sync_comm, self._offload, task_flow)
# register previous forward hooks
sub_layer.register_forward_pre_hook(_forward_pre_hook)
@@ -302,6 +351,10 @@ class ShardingStage3(nn.Layer):
@paddle.no_grad()
def _sync_buffers(self):
"""
Sync all the param buffers across ranks (e.g. batch norm statistics).
"""
for buffer in self._layer.buffers(include_sublayers=True):
dist.broadcast(
buffer,
@@ -319,6 +372,9 @@ class ShardingStage3(nn.Layer):
return getattr(self._layer, name)
def _update_params(self):
"""
Prepare the parameter slices and their gradients for the optimizer update.
"""
update_list = []
assert len(self._trainable_params.keys()) > 0
current_layer_params = self._layer.parameters(include_sublayers=True)
@@ -331,36 +387,35 @@ class ShardingStage3(nn.Layer):
param.name)
if self._accumulate_grads:
param.bw_storage.scale_(scale=self._world_size_scaling)
if self._offload:
with device_guard(device="cpu"):
param.bw_storage.scale_(scale=self._world_size_scaling)
else:
param.bw_storage.scale_(scale=self._world_size_scaling)
param.fw_storage = _VarBaseWrapper(param)
param.fw_storage._copy_gradient_from(param.bw_storage)
update_list.append(param)
return update_list
def get_all_parameters(self):
def get_all_parameters(self, convert2cpu=False):
"""
Gather the full parameters on the current rank; move them to CPU when convert2cpu=True.
"""
assert len(self._trainable_params.keys()) > 0
current_layer_params = self._layer.parameters(include_sublayers=True)
trainable_params = list(
filter(lambda x: x.trainable, current_layer_params))
for param in trainable_params:
if param.use_count > 0:
continue
assert hasattr(
param,
"fw_storage"), "Find {} don't have fw_storage attribute".format(
param.name)
full_param = _all_gather(
param.fw_storage, self._group, use_calc_stream=True)
dist.wait(
tensor=full_param, group=self._group, use_calc_stream=True)
core.VarBase(full_param._slice(0, param._numel()))._share_buffer_to(
param)
param.value().get_tensor()._set_dims(param.shape)
param.fw_storage._clear()
param.fw_storage = None
param.status = "all"
param.use_count += 1
t_flow = _allgather_buffer(
trainable_params,
self._group,
use_calc_stream=True,
task_flow=TaskFlow(),
sync_wait=True,
offload=self._offload,
convert2cpu=convert2cpu)
if convert2cpu:
for param in current_layer_params:
t_flow.full_param[param.name]._share_buffer_to(param)
self._optim._parameter_list = self._ori_parameter_list
self._optim._param_groups = self._ori_param_groups
@@ -393,13 +448,28 @@ class ShardingStage3(nn.Layer):
use_calc_stream=True)
start, end = self._param2buffer[param.name][self._rank]
if not self._accumulate_grads or param.bw_storage is None:
if not self._accumulate_grads or param.bw_storage is None or not param.bw_storage.value(
).get_tensor()._is_initialized():
param.bw_storage = core.VarBase(
full_grad._slice(start, end)).detach().clone()
if self._offload:
param.bw_storage = _device2cpu(param.bw_storage,
True)
else:
param.bw_storage.add_(
core.VarBase(full_grad._slice(start, end)).detach()
.clone())
if self._offload:
cpu_grad = _device2cpu(
core.VarBase(full_grad._slice(start, end))
.detach().clone(), True)
param.bw_storage = paddle.add(param.bw_storage,
cpu_grad)
else:
# param.bw_storage.add_(
# core.VarBase(full_grad._slice(start, end))
# .detach().clone())
param.bw_storage = paddle.add(
param.bw_storage,
core.VarBase(full_grad._slice(
start, end)).detach().clone())
param.clear_gradient(False)
param._gradient_set_empty(False)
tmp_var = self._task_flow.full_grad.pop(param.name)
@@ -410,15 +480,16 @@ class ShardingStage3(nn.Layer):
param.use_count = 0
param._clear()
start, end = self._param2buffer[param.name][self._rank]
with paddle.amp.auto_cast(enable=False):
param.fw_storage = core.VarBase(
self._task_flow.full_param[param.name]._slice(start,
end),
param.name + "@slice").detach().clone()
param.fw_storage = core.VarBase(
self._task_flow.full_param[param.name]._slice(
start, end), param.name + "@slice").detach().clone()
param.status = "part"
tmp_var = self._task_flow.full_param.pop(param.name)
tmp_var._clear()
if self._offload:
param.fw_storage = _device2cpu(param.fw_storage, True)
return reduce
def _redefine_opt_step(self):
@@ -429,7 +500,11 @@ class ShardingStage3(nn.Layer):
def _opt_step(self):
if not update_scaler:
params_slice_func()
opt_step()
if self.offload:
with device_guard(device="cpu"):
opt_step()
else:
opt_step()
self._optim.step = MethodType(_opt_step, self._optim)
@@ -443,7 +518,7 @@ class ShardingStage3(nn.Layer):
def ForwardPreHooks(layer, order_tracer, trainable_params, param2buffer, rank,
group, sync_comm, task_flow):
group, sync_comm, offload, task_flow):
# Record layer's id
layer_id = id(layer)
@@ -451,21 +526,28 @@ def ForwardPreHooks(layer, order_tracer, trainable_params, param2buffer, rank,
if layer_id not in order_tracer.keys() or sync_comm:
use_calc, sync_wait = True, True
# Whether to use calc stream
task_flow.use_calc[layer_id] = use_calc
else:
# Whether to use calc stream
task_flow.use_calc[layer_id] = use_calc
_wait_layer(trainable_params, layer_id, task_flow, group, use_calc)
# wait current layer params
_wait_layer(trainable_params[layer_id], task_flow, group, use_calc,
offload)
if layer_id == order_tracer["layer"][-1]: return
order_ = order_tracer[layer_id]
layer_id = order_tracer["layer"][order_ + 1]
_allgather_buffer(
layer_id,
trainable_params,
trainable_params[layer_id],
group,
use_calc_stream=use_calc,
task_flow=task_flow,
sync_wait=sync_wait)
sync_wait=sync_wait,
offload=offload)
return
@@ -473,15 +555,20 @@ class ForwardPostHooks(PyLayer):
@staticmethod
def forward(ctx, inputs, layer, order_tracer, trainable_params,
param2buffer, param2buffer_size, rank, group, sync_comm,
task_flow):
_release_param(layer, trainable_params, param2buffer, rank, task_flow)
offload, task_flow):
layer_id = id(layer)
# release current layer full params
_release_param(trainable_params[layer_id], param2buffer, rank,
task_flow, offload)
if layer_id not in order_tracer.keys():
order_ = order_tracer["order"]
order_tracer[layer_id] = order_
order_tracer["order"] += 1
order_tracer["layer"].append(layer_id)
# Record bw info
ctx.order_tracer = order_tracer
ctx.task_flow = task_flow
ctx.group = group
@@ -489,6 +576,7 @@ class ForwardPostHooks(PyLayer):
ctx.sync_comm = sync_comm
ctx.trainable_params = trainable_params
ctx.param2buffer_size = param2buffer_size
ctx.offload = offload
return inputs
@@ -502,31 +590,39 @@ class ForwardPostHooks(PyLayer):
trainable_params = ctx.trainable_params
param2buffer_size = ctx.param2buffer_size
sync_comm = ctx.sync_comm
offload = ctx.offload
layer_id = id(layer)
use_calc, sync_wait = False, False
# Allgather params synchronization
if sync_comm:
use_calc, sync_wait = True, True
_allgather_buffer(
layer_id,
trainable_params,
trainable_params[layer_id],
group,
use_calc_stream=use_calc,
task_flow=task_flow,
sync_wait=sync_wait)
sync_wait=sync_wait,
offload=offload)
else:
_wait_layer(trainable_params, layer_id, task_flow, group, use_calc)
_create_params_grad(layer, trainable_params, param2buffer_size,
_wait_layer(trainable_params[layer_id], task_flow, group, use_calc,
offload)
# Create the params' grads
_create_params_grad(trainable_params[layer_id], param2buffer_size,
task_flow)
# Whether to use calc stream
task_flow.use_calc[layer_id] = use_calc
if layer_id != order_tracer["layer"][0] and not sync_comm:
layer_next_id = order_tracer["layer"][order_tracer[layer_id] - 1]
_allgather_buffer(
layer_next_id,
trainable_params,
trainable_params[layer_next_id],
group,
use_calc_stream=use_calc,
task_flow=task_flow,
sync_wait=sync_wait)
sync_wait=sync_wait,
offload=offload)
return args
@@ -547,8 +643,12 @@ class TaskFlow:
self.callback = callback
def _release_param(layer, trainable_params, param2buffer, rank, task_flow):
for param in trainable_params[id(layer)]:
def _release_param(trainable_params,
param2buffer,
rank,
task_flow,
offload=False):
for param in trainable_params:
# Shared weights used by async communication are not cleared here
param.use_count -= 1
if param.use_count == 0:
@@ -562,11 +662,18 @@ def _release_param(layer, trainable_params, param2buffer, rank, task_flow):
param.status = "part"
tmp_var = task_flow.full_param.pop(param.name)
tmp_var._clear()
if offload:
param.fw_storage = _device2cpu(param.fw_storage)
return
def _wait_layer(trainable_params, layer_id, task_flow, group, use_calc_stream):
for param in trainable_params[layer_id]:
def _wait_layer(trainable_params,
task_flow,
group,
use_calc_stream,
offload=False):
for param in trainable_params:
if param.status == "all":
param.use_count += 1
continue
@@ -576,36 +683,43 @@ def _wait_layer(trainable_params, layer_id, task_flow, group, use_calc_stream):
paddle.device.cuda.synchronize()
core.VarBase(full_param._slice(0, param._numel()))._share_buffer_to(
param)
param.value().get_tensor()._set_dims(param.shape)
param.fw_storage._clear()
param.fw_storage = None
param.status = "all"
param.use_count += 1
else:
_allgather_buffer(
layer_id,
trainable_params,
group,
use_calc_stream,
task_flow,
sync_wait=True)
use_calc_stream=True,
task_flow=task_flow,
sync_wait=True,
offload=offload)
break
return task_flow
def _allgather_buffer(layer_id,
trainable_params,
def _allgather_buffer(trainable_params,
group,
use_calc_stream,
task_flow,
sync_wait=False):
for param in trainable_params[layer_id]:
sync_wait=False,
offload=False,
convert2cpu=False):
for param in trainable_params:
if param.status == "all":
param.use_count += 1
continue
if offload:
param.fw_storage = _cpu2device(param)
with paddle.amp.auto_cast(enable=False):
full_param = _all_gather(
param.fw_storage, group, use_calc_stream=use_calc_stream)
# Allgather current layer in the 1st step
if sync_wait:
with paddle.amp.auto_cast(enable=False):
dist.wait(
@@ -614,18 +728,26 @@ def _allgather_buffer(layer_id,
use_calc_stream=use_calc_stream)
core.VarBase(full_param._slice(0, param._numel()))._share_buffer_to(
param)
param.value().get_tensor()._set_dims(param.shape)
param.fw_storage._clear()
param.fw_storage = None
param.status = "all"
param.use_count += 1
task_flow.full_param[param.name] = full_param
# Convert the parameter to CPU
if convert2cpu:
p_name = param.name
param = _device2cpu(param)
tmp_var = task_flow.full_param.pop(p_name)
tmp_var._clear()
task_flow.full_param[p_name] = param
return task_flow
@paddle.no_grad()
def _create_params_grad(layer, trainable_params, param2buffer_size, task_flow):
for param in trainable_params[id(layer)]:
def _create_params_grad(trainable_params, param2buffer_size, task_flow):
for param in trainable_params:
if param.name in task_flow.full_grad.keys():
continue
assert isinstance(param2buffer_size[param.name], int)
@@ -668,6 +790,23 @@ def _OptimizerWrapper(optimizer, offload, group, update_params_slice):
return optimizer
def _device2cpu(trans_param, convert_dtype=False):
if convert_dtype:
trans_param = paddle.cast(trans_param, Type.fp32.value)
tmp_p = trans_param.cpu()
trans_param._clear()
return tmp_p
def _cpu2device(param):
tmp_p = param.fw_storage.cuda(DEV_ID)
param.fw_storage._clear()
if tmp_p.dtype == Type.fp32.value and param2dtype[
param.name] == Type.fp16.value:
tmp_p = paddle.cast(tmp_p, Type.fp16.value)
return tmp_p
def _current_layer_params(layer):
return layer.parameters(
include_sublayers=False) + list(layer.extra_parameters) if hasattr(
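The offload path added above hinges on a CPU/GPU round trip: _device2cpu casts the current rank's fp16 slice to fp32 and moves it to host memory, gradient scaling and the optimizer update for that slice then run on the CPU under device_guard, and _cpu2device copies the slice back to the GPU and restores the fp16 dtype recorded in param2dtype right before the all-gather. A standalone sketch of that pattern; cpu_device_guard below is a hypothetical stand-in for the device_guard helper imported from sharding_utils, and the slice is a dummy tensor (GPU assumed).

import contextlib
import paddle

@contextlib.contextmanager
def cpu_device_guard():
    # Stand-in for device_guard(device="cpu"): temporarily make CPU the default
    # place so the in-place ops below execute on the CPU tensor.
    origin = paddle.get_device()
    paddle.set_device("cpu")
    try:
        yield
    finally:
        paddle.set_device(origin)

# A GPU fp16 parameter slice, as produced by _param_storage/_param2buffer.
gpu_slice = paddle.cast(paddle.randn([1024]), "float16")

# _device2cpu(convert_dtype=True): cast to fp32, then move to host memory.
cpu_slice = paddle.cast(gpu_slice, "float32").cpu()

# Grad scaling / optimizer work for this slice runs on the CPU
# (see _update_params and _opt_step above).
with cpu_device_guard():
    cpu_slice.scale_(0.5)

# _cpu2device: copy back to the GPU and restore the recorded fp16 dtype.
gpu_slice = paddle.cast(cpu_slice.cuda(), "float16")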
@@ -30,7 +30,6 @@ from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import Shar
from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import ShardingScaler
epoch = 10
batch_size = 32
paddle.seed(2021)
np.random.seed(2021)
base_lr = 0.1
@@ -66,10 +65,10 @@ def reader_decorator(linear_size=1000):
def optimizer_setting(model, use_pure_fp16, opt_group=False):
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
optimizer = paddle.optimizer.AdamW(
optimizer = paddle.optimizer.Momentum(
parameters=[{
"params": model.parameters()
}] if opt_group else model.parameters(),
"params": list(model.parameters())
}] if opt_group else list(model.parameters()),
learning_rate=0.001,
weight_decay=0.00001,
grad_clip=clip,
@@ -82,6 +81,7 @@ def train_mlp(model,
sharding_stage,
use_pure_fp16=False,
accumulate_grad=False,
batch_size=100,
opt_group=False,
recompute=False):
group = paddle.distributed.new_group([0, 1])
@@ -104,10 +104,14 @@
optimizer,
group=group,
buffer_max_size=2**21,
accumulate_grads=accumulate_grad)
accumulate_grads=batch_size == 20)
elif sharding_stage == 3:
model = ShardingStage3(
model, optimizer=optimizer, group=group, sync_comm=recompute)
model,
optimizer=optimizer,
group=group,
accumulate_grads=batch_size == 20,
sync_comm=recompute)
train_reader = paddle.batch(
reader_decorator(), batch_size=batch_size, drop_last=True)
@@ -131,21 +135,22 @@ def train_mlp(model,
loss = paddle.nn.functional.cross_entropy(
input=out, label=label)
avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
if not use_pure_fp16:
avg_loss.backward()
else:
scaler.scale(avg_loss).backward()
if not accumulate_grad:
if not use_pure_fp16:
avg_loss.backward()
optimizer.step()
else:
scaler.scale(avg_loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.clear_grad()
if accumulate_grad:
if not use_pure_fp16:
avg_loss.backward()
optimizer.step()
else:
scaler.scale(avg_loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.clear_grad()
@@ -168,48 +173,50 @@ def test_stage2_stage3():
mlp8.set_state_dict(state_dict)
# fp32
stage2_params = train_mlp(
mlp1, sharding_stage=2, use_pure_fp16=False, opt_group=True)
mlp1, sharding_stage=2, use_pure_fp16=False, opt_group=False)
stage3_params = train_mlp(
mlp2, sharding_stage=3, use_pure_fp16=False, opt_group=True)
mlp2, sharding_stage=3, use_pure_fp16=False, opt_group=False)
for i in range(len(stage2_params)):
for j in range(len(stage3_params)):
if stage2_params[i].name == stage3_params[j].name:
np.testing.assert_allclose(
stage2_params[i].numpy(),
stage3_params[j].numpy(),
rtol=1e-6)
np.testing.assert_allclose(
stage2_params[i].numpy(),
stage3_params[i].numpy(),
rtol=1e-6,
atol=1e-6)
# fp32 accumulate grad
stage2_params = train_mlp(
stage3_params = train_mlp(
mlp3,
sharding_stage=2,
sharding_stage=3,
use_pure_fp16=False,
accumulate_grad=True,
opt_group=True)
stage3_params = train_mlp(
stage3_params_add = train_mlp(
mlp4,
sharding_stage=3,
use_pure_fp16=False,
accumulate_grad=True,
batch_size=20,
opt_group=True)
for i in range(len(stage2_params)):
for j in range(len(stage3_params)):
if stage2_params[i].name == stage3_params[j].name:
np.testing.assert_allclose(
stage2_params[i].numpy(),
stage3_params[j].numpy(),
rtol=1e-6)
for i in range(len(stage3_params)):
np.testing.assert_allclose(
stage3_params[i].numpy(),
stage3_params_add[i].numpy(),
rtol=1e-6,
atol=1e-6)
# fp16
stage2_params = train_mlp(
mlp5, sharding_stage=2, use_pure_fp16=True, opt_group=False)
stage3_params = train_mlp(
mlp6, sharding_stage=3, use_pure_fp16=True, opt_group=False)
for i in range(len(stage2_params)):
for j in range(len(stage3_params)):
if stage2_params[i].name == stage3_params[j].name:
np.testing.assert_allclose(
stage2_params[i].numpy(),
stage3_params[j].numpy(),
rtol=1e-6)
np.testing.assert_allclose(
stage2_params[i].numpy(),
stage3_params[i].numpy(),
rtol=1e-4,
atol=1e-4)
# fp16 recompute
stage3_params = train_mlp(
mlp7, sharding_stage=3, use_pure_fp16=True, opt_group=False)
@@ -220,12 +227,8 @@ def test_stage2_stage3():
opt_group=False,
recompute=True)
for i in range(len(stage3_params)):
for j in range(len(stage3_params_re)):
if stage3_params[i].name == stage3_params_re[j].name:
np.testing.assert_allclose(
stage3_params[i].numpy(),
stage3_params_re[j].numpy(),
rtol=1e-6)
np.testing.assert_allclose(
stage3_params[i].numpy(), stage3_params_re[i].numpy(), rtol=1e-6)
return
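Note that the reworked checks above compare the two parameter lists positionally (both models start from the same state_dict, so parameters() returns them in the same order) instead of matching by name, and the fp16 run uses a looser 1e-4 tolerance. A small hypothetical helper capturing that pattern:

import numpy as np

def assert_params_allclose(params_a, params_b, rtol=1e-6, atol=1e-6):
    # Both lists come from models trained from the same initial state_dict,
    # so they can be compared index by index.
    assert len(params_a) == len(params_b)
    for p_a, p_b in zip(params_a, params_b):
        np.testing.assert_allclose(p_a.numpy(), p_b.numpy(), rtol=rtol, atol=atol)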
# -*- coding: UTF-8 -*-
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import argparse
import ast
import time
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Linear
from paddle.distributed import fleet
from paddle.fluid.dygraph import nn
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import ShardingStage3
from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import ShardingScaler
epoch = 10
batch_size = 32
paddle.seed(2022)
np.random.seed(2022)
base_lr = 0.1
momentum_rate = 0.9
l2_decay = 1e-4
fleet.init(is_collective=True)
class MLP(fluid.Layer):
def __init__(self, linear_size=1000, param_attr=None, bias_attr=None):
super(MLP, self).__init__()
self._linear1 = Linear(linear_size, linear_size)
self._linear2 = Linear(linear_size, linear_size)
self._linear3 = Linear(linear_size, 10)
def forward(self, inputs):
y = self._linear1(inputs)
y = self._linear2(y)
y = self._linear3(y)
return y
def reader_decorator(linear_size=1000):
def __reader__():
for _ in range(100):
img = np.random.rand(linear_size).astype('float32')
label = np.ones(1).astype('int64')
yield img, label
return __reader__
def optimizer_setting(model, use_pure_fp16, opt_group=False):
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
optimizer = paddle.optimizer.AdamW(
parameters=[{
"params": model.parameters()
}] if opt_group else model.parameters(),
learning_rate=0.001,
weight_decay=0.00001,
grad_clip=clip,
multi_precision=use_pure_fp16)
return optimizer
def train_mlp(model,
use_pure_fp16=False,
accumulate_grad=False,
offload=False,
convert2cpu=False):
group = paddle.distributed.new_group([0, 1])
optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)
if use_pure_fp16:
model = paddle.amp.decorate(
models=model, level='O2', save_dtype='float32')
scaler = paddle.amp.GradScaler(init_loss_scaling=32768)
scaler = ShardingScaler(scaler)
model = ShardingStage3(
model, optimizer=optimizer, group=group, offload=offload)
train_reader = paddle.batch(
reader_decorator(), batch_size=batch_size, drop_last=True)
train_loader = paddle.io.DataLoader.from_generator(
capacity=32,
use_double_buffer=True,
iterable=True,
return_list=True,
use_multiprocess=True)
train_loader.set_sample_list_generator(train_reader)
for eop in range(epoch):
model.train()
for batch_id, data in enumerate(train_loader()):
img, label = data
label.stop_gradient = True
img.stop_gradient = True
with paddle.amp.auto_cast(True, level='O2'):
out = model(img)
loss = paddle.nn.functional.cross_entropy(
input=out, label=label)
avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
if not use_pure_fp16:
avg_loss.backward()
else:
scaler.scale(avg_loss).backward()
if not accumulate_grad:
if not use_pure_fp16:
optimizer.step()
else:
scaler.step(optimizer)
scaler.update()
optimizer.clear_grad()
if accumulate_grad:
if not use_pure_fp16:
optimizer.step()
else:
scaler.step(optimizer)
scaler.update()
optimizer.clear_grad()
if not convert2cpu:
model.get_all_parameters()
else:
model.get_all_parameters(convert2cpu)
return model.parameters()
def test_stage3_offload():
mlp, mlp1, mlp2, mlp3, mlp4, mlp5, mlp6 = MLP(), MLP(), MLP(), MLP(), MLP(
), MLP(), MLP()
state_dict = mlp.state_dict()
mlp1.set_state_dict(state_dict)
mlp2.set_state_dict(state_dict)
mlp3.set_state_dict(state_dict)
mlp4.set_state_dict(state_dict)
mlp5.set_state_dict(state_dict)
mlp6.set_state_dict(state_dict)
# fp32 offload
stage3_params = train_mlp(mlp1, use_pure_fp16=False)
stage3_params_offload = train_mlp(mlp2, use_pure_fp16=False, offload=True)
for i in range(len(stage3_params)):
np.testing.assert_allclose(
stage3_params[i].numpy(),
stage3_params_offload[i].numpy(),
rtol=1e-6,
atol=1e-8)
# fp16 offload
stage3_params = train_mlp(mlp3, use_pure_fp16=True)
stage3_params_offload = train_mlp(mlp4, use_pure_fp16=True, offload=True)
for i in range(len(stage3_params)):
np.testing.assert_allclose(
stage3_params[i].numpy(),
stage3_params_offload[i].numpy(),
rtol=1e-2,
atol=1e-2)
# fp32 accumulate grad offload
stage3_params = train_mlp(mlp5, use_pure_fp16=False, accumulate_grad=True)
stage3_params_offload = train_mlp(
mlp6,
use_pure_fp16=False,
accumulate_grad=True,
offload=True,
convert2cpu=True)
for i in range(len(stage3_params)):
np.testing.assert_allclose(
stage3_params[i].numpy(),
stage3_params_offload[i].numpy(),
rtol=1e-6,
atol=1e-8)
return
if __name__ == '__main__':
test_stage3_offload()
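When training with offload, the final model.get_all_parameters(convert2cpu=True) call in train_mlp above re-gathers every full parameter and then moves it to host memory, so the parameters returned by model.parameters() can be inspected or compared without keeping the full model on the GPU. A minimal usage note, assuming the wrapped ShardingStage3 model from the script above:

model.get_all_parameters(convert2cpu=True)            # all-gather the slices, then place full params on CPU
cpu_params = [p.numpy() for p in model.parameters()]  # safe to compare/save on the host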
@@ -23,10 +23,10 @@ from test_parallel_dygraph_dataparallel import TestMultipleGpus
class TestDygraphShardingStage2(TestMultipleGpus):
# check sharding logic as well as the accuracy with single mode
def test_dygraph_sharding_optimizer_stage2(self):
def test_dygraph_sharding_stage2(self):
self.run_mnist_2gpu('dygraph_sharding_stage2.py')
def test_dygraph_sharding_optimizer_stage2_offload(self):
def test_dygraph_sharding_stage2_offload(self):
self.run_mnist_2gpu('dygraph_sharding_stage2_offload.py')
@@ -23,9 +23,12 @@ from test_parallel_dygraph_dataparallel import TestMultipleGpus
class TestDygraphShardingStage3(TestMultipleGpus):
# check sharding logic as well as the accuracy with single mode
def test_dygraph_sharding_optimizer_stage3(self):
def test_dygraph_sharding_stage3(self):
self.run_mnist_2gpu('dygraph_sharding_stage3.py')
def test_dygraph_sharding_stage3_offload(self):
self.run_mnist_2gpu('dygraph_sharding_stage3_offload.py')
if __name__ == "__main__":
unittest.main()