未验证 提交 67ca9d45 编写于 作者: R Roc 提交者: GitHub

[INCUBATE] Add dist save/load for sharding stage2 (#46908)

上级 c036c5c0
......@@ -942,4 +942,13 @@ if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT)
py_test_modules(
test_fleet_log MODULES test_fleet_log ENVS
"http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_fleet_log PROPERTIES TIMEOUT "200" LABELS
"RUN_TYPE=DIST")
endif()
if((WITH_GPU) AND (LINUX))
py_test_modules(
test_dygraph_dist_save_load MODULES test_dygraph_dist_save_load ENVS
"http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_dygraph_dist_save_load
PROPERTIES TIMEOUT "200" LABELS "RUN_TYPE=DIST")
endif()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import numpy as np
import tempfile
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Linear
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_optimizer_stage2 import (
GroupShardedOptimizerStage2,
)
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage2 import (
GroupShardedStage2,
)
import sys
import subprocess
import argparse
from paddle import distributed as dist
from paddle.incubate.distributed.utils.io import save, load
print(load)
epoch = 2
linear_size = 1000
class MLP(fluid.Layer):
def __init__(self, linear_size=1000, param_attr=None, bias_attr=None):
super(MLP, self).__init__()
self._linear1 = Linear(linear_size, linear_size)
self._linear2 = Linear(linear_size, linear_size)
self._linear3 = Linear(linear_size, 10)
def forward(self, inputs):
y = self._linear1(inputs)
y = self._linear2(y)
y = self._linear3(y)
return y
def reader_decorator(linear_size=1000):
def __reader__():
for _ in range(100):
img = np.random.rand(linear_size).astype('float32')
label = np.ones(1).astype('int64')
yield img, label
return __reader__
def optimizer_setting(model, use_pure_fp16, opt_group=False):
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
optimizer = paddle.optimizer.AdamW(
parameters=[
{
"params": model.parameters(),
}
]
if opt_group
else model.parameters(),
learning_rate=0.001,
weight_decay=0.00001,
grad_clip=clip,
multi_precision=use_pure_fp16,
)
return optimizer
def train_mlp(
model,
sharding_stage,
batch_size=100,
use_pure_fp16=False,
accumulate_grad=False,
opt_group=False,
save_model=False,
test_minimize=False,
opt_state=None,
):
if sharding_stage != "dp":
group = paddle.distributed.new_group([0, 1], backend="nccl")
if opt_group:
optimizer = optimizer_setting(
model=model, use_pure_fp16=use_pure_fp16, opt_group=opt_group
)
else:
optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)
if sharding_stage == 2:
optimizer = GroupShardedOptimizerStage2(
params=optimizer._parameter_list, optim=optimizer, group=group
)
model = GroupShardedStage2(
model, optimizer, group=group, buffer_max_size=2**21
)
model._set_reduce_overlap(True)
optimizer._set_broadcast_overlap(True, model)
else:
model = paddle.DataParallel(model)
# check optimizer.minimize() error
if test_minimize:
try:
optimizer.minimize()
except:
print(
"====== Find sharding_stage2_optimizer.minimize() error ======"
)
return
train_reader = paddle.batch(
reader_decorator(), batch_size=batch_size, drop_last=True
)
train_loader = paddle.io.DataLoader.from_generator(
capacity=32,
use_double_buffer=True,
iterable=True,
return_list=True,
use_multiprocess=True,
)
train_loader.set_sample_list_generator(train_reader)
if sharding_stage == 2:
model.to(device="gpu")
if opt_state is not None:
optimizer.set_state_dict(opt_state)
for eop in range(epoch):
model.train()
for batch_id, data in enumerate(train_loader()):
img, label = data
label.stop_gradient = True
img.stop_gradient = True
out = model(img)
loss = paddle.nn.functional.cross_entropy(input=out, label=label)
avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
if batch_size == 20:
avg_loss = avg_loss / 5
avg_loss.backward()
if not accumulate_grad:
optimizer.step()
optimizer.clear_grad()
if accumulate_grad:
optimizer.step()
optimizer.clear_grad()
paddle.device.cuda.synchronize()
if save_model:
return model, optimizer
return model.parameters()
def save_model(model, output_dir, **configs):
configs["save_model"] = True
model, opt = train_mlp(model, **configs)
model_file = os.path.join(
output_dir, f"rank{dist.get_rank()}model.pdparams"
)
opt_file = os.path.join(output_dir, f"rank{dist.get_rank()}model.pdopt")
g_model_file = os.path.join(
output_dir, f"rank{dist.get_rank()}g_model.pdparams"
)
g_opt_file = os.path.join(output_dir, f"rank{dist.get_rank()}g_model.pdopt")
paddle.save(model.state_dict(), model_file)
paddle.save(opt.state_dict(), opt_file)
save(
model.state_dict(), g_model_file, gather_to=[0, 1], state_type="params"
)
save(opt.state_dict(), g_opt_file, gather_to=[0, 1], state_type="opt")
def load_mode(model, model_state_dict, output_param_path, **configs):
configs["save_model"] = False
model.set_state_dict(model_state_dict)
params = train_mlp(model, **configs)
paddle.save(params, output_param_path)
def step_check(path1, path2):
m1 = paddle.load(path1)
m2 = paddle.load(path2)
for v1, v2 in zip(m1, m2):
assert np.allclose(v1.numpy(), v2.numpy())
print(f"value same: {v1.name}")
def step_save(strategy, output_dir, seed):
python_exe = sys.executable
# save data
os.makedirs(output_dir + "/logs", exist_ok=True)
filename = os.path.basename(__file__)
cmd = (
f"{python_exe} -m paddle.distributed.launch --log_dir {output_dir}/logs"
f" --gpus 0,1 {filename} --cmd save --strategy {strategy} --output_dir {output_dir} --seed {seed}"
)
p = subprocess.Popen(cmd.split())
p.communicate()
assert p.poll() == 0
def step_load(
saved_strategy, curent_strateggy, saved_dir, load_way, output_path, seed
):
python_exe = sys.executable
os.makedirs(f"{saved_dir}/load/logs", exist_ok=True)
filename = os.path.basename(__file__)
# load dp
cmd = (
f"{python_exe} -m paddle.distributed.launch --log_dir {saved_dir}/load/logs"
f" --gpus 0,1 {filename} --cmd load --strategy {curent_strateggy} --output_dir {saved_dir} --load_dir {saved_dir}/{saved_strategy}/save --load_way {load_way}"
f" --output_param_path {output_path} --seed {seed}"
)
p = subprocess.Popen(cmd.split())
p.communicate()
assert p.poll() == 0
def test_save_load(args):
np.random.seed(args.seed)
paddle.seed(args.seed)
if args.cmd == "main":
run_case(args)
return
paddle.distributed.init_parallel_env()
strategy = fleet.DistributedStrategy()
if args.strategy == "dp":
strategy.hybrid_configs = {
"dp_degree": 2,
"mp_degree": 1,
"pp_degree": 1,
"sharding_degree": 1,
}
elif args.strategy == "sharding_stage2":
strategy.hybrid_configs = {
"dp_degree": 1,
"mp_degree": 1,
"pp_degree": 1,
"sharding_degree": 2,
}
else:
raise ValueError(f"Not supported strategy: {args.strategy}")
fleet.init(is_collective=True, strategy=strategy)
fleet.set_log_level("DEBUG")
mlp1 = MLP()
output_dir = os.path.join(args.output_dir, args.strategy, args.cmd)
os.makedirs(output_dir, exist_ok=True)
if args.cmd.lower() == "save":
if args.strategy == "dp":
# DP VS stage2
save_model(
mlp1,
output_dir,
sharding_stage="dp",
use_pure_fp16=False,
opt_group=False,
save_model=True,
)
elif args.strategy == "sharding_stage2":
save_model(
mlp1,
output_dir,
sharding_stage=2,
use_pure_fp16=False,
opt_group=False,
save_model=True,
)
else:
raise ValueError(f"Not supported {args.strategy}")
elif args.cmd.lower() == "load":
output_dir = args.load_dir
model_file = os.path.join(
output_dir, f"rank{dist.get_rank()}model.pdparams"
)
opt_file = os.path.join(output_dir, f"rank{dist.get_rank()}model.pdopt")
g_model_file = os.path.join(
output_dir, f"rank{args.gather_to}g_model.pdparams"
)
g_opt_file = os.path.join(
output_dir, f"rank{args.gather_to}g_model.pdopt"
)
if args.load_way == "load":
model_file = g_model_file
opt_file = g_opt_file
load_ = lambda x: eval(args.load_way)(x, place='cpu')
else:
load_ = eval(args.load_way)
model = load_(model_file)
opt = load_(opt_file)
for k in opt.keys():
print("opt k:", k)
if args.strategy == "dp":
load_mode(
mlp1,
model,
args.output_param_path,
sharding_stage="dp",
use_pure_fp16=False,
opt_group=False,
save_model=False,
opt_state=opt,
)
elif args.strategy == "sharding_stage2":
load_mode(
mlp1,
model,
args.output_param_path,
sharding_stage=2,
use_pure_fp16=False,
opt_group=False,
save_model=False,
opt_state=opt,
)
else:
raise ValueError(f"Not supported strategy {args.strategy}")
else:
raise ValueError(f"Not supported cmd: {args.cmd}")
def run_case(args):
saving_strategy = args.test_case.split(":")[0]
loading_strategy = args.test_case.split(":")[1]
output_dir = tempfile.mkdtemp()
print("output dir:", output_dir)
os.makedirs(output_dir + "/load_save", exist_ok=True)
# save dp
step_save(saving_strategy, output_dir, args.seed)
# return
# load dp
p1 = os.path.join(output_dir, "m1.pdparams")
p2 = os.path.join(output_dir, "m2.pdparams")
step_load(
saving_strategy,
saving_strategy,
output_dir,
"paddle.load",
p1,
args.seed + 1,
)
step_load(
saving_strategy, loading_strategy, output_dir, "load", p2, args.seed + 2
)
# check
step_check(p1, p2)
shutil.rmtree(output_dir)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
"--cmd", default="main", choices=["main", "save", "load"]
)
parser.add_argument(
"--strategy", required=False, choices=["dp", "sharding_stage2"]
)
parser.add_argument(
"--load_way", choices=["paddle.load", "load"], required=False
)
parser.add_argument("--load_dir", required=False)
parser.add_argument("--output_dir", required=False)
parser.add_argument("--output_param_path", required=False)
parser.add_argument(
"--test_case",
required=False,
choices=[
"dp:dp",
"dp:sharding_stage2",
"sharding_stage2:dp",
"sharding_stage2:sharding_stage2",
],
)
parser.add_argument("--gather_to", required=False, default=0)
parser.add_argument("--seed", type=int, default=2022)
args = parser.parse_args()
test_save_load(args)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import subprocess
import sys
def strategy_test(saving, loading, gather_to):
cmd = f"{sys.executable} dygraph_dist_save_load.py --test_case {saving}:{loading} --gather_to {gather_to}"
p = subprocess.Popen(cmd.split())
p.communicate()
assert p.poll() == 0
class TestDistSaveLoad(unittest.TestCase):
def test_dygraph_save_load_dp_sharding_stage2(self):
strategy_test("dp", "sharding_stage2", 0)
strategy_test("dp", "sharding_stage2", 1)
strategy_test("sharding_stage2", "dp", 1)
if __name__ == "__main__":
os.environ["FLAGS_enable_eager_mode"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
unittest.main()
......@@ -82,4 +82,5 @@ test_hdfs1,LINUX,,200,EXCLUSIVE:NIGHTLY,../../dist_test.sh,2,,http_proxy=;https_
test_hdfs2,LINUX,,200,EXCLUSIVE:NIGHTLY,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_hdfs3,LINUX,,200,EXCLUSIVE:NIGHTLY,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_checkpoint,LINUX,GPU;ROCM,200,EXCLUSIVE:NIGHTLY,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_log,,,,DIST,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_log,,,200,DIST,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_dygraph_dist_save_load,LINUX,GPU,200,DIST,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .dist_save import save
from .dist_load import load
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid.framework import dygraph_only
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet
import re
import copy
@dygraph_only
def load(path, **configs):
"""
Load an object can be used in paddle from specified path.
The file is saved by distributed.save
Note:
The file to load must be saved bu the API paddle.incubate.distributed.utils.io.save
Args:
path(str|BytesIO) : The path/buffer to load the target object. Generally, the path is the target
file path. When loading state_dict from the saved result of the API used to save
the inference model, the path may be a file prefix or directory.
**configs (dict, optional): other load configuration options for compatibility. We do not
recommend using these configurations, they may be removed in the future. If not necessary,
DO NOT use them. Default None.
The following options are currently supported:
(1) place: where to place the loaded state dict.
If the state dict is too large, the palce should be set 'cpu'.
Note:
Other config value may cause some error.Please don't use any more config options.
Returns:
Object(Object): a target object can be used in paddle
Examples:
import paddle
paddle.distributed.init_process_group(backend='nccl')
paddle.distributed.fleet.init(is_collective=True)
model = build_model()
optimizer = build_optimizer(model)
dist_model = paddle.distributed_optimizer(model)
dist_optimizer = paddle.distributed_optimizer(optimizer)
# load model state dict
model_state_dict = paddle.incubate.distributed.utils.io.load(path="path/to/load.pdparams")
dist_model.set_state_dict(model_state_dict)
# load optimizer satte dict
optimizer_state_dict = paddle.incubate.distributed.utils.io.load(path="path/to/load.pdopt")
dist_optimizer.set_state_dict(optimizer_state_dict)
"""
if dist.get_world_size() == 1:
return paddle.load(path, **configs)
hcg = fleet.get_hybrid_communicate_group()
assert (
hcg.get_model_parallel_world_size() == 1
and hcg.get_pipe_parallel_world_size() == 1
), "Sharding and DP are supported only now"
# assert (
# "place" in configs
# ), "the arg place ('cpu' or 'gpu:0', 'gpus:1' ...)must be passed"
if "place" not in configs:
configs["place"] = "cpu"
place = configs["place"]
assert isinstance(
place, str
), f"configs[place] must be a str, but this is a {type(place)}"
assert re.search(
"^(cpu|gpu:[0-9]*)$", place
), "configs[place] must be cpu, gpu:0, gpu:1 ..."
return load_with_place(path, **configs)
def load_with_place(path, **configs):
place = configs["place"]
if place is None:
return paddle.load(path)
origin_place = paddle.get_device()
paddle.set_device(place)
configs = _remove_not_supported_itmes(configs)
state_dict = paddle.load(path, **configs)
paddle.set_device(origin_place)
return state_dict
def _remove_not_supported_itmes(configs):
__supported_by_load__ = [
"model_filename",
"params_filename",
"return_numpy",
]
_configs = copy.copy(configs)
for k in configs.keys():
if k not in __supported_by_load__:
_configs.pop(k, None)
return _configs
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.distributed as dist
import paddle.distributed.fleet as fleet
import re
import paddle
from paddle.distributed.fleet.utils.log_util import logger
from paddle.fluid.framework import dygraph_only
import copy
import sys
from paddle.distributed.fleet.utils.log_util import logger
__all__ = ["save"]
@dygraph_only
def save(state_dict, path, **configs):
'''
Save a state dict to the specified path in both distributed and single-card environment.
Note:
Now supports saving ``state_dict`` of Layer/Optimizer, Tensor and nested structure containing Tensor, Program.
Note:
Different from ``paddle.jit.save``, since the save result of ``paddle.save`` is a single file,
there is no need to distinguish multiple saved files by adding a suffix. The argument ``path``
of ``paddle.save`` will be directly used as the saved file name instead of a prefix.
In order to unify the saved file name format, we recommend using the paddle standard suffix:
1. for ``Layer.state_dict`` , recommend to use ``.pdparams`` ;
2. for ``Optimizer.state_dict`` , recommend to use ``.pdopt`` .
For specific examples, please refer to API code examples.
Args:
obj(Object) : The object to be saved.
path(str|BytesIO) : The path/buffer of the object to be saved.
If saved in the current directory, the input path string will be used as the file name.
protocol(int, optional): The protocol version of pickle module must be greater than 1 and less than 5.
Default: 4
**configs(dict, optional): optional keyword arguments. The following options are currently supported:
(1)use_binary_format(bool):
To be used in paddle.save. When the saved object is static graph variable, you can specify ``use_binary_for_var``.
If True, save the file in the c++ binary format when saving a single static graph variable; otherwise, save it in pickle format.
Default: False
(2)gather_to(int|list|tuple|None):
To specify which global rank to save in.Defalut is None.
None value means distributed saving with no gathering to a single card.
(3)state_type(str):
Value can be 'params' or 'opt', specifying to save parametres or optimizer state.
(4)max_grouped_size(str|int):
To limit the max size(how many bits) a object group to be transfered a time.
If str, the format must be as num+'G/M/K', for example, 3G, 2K, 10M, etc. Default is 3G.
Returns:
None
Examples:
import paddle
paddle.distributed.init_process_group(backend='nccl')
paddle.distributed.fleet.init(is_collective=True)
model = build_model()
optimizer = build_optimizer(model)
dist_optimizer = paddle.distributed_optimizer(optimizer)
dist_model = paddle.distributed_optimizer(model)
# gather params to rank 0 and then save
paddle.incubate.distributed.utils.io.save(model.state_dict(), path="path/to/save.pdparams", gather_to=[0], state_type="params")
# save whoe params on all ranks
paddle.incubate.distributed.utils.io.save(model.state_dict(), path="path/to/save.pdparams", gather_to=[0,1], state_type="params")
# save optimizer state dict on rank 0
paddle.incubate.distributed.utils.io.save(optimizer.state_dict(), path="path/to/save.pdopt", gather=0, state_type="opt")
'''
gather_to = configs.get("gather_to", None)
if dist.get_world_size() == 1 or gather_to is None:
configs = _remove_not_supported_conf(configs)
return paddle.save(state_dict, path, **configs)
# gather_to is not None and world size > 1
state_type = configs.get("state_type", None)
assert isinstance(
state_type, str
), "must pass an arg state_type='params' or state_type='opt' to specify whether to save model state_dict or optimizer state_dict"
assert state_type in [
"params",
"opt",
], "must pass an arg state_type='params' or state_type='opt'"
if re.search(f"{state_type}$", path) is None:
logger.warning(
f"You are saving {state_type}, while the path({path} does not end with {state_type})"
)
hcg = fleet.get_hybrid_communicate_group()
assert (
hcg.get_model_parallel_world_size() == 1
and hcg.get_pipe_parallel_world_size() == 1
), f"Only DP and Sharding is supported now. However, current MP={hcg.get_model_parallel_world_size()} , PP={hcg.get_pipe_parallel_world_size()}"
sharding_group = hcg.get_sharding_parallel_group()
dp_group = hcg.get_data_parallel_group()
if state_type == "params":
if dp_group.nranks > 1:
assert _same_keys(
state_dict, dp_group
), "only sharding stage 1/2 and DP are supported now"
if sharding_group.nranks > 1:
assert _same_keys(
state_dict, sharding_group
), "only sharding stage 1/2 and DP are supported now"
configs = _remove_not_supported_conf(configs)
return paddle.save(state_dict, path, **configs)
# state_type == "opt"
if sharding_group.nranks == 1:
configs = _remove_not_supported_conf(configs)
return paddle.save(state_dict, path, **configs)
if _same_keys(state_dict, sharding_group):
return paddle.save(state_dict, path, **configs)
assert isinstance(gather_to, (list, tuple, int))
if isinstance(gather_to, int):
gather_to = [gather_to]
max_size = configs.get("max_grouped_size", "3G")
try:
logger.info("state_dict_keys:" + str(state_dict.keys()))
gathered_state_dict = _gather_state_dict(
state_dict, gather_to, sharding_group, max_size=max_size
)
logger.info("gathered_state_dict_keys:" + str(state_dict.keys()))
if dist.get_rank() in gather_to:
configs = _remove_not_supported_conf(configs)
paddle.save(gathered_state_dict, path, **configs)
except:
raise RuntimeError(
f'''Saving failed. Follwing are some suggestions:
1) pass the param max_grouped_size to turn the grouped size smaller (current value of max_grouped_size is {max_size})
2) if sharding stage is 1, use paddle.save rather than paddle.distributed.save
3) Concat the developers
'''
)
def _state_dict_groups(state_dict, max_size):
"""
Description:
Generator of state dict groups to transfer.the size of each group is less than max_size.
"""
# find the max size of a whole tensor
# now we only support to transfer at least one whole tensor
max_tensor_size = 0
for k, v in state_dict.items():
if max_tensor_size < sys.getsizeof(v) + sys.getsizeof(k):
max_tensor_size = sys.getsizeof(v) + sys.getsizeof(k)
max_size = max(max_size, max_tensor_size)
logger.debug(f"max tensor size: {max_size}")
state_group = dict()
k_list = list(state_dict.keys())
index = 0
bits = 0
# generate groups utils the end
while index < len(k_list):
bsize = sys.getsizeof(state_dict[k_list[index]]) + sys.getsizeof(
k_list[index]
)
if bits + bsize >= max_size:
yield state_group
state_group = dict()
bits = 0
state_group[k_list[index]] = state_dict[k_list[index]]
index += 1
bits += bsize
if index == len(k_list) and bits > 0:
yield state_group
def all_empty(dict_list):
"""
Check if all items are empty
"""
for v in dict_list:
if len(v) > 0:
return False
return True
def _parse_mem_size_to_bits(max_size):
"""
Parse an integer or a mem size str to an integer
convert xxxG to xxx * 1024^3
convert xxxM to xxx * 1024^2
convert xxxK to xxx * 1024^1
"""
assert isinstance(max_size, (int, str))
if isinstance(max_size, str):
assert re.search(
"^[0-9]*[GMK]$", max_size
), f"Wrong max_size 's format, the format ust be like 10K, 9M, 200G , etc, or an integer. However this is {max_size}"
num = int(max_size[:-1])
if max_size[-1] == "G":
max_size = num * 1024**3
elif max_size[-1] == "M":
max_size = num * 1024**2
else:
max_size = num * 1024
return max_size
def _gather_state_dict(state_dict, dst, group, max_size="3G"):
"""
Description:
Gather state dicts across all group ranks to dst, Depiring the same elements. including LR_Scheduler.
Args:
state_dict(dict):
local state dict
dst(int|list|tuple):
ranks the state dicts are gathered to
group(ProcessGroup):
group across which the state dicts are gathered
max_size(int|str):
The max limitation of the gathered tensor group size transformered a time. Default is 3G bits.
Each rank 's max tensor group before gathering is max_size // group.size
Returns:
Gathered state dict
"""
assert isinstance(
dst, (list, tuple, int)
), "dst' type must be one of int, list and tuple"
if isinstance(dst, int):
dst = [dst]
max_size = _parse_mem_size_to_bits(max_size)
max_size //= dist.get_world_size(group)
logger.debug("len state_dict: len(state_dict)")
state_dict_ = copy.copy(state_dict)
mw = None
has_mw = False
has_lr = False
# Remove master_weights and LR_Scheduler to ensure that all the elements of the state dict are str->Tensor
if "master_weights" in state_dict_:
mw = state_dict_.pop("master_weights", None)
has_mw = True
if "LR_Scheduler" in state_dict_:
lr = state_dict_.pop("LR_Scheduler", None)
has_lr = True
# Gather optimizer state_dict
output = _grouped_gather_data_dict(state_dict_, dst, group, max_size)
# Gather master_weights if it exists
if isinstance(mw, dict):
masters = _grouped_gather_data_dict(mw, dst, group, max_size)
else:
assert mw is None, f"Wrong type of master weights . type: {type(mw)}"
# assign master_weights and LR_Scheduler
# Because LR_Schedulers are same across group, it just needs to be reset
if has_mw:
output["master_weights"] = masters
if has_lr:
output["LR_Scheduler"] = lr
return output
def _grouped_gather_data_dict(state_data_dict, dst, group, max_size):
"""
Description:
Gather state data dict by groups.
Args:
state__data_dict(dict):
local dict to transfer.The state_data_dict only contains the mapping: str->paddle.Tensor
dst(int|list|tuple):
ranks the state dicts are gathered to
group(ProcessGroup):
group across which the state dicts are gathered
max_size(int|str):
The max limitation of the gathered tensor group size transformered a time. Default is 3G bits.
Each rank 's max tensor group before gathering is max_size // group.size
Returns:
Gatherd state_data_dict
"""
numpy_dict = {}
logger.debug(f"len state_tict_ : {len(state_data_dict)}")
for k, v in state_data_dict.items():
try:
numpy_dict[k] = v.numpy()
except:
raise TypeError(
f"the object (type of {type(v)}) of '{k}' is neither tensor nor parameter"
)
total = 0
output_state = dict()
logger.info("start all gather ...")
# gather all state_dict by groups
for state in _state_dict_groups(numpy_dict, max_size):
s_list = []
total += len(state)
logger.info(f"gen to gather: {total} / {len(numpy_dict)}")
dist.all_gather_object(s_list, state, group)
if dist.get_rank() in dst:
for s in s_list:
for k, v in s.items():
logger.debug(f"gathered: {k}, {v.shape}")
output_state.update(s)
logger.debug(
f"s list size: {sum(len(s) for s in s_list)} output: {len(output_state)}"
)
# Because each size of groups may be different, here we should wait all objetcs gatherd.
# The while block breaks until all objects from every rank are empty, which means all of the objects transforming is done.
while True:
s_list = []
state = {}
logger.debug("while True")
dist.all_gather_object(s_list, state, group)
if all_empty(s_list):
break
if dist.get_rank() in dst:
for s in s_list:
for k, v in s.items():
logger.debug(f"gathered: {k}, {v.shape}")
output_state.update(s)
logger.debug(
f"s list size: {sum(len(s) for s in s_list)} output: {len(output_state)}"
)
logger.debug("all gathered ...")
if dist.get_rank() in dst:
# convert numpy.ndarray to Tensor in cpu palce
place = paddle.CPUPlace()
for k in output_state.keys():
output_state[k] = paddle.to_tensor(output_state[k], place=place)
output_state[k].name = k
return output_state
return {}
def _same_keys(state_dict, group):
"""
Check whther all keys in each dict in the group are the same.
Used in sharding strategy to determine whether a dict needs to be gathered.
"""
keys = list(state_dict.keys())
key_list = []
logger.info(keys)
dist.all_gather_object(key_list, keys, group=group)
for k in key_list:
if not k == keys:
return False
return True
def _remove_not_supported_conf(configs):
"""
Remove the config values not supported by paddle.save
"""
__supported_by_save__ = ["use_binary_format"]
configs_ = copy.copy(configs)
for k in configs.keys():
if k not in __supported_by_save__:
configs_.pop(k, None)
return configs_
......@@ -383,6 +383,8 @@ packages=['paddle',
'paddle.incubate.optimizer.functional',
'paddle.incubate.autograd',
'paddle.incubate.distributed',
'paddle.incubate.distributed.utils',
'paddle.incubate.distributed.utils.io',
'paddle.incubate.distributed.fleet',
'paddle.incubate.distributed.models',
'paddle.incubate.distributed.models.moe',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册