Unverified commit 974676bc authored by pangengzheng, committed by GitHub

support sharding stage1 (#54069)

* support sharding stage1

* fix unittest

* format

* pass sharded params_and_grads to inner_opt _apply_optimize

* change sharding gradient allreduce to reduce

* support saving state_dict adaptively and support sharding with mp

* fix sharding test

* test set_state_dict

* add more unit test

* fix global norm of mp case

* polish

* hack the global norm calculation to remove the difference in global norm values computed by HybridParallelClipGrad compared to dp

* remove print
Parent 7df043ec
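For orientation: the core of this change is that DygraphShardingOptimizer now wraps an already-built optimizer and splits its state across sharding ranks via `_rank2params` / `_param2rank`. The sketch below is a minimal, framework-free illustration of the size-greedy partitioning idea, assuming a largest-first, lightest-rank-wins assignment; the function and sample shapes are hypothetical, and the actual `_partition_parameters` in the diff may differ in details.

```python
# Illustrative only: assign each parameter to the currently lightest sharding
# rank so optimizer state (moments, master weights) is split roughly evenly.
from math import prod

def partition_parameters(named_shapes, sharding_world_size):
    """Return {rank: [param_name, ...]} balanced by element count."""
    rank2params = {r: [] for r in range(sharding_world_size)}
    sizes = {r: 0 for r in range(sharding_world_size)}
    # largest-first keeps the shards reasonably balanced
    for name, shape in sorted(named_shapes.items(), key=lambda kv: -prod(kv[1])):
        rank = min(sizes, key=sizes.get)   # lightest rank wins
        rank2params[rank].append(name)
        sizes[rank] += prod(shape)
    return rank2params

rank2params = partition_parameters(
    {"w1": (512, 512), "w2": (512,), "w3": (1024, 256)}, sharding_world_size=2
)
param2rank = {n: r for r, names in rank2params.items() for n in names}
print(rank2params)  # {0: ['w1', 'w2'], 1: ['w3']}
print(param2rank)   # {'w1': 0, 'w2': 0, 'w3': 1}
```

In the diff below, the result of this partitioning is what gets assigned to the inner optimizer's `_parameter_list` / `_param_groups` via `_set_inner_opt_attr`, so each rank only creates optimizer state for the parameters it owns.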
......@@ -2867,7 +2867,7 @@ void StackInferMeta(const std::vector<const MetaTensor*>& x,
rank,
axis));
if (axis < 0) axis += (rank + 1);
auto vec = phi::vectorize<int>(out_dim);
auto vec = phi::vectorize<int64_t>(out_dim);
vec.insert(vec.begin() + axis, input_dims.size());
out->set_dims(phi::make_ddim(vec));
out->set_dtype(x.at(0)->dtype());
......
......@@ -553,7 +553,7 @@ def _set_multi_precision(optimizer, multi_precision):
)
optimizer = (
optimizer._inner_optimizer
optimizer._inner_opt
if isinstance(optimizer, DygraphShardingOptimizer)
else optimizer
)
......
......@@ -183,6 +183,11 @@ class HybridCommunicateGroup:
"data"
)
(
self.sharding_check_group,
self.sharding_check_comm_group,
) = self._set_check_group("sharding")
# create p2p group
self.is_first_stage = self.stage_id == 0
self.is_last_stage = self.stage_id == (self._pp_degree - 1)
......@@ -376,8 +381,11 @@ class HybridCommunicateGroup:
return self._sharding_comm_group.ranks[0]
# check parallel group
def get_check_parallel_group(self):
return self._check_comm_group
def get_check_parallel_group(self, sharding=False):
if sharding:
return self.sharding_check_comm_group
else:
return self._check_comm_group
def get_rank_from_stage(self, stage_id, **kwargs):
return self._topo.get_rank_from_stage(
......
......@@ -18,6 +18,7 @@ from functools import reduce
import paddle
from paddle import framework
from paddle.fluid.dygraph import base as imperative_base
from ...utils.log_util import logger
......@@ -43,55 +44,51 @@ class DygraphShardingOptimizer:
# 3. dynamic trainable params, which is the case between pretraining and finetuning
# 4. option to choose fused comm (needs more GPU memory) or un-fused comm
def __init__(
self,
hcg,
user_defined_strategy,
params,
inner_optimizer_class,
**inner_optimizer_kargs
):
if not isinstance(params, list):
def __init__(self, optimizer, hcg):
# TODO(pangengzheng): support param_groups
if isinstance(optimizer._parameter_list[0], dict):
raise TypeError(
"`parameters` argument given to the DygraphShardingOptimizer should be "
"an iterable of paddle Tensors, but got argument type is `{}`.".format(
type(params)
)
"Do not support param_groups now, please set optimizer._parameter_list as a list of Parameter"
)
if not hasattr(optimizer, '_apply_optimize') or not callable(
optimizer._apply_optimize
):
raise ValueError(
"the optimzier object should have _apply_optimize function"
)
self._parameter_list = params
self._reference_is_trainable_params = list(
map(_is_trainable, self._parameter_list)
)
self._inner_optimizer_class = inner_optimizer_class
self._inner_optimizer_kargs = inner_optimizer_kargs
# sharding parallel information
# TODO better way to get the hcg & user_defined_strategy
# the self._parameter_list holds the whole model parameters
self._parameter_list = optimizer._parameter_list
self._inner_opt = optimizer
self._hcg = hcg
self._user_defined_strategy = user_defined_strategy
self._sharding_world_size = self._hcg.get_sharding_parallel_world_size()
self._sharding_rank = self._hcg.get_sharding_parallel_rank()
# logic partitioning
self._build_sharding_mapping()
self._rank2params = self._partition_parameters()
self._param2rank = self._map_param_to_rank()
# actually create opt ops
self._buid_inner_optimizer()
self._set_inner_opt_attr(
'_parameter_list', self._rank2params[self._sharding_rank]
)
self._set_inner_opt_attr(
'_param_groups', self._rank2params[self._sharding_rank]
)
def clear_grad(self):
def clear_grad(self, set_to_zero=True):
"""
should clear grad for all parameters in model
"""
#
for p in self._parameter_list:
if not p.stop_gradient:
p.clear_gradient()
def _build_sharding_mapping(self):
self._rank2params = self._partition_parameters()
self._param2rank = self._map_param_to_rank()
if hasattr(p, "main_grad") and p.main_grad is not None:
assert p._grad_ivar() is None
if set_to_zero:
p.main_grad.zero_()
else:
p.main_grad._clear()
p.main_grad = None
elif not hasattr(p, "main_grad"):
p.clear_gradient(set_to_zero)
def _partition_parameters(self):
"""
......@@ -134,14 +131,35 @@ class DygraphShardingOptimizer:
mapping[param.name] = rank
return mapping
def _buid_inner_optimizer(self):
# we rely on the inner opt to determine whether a parameter is stop_gradient or not:
# create moment
# update related ops: clip, regular, opt
self._inner_optimizer = self._inner_optimizer_class(
parameters=self._rank2params[self._sharding_rank],
**self._inner_optimizer_kargs
)
def reduce_gradients(self, parameter_list, hcg):
# TODO merge grad / nrank with dp
logger.debug("sharding start gradients sync")
with framework.no_grad():
sharding_nrank = hcg.get_sharding_parallel_group().nranks
for param in parameter_list:
g_var = None
if param.trainable and (param._grad_ivar() is not None):
g_var = param._grad_ivar()
if param.trainable and hasattr(param, "main_grad"):
assert (
param._grad_ivar() is None
), "param.grad should be None when using main_grad"
g_var = param.main_grad
if g_var is not None:
g_var.scale_(1.0 / sharding_nrank)
param_rank = self._param2rank[param.name]
paddle.distributed.all_reduce(
g_var,
group=hcg.get_sharding_parallel_group(),
sync_op=True,
)
# TODO(pangengzheng): change to reduce operation when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp.
# paddle.distributed.reduce(
# g_var,
# dst=hcg.get_sharding_parallel_group().ranks[param_rank],
# group=hcg.get_sharding_parallel_group(),
# sync_op=True,
# )
def _sharding_sync_parameters(self):
"""
......@@ -149,7 +167,6 @@ class DygraphShardingOptimizer:
"""
# TODO speed up this functional
logger.debug("sharding start sync parameters")
with framework.no_grad():
# TODO detach not need (?)
for rank, params in self._rank2params.items():
......@@ -172,7 +189,6 @@ class DygraphShardingOptimizer:
def minimize(
self, loss, startup_program=None, parameters=None, no_grad_set=None
):
# NOTE in dygraph mode, the only difference between step and minimize is that minimize
# allow user to customize the parameters for updating on each step
......@@ -183,7 +199,7 @@ class DygraphShardingOptimizer:
self._rank2params[self._sharding_rank],
)
)
result = self._inner_optimizer.minimize(
result = self._inner_opt.minimize(
loss, startup_program, parameters, no_grad_set
)
......@@ -192,22 +208,97 @@ class DygraphShardingOptimizer:
return result
@imperative_base.no_grad
@framework.dygraph_only
def step(self):
# TODO Check whether the model trainable param changed and update state accordingly
# actually updating
self._inner_optimizer.step()
# hack to grad_clip all parameters,
# otherwise the self._inner_opt will only grad_clip the self._rank2params[self._sharding_rank] params
# TODO(pangengzheng): remove the hacked grad_clip codes here when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp.
origin_clip = self._inner_opt._grad_clip
if not isinstance(self._parameter_list[0], dict):
params_grads = []
for param in self._parameter_list:
if (
hasattr(param, "regularizer")
and param.regularizer is not None
):
raise ValueError(
"param {} should not has the regularizer attribute".format(
param.name
)
)
if param.stop_gradient:
continue
grad_var = param._grad_ivar()
if hasattr(param, "main_grad") and param.main_grad is not None:
grad_var = param.main_grad
params_grads.append((param, grad_var))
if hasattr(self._inner_opt._grad_clip, 'not_sharding_stage1'):
self._inner_opt._grad_clip.not_sharding_stage1 = False
params_grads = self._inner_opt._grad_clip(params_grads)
# set inner_opt._grad_clip to None to avoid clipping the gradients again inside inner_opt._apply_optimize
self._set_inner_opt_attr('_grad_clip', None)
update_param_names = [
p.name for p in self._rank2params[self._sharding_rank]
]
update_params_grads = [
(p, g) for p, g in params_grads if p.name in update_param_names
]
self._apply_optimize(
loss=None,
startup_program=None,
params_grads=update_params_grads,
)
# restore the grad clip
self._set_inner_opt_attr('_grad_clip', origin_clip)
# sync parameters across sharding ranks
self._sharding_sync_parameters()
# TODO is it a good way to make _grad_clip a property
@property
def _grad_clip(self):
assert (
self._inner_optimizer is not None
), "inner opt of sharding is not initiliazed."
return self._inner_optimizer._grad_clip
@framework.dygraph_only
def set_state_dict(self, state_dict):
inner_state = {}
parameters = self._rank2params[self._sharding_rank]
if "LR_Scheduler" in state_dict:
inner_state["LR_Scheduler"] = state_dict.pop("LR_Scheduler")
if "master_weights" in state_dict:
master = state_dict.pop("master_weights")
inner_state["master_weights"] = {}
for p in parameters:
for k, v in master.items():
if p.name == k:
v.name = self._inner_opt._gen_master_weight_var_name(p)
inner_state["master_weights"][k] = v
for p in parameters:
for k, v in state_dict.items():
if p.name in k:
inner_state[k] = v
self._inner_opt.set_state_dict(inner_state)
def _set_inner_opt_attr(self, attr_name, value):
inner_opt = self._inner_opt
inner_opt_name = '_inner_opt'
if not isinstance(attr_name, str):
raise TypeError(
"attr_name should be str type, but is {}".format(
type(attr_name)
)
)
while hasattr(inner_opt, attr_name):
setattr(inner_opt, attr_name, value)
if (
hasattr(inner_opt, inner_opt_name)
and getattr(inner_opt, inner_opt_name, None) is not None
):
inner_opt = getattr(inner_opt, inner_opt_name, None)
else:
break
def __getattr__(self, item):
return getattr(self._inner_optimizer, item)
return getattr(self._inner_opt, item)
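To make the hacked `step()` above concrete: after `reduce_gradients` (which currently all-reduces), the grads are identical on every sharding rank, so each rank can compute the same global norm over the full parameter list, exactly as plain data parallel would, and then update only the shard it owns before `_sharding_sync_parameters()` broadcasts the results. Below is a toy, NumPy-only sketch of that "clip globally, update locally" flow; the names and the plain-SGD update are hypothetical stand-ins, not the Paddle implementation.

```python
# Toy illustration of "clip globally, update locally"; assumes grads are
# already identical on every sharding rank (they were all-reduced).
from dataclasses import dataclass
import numpy as np

@dataclass
class Param:
    name: str
    value: np.ndarray

def sharded_clip_and_step(params, grads, owned_names, clip_norm, lr=0.1):
    # 1) global norm over *all* (param, grad) pairs, matching pure dp
    global_norm = float(np.sqrt(sum(np.sum(g * g) for g in grads)))
    scale = min(1.0, clip_norm / (global_norm + 1e-6))
    # 2) update only the parameters this sharding rank owns; the other shards
    #    are updated by their owning ranks and broadcast back afterwards
    for p, g in zip(params, grads):
        if p.name in owned_names:
            p.value -= lr * scale * g
    return global_norm

params = [Param("w0", np.ones(4)), Param("w1", np.ones(4))]
grads = [np.full(4, 3.0), np.full(4, 4.0)]
print(sharded_clip_and_step(params, grads, owned_names={"w0"}, clip_norm=1.0))  # 10.0
```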
......@@ -17,13 +17,19 @@ import paddle
from paddle import framework
from paddle.autograd import no_grad
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
DygraphShardingOptimizer,
)
from paddle.distributed.fleet.utils.hybrid_parallel_util import (
obtain_optimizer_parameters_list,
)
from paddle.framework import core
from paddle.nn import ClipGradByGlobalNorm, clip
from ...base.topology import ParallelMode
from ...utils.hybrid_parallel_util import (
fused_allreduce_gradients,
sharding_reduce_gradients,
unwrap_optimizer,
)
from ...utils.log_util import logger
from ...utils.mix_precision_utils import MixPrecisionOptimizer
......@@ -31,24 +37,11 @@ from ...utils.mix_precision_utils import MixPrecisionOptimizer
__all__ = []
def _obtain_optimizer_parameters_list(optimizer):
if getattr(optimizer, '_param_groups', None) and isinstance(
optimizer._param_groups[0], dict
):
parameters_list = []
for group in optimizer._param_groups:
for param in group['params']:
parameters_list.append(param)
else:
parameters_list = list(optimizer._parameter_list)
return parameters_list
class HybridParallelClipGrad:
def __init__(self, clip, hcg):
self._clip = clip
self._hcg = hcg
self.not_sharding_stage1 = True
@no_grad()
def _dygraph_clip(self, params_grads):
......@@ -169,8 +162,15 @@ class HybridParallelClipGrad:
# add all reduce to get global norm of distributed params_and_grads
if self._hcg.get_model_parallel_world_size() > 1:
sharding_flag = False
if (
self._hcg.get_sharding_parallel_world_size() > 1
and self._hcg.get_data_parallel_world_size() == 1
):
sharding_flag = True
paddle.distributed.all_reduce(
global_norm_var_dist, group=self._hcg.get_check_parallel_group()
global_norm_var_dist,
group=self._hcg.get_check_parallel_group(sharding_flag),
)
# add all reduce to get global norm of non-distributed params_and_grads in groups of pp
......@@ -182,7 +182,11 @@ class HybridParallelClipGrad:
# In sharding mode, params and grads are mapped to different ranks in the optimizer.
# ClipGradByGlobalNorm needs an allreduce to get the global norm
if self._hcg.get_sharding_parallel_world_size() > 1:
# TODO(pangengzheng): remove the self.not_sharding_stage1 flag when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp.
if (
self._hcg.get_sharding_parallel_world_size() > 1
and self.not_sharding_stage1
):
paddle.distributed.all_reduce(
global_norm_var_not_dist,
group=self._hcg.get_sharding_parallel_group(),
......@@ -236,6 +240,10 @@ class HybridParallelClipGrad:
class HybridParallelOptimizer:
# adapter wrapper for optimizer
def __init__(self, optimizer, hcg, strategy):
# Note: Only sharding stage 1 is considered in HybridParallelOptimizer.
# The sharding stage 2 and stage 3 optimizers are invoked through other APIs.
if hcg.get_sharding_parallel_world_size() > 1:
optimizer = DygraphShardingOptimizer(optimizer, hcg)
self._inner_opt = optimizer
self._strategy = strategy
self._hcg = hcg
......@@ -260,16 +268,11 @@ class HybridParallelOptimizer:
"While using ClipGradByGlobalNorm in TensorParallel, PipelineParallel "
"or Sharding, the grad clip of original optimizer will be changed."
)
inner_opt = (
self._inner_opt._inner_optimizer
if self._sharding_enable
else self._inner_opt
inner_opt = unwrap_optimizer(
self._inner_opt,
(MixPrecisionOptimizer, DygraphShardingOptimizer),
)
if isinstance(inner_opt, MixPrecisionOptimizer):
inner_opt = inner_opt._inner_opt
if (
inner_opt._parameter_list
and not isinstance(inner_opt._parameter_list[0], dict)
......@@ -413,9 +416,10 @@ class HybridParallelOptimizer:
@no_grad()
@framework.dygraph_only
def step(self):
parameters_list = _obtain_optimizer_parameters_list(self._inner_opt)
parameters_list = obtain_optimizer_parameters_list(self._inner_opt)
if self._sharding_enable:
sharding_reduce_gradients(list(parameters_list), self._hcg)
assert isinstance(self._inner_opt, DygraphShardingOptimizer)
self._inner_opt.reduce_gradients(list(parameters_list), self._hcg)
if self._dp_enable:
fused_allreduce_gradients(list(parameters_list), self._hcg)
......@@ -435,7 +439,8 @@ class HybridParallelOptimizer:
# Here sharding should use global parameter list
if self._sharding_enable:
sharding_reduce_gradients(list(parameter_list), self._hcg)
assert isinstance(self._inner_opt, DygraphShardingOptimizer)
self._inner_opt.reduce_gradients(list(parameter_list), self._hcg)
if self._dp_enable:
fused_allreduce_gradients(list(parameter_list), self._hcg)
......
......@@ -29,6 +29,20 @@ from .log_util import logger
__all__ = []
def obtain_optimizer_parameters_list(optimizer):
if getattr(optimizer, '_param_groups', None) and isinstance(
optimizer._param_groups[0], dict
):
parameters_list = []
for group in optimizer._param_groups:
for param in group['params']:
parameters_list.append(param)
else:
parameters_list = list(optimizer._parameter_list)
return parameters_list
def _apply_collective_grads(parameters, comm_group, bucket_size, scale=None):
grad_var_set = set()
grad_vars = []
......@@ -224,31 +238,6 @@ def fused_allreduce_gradients(parameter_list, hcg):
fused_allreduce_gradients_with_group(parameter_list, data_parallel_group)
def sharding_reduce_gradients(parameter_list, hcg):
# TODO allreduce --> reduce
# TODO merge grad / nrank with dp
logger.debug("sharding start gradients sync")
with framework.no_grad():
sharding_nrank = hcg.get_sharding_parallel_group().nranks
for param in parameter_list:
g_var = None
if param.trainable and (param._grad_ivar() is not None):
g_var = param._grad_ivar()
if param.trainable and hasattr(param, "main_grad"):
assert (
param._grad_ivar() is None
), "param.grad should be None when using main_grad"
g_var = param.main_grad
if g_var is not None:
g_var.scale_(1.0 / sharding_nrank)
paddle.distributed.all_reduce(
g_var,
group=hcg.get_sharding_parallel_group(),
sync_op=True,
)
def broadcast_sharding_parameters(model, hcg):
# TODO To save memory, use un-fused broadcast to avoid potential OOM
logger.debug("sharding start init parameters sync")
......@@ -257,3 +246,10 @@ def broadcast_sharding_parameters(model, hcg):
sync_params_buffers(
model, sharding_parallel_group, src_rank, is_model_parallel=False
)
def unwrap_optimizer(optimizer, optimizer_instances=()):
_inner_opt = optimizer
while isinstance(_inner_opt, optimizer_instances):
_inner_opt = _inner_opt._inner_opt
return _inner_opt
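For reference, a toy check of the `unwrap_optimizer` helper defined just above (assuming that definition is in scope); `_Wrapper` is a hypothetical stand-in for wrappers such as MixPrecisionOptimizer or DygraphShardingOptimizer that expose an `_inner_opt` attribute.

```python
class _Wrapper:  # stand-in wrapper, not a Paddle type
    def __init__(self, inner):
        self._inner_opt = inner

raw_opt = object()
wrapped = _Wrapper(_Wrapper(raw_opt))  # two wrapper layers

assert unwrap_optimizer(wrapped, (_Wrapper,)) is raw_opt
assert unwrap_optimizer(raw_opt) is raw_opt  # no-op when nothing matches
```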
......@@ -21,6 +21,9 @@ import numpy as np
import paddle
from paddle import _legacy_C_ops, nn
from paddle.distributed import fleet
from paddle.distributed.fleet.utils.hybrid_parallel_util import (
obtain_optimizer_parameters_list,
)
from paddle.fluid import framework
from paddle.fluid.dygraph import base as imperative_base
from paddle.fluid.dygraph import to_variable
......@@ -95,20 +98,7 @@ class MixPrecisionLayer(nn.Layer):
class MixPrecisionOptimizer:
def __init__(self, optimizer):
self._inner_opt = optimizer
self._parameter_list = self._obtain_optimizer_parameters_list()
def _obtain_optimizer_parameters_list(self):
if getattr(self._inner_opt, '_param_groups', None) and isinstance(
self._inner_opt._param_groups[0], dict
):
parameters_list = []
for group in self._inner_opt._param_groups:
for param in group['params']:
parameters_list.append(param)
else:
parameters_list = list(self._inner_opt._parameter_list)
return parameters_list
self._parameter_list = obtain_optimizer_parameters_list(optimizer)
@imperative_base.no_grad
@framework.dygraph_only
......
......@@ -212,44 +212,21 @@ class TestDistMPTraning(unittest.TestCase):
optimizer.clear_grad()
return loss
def build_optimizer(
self, model, strategy=None, is_sharding=True, Optimizer="adam"
):
def build_optimizer(self, model, strategy=None, Optimizer="adam"):
clip = paddle.nn.ClipGradByGlobalNorm(0.5)
if Optimizer == "adam":
if is_sharding:
optimizer = DygraphShardingOptimizer(
hcg=fleet.get_hybrid_communicate_group(),
user_defined_strategy=strategy,
params=model.parameters(),
inner_optimizer_class=paddle.optimizer.AdamW,
learning_rate=0.001,
weight_decay=0.00001,
grad_clip=clip,
)
else:
optimizer = paddle.optimizer.AdamW(
parameters=model.parameters(),
learning_rate=0.001,
weight_decay=0.00001,
grad_clip=clip,
)
optimizer = paddle.optimizer.AdamW(
parameters=model.parameters(),
learning_rate=0.001,
weight_decay=0.00001,
grad_clip=clip,
)
else:
if is_sharding:
optimizer = DygraphShardingOptimizer(
hcg=fleet.get_hybrid_communicate_group(),
user_defined_strategy=strategy,
params=model.parameters(),
inner_optimizer_class=paddle.optimizer.Momentum,
learning_rate=0.001,
grad_clip=clip,
)
else:
optimizer = paddle.optimizer.Momentum(
learning_rate=0.001,
parameters=model.parameters(),
grad_clip=clip,
)
optimizer = paddle.optimizer.Momentum(
learning_rate=0.001,
parameters=model.parameters(),
grad_clip=clip,
)
return optimizer
def build_model_optimizer(self, Optimizer="adam"):
......@@ -268,7 +245,6 @@ class TestDistMPTraning(unittest.TestCase):
optimizer_a = self.build_optimizer(
model_a,
strategy=self.strategy,
is_sharding=True,
Optimizer=Optimizer,
)
model_a = fleet.distributed_model(model_a)
......@@ -280,7 +256,6 @@ class TestDistMPTraning(unittest.TestCase):
optimizer_b = self.build_optimizer(
model_b,
strategy=self.strategy,
is_sharding=False,
Optimizer=Optimizer,
)
......@@ -299,9 +274,7 @@ class TestDistMPTraning(unittest.TestCase):
if idx == 2 and paddle.distributed.get_rank() == 0:
self.assertTrue(
set(
optimizer_a._inner_opt._inner_optimizer.state_dict().keys()
)
set(optimizer_a._inner_opt._inner_opt.state_dict().keys())
== sharded_accumulators
)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import random
import unittest
import numpy as np
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.utils.mix_precision_utils import (
MixPrecisionOptimizer,
)
vocab_size = 20
hidden_size = 10
inner_size = 8
output_size = 10
seq_length = 2
batch_size = 4
STEPS = 10
class SimpleDPNet(paddle.nn.Layer):
def __init__(
self, vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
):
super().__init__()
self.linear1 = paddle.nn.Linear(
hidden_size,
inner_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Assign(np_fc1)
),
bias_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)
),
)
self.linear2 = paddle.nn.Linear(
inner_size,
hidden_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Assign(np_fc2)
),
bias_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)
),
)
self.linear3 = paddle.nn.Linear(
hidden_size,
output_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)
),
bias_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)
),
)
self.embedding = paddle.nn.Embedding(
vocab_size,
hidden_size,
weight_attr=paddle.nn.initializer.Constant(value=0.5),
)
def forward(self, x):
x = self.embedding(x)
x = self.linear1(x)
x = self.linear2(x)
x = self.linear3(x)
x = paddle.matmul(x, self.embedding.weight, transpose_y=True)
return x
class TestDistShardingTraining(unittest.TestCase):
def setUp(self):
random.seed(2021)
np.random.seed(2021)
paddle.seed(2021)
self.strategy = fleet.DistributedStrategy()
self.strategy.hybrid_configs = {
"sharding_degree": 2,
"dp_degree": 1,
"mp_degree": 1,
"pp_degree": 1,
}
fleet.init(is_collective=True, strategy=self.strategy)
self.data = [
np.random.randint(
0,
vocab_size,
(
batch_size,
seq_length,
),
)
for _ in range(STEPS)
]
def build_adam_optimizer(self, model, lr=0.001):
clip = paddle.nn.ClipGradByGlobalNorm(0.5)
optimizer = paddle.optimizer.AdamW(
parameters=model.parameters(),
learning_rate=lr,
weight_decay=0.00001,
grad_clip=clip,
)
return optimizer
def test_set_state_dict(self):
np_fc1 = np.random.random_sample((hidden_size, inner_size))
np_fc2 = np.random.random_sample((inner_size, hidden_size))
model = SimpleDPNet(
vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
)
init_lr = 0.001
init_lr_scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
learning_rate=init_lr, T_max=1
)
local_optimizer = self.build_adam_optimizer(model, init_lr_scheduler)
dist_optimizer = fleet.distributed_optimizer(local_optimizer)
# prepare state_dict
state_dict = {}
# lr_scheduler
base_lr = 0.1
lr_scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
learning_rate=base_lr, T_max=1
)
state_dict["LR_Scheduler"] = lr_scheduler.state_dict()
# master_weights and accumulators
state_dict["master_weights"] = {}
all_param_names = []
accumulator_names = ["moment1", "moment2"]
#
local_params = dist_optimizer._rank2params[
dist_optimizer._sharding_rank
]
local_param_names = [p.name for p in local_params]
local_acc_names = []
other_acc_names = []
for p in model.parameters():
var_name = dist_optimizer._gen_master_weight_var_name(p)
var = paddle.static.create_global_var(
name=var_name,
shape=p.shape,
value=0,
dtype='float32',
persistable=True,
)
var = paddle.randn(shape=var.shape, dtype=var.dtype, name=var.name)
state_dict["master_weights"][p.name] = var
# accumulator
for name in accumulator_names:
acc_name = p.name + '_' + name
state_dict[acc_name] = paddle.randn(
shape=var.shape, dtype=var.dtype, name=acc_name
)
if p.name in local_param_names:
local_acc_names.append(acc_name)
else:
other_acc_names.append(acc_name)
all_param_names.append(p.name)
# test api
tmp_state_dict = copy.deepcopy(state_dict)
dist_optimizer.set_state_dict(state_dict)
# check result
other_param_names = [
p_name
for p_name in all_param_names
if p_name not in local_param_names
]
inner_opt = dist_optimizer._inner_opt
self.assertEqual(inner_opt._learning_rate.last_lr, base_lr)
assert hasattr(inner_opt, "_master_weights")
for p_name, weight in inner_opt._master_weights.items():
assert p_name in local_param_names
assert p_name not in other_param_names
assert p_name in tmp_state_dict["master_weights"]
np.testing.assert_array_almost_equal(
weight.numpy(), tmp_state_dict["master_weights"][p_name].numpy()
)
for acc_name, val in inner_opt._accumulators_holder.items():
assert acc_name in local_acc_names
assert acc_name not in other_acc_names
assert acc_name in tmp_state_dict
np.testing.assert_array_almost_equal(
val.numpy(), tmp_state_dict[acc_name].numpy()
)
def test_clear_grad(self):
np_fc1 = np.random.random_sample((hidden_size, inner_size))
np_fc2 = np.random.random_sample((inner_size, hidden_size))
model = SimpleDPNet(
vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
)
local_optimizer = self.build_adam_optimizer(model)
dist_optimizer = fleet.distributed_optimizer(local_optimizer)
tmp_parameter_list = []
for p in dist_optimizer._inner_opt._parameter_list:
main_grad = paddle.randn(shape=p.shape, dtype=p.dtype, name=p.name)
p.main_grad = main_grad
tmp_parameter_list.append(p)
assert hasattr(
dist_optimizer._inner_opt._parameter_list[0], "main_grad"
)
# test set_to_zero True
dist_optimizer._inner_opt.clear_grad(set_to_zero=True)
for p in dist_optimizer._inner_opt._parameter_list:
np.testing.assert_array_almost_equal(
p.main_grad.numpy(), np.zeros(p.main_grad.numpy().shape)
)
# test set_to_zero False
dist_optimizer._inner_opt.clear_grad(set_to_zero=False)
for p in dist_optimizer._inner_opt._parameter_list:
self.assertTrue(p.main_grad is None)
def test_set_inner_opt_attr(self):
np_fc1 = np.random.random_sample((hidden_size, inner_size))
np_fc2 = np.random.random_sample((inner_size, hidden_size))
model = SimpleDPNet(
vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
)
local_optimizer = self.build_adam_optimizer(model)
local_optimizer = MixPrecisionOptimizer(local_optimizer)
dist_optimizer = fleet.distributed_optimizer(local_optimizer)
sharding_opt = dist_optimizer._inner_opt
sharding_opt._set_inner_opt_attr('_parameter_list', 123)
self.assertTrue(hasattr(sharding_opt._inner_opt, '_parameter_list'))
self.assertTrue(
hasattr(sharding_opt._inner_opt._inner_opt, '_parameter_list')
)
self.assertEqual(sharding_opt._inner_opt._parameter_list, 123)
self.assertEqual(
sharding_opt._inner_opt._inner_opt._parameter_list, 123
)
sharding_opt._set_inner_opt_attr('_param_groups', 123)
self.assertTrue(hasattr(sharding_opt._inner_opt, '_param_groups'))
self.assertTrue(
hasattr(sharding_opt._inner_opt._inner_opt, '_param_groups')
)
self.assertEqual(sharding_opt._inner_opt._param_groups, 123)
self.assertEqual(sharding_opt._inner_opt._inner_opt._param_groups, 123)
# test bad case
try:
sharding_opt._set_inner_opt_attr(123, 123)
self.assertTrue(False)
except:
pass
if __name__ == "__main__":
unittest.main()
......@@ -23,6 +23,9 @@ class TestHybridParallel(TestMultipleGpus):
def test_hybrid_parallel_sharding_logic(self):
self.run_mnist_2gpu('hybrid_parallel_sharding_model.py')
def test_hybrid_parallel_sharding_state_dict(self):
self.run_mnist_2gpu('hybrid_parallel_sharding_state_dict.py')
if __name__ == "__main__":
unittest.main()
......@@ -649,8 +649,7 @@ class Optimizer:
else:
assert isinstance(self.helper, LayerHelper)
var_name = param.name + "_fp32_master"
var_name = unique_name.generate(var_name)
var_name = self._gen_master_weight_var_name(param)
var = paddle.static.create_global_var(
name=var_name,
shape=param.shape,
......@@ -671,6 +670,10 @@ class Optimizer:
self._master_weights[param.name] = var
return var
def _gen_master_weight_var_name(self, param):
var_name = param.name + "_fp32_master"
return unique_name.generate(var_name)
def _create_accumulators(self, block, parameters):
"""Create all accumulators needed by the parameters
......