Unverified · Commit 327e5050 · authored by Baibaifan, committed by GitHub

Integration sharding stage2 function (#38151)

Parent 9e42fe9a
@@ -16,21 +16,19 @@
 #Commit: 8acbec718f3c70a6b9785470bb9e05cd84fc3f8e
 import copy
-import time
 import logging
 import numpy as np
-from math import inf
 from itertools import chain
 from functools import reduce
 from collections import OrderedDict

 import paddle
 import paddle.fluid as fluid
-from paddle import framework
 from paddle.fluid import core
 import paddle.distributed as dist
 from paddle.optimizer import Optimizer
 from paddle.fluid.clip import ClipGradByGlobalNorm
+from paddle.distributed.collective import _get_global_group

 from ...utils.internal_storage import ParamStorage
 from ...meta_parallel.sharding.sharding_utils import Type, device_guard, ShardingClipGrad
@@ -59,14 +57,14 @@ class ShardingOptimizerStage2(Optimizer):
     # Feature Notes:
     # 1. Unified memory for parameters and parameters.grad to InternalStorage.
     # 2. Support the segmentation of optimizer parameters and partial updating of parameters.
-    # 3. Dynamically adjust training parameters and models
+    # 3. Dynamically adjust training parameters and models.
     # 4. Support offload function.
     # 5. Support the establishment of independent communication groups.
     # 6. Broadcast_fp16 is not supported now.

     def __init__(self,
                  params,
                  optim,
-                 group,
+                 group=None,
                  broadcast_fp16=False,
                  offload=False,
                  device="gpu",
@@ -78,13 +76,16 @@ class ShardingOptimizerStage2(Optimizer):
         self._dtype_rank_params = OrderedDict(
         )  # {dtype:[param1,param2]} device, rank, params
         self._param2rank = {}
-        self._segment_params = []
+        self.__segment_params = []
         self._rank_buffer_size = {}  # {dtype: {rank: numel+alignment}}
         self._param2align = {}  # {param.name: align}

         # Default information
         self._optim_defaults = kw
         self._optim = optim
+        self._ori_parameter_list = self._optim._parameter_list
+        self._ori_param_groups = self._optim._param_groups

         assert hasattr(self._optim, "_master_weights"
                        ), "Must use optimizer with _master_weights attribute"
         self._local_params = params
@@ -94,8 +95,8 @@ class ShardingOptimizerStage2(Optimizer):
                 filter(lambda x: x.trainable and x.dtype == Type.fp16.value,
                        self._local_params))) > 0

-        assert group is not None, "Distributed communication group is must be gived"
         self.group = group
+        group = _get_global_group() if group is None else group

         self.world_size = group.nranks
         self.rank = group.rank
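Note on the new default: with `group=None`, the optimizer (and the `ShardingStage2` wrapper below) now falls back to the global communication group via `_get_global_group()` instead of asserting. A minimal usage sketch, not part of this diff; it assumes the script is launched with `paddle.distributed.launch` on at least two GPUs, and uses the import paths exercised by the tests later in this PR:

```python
# Minimal sketch: `group` can now be omitted and resolves to the global group.
import paddle
import paddle.distributed as dist
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2

dist.init_parallel_env()  # sets up the global NCCL group

model = paddle.nn.Linear(1000, 1000)
base_opt = paddle.optimizer.AdamW(
    learning_rate=0.001, parameters=model.parameters())

# No explicit group passed: both wrappers pick up the global group internally.
opt = ShardingOptimizerStage2(params=model.parameters(), optim=base_opt)
model = ShardingStage2(model, opt)
```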
@@ -119,7 +120,7 @@ class ShardingOptimizerStage2(Optimizer):
         self._master_params = {}

         # Update optimizer parameters and adjust parameter storage and use according to rank.
-        self.update_opt_status()
+        self._update_opt_status()

     def _generate_master_params(self, trainable_params):
         if self.offload:
@@ -137,7 +138,7 @@ class ShardingOptimizerStage2(Optimizer):
                 self._optim._master_weights[param.name] = paddle.cast(
                     param, Type.fp32.value)

-    def update_opt_status(self):
+    def _update_opt_status(self):
         """Update optimizer status and parameter storage information, and special functions to be developed.
         """
         # func 1
@@ -147,12 +148,12 @@ class ShardingOptimizerStage2(Optimizer):

     # Segement helpers

-    def segment_params(self):
+    def _segment_params(self):
         """
         Divide all optimizer parameters equally into rank.
         """
-        if len(self._segment_params) == 0:
-            self._segment_params, param_lists = [
+        if len(self.__segment_params) == 0:
+            self.__segment_params, param_lists = [
                 [] for _ in range(self.world_size)
             ], [[] for _ in range(self.world_size)]
             sizes = [0] * self.world_size
@@ -165,9 +166,8 @@ class ShardingOptimizerStage2(Optimizer):
                 sizes[rank] += np.prod(param.shape) if param.trainable else 0

             for rank, params in enumerate(param_lists):
-                # param_group_rank = copy.copy(params)
-                self._segment_params[rank].extend(params)
-        return self._segment_params
+                self.__segment_params[rank].extend(params)
+        return self.__segment_params
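For context, the split performed here is a greedy balance by element count: each parameter is assigned to the rank whose bucket currently holds the fewest elements, so per-rank optimizer state stays roughly even. A standalone sketch of that policy (illustrative names only, not the code in this file):

```python
import numpy as np

def segment_by_size(params, world_size):
    """Greedy split: send each parameter to the least-loaded rank so far."""
    param_lists = [[] for _ in range(world_size)]
    sizes = [0] * world_size
    for param in params:
        rank = sizes.index(min(sizes))  # rank with the fewest elements so far
        param_lists[rank].append(param)
        sizes[rank] += np.prod(param.shape) if param.trainable else 0
    return param_lists
```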
     @property
     def local_params(self):
@@ -177,7 +177,7 @@
     def param2rank(self):
         """Map the params to the rank which owns them"""
         if len(self._param2rank) == 0:
-            for rank, params in enumerate(self.segment_params()):
+            for rank, params in enumerate(self._segment_params()):
                 for param in params:
                     self._param2rank[param.name] = rank
         return self._param2rank
@@ -271,32 +271,31 @@
         """
         if self.offload:
-            self._optim._parameter_list = [
-                param for name, param in self._master_params.items()
-            ]
+            params_list = list(self._master_params.values())
         else:
             # Synchronize optimizer parameters for the current rank
-            if len(self.dtype_rank_params.keys(
-            )) == 1 and Type.fp32.value in self.dtype_rank_params.keys():
-                self._optim._parameter_list = self.dtype_rank_params[
-                    Type.fp32.value][self.rank]
-            elif len(self.dtype_rank_params.keys(
-            )) == 1 and Type.fp16.value in self.dtype_rank_params.keys():
-                self._optim._parameter_list = self.dtype_rank_params[
-                    Type.fp16.value][self.rank]
-            else:
-                self._optim._parameter_list = self.dtype_rank_params[
-                    Type.fp16.value][self.rank] + self.dtype_rank_params[
-                        Type.fp32.value][self.rank]
+            params_list = []
+            for dtype in self.dtype_rank_params.keys():
+                params_list.extend(self.dtype_rank_params[dtype][self.rank])
+
+        params_name_list = list(map(lambda p: p.name, params_list))
+        if not isinstance(self._optim._param_groups[0], dict):
+            self._optim._parameter_list = params_list
+            self._optim._param_groups = params_list
+        else:
+            for param_group in self._optim._param_groups:
+                p_group = []
+                for param in param_group['params']:
+                    if param.name in params_name_list:
+                        p_group.append(params_list[params_name_list.index(
+                            param.name)])
+                param_group['params'] = p_group

         # Run the optimizer of the current rank step
         if self.offload:
-            with device_guard(self.rank, self.offload_device):
+            with device_guard(device=self.offload_device):
                 self._optim.step()
-            for param in self._optim._parameter_list:
-                self._master_params[param.name].set_value(param)

             dev_id = 0 if paddle.get_device() == "cpu" else int(
                 paddle.get_device().split(":")[1])
@@ -312,10 +311,11 @@
             self._broadcast_params()

         # Return full parameters to optimizer parameters
-        self._optim._parameter_list = self._local_params
+        self._optim._parameter_list = self._ori_parameter_list
+        self._optim._param_groups = self._ori_param_groups

-    def clear_cache(self):
-        self._segment_params.clear()
+    def _clear_cache(self):
+        self.__segment_params.clear()
         self._dtype_rank_params.clear()
         self._param2rank.clear()
......
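The main behavioural change in `step()` above is support for optimizers built with parameter groups: each group dict is temporarily filtered down to the parameters owned by the current rank, the inner optimizer runs, and the original `_parameter_list` / `_param_groups` saved in `__init__` are restored afterwards. A self-contained sketch of the filtering idea (the helper name is hypothetical, and it only mirrors, not reproduces, the in-place logic above):

```python
def filter_param_groups(param_groups, owned_params):
    """Keep, inside each group dict, only the parameters this rank owns,
    preserving the group's other settings (learning_rate, weight_decay, ...)."""
    owned = {p.name: p for p in owned_params}
    filtered = []
    for group in param_groups:
        new_group = dict(group)
        new_group['params'] = [
            owned[p.name] for p in group['params'] if p.name in owned
        ]
        filtered.append(new_group)
    return filtered
```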
@@ -24,10 +24,12 @@ import numpy as np
 from itertools import chain
 from functools import reduce
 from collections import deque
+from types import MethodType

 import paddle
 from paddle import nn
 import paddle.distributed as dist
+from paddle.distributed.collective import _get_global_group

 from ...utils.internal_storage import GradStorage
 from ...meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
@@ -57,7 +59,7 @@ class ShardingStage2(nn.Layer):
             self,
             layer,
             sharding_optimizer,
-            group,
+            group=None,
             sync_buffers=False,
             pertrain_sync_models=True,
             buffer_max_size=2**23,  #8MB
@@ -83,13 +85,12 @@
         self._accumulate_grads = accumulate_grads

         # Communication related attributes
-        assert group is not None, "Distributed communication group is must be gived"
         self._group = group
-        self._world_size_scaling = 1.0 / self._group.nranks
-        assert self._group.nranks > 1, "Training must be distributed, ranks must be greater than 1"
-        self._rank = self._group.rank
+        group = _get_global_group() if group is None else group
+        self._world_size_scaling = 1.0 / group.nranks
+        assert group.nranks > 1, "Training must be distributed, ranks must be greater than 1"
+        self._rank = group.rank
         self._global_root_rank = 0  # picking rank 0 as the reference
-        self._global_ranks = self._group.ranks
         self._default_device = device

         # Global statistical parameters
@@ -112,8 +113,8 @@
         self._has_grad_storage = []
         self._grad_storage_list = []

-        # offload
-        # TODO(haohongxiang): Now it's not supported for multi-optimizers using Offload strategy
+        # Offload
+        # TODO(haohongxiang): Now it's not be supported for multi-optimizers using Offload strategy
         self._offload_optims = list(
             filter(lambda optim: optim.offload, self._sharding_optimizers))
         if len(self._offload_optims) > 0:
@@ -134,6 +135,11 @@
         # Set tasks flow
         self._tasks_flow = deque()

+        # Define optimizer step and clear_grad
+        if self._accumulate_grads:
+            self._redefine_opt_step()
+        self._redefine_opt_clear()
+
     def forward(self, *inputs, **kwargs):
         """
         A wrapper for Sharding Stage2 layer.
@@ -161,7 +167,7 @@

         return fw

-    def clear_gradients(self):
+    def _clear_gradients(self):
         """
         Set zero to the gradient of the optimizer's current rank trainable parameters.
         """
@@ -176,7 +182,7 @@
             if param.name in self._param_grads and param.grad is not None:
                 param.clear_gradient()

-    def grad_scale(self):
+    def _grad_scale(self):
         """
         Before the gradient accumulation, scale the gradient.
         """
@@ -287,9 +293,6 @@
         for grad_storage in self._grad_storage_list:
             grad_storage.reset_checked_in()

-        if not self._accumulate_grads:
-            self._grads_flipped = False
-
     def _get_reduce_fn(self, index, param, dst_rank):
         """
         There are two ways to reduce gradient.
@@ -412,7 +415,6 @@
                 self._bw_hooks.pop().remove()

         # Go through the parameters, attach the hook
-        self._grad_accs = []
         if not self.training:
             return
@@ -500,9 +502,6 @@
         # Whether parameters trainability changed
         trainability_changed = trainable_mask != self._trainable_mask

-        # The whole model is not trainable but we still have grad hooks
-        trainability_changed |= not self.training and len(self._bw_hooks) > 0
-
         if trainability_changed:
             logging.warning(
                 "Trainable params changed, because of eval/train mode or parameter freezing/unfreeze."
@@ -548,3 +547,25 @@
                 format(rank_buffer_size[Type.fp32.value] / 2**18, model_size / 2
                        **18))
         return rank_buffer_size
+
+    def _redefine_opt_step(self):
+        if not self._accumulate_grads:
+            return
+        grad_func = self._grad_scale
+        for opt in self._sharding_optimizers:
+            opt_step = opt.step
+
+            def _opt_step(self):
+                grad_func()
+                opt_step()
+
+            opt.step = MethodType(_opt_step, opt)
+
+    def _redefine_opt_clear(self):
+        clear_func = self._clear_gradients
+
+        def _opt_clear(self):
+            clear_func()
+
+        for opt in self._sharding_optimizers:
+            opt.clear_grad = MethodType(_opt_clear, opt)
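These two helpers patch each wrapped optimizer instance with `types.MethodType`, so that under gradient accumulation `optimizer.step()` first rescales the accumulated gradients, and `optimizer.clear_grad()` routes to the wrapper's `_clear_gradients`. A tiny self-contained illustration of the per-instance patching pattern (dummy class, not Paddle code):

```python
from types import MethodType

class Opt:
    def step(self):
        print("inner step")

opt = Opt()
opt_step = opt.step  # keep a reference to the original bound method

def _opt_step(self):
    print("scale accumulated grads")  # stands in for _grad_scale()
    opt_step()

opt.step = MethodType(_opt_step, opt)  # only this instance is affected
opt.step()  # prints the scaling message, then "inner step"
```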
@@ -131,7 +131,7 @@ class ShardingClipGrad:


 @contextlib.contextmanager
-def device_guard(dev_id, device="cpu"):
+def device_guard(dev_id=0, device="cpu"):
     origin_device = paddle.device.get_device()
     if device == "cpu":
         paddle.set_device(device)
......
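`dev_id` now defaults to 0, which matches the new call site `device_guard(device=self.offload_device)` in `ShardingOptimizerStage2.step()`. For reference, a sketch of the context-manager pattern this utility follows; the `gpu` branch and the `try/finally` restore are assumptions beyond what the hunk shows:

```python
import contextlib
import paddle

@contextlib.contextmanager
def device_guard(dev_id=0, device="cpu"):
    origin_device = paddle.device.get_device()
    if device == "cpu":
        paddle.set_device(device)
    elif device == "gpu":
        paddle.set_device("gpu:{}".format(dev_id))
    try:
        yield  # run the wrapped block on the selected device
    finally:
        paddle.set_device(origin_device)  # always restore the original device
```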
@@ -125,7 +125,7 @@ def train_mlp():
         oss_optimizer.step()

         # oss_optimizer clear cache
-        oss_optimizer.clear_cache()
+        oss_optimizer._clear_cache()


 if __name__ == '__main__':
......
@@ -30,7 +30,7 @@ from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import Shar
 seed = 2021
 epoch = 2
 batch_size = 32
-linear_size = 10000
+linear_size = 1000

 strategy = fleet.DistributedStrategy()
 strategy.hybrid_configs = {
@@ -46,7 +46,7 @@ paddle.seed(seed)


 class MLP(fluid.Layer):
-    def __init__(self, linear_size=10000, param_attr=None, bias_attr=None):
+    def __init__(self, linear_size=1000, param_attr=None, bias_attr=None):
         super(MLP, self).__init__()

         self._linear1 = Linear(linear_size, linear_size)
@@ -60,7 +60,7 @@ class MLP(fluid.Layer):
         return y


-def reader_decorator(linear_size=10000):
+def reader_decorator(linear_size=1000):
     def __reader__():
         for _ in range(100):
             img = np.random.rand(linear_size).astype('float32')
@@ -70,10 +70,12 @@ def reader_decorator(linear_size=10000):
     return __reader__


-def optimizer_setting(model, use_pure_fp16):
+def optimizer_setting(model, use_pure_fp16, opt_group=False):
     clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
     optimizer = paddle.optimizer.AdamW(
-        parameters=model.parameters(),
+        parameters=[{
+            "params": model.parameters()
+        }] if opt_group else model.parameters(),
         learning_rate=0.001,
         weight_decay=0.00001,
         grad_clip=clip,
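`opt_group=True` makes the test exercise Paddle's parameter-group form of `parameters`, i.e. a list of dicts rather than a flat parameter list, which is exactly the case the new `step()` above has to handle. A hedged side-by-side sketch of the two forms:

```python
import paddle

model = paddle.nn.Linear(1000, 1000)

# Flat list form.
flat_opt = paddle.optimizer.AdamW(
    learning_rate=0.001, parameters=model.parameters())

# Parameter-group form: each dict may carry its own settings.
grouped_opt = paddle.optimizer.AdamW(
    learning_rate=0.001,
    parameters=[{"params": model.parameters()}])
```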
@@ -85,27 +87,32 @@ def optimizer_setting(model, use_pure_fp16):
 def train_mlp(model,
               sharding_stage,
               use_pure_fp16=False,
-              all_test=False,
-              accumulate_grad=False):
+              accumulate_grad=False,
+              opt_group=False):
     if sharding_stage == "dp":
         hcg = fleet.get_hybrid_communicate_group()
         group = hcg.get_check_parallel_group()
     else:
         group = paddle.distributed.new_group([0, 1])
-    optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)
-
-    if use_pure_fp16:
-        model = paddle.amp.decorate(
-            models=model, level='O2', save_dtype='float32')
+    if opt_group:
+        optimizer = optimizer_setting(
+            model=model, use_pure_fp16=use_pure_fp16, opt_group=opt_group)
+    else:
+        optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)

     if sharding_stage == 2:
         optimizer = ShardingOptimizerStage2(
             params=model.parameters(), optim=optimizer, group=group)
-        if all_test:
+        if accumulate_grad:
             model = ShardingStage2(
-                model, optimizer, group=group, accumulate_grads=accumulate_grad)
+                model,
+                optimizer,
+                group=group,
+                buffer_max_size=2**21,
+                accumulate_grads=accumulate_grad)
         else:
-            model = ShardingStage2(model, optimizer, group=group)
+            model = ShardingStage2(
+                model, optimizer, group=group, buffer_max_size=2**21)
     else:
         optimizer = fleet.distributed_optimizer(optimizer)
         model = fleet.distributed_model(model)
@@ -132,29 +139,16 @@ def train_mlp(model,
             label.stop_gradient = True
             img.stop_gradient = True

-            with paddle.amp.auto_cast(enable=use_pure_fp16, level='O2'):
-                out = model(img)
-                loss = paddle.nn.functional.cross_entropy(
-                    input=out, label=label)
-
+            out = model(img)
+            loss = paddle.nn.functional.cross_entropy(input=out, label=label)
+
             avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
             avg_loss.backward()

-            if accumulate_grad and batch_id == 2:
-                model.grad_scale()
-                optimizer.step()
-                model.clear_gradients()
-                return model.parameters()
-
-            if not accumulate_grad:
-                optimizer.step()
-
-                if sharding_stage == 2:
-                    model.clear_gradients()
-                else:
-                    optimizer.clear_grad()
+            optimizer.step()
+            optimizer.clear_grad()

-            if all_test and batch_id == 2:
+            if accumulate_grad and batch_id == 2:
                 return model.parameters()

     return model.parameters()
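With `step()` and `clear_grad()` patched by the stage-2 wrapper, the loop no longer needs separate accumulate/non-accumulate branches; the same two calls work in every configuration. A condensed, single-process stand-in for the loop (plain `Linear` and `AdamW` here, no sharding, purely illustrative):

```python
import numpy as np
import paddle

model = paddle.nn.Linear(1000, 10)
optimizer = paddle.optimizer.AdamW(
    learning_rate=0.001, parameters=model.parameters())

for batch_id in range(3):
    img = paddle.to_tensor(np.random.rand(32, 1000).astype('float32'))
    label = paddle.to_tensor(np.random.randint(0, 10, (32, 1)).astype('int64'))

    out = model(img)
    loss = paddle.nn.functional.cross_entropy(input=out, label=label)
    avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
    avg_loss.backward()

    # Plain AdamW here; in the sharded test these two calls hit the
    # wrapper-patched step()/clear_grad() methods.
    optimizer.step()
    optimizer.clear_grad()
```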
@@ -171,22 +165,19 @@ def test_dp_stage2():
     mlp2.set_state_dict(state_dict)
     mlp3.set_state_dict(state_dict)
     mlp4.set_state_dict(state_dict)

-    dp_params = train_mlp(mlp1, sharding_stage="dp", use_pure_fp16=False)
-    stage2_params = train_mlp(mlp2, sharding_stage=2, use_pure_fp16=False)
+    dp_params = train_mlp(
+        mlp1, sharding_stage="dp", use_pure_fp16=False, opt_group=True)
+    stage2_params = train_mlp(
+        mlp2, sharding_stage=2, use_pure_fp16=False, opt_group=True)
     for i in range(len(dp_params)):
         for j in range(len(stage2_params)):
             if dp_params[i].name == stage2_params[j].name:
                 np.testing.assert_allclose(
                     dp_params[i].numpy(), stage2_params[j].numpy(), rtol=1e-6)

-    stage2_params = train_mlp(
-        mlp3, sharding_stage=2, use_pure_fp16=True, all_test=True)
+    stage2_params = train_mlp(mlp3, sharding_stage=2)
     stage2_accumulate_grad = train_mlp(
-        mlp4,
-        sharding_stage=2,
-        use_pure_fp16=True,
-        all_test=True,
-        accumulate_grad=True)
+        mlp4, sharding_stage=2, accumulate_grad=True)
     for i in range(len(stage2_params)):
         for j in range(len(stage2_accumulate_grad)):
             if stage2_params[i].name == stage2_accumulate_grad[j].name:
......
@@ -33,7 +33,7 @@ from dygraph_sharding_stage2 import MLP, reader_decorator, optimizer_setting
 seed = 2021
 epoch = 2
 batch_size = 32
-linear_size = 8000
+linear_size = 1000

 np.random.seed(seed)
 paddle.seed(seed)
@@ -52,7 +52,12 @@ def train_mlp(model, offload=False):
         optim=optimizer,
         group=group,
         offload=offload)
-    model = ShardingStage2(model, optimizer, group=group, accumulate_grads=True)
+    model = ShardingStage2(
+        model,
+        optimizer,
+        group=group,
+        buffer_max_size=2**21,
+        accumulate_grads=True)

     train_reader = paddle.batch(
         reader_decorator(linear_size), batch_size=batch_size, drop_last=True)
@@ -81,10 +86,9 @@ def train_mlp(model, offload=False):
         avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
         scaler.scale(avg_loss).backward()

-        model.grad_scale()
         scaler.step(optimizer)
         scaler.update()
-        model.clear_gradients()
+        optimizer.clear_grad()

     for dtype in optimizer.param_storages:
         for dst_rank, param_storage in optimizer.param_storages[dtype].items():
......