Unverified commit b7a05057, authored by pangengzheng, committed by GitHub

support sharding parallel (#54634)

* support sharding parallel

* fix name

* fix

* update

* test amp for sharding

---------

Co-authored-by: pangengzheng <pangengzheng.baidu.com>
Parent ab73b8c6
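For context, a minimal sketch (not part of this commit) of how the sharding-parallel path that this change targets gets enabled. The degrees and the model are illustrative, and the script has to run under a multi-GPU launch; the hybrid_configs values mirror the test added in this diff.

import paddle
from paddle.distributed import fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 1,
    "sharding_degree": 2,  # shard optimizer states across 2 ranks
}
fleet.init(is_collective=True, strategy=strategy)

model = paddle.nn.Linear(8, 8)  # placeholder model
opt = paddle.optimizer.AdamW(parameters=model.parameters(), learning_rate=0.001)

# With sharding_degree > 1, fleet.distributed_optimizer now wraps the optimizer in
# DygraphShardingOptimizer inside HybridParallelOptimizer (see the hunks below).
model = fleet.distributed_model(model)
opt = fleet.distributed_optimizer(opt)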
@@ -3096,7 +3096,7 @@ void StackInferMeta(const std::vector<const MetaTensor*>& x,
                         rank,
                         axis));
   if (axis < 0) axis += (rank + 1);
-  auto vec = phi::vectorize<int>(out_dim);
+  auto vec = phi::vectorize<int64_t>(out_dim);
   vec.insert(vec.begin() + axis, input_dims.size());
   out->set_dims(phi::make_ddim(vec));
   out->set_dtype(x.at(0)->dtype());
...
@@ -477,7 +477,7 @@ def _set_multi_precision(optimizer, multi_precision):
         )
     optimizer = (
-        optimizer._inner_optimizer
+        optimizer._inner_opt
         if isinstance(optimizer, DygraphShardingOptimizer)
         else optimizer
     )
...
@@ -187,6 +187,11 @@ class HybridCommunicateGroup:
             "data"
         )
+        (
+            self.sharding_check_group,
+            self.sharding_check_comm_group,
+        ) = self._set_check_group("sharding")
+
         # create p2p group
         self.is_first_stage = self.stage_id == 0
         self.is_last_stage = self.stage_id == (self._pp_degree - 1)
@@ -428,8 +433,11 @@ class HybridCommunicateGroup:
         return self._sharding_comm_group.ranks[0]

     # check parallel group
-    def get_check_parallel_group(self):
-        return self._check_comm_group
+    def get_check_parallel_group(self, sharding=False):
+        if sharding:
+            return self.sharding_check_comm_group
+        else:
+            return self._check_comm_group

     def get_rank_from_stage(self, stage_id, **kwargs):
         return self._topo.get_rank_from_stage(
...
@@ -43,53 +43,49 @@ class DygraphShardingOptimizer:
     # 3. dynamic trainable params, which is the case bewteen pretraining and finetuning
     # 4. option to choose fuse comm (more GPU MEM need) or un-fuse comm

-    def __init__(
-        self,
-        hcg,
-        user_defined_strategy,
-        params,
-        inner_optimizer_class,
-        **inner_optimizer_kargs
-    ):
-        if not isinstance(params, list):
+    def __init__(self, optimizer, hcg):
+        # TODO(pangengzheng): support param_groups
+        if isinstance(optimizer._parameter_list[0], dict):
             raise TypeError(
-                "`parameters` argument given to the DygraphShardingOptimizer should be "
-                "an iterable of paddle Tensors, but got argument type is `{}`.".format(
-                    type(params)
-                )
+                "Do not support param_groups now, please set optimizer._parameter_list as a list of Parameter"
             )
-        self._parameter_list = params
-        self._reference_is_trainable_params = list(
-            map(_is_trainable, self._parameter_list)
-        )
-
-        self._inner_optimizer_class = inner_optimizer_class
-        self._inner_optimizer_kargs = inner_optimizer_kargs
-
-        # sharding parallel information
-        # TODO better way to get the hcg & user_defined_strategy
+        if not hasattr(optimizer, '_apply_optimize') or not callable(
+            optimizer._apply_optimize
+        ):
+            raise ValueError(
+                "the optimzier object should have _apply_optimize function"
+            )
+        # the self._parameter_list holds the whole model paramters
+        self._parameter_list = optimizer._parameter_list
+        self._inner_opt = optimizer
         self._hcg = hcg
-        self._user_defined_strategy = user_defined_strategy
         self._sharding_world_size = self._hcg.get_sharding_parallel_world_size()
         self._sharding_rank = self._hcg.get_sharding_parallel_rank()

-        # logic partitioning
-        self._build_sharding_mapping()
-
-        # actually create opt ops
-        self._buid_inner_optimizer()
+        self._rank2params = self._partition_parameters()
+        self._param2rank = self._map_param_to_rank()

-    def clear_grad(self):
+        self._set_inner_opt_attr(
+            '_parameter_list', self._rank2params[self._sharding_rank]
+        )
+        self._set_inner_opt_attr(
+            '_param_groups', self._rank2params[self._sharding_rank]
+        )
+
+    def clear_grad(self, set_to_zero=True):
         """
         should clear grad for all parameters in model
         """
         for p in self._parameter_list:
-            if not p.stop_gradient:
-                p.clear_gradient()
-
-    def _build_sharding_mapping(self):
-        self._rank2params = self._partition_parameters()
-        self._param2rank = self._map_param_to_rank()
+            if hasattr(p, "main_grad") and p.main_grad is not None:
+                assert p._grad_ivar() is None
+                if set_to_zero:
+                    p.main_grad.zero_()
+                else:
+                    p.main_grad._clear()
+                    p.main_grad = None
+            elif not hasattr(p, "main_grad"):
+                p.clear_gradient(set_to_zero)

     def _partition_parameters(self):
         """
@@ -132,14 +128,35 @@ class DygraphShardingOptimizer:
             mapping[param.name] = rank
         return mapping

-    def _buid_inner_optimizer(self):
-        # we rely on the inner opt to determine whether a parameter is stop_gradient or not:
-        # create moment
-        # update related ops: clip, regular, opt
-        self._inner_optimizer = self._inner_optimizer_class(
-            parameters=self._rank2params[self._sharding_rank],
-            **self._inner_optimizer_kargs
-        )
+    def reduce_gradients(self, parameter_list, hcg):
+        # TODO merge grad / nrank with dp
+        logger.debug("sharding start gradients sync")
+        with framework.no_grad():
+            sharding_nrank = hcg.get_sharding_parallel_group().nranks
+            for param in parameter_list:
+                g_var = None
+                if param.trainable and (param._grad_ivar() is not None):
+                    g_var = param._grad_ivar()
+                if param.trainable and hasattr(param, "main_grad"):
+                    assert (
+                        param._grad_ivar() is None
+                    ), "param.grad should be None when using main_grad"
+                    g_var = param.main_grad
+                if g_var is not None:
+                    g_var.scale_(1.0 / sharding_nrank)
+                    param_rank = self._param2rank[param.name]
+                    paddle.distributed.all_reduce(
+                        g_var,
+                        group=hcg.get_sharding_parallel_group(),
+                        sync_op=True,
+                    )
+                    # TODO(pangengzheng): change to reduce operation when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp.
+                    # paddle.distributed.reduce(
+                    #     g_var,
+                    #     dst=hcg.get_sharding_parallel_group().ranks[param_rank],
+                    #     group=hcg.get_sharding_parallel_group(),
+                    #     sync_op=True,
+                    # )

     def _sharding_sync_parameters(self):
         """
@@ -180,7 +197,7 @@ class DygraphShardingOptimizer:
                 self._rank2params[self._sharding_rank],
             )
         )
-        result = self._inner_optimizer.minimize(
+        result = self._inner_opt.minimize(
             loss, startup_program, parameters, no_grad_set
         )
@@ -192,19 +209,92 @@ class DygraphShardingOptimizer:
     def step(self):
         # TODO Check whether the model trainable param changed and update state accordingly
-        # actually updating
-        self._inner_optimizer.step()
+        # hack to grad_clip all parameters,
+        # otherwise the self._inner_opt will only grad_clip the self._rank2params[self._sharding_rank] params
+        # TODO(pangengzheng): remove the hacked grad_clip codes here when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp.
+        origin_clip = self._inner_opt._grad_clip
+        if not isinstance(self._parameter_list[0], dict):
+            params_grads = []
+            for param in self._parameter_list:
+                if (
+                    hasattr(param, "regularizer")
+                    and param.regularizer is not None
+                ):
+                    raise ValueError(
+                        "param {} should not has the regularizer attribute".format(
+                            param.name
+                        )
+                    )
+                if param.stop_gradient:
+                    continue
+                grad_var = param._grad_ivar()
+                if hasattr(param, "main_grad") and param.main_grad is not None:
+                    grad_var = param.main_grad
+                params_grads.append((param, grad_var))
+            if hasattr(self._inner_opt._grad_clip, 'not_sharding_stage1'):
+                self._inner_opt._grad_clip.not_sharding_stage1 = False
+            params_grads = self._inner_opt._grad_clip(params_grads)
+            # set inner_opt._grad_clip None to avoid repeatedly grad_clip gradients inside inner_opt._apply_optimize
+            self._set_inner_opt_attr('_grad_clip', None)
+            update_param_names = [
+                p.name for p in self._rank2params[self._sharding_rank]
+            ]
+            update_params_grads = [
+                (p, g) for p, g in params_grads if p.name in update_param_names
+            ]
+            self._apply_optimize(
+                loss=None,
+                startup_program=None,
+                params_grads=update_params_grads,
+            )
+            # restore the grad clip
+            self._set_inner_opt_attr('_grad_clip', origin_clip)

         # sync parameters across sharding ranks
         self._sharding_sync_parameters()

-    # TODO is it a good way to make _grad_clip a property
-    @property
-    def _grad_clip(self):
-        assert (
-            self._inner_optimizer is not None
-        ), "inner opt of sharding is not initiliazed."
-        return self._inner_optimizer._grad_clip
+    @framework.dygraph_only
+    def set_state_dict(self, state_dict):
+        inner_state = {}
+        parameters = self._rank2params[self._sharding_rank]
+
+        if "LR_Scheduler" in state_dict:
+            inner_state["LR_Scheduler"] = state_dict.pop("LR_Scheduler")
+
+        if "master_weights" in state_dict:
+            master = state_dict.pop("master_weights")
+            inner_state["master_weights"] = {}
+            for p in parameters:
+                for k, v in master.items():
+                    if p.name == k:
+                        v.name = self._inner_opt._gen_master_weight_var_name(p)
+                        inner_state["master_weights"][k] = v
+
+        for p in parameters:
+            for k, v in state_dict.items():
+                if p.name in k:
+                    inner_state[k] = v
+
+        self._inner_opt.set_state_dict(inner_state)
+
+    def _set_inner_opt_attr(self, attr_name, value):
+        inner_opt = self._inner_opt
+        inner_opt_name = '_inner_opt'
+        if not isinstance(attr_name, str):
+            raise TypeError(
+                "attr_name should be str type, but is {}".format(
+                    type(attr_name)
+                )
+            )
+        while hasattr(inner_opt, attr_name):
+            setattr(inner_opt, attr_name, value)
+            if (
+                hasattr(inner_opt, inner_opt_name)
+                and getattr(inner_opt, inner_opt_name, None) is not None
+            ):
+                inner_opt = getattr(inner_opt, inner_opt_name, None)
+            else:
+                break

     def __getattr__(self, item):
-        return getattr(self._inner_optimizer, item)
+        return getattr(self._inner_opt, item)
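As a quick illustration of the rewritten wrapper above: it is now built from an existing optimizer plus the hybrid communicate group, and clear_grad gained a set_to_zero switch for main_grad. A rough sketch, assuming a fleet environment with sharding_degree > 1 is already initialized as in the setup sketch near the top; it mirrors the clear_grad test added later in this diff.

import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
    DygraphShardingOptimizer,
)

# Hedged sketch; model is any dygraph Layer, hcg comes from the initialized fleet env.
opt = paddle.optimizer.AdamW(parameters=model.parameters(), learning_rate=0.001)
hcg = fleet.get_hybrid_communicate_group()
sharding_opt = DygraphShardingOptimizer(opt, hcg)  # new signature: (optimizer, hcg)

# _rank2params/_param2rank are built in __init__, and the inner optimizer only sees
# this rank's shard via _set_inner_opt_attr('_parameter_list', ...).
local_params = sharding_opt._rank2params[sharding_opt._sharding_rank]

# clear_grad(set_to_zero=True) zeroes main_grad in place; set_to_zero=False frees it.
for p in sharding_opt._parameter_list:
    p.main_grad = paddle.randn(shape=p.shape, dtype=p.dtype)
sharding_opt.clear_grad(set_to_zero=True)   # main_grad becomes all zeros
sharding_opt.clear_grad(set_to_zero=False)  # main_grad storage released, set to None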
@@ -18,13 +18,19 @@ import paddle
 from paddle import framework
 from paddle.autograd import no_grad
 from paddle.distributed import fleet
+from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
+    DygraphShardingOptimizer,
+)
+from paddle.distributed.fleet.utils.hybrid_parallel_util import (
+    obtain_optimizer_parameters_list,
+)
 from paddle.framework import core
 from paddle.nn import ClipGradByGlobalNorm, clip

 from ...base.topology import ParallelMode
 from ...utils.hybrid_parallel_util import (
     fused_allreduce_gradients,
-    sharding_reduce_gradients,
+    unwrap_optimizer,
 )
 from ...utils.log_util import logger
 from ...utils.mix_precision_utils import MixPrecisionOptimizer
@@ -32,24 +38,11 @@ from ...utils.mix_precision_utils import MixPrecisionOptimizer
 __all__ = []


-def _obtain_optimizer_parameters_list(optimizer):
-    if getattr(optimizer, '_param_groups', None) and isinstance(
-        optimizer._param_groups[0], dict
-    ):
-        parameters_list = []
-        for group in optimizer._param_groups:
-            for param in group['params']:
-                parameters_list.append(param)
-    else:
-        parameters_list = list(optimizer._parameter_list)
-
-    return parameters_list
-
-
 class HybridParallelClipGrad:
     def __init__(self, clip, hcg):
         self._clip = clip
         self._hcg = hcg
+        self.not_sharding_stage1 = True

     @no_grad()
     def _dygraph_clip(self, params_grads):
@@ -166,8 +159,15 @@ class HybridParallelClipGrad:
         # add all reduce to get global norm of distributed params_and_grads
         if self._hcg.get_model_parallel_world_size() > 1:
+            sharding_flag = False
+            if (
+                self._hcg.get_sharding_parallel_world_size() > 1
+                and self._hcg.get_data_parallel_world_size() == 1
+            ):
+                sharding_flag = True
+
             paddle.distributed.all_reduce(
-                global_norm_var_dist, group=self._hcg.get_check_parallel_group()
+                global_norm_var_dist,
+                group=self._hcg.get_check_parallel_group(sharding_flag),
             )

         # add all reduce to get global norm of non-distributed params_and_grads in groups of pp
@@ -179,7 +179,11 @@ class HybridParallelClipGrad:
         # In Sharding mode, param and grad is mapping different rank in optimizer.
         # ClipGradByGlobalNorm need allreduce to get globol norm
-        if self._hcg.get_sharding_parallel_world_size() > 1:
+        # TODO(pangengzheng): remove the self.not_sharding_stage1 flag when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp.
+        if (
+            self._hcg.get_sharding_parallel_world_size() > 1
+            and self.not_sharding_stage1
+        ):
             paddle.distributed.all_reduce(
                 global_norm_var_not_dist,
                 group=self._hcg.get_sharding_parallel_group(),
@@ -238,6 +242,10 @@ class HybridParallelClipGrad:
 class HybridParallelOptimizer:
     # adapter wrapper for optimizer
     def __init__(self, optimizer, hcg, strategy):
+        # Note: Only sharding stage 1 is considered in HybridParallelOptimizer.
+        # The sharding stage2 and stage3 optimizers are invoked in other api.
+        if hcg.get_sharding_parallel_world_size() > 1:
+            optimizer = DygraphShardingOptimizer(optimizer, hcg)
         self._inner_opt = optimizer
         self._strategy = strategy
         self._hcg = hcg
@@ -263,15 +271,11 @@ class HybridParallelOptimizer:
                 "or Sharding, the grad clip of original optimizer will be changed."
             )

-            inner_opt = (
-                self._inner_opt._inner_optimizer
-                if self._sharding_enable
-                else self._inner_opt
-            )
-
-            if isinstance(inner_opt, MixPrecisionOptimizer):
-                inner_opt = inner_opt._inner_opt
+            inner_opt = unwrap_optimizer(
+                self._inner_opt,
+                (MixPrecisionOptimizer, DygraphShardingOptimizer),
+            )

             if (
                 inner_opt._parameter_list
                 and not isinstance(inner_opt._parameter_list[0], dict)
@@ -415,9 +419,10 @@ class HybridParallelOptimizer:
     @no_grad()
     @framework.dygraph_only
     def step(self):
-        parameters_list = _obtain_optimizer_parameters_list(self._inner_opt)
+        parameters_list = obtain_optimizer_parameters_list(self._inner_opt)
         if self._sharding_enable:
-            sharding_reduce_gradients(list(parameters_list), self._hcg)
+            assert isinstance(self._inner_opt, DygraphShardingOptimizer)
+            self._inner_opt.reduce_gradients(list(parameters_list), self._hcg)

         if self._dp_enable:
             fused_allreduce_gradients(list(parameters_list), self._hcg)
@@ -433,12 +438,13 @@ class HybridParallelOptimizer:
         parameter_list = (
             parameters
             if parameters
-            else _obtain_optimizer_parameters_list(self._inner_opt)
+            else obtain_optimizer_parameters_list(self._inner_opt)
         )

         # Here sharding should use global parameter list
         if self._sharding_enable:
-            sharding_reduce_gradients(list(parameter_list), self._hcg)
+            assert isinstance(self._inner_opt, DygraphShardingOptimizer)
+            self._inner_opt.reduce_gradients(list(parameter_list), self._hcg)

         if self._dp_enable:
             fused_allreduce_gradients(list(parameter_list), self._hcg)
...
@@ -246,17 +246,21 @@ class FusedCommBuffer:
     def _comm_grads(self):
         assert self._all_params_checked_in

-        if self._act == HOOK_ACTION.ALL_REDUCE:
-            task = paddle.distributed.all_reduce(
-                self.grad_storage, group=self._comm_group, sync_op=False
-            )
-        elif self._act == HOOK_ACTION.REDUCE:
-            task = paddle.distributed.reduce(
-                self.grad_storage,
-                dst=self._dst,
-                group=self._comm_group,
-                sync_op=False,
-            )
+        # Note: after sharding change to reduce operation here also need to be updated
+        # if self._act == HOOK_ACTION.ALL_REDUCE:
+        #     task = paddle.distributed.all_reduce(
+        #         self.grad_storage, group=self._comm_group, sync_op=False
+        #     )
+        # elif self._act == HOOK_ACTION.REDUCE:
+        #     task = paddle.distributed.reduce(
+        #         self.grad_storage,
+        #         dst=self._dst,
+        #         group=self._comm_group,
+        #         sync_op=False,
+        #     )
+        task = paddle.distributed.all_reduce(
+            self.grad_storage, group=self._comm_group, sync_op=False
+        )
         self._task = task

     @imperative_base.no_grad
...
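The hunk above keeps the fused buffer on all_reduce for now. A hedged, standalone illustration of the two collectives the note refers to; the tensor is a stand-in for the fused gradient buffer and the snippet must run under a multi-rank launch with the default process group.

import paddle
import paddle.distributed as dist

dist.init_parallel_env()
grad_storage = paddle.ones([1024], dtype="float32")  # stands in for the fused buffer

# all_reduce: every rank ends up holding the summed buffer.
task = dist.all_reduce(grad_storage, sync_op=False)
task.wait()  # async collectives return a task that must complete before the buffer is read

# reduce: only the destination rank would hold the sum afterwards; this is the
# HOOK_ACTION.REDUCE branch that the note says will be revisited later.
# task = dist.reduce(grad_storage, dst=0, sync_op=False)
# task.wait()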
@@ -29,6 +29,20 @@ from .log_util import logger
 __all__ = []


+def obtain_optimizer_parameters_list(optimizer):
+    if getattr(optimizer, '_param_groups', None) and isinstance(
+        optimizer._param_groups[0], dict
+    ):
+        parameters_list = []
+        for group in optimizer._param_groups:
+            for param in group['params']:
+                parameters_list.append(param)
+    else:
+        parameters_list = list(optimizer._parameter_list)
+
+    return parameters_list
+
+
 def _apply_collective_grads(parameters, comm_group, bucket_size, scale=None):
     grad_var_set = set()
     grad_vars = []
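The helper above consolidates the copies that previously lived in hybrid_parallel_optimizer.py and mix_precision_utils.py. A small hedged usage sketch, assuming the public import path used elsewhere in this diff and that Paddle optimizers expose _param_groups when built from parameter-group dicts:

import paddle
from paddle.distributed.fleet.utils.hybrid_parallel_util import (
    obtain_optimizer_parameters_list,
)

layer = paddle.nn.Linear(4, 4)  # placeholder layer

flat_opt = paddle.optimizer.SGD(learning_rate=0.1, parameters=layer.parameters())
grouped_opt = paddle.optimizer.SGD(
    learning_rate=0.1,
    parameters=[{'params': list(layer.parameters())}],
)

# Both construction styles should yield the same flat list of parameters.
assert len(obtain_optimizer_parameters_list(flat_opt)) == len(list(layer.parameters()))
assert len(obtain_optimizer_parameters_list(grouped_opt)) == len(list(layer.parameters()))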
@@ -230,30 +244,6 @@ def fused_allreduce_gradients(parameter_list, hcg):
     fused_allreduce_gradients_with_group(parameter_list, data_parallel_group)


-def sharding_reduce_gradients(parameter_list, hcg):
-    # TODO allreduce --> reduce
-    # TODO merge grad / nrank with dp
-    logger.debug("sharding start gradients sync")
-    with framework.no_grad():
-        sharding_nrank = hcg.get_sharding_parallel_group().nranks
-        for param in parameter_list:
-            g_var = None
-            if param.trainable and (param._grad_ivar() is not None):
-                g_var = param._grad_ivar()
-            if param.trainable and hasattr(param, "main_grad"):
-                assert (
-                    param._grad_ivar() is None
-                ), "param.grad should be None when using main_grad"
-                g_var = param.main_grad
-            if g_var is not None:
-                g_var.scale_(1.0 / sharding_nrank)
-                paddle.distributed.all_reduce(
-                    g_var,
-                    group=hcg.get_sharding_parallel_group(),
-                    sync_op=True,
-                )
-
-
 def broadcast_sharding_parameters(model, hcg):
     # TODO TO save memory, use un-fused broadcast to avoid potentional OOM
     logger.debug("sharding start init parameters sync")
@@ -262,3 +252,10 @@ def broadcast_sharding_parameters(model, hcg):
     sync_params_buffers(
         model, sharding_parallel_group, src_rank, is_model_parallel=False
     )
+
+
+def unwrap_optimizer(optimizer, optimizer_instances=()):
+    _inner_opt = optimizer
+    while isinstance(_inner_opt, optimizer_instances):
+        _inner_opt = _inner_opt._inner_opt
+    return _inner_opt
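The new unwrap_optimizer helper peels nested wrappers until the object is no longer an instance of the given wrapper classes, which is exactly how the grad-clip setup in hybrid_parallel_optimizer.py uses it. A hedged usage sketch, assuming the same public import paths used elsewhere in this diff:

import paddle
from paddle.distributed.fleet.utils.mix_precision_utils import MixPrecisionOptimizer
from paddle.distributed.fleet.utils.hybrid_parallel_util import unwrap_optimizer

layer = paddle.nn.Linear(4, 4)
base_opt = paddle.optimizer.AdamW(parameters=layer.parameters(), learning_rate=0.001)
wrapped = MixPrecisionOptimizer(base_opt)  # stores base_opt as ._inner_opt

assert unwrap_optimizer(wrapped, (MixPrecisionOptimizer,)) is base_opt
assert unwrap_optimizer(base_opt, (MixPrecisionOptimizer,)) is base_opt  # no-op when not wrapped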
@@ -21,6 +21,9 @@ import numpy as np
 import paddle
 from paddle import _legacy_C_ops, nn
 from paddle.distributed import fleet
+from paddle.distributed.fleet.utils.hybrid_parallel_util import (
+    obtain_optimizer_parameters_list,
+)
 from paddle.fluid import framework
 from paddle.fluid.dygraph import base as imperative_base
 from paddle.fluid.dygraph import to_variable
@@ -93,20 +96,7 @@ class MixPrecisionLayer(nn.Layer):
 class MixPrecisionOptimizer:
     def __init__(self, optimizer):
         self._inner_opt = optimizer
-        self._parameter_list = self._obtain_optimizer_parameters_list()
-
-    def _obtain_optimizer_parameters_list(self):
-        if getattr(self._inner_opt, '_param_groups', None) and isinstance(
-            self._inner_opt._param_groups[0], dict
-        ):
-            parameters_list = []
-            for group in self._inner_opt._param_groups:
-                for param in group['params']:
-                    parameters_list.append(param)
-        else:
-            parameters_list = list(self._inner_opt._parameter_list)
-
-        return parameters_list
+        self._parameter_list = obtain_optimizer_parameters_list(optimizer)

     @imperative_base.no_grad
     @framework.dygraph_only
...
@@ -700,8 +700,7 @@ class Optimizer:
             else:
                 assert isinstance(self.helper, LayerHelper)

-                var_name = param.name + "_fp32_master"
-                var_name = unique_name.generate(var_name)
+                var_name = self._gen_master_weight_var_name(param)
                 var = paddle.static.create_global_var(
                     name=var_name,
                     shape=param.shape,
@@ -722,6 +721,10 @@ class Optimizer:
             self._master_weights[param.name] = var
         return var

+    def _gen_master_weight_var_name(self, param):
+        var_name = param.name + "_fp32_master"
+        return unique_name.generate(var_name)
+
     def _create_master_grad(self, grad):
         assert self._is_dtype_fp16_or_bf16(grad.dtype)
         if grad.name in self._master_grads:
...
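The naming logic is factored out above so that DygraphShardingOptimizer.set_state_dict can regenerate matching master-weight names when loading a checkpoint. A small hedged sketch of what the helper produces; the layer and optimizer are placeholders.

import paddle

layer = paddle.nn.Linear(4, 4)
opt = paddle.optimizer.AdamW(parameters=layer.parameters(), multi_precision=True)

param = layer.weight
name = opt._gen_master_weight_var_name(param)
# unique_name.generate appends a unique suffix to "<param.name>_fp32_master".
assert name.startswith(param.name + "_fp32_master")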
@@ -23,6 +23,10 @@ from paddle.distributed import fleet
 from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
     DygraphShardingOptimizer,
 )
+from paddle.distributed.fleet.utils.mix_precision_utils import (
+    MixPrecisionLayer,
+    MixPrecisionOptimizer,
+)

 vocab_size = 20
 hidden_size = 10
@@ -210,47 +214,24 @@ class TestDistMPTraning(unittest.TestCase):
             optimizer.clear_grad()
         return loss

-    def build_optimizer(
-        self, model, strategy=None, is_sharding=True, Optimizer="adam"
-    ):
+    def build_optimizer(self, model, strategy=None, Optimizer="adam"):
         clip = paddle.nn.ClipGradByGlobalNorm(0.5)
         if Optimizer == "adam":
-            if is_sharding:
-                optimizer = DygraphShardingOptimizer(
-                    hcg=fleet.get_hybrid_communicate_group(),
-                    user_defined_strategy=strategy,
-                    params=model.parameters(),
-                    inner_optimizer_class=paddle.optimizer.AdamW,
-                    learning_rate=0.001,
-                    weight_decay=0.00001,
-                    grad_clip=clip,
-                )
-            else:
-                optimizer = paddle.optimizer.AdamW(
-                    parameters=model.parameters(),
-                    learning_rate=0.001,
-                    weight_decay=0.00001,
-                    grad_clip=clip,
-                )
+            optimizer = paddle.optimizer.AdamW(
+                parameters=model.parameters(),
+                learning_rate=0.001,
+                weight_decay=0.00001,
+                grad_clip=clip,
+            )
         else:
-            if is_sharding:
-                optimizer = DygraphShardingOptimizer(
-                    hcg=fleet.get_hybrid_communicate_group(),
-                    user_defined_strategy=strategy,
-                    params=model.parameters(),
-                    inner_optimizer_class=paddle.optimizer.Momentum,
-                    learning_rate=0.001,
-                    grad_clip=clip,
-                )
-            else:
-                optimizer = paddle.optimizer.Momentum(
-                    learning_rate=0.001,
-                    parameters=model.parameters(),
-                    grad_clip=clip,
-                )
+            optimizer = paddle.optimizer.Momentum(
+                learning_rate=0.001,
+                parameters=model.parameters(),
+                grad_clip=clip,
+            )
         return optimizer

-    def build_model_optimizer(self, Optimizer="adam"):
+    def build_model_optimizer(self, Optimizer="adam", amp_level=None):
         hcg = fleet.get_hybrid_communicate_group()
         word_size = hcg.get_model_parallel_world_size()
         sharding_id = hcg.get_sharding_parallel_rank()
@@ -266,11 +247,8 @@ class TestDistMPTraning(unittest.TestCase):
         optimizer_a = self.build_optimizer(
             model_a,
             strategy=self.strategy,
-            is_sharding=True,
             Optimizer=Optimizer,
         )
-        model_a = fleet.distributed_model(model_a)
-        optimizer_a = fleet.distributed_optimizer(optimizer_a)

         model_b = SimpleDPNet(
             vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
@@ -278,15 +256,23 @@ class TestDistMPTraning(unittest.TestCase):
         optimizer_b = self.build_optimizer(
             model_b,
             strategy=self.strategy,
-            is_sharding=False,
             Optimizer=Optimizer,
         )

+        if amp_level is not None and amp_level == "O2":
+            model_a = MixPrecisionLayer(model_a)
+            optimizer_a = MixPrecisionOptimizer(optimizer_a)
+            model_b = MixPrecisionLayer(model_b)
+            optimizer_b = MixPrecisionOptimizer(optimizer_b)
+
+        model_a = fleet.distributed_model(model_a)
+        optimizer_a = fleet.distributed_optimizer(optimizer_a)
+
         return model_a, optimizer_a, model_b, optimizer_b

-    def sharding_model(self, Optimizer, sharded_accumulators):
+    def sharding_model(self, Optimizer, sharded_accumulators, amp_level=None):
         model_a, optimizer_a, model_b, optimizer_b = self.build_model_optimizer(
-            Optimizer=Optimizer
+            Optimizer=Optimizer, amp_level=amp_level
         )

         self.assertTrue(
@@ -296,9 +282,7 @@ class TestDistMPTraning(unittest.TestCase):
         for idx in range(STEPS):
             if idx == 2 and paddle.distributed.get_rank() == 0:
                 self.assertTrue(
-                    set(
-                        optimizer_a._inner_opt._inner_optimizer.state_dict().keys()
-                    )
+                    set(optimizer_a._inner_opt._inner_opt.state_dict().keys())
                     == sharded_accumulators
                 )
@@ -352,6 +336,19 @@ class TestDistMPTraning(unittest.TestCase):
             Optimizer="Momentum", sharded_accumulators=sharded_accumulators
         )

+    def test_sharding_momentum_amp(self):
+        sharded_accumulators = {
+            'linear_12.w_0_velocity_0',
+            'linear_13.b_0_velocity_0',
+            'linear_14.b_0_velocity_0',
+            'embedding_4.w_0_velocity_0',
+        }
+        self.sharding_model(
+            Optimizer="Momentum",
+            sharded_accumulators=sharded_accumulators,
+            amp_level="O2",
+        )
+

 if __name__ == "__main__":
     unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import random
import unittest
import numpy as np
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.utils.mix_precision_utils import (
MixPrecisionOptimizer,
)
vocab_size = 20
hidden_size = 10
inner_size = 8
output_size = 10
seq_length = 2
batch_size = 4
STEPS = 10
class SimpleDPNet(paddle.nn.Layer):
def __init__(
self, vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
):
super().__init__()
self.linear1 = paddle.nn.Linear(
hidden_size,
inner_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Assign(np_fc1)
),
bias_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)
),
)
self.linear2 = paddle.nn.Linear(
inner_size,
hidden_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Assign(np_fc2)
),
bias_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)
),
)
self.linear3 = paddle.nn.Linear(
hidden_size,
output_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)
),
bias_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)
),
)
self.embedding = paddle.nn.Embedding(
vocab_size,
hidden_size,
weight_attr=paddle.nn.initializer.Constant(value=0.5),
)
def forward(self, x):
x = self.embedding(x)
x = self.linear1(x)
x = self.linear2(x)
x = self.linear3(x)
x = paddle.matmul(x, self.embedding.weight, transpose_y=True)
return x
class TestDistShardingTraining(unittest.TestCase):
def setUp(self):
random.seed(2021)
np.random.seed(2021)
paddle.seed(2021)
self.strategy = fleet.DistributedStrategy()
self.strategy.hybrid_configs = {
"sharding_degree": 2,
"dp_degree": 1,
"mp_degree": 1,
"pp_degree": 1,
}
fleet.init(is_collective=True, strategy=self.strategy)
self.data = [
np.random.randint(
0,
vocab_size,
(
batch_size,
seq_length,
),
)
for _ in range(STEPS)
]
def build_adam_optimizer(self, model, lr=0.001):
clip = paddle.nn.ClipGradByGlobalNorm(0.5)
optimizer = paddle.optimizer.AdamW(
parameters=model.parameters(),
learning_rate=lr,
weight_decay=0.00001,
grad_clip=clip,
)
return optimizer
def test_set_state_dict(self):
np_fc1 = np.random.random_sample((hidden_size, inner_size))
np_fc2 = np.random.random_sample((inner_size, hidden_size))
model = SimpleDPNet(
vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
)
init_lr = 0.001
init_lr_scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
learning_rate=init_lr, T_max=1
)
local_optimizer = self.build_adam_optimizer(model, init_lr_scheduler)
dist_optimizer = fleet.distributed_optimizer(local_optimizer)
# prepare state_dict
state_dict = {}
# lr_scheduler
base_lr = 0.1
lr_scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
learning_rate=base_lr, T_max=1
)
state_dict["LR_Scheduler"] = lr_scheduler.state_dict()
# master_weights and accumulators
state_dict["master_weights"] = {}
all_param_names = []
accumulator_names = ["moment1", "moment2"]
#
local_params = dist_optimizer._rank2params[
dist_optimizer._sharding_rank
]
local_param_names = [p.name for p in local_params]
local_acc_names = []
other_acc_names = []
for p in model.parameters():
var_name = dist_optimizer._gen_master_weight_var_name(p)
var = paddle.static.create_global_var(
name=var_name,
shape=p.shape,
value=0,
dtype='float32',
persistable=True,
)
var = paddle.randn(shape=var.shape, dtype=var.dtype, name=var.name)
state_dict["master_weights"][p.name] = var
# accumulator
for name in accumulator_names:
acc_name = p.name + '_' + name
state_dict[acc_name] = paddle.randn(
shape=var.shape, dtype=var.dtype, name=acc_name
)
if p.name in local_param_names:
local_acc_names.append(acc_name)
else:
other_acc_names.append(acc_name)
all_param_names.append(p.name)
# test api
tmp_state_dict = copy.deepcopy(state_dict)
dist_optimizer.set_state_dict(state_dict)
# check result
other_param_names = [
p_name
for p_name in all_param_names
if p_name not in local_param_names
]
inner_opt = dist_optimizer._inner_opt
self.assertEqual(inner_opt._learning_rate.last_lr, base_lr)
assert hasattr(inner_opt, "_master_weights")
for p_name, weight in inner_opt._master_weights.items():
assert p_name in local_param_names
assert p_name not in other_param_names
assert p_name in tmp_state_dict["master_weights"]
np.testing.assert_array_almost_equal(
weight.numpy(), tmp_state_dict["master_weights"][p_name].numpy()
)
for acc_name, val in inner_opt._accumulators_holder.items():
assert acc_name in local_acc_names
assert acc_name not in other_acc_names
assert acc_name in tmp_state_dict
np.testing.assert_array_almost_equal(
val.numpy(), tmp_state_dict[acc_name].numpy()
)
def test_clear_grad(self):
np_fc1 = np.random.random_sample((hidden_size, inner_size))
np_fc2 = np.random.random_sample((inner_size, hidden_size))
model = SimpleDPNet(
vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
)
local_optimizer = self.build_adam_optimizer(model)
dist_optimizer = fleet.distributed_optimizer(local_optimizer)
tmp_parameter_list = []
for p in dist_optimizer._inner_opt._parameter_list:
main_grad = paddle.randn(shape=p.shape, dtype=p.dtype, name=p.name)
p.main_grad = main_grad
tmp_parameter_list.append(p)
assert hasattr(
dist_optimizer._inner_opt._parameter_list[0], "main_grad"
)
# test set_to_zero True
dist_optimizer._inner_opt.clear_grad(set_to_zero=True)
for p in dist_optimizer._inner_opt._parameter_list:
np.testing.assert_array_almost_equal(
p.main_grad.numpy(), np.zeros(p.main_grad.numpy().shape)
)
# test set_to_zero False
dist_optimizer._inner_opt.clear_grad(set_to_zero=False)
for p in dist_optimizer._inner_opt._parameter_list:
self.assertTrue(p.main_grad is None)
def test_set_inner_opt_attr(self):
np_fc1 = np.random.random_sample((hidden_size, inner_size))
np_fc2 = np.random.random_sample((inner_size, hidden_size))
model = SimpleDPNet(
vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
)
local_optimizer = self.build_adam_optimizer(model)
local_optimizer = MixPrecisionOptimizer(local_optimizer)
dist_optimizer = fleet.distributed_optimizer(local_optimizer)
sharding_opt = dist_optimizer._inner_opt
sharding_opt._set_inner_opt_attr('_parameter_list', 123)
self.assertTrue(hasattr(sharding_opt._inner_opt, '_parameter_list'))
self.assertTrue(
hasattr(sharding_opt._inner_opt._inner_opt, '_parameter_list')
)
self.assertEqual(sharding_opt._inner_opt._parameter_list, 123)
self.assertEqual(
sharding_opt._inner_opt._inner_opt._parameter_list, 123
)
sharding_opt._set_inner_opt_attr('_param_groups', 123)
self.assertTrue(hasattr(sharding_opt._inner_opt, '_param_groups'))
self.assertTrue(
hasattr(sharding_opt._inner_opt._inner_opt, '_param_groups')
)
self.assertEqual(sharding_opt._inner_opt._param_groups, 123)
self.assertEqual(sharding_opt._inner_opt._inner_opt._param_groups, 123)
# test bad case
try:
sharding_opt._set_inner_opt_attr(123, 123)
self.assertTrue(False)
except:
pass
if __name__ == "__main__":
unittest.main()
@@ -22,6 +22,9 @@ class TestHybridParallel(TestMultipleGpus):
     def test_hybrid_parallel_sharding_logic(self):
         self.run_mnist_2gpu('hybrid_parallel_sharding_model.py')

+    def test_hybrid_parallel_sharding_state_dict(self):
+        self.run_mnist_2gpu('hybrid_parallel_sharding_state_dict.py')
+

 if __name__ == "__main__":
     unittest.main()