From 67ca9d45879f24bc974191dbb01b6d9c1069c833 Mon Sep 17 00:00:00 2001 From: Roc <30228238+sljlp@users.noreply.github.com> Date: Sat, 29 Oct 2022 22:58:04 +0800 Subject: [PATCH] [INCUBATE] Add dist save/load for sharding stage2 (#46908) --- .../unittests/collective/fleet/CMakeLists.txt | 9 + .../fleet/dygraph_dist_save_load.py | 419 ++++++++++++++++++ .../fleet/test_dygraph_dist_save_load.py | 38 ++ .../unittests/collective/fleet/testslist.csv | 3 +- .../incubate/distributed/utils/io/__init__.py | 16 + .../distributed/utils/io/dist_load.py | 120 +++++ .../distributed/utils/io/dist_save.py | 392 ++++++++++++++++ python/setup.py.in | 2 + 8 files changed, 998 insertions(+), 1 deletion(-) create mode 100644 python/paddle/fluid/tests/unittests/collective/fleet/dygraph_dist_save_load.py create mode 100644 python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_dist_save_load.py create mode 100644 python/paddle/incubate/distributed/utils/io/__init__.py create mode 100644 python/paddle/incubate/distributed/utils/io/dist_load.py create mode 100644 python/paddle/incubate/distributed/utils/io/dist_save.py diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/CMakeLists.txt b/python/paddle/fluid/tests/unittests/collective/fleet/CMakeLists.txt index b47e4b5b53..f853e96204 100644 --- a/python/paddle/fluid/tests/unittests/collective/fleet/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/collective/fleet/CMakeLists.txt @@ -942,4 +942,13 @@ if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT) py_test_modules( test_fleet_log MODULES test_fleet_log ENVS "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python") + set_tests_properties(test_fleet_log PROPERTIES TIMEOUT "200" LABELS + "RUN_TYPE=DIST") +endif() +if((WITH_GPU) AND (LINUX)) + py_test_modules( + test_dygraph_dist_save_load MODULES test_dygraph_dist_save_load ENVS + "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python") + set_tests_properties(test_dygraph_dist_save_load + 
PROPERTIES TIMEOUT "200" LABELS "RUN_TYPE=DIST") endif() diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_dist_save_load.py b/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_dist_save_load.py new file mode 100644 index 0000000000..0ade6b0cb7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_dist_save_load.py @@ -0,0 +1,419 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 

import os
import shutil
import numpy as np
import tempfile
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Linear
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_optimizer_stage2 import (
    GroupShardedOptimizerStage2,
)
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage2 import (
    GroupShardedStage2,
)

import sys
import subprocess
import argparse

from paddle import distributed as dist
from paddle.incubate.distributed.utils.io import save, load

# Debug output: shows which ``load`` implementation was imported.
print(load)
epoch = 2
linear_size = 1000


class MLP(fluid.Layer):
    """Three stacked ``Linear`` layers: size -> size -> size -> 10."""

    def __init__(self, linear_size=1000, param_attr=None, bias_attr=None):
        # ``param_attr`` / ``bias_attr`` are accepted but not used.
        super().__init__()

        self._linear1 = Linear(linear_size, linear_size)
        self._linear2 = Linear(linear_size, linear_size)
        self._linear3 = Linear(linear_size, 10)

    def forward(self, inputs):
        """Apply the three linear layers in order and return the result."""
        y = self._linear1(inputs)
        y = self._linear2(y)
        y = self._linear3(y)
        return y


def reader_decorator(linear_size=1000):
    """Return a reader yielding 100 random float32 samples with label 1."""

    def __reader__():
        for _ in range(100):
            img = np.random.rand(linear_size).astype('float32')
            label = np.ones(1).astype('int64')
            yield img, label

    return __reader__


def optimizer_setting(model, use_pure_fp16, opt_group=False):
    """Build an AdamW optimizer with global-norm gradient clipping.

    When ``opt_group`` is true the parameters are passed as a single
    parameter group (a list holding one dict) instead of a flat list.
    """
    clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
    if opt_group:
        parameters = [{"params": model.parameters()}]
    else:
        parameters = model.parameters()
    return paddle.optimizer.AdamW(
        parameters=parameters,
        learning_rate=0.001,
        weight_decay=0.00001,
        grad_clip=clip,
        multi_precision=use_pure_fp16,
    )


def train_mlp(
    model,
    sharding_stage,
    batch_size=100,
    use_pure_fp16=False,
    accumulate_grad=False,
    opt_group=False,
    save_model=False,
    test_minimize=False,
    opt_state=None,
):
    """Train ``model`` for ``epoch`` epochs under the requested strategy.

    Args:
        model: the layer to train.
        sharding_stage: ``2`` wraps model and optimizer with the group-sharded
            stage-2 classes; any other value wraps the model in
            ``paddle.DataParallel``.
        batch_size: reader batch size.
        use_pure_fp16: forwarded to the optimizer as ``multi_precision``.
        accumulate_grad: if true, step once per epoch instead of per batch.
        opt_group: build the optimizer from a parameter group.
        save_model: if true, return ``(model, optimizer)``.
        test_minimize: only probe ``optimizer.minimize()`` and return ``None``.
        opt_state: optional optimizer state dict restored before training.

    Returns:
        ``(model, optimizer)`` when ``save_model`` is true, otherwise
        ``model.parameters()``; ``None`` when ``test_minimize`` is set.
    """
    if sharding_stage != "dp":
        group = paddle.distributed.new_group([0, 1], backend="nccl")

    # ``opt_group`` defaults to False inside ``optimizer_setting`` as well, so
    # forwarding it unconditionally matches both original call shapes.
    optimizer = optimizer_setting(
        model=model, use_pure_fp16=use_pure_fp16, opt_group=opt_group
    )

    if sharding_stage == 2:
        optimizer = GroupShardedOptimizerStage2(
            params=optimizer._parameter_list, optim=optimizer, group=group
        )
        model = GroupShardedStage2(
            model, optimizer, group=group, buffer_max_size=2**21
        )
        model._set_reduce_overlap(True)
        optimizer._set_broadcast_overlap(True, model)
    else:
        model = paddle.DataParallel(model)

    # check optimizer.minimize() error
    if test_minimize:
        try:
            optimizer.minimize()
        except Exception:
            # Narrowed from a bare ``except`` so Ctrl-C / SystemExit still
            # propagate.
            print(
                "====== Find sharding_stage2_optimizer.minimize() error ======"
            )
        return

    train_reader = paddle.batch(
        reader_decorator(), batch_size=batch_size, drop_last=True
    )

    train_loader = paddle.io.DataLoader.from_generator(
        capacity=32,
        use_double_buffer=True,
        iterable=True,
        return_list=True,
        use_multiprocess=True,
    )
    train_loader.set_sample_list_generator(train_reader)

    if sharding_stage == 2:
        model.to(device="gpu")
    if opt_state is not None:
        optimizer.set_state_dict(opt_state)

    for _ in range(epoch):
        model.train()

        for img, label in train_loader():
            label.stop_gradient = True
            img.stop_gradient = True

            out = model(img)
            loss = paddle.nn.functional.cross_entropy(input=out, label=label)

            avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
            if batch_size == 20:
                avg_loss = avg_loss / 5
            avg_loss.backward()

            if not accumulate_grad:
                optimizer.step()
                optimizer.clear_grad()

        if accumulate_grad:
            optimizer.step()
            optimizer.clear_grad()

    paddle.device.cuda.synchronize()

    if save_model:
        return model, optimizer
    return model.parameters()


def save_model(model, output_dir, **configs):
    """Train ``model`` and write its state with both save APIs.

    Each rank writes ``rank<N>model.*`` through ``paddle.save`` and
    ``rank<N>g_model.*`` through the distributed ``save`` (gathered to ranks
    0 and 1).
    """
    configs["save_model"] = True
    model, opt = train_mlp(model, **configs)

    rank = dist.get_rank()
    model_file = os.path.join(output_dir, f"rank{rank}model.pdparams")
    opt_file = os.path.join(output_dir, f"rank{rank}model.pdopt")
    g_model_file = os.path.join(output_dir, f"rank{rank}g_model.pdparams")
    g_opt_file = os.path.join(output_dir, f"rank{rank}g_model.pdopt")

    paddle.save(model.state_dict(), model_file)
    paddle.save(opt.state_dict(), opt_file)

    save(
        model.state_dict(), g_model_file, gather_to=[0, 1], state_type="params"
    )
    save(opt.state_dict(), g_opt_file, gather_to=[0, 1], state_type="opt")


def load_mode(model, model_state_dict, output_param_path, **configs):
    """Restore ``model_state_dict``, train, and save the resulting parameters."""
    configs["save_model"] = False
    model.set_state_dict(model_state_dict)
    params = train_mlp(model, **configs)
    paddle.save(params, output_param_path)


def step_check(path1, path2):
    """Assert that the two saved parameter lists hold the same values."""
    m1 = paddle.load(path1)
    m2 = paddle.load(path2)
    # ``zip`` stops at the shorter input, so a missing parameter would
    # otherwise go unnoticed.
    assert len(m1) == len(m2), f"parameter count differs: {len(m1)} vs {len(m2)}"
    for v1, v2 in zip(m1, m2):
        assert np.allclose(v1.numpy(), v2.numpy())
        print(f"value same: {v1.name}")


def _run_and_check(cmd):
    """Run ``cmd`` (an argument list) and fail if it exits non-zero."""
    returncode = subprocess.run(cmd).returncode
    if returncode != 0:
        raise AssertionError(f"command failed with exit code {returncode}: {cmd}")


def step_save(strategy, output_dir, seed):
    """Launch a 2-GPU job that trains and saves under ``strategy``."""
    # save data
    os.makedirs(output_dir + "/logs", exist_ok=True)
    filename = os.path.basename(__file__)
    # Built as a list so paths containing spaces stay single arguments.
    cmd = [
        sys.executable, "-m", "paddle.distributed.launch",
        "--log_dir", f"{output_dir}/logs",
        "--gpus", "0,1", filename,
        "--cmd", "save",
        "--strategy", str(strategy),
        "--output_dir", str(output_dir),
        "--seed", str(seed),
    ]
    _run_and_check(cmd)


def step_load(
    saved_strategy,
    curent_strateggy,
    saved_dir,
    load_way,
    output_path,
    seed,
    gather_to=0,
):
    """Launch a 2-GPU job that loads a checkpoint and continues training.

    ``gather_to`` selects which rank's gathered file the child reads. It is
    forwarded explicitly because the child's own default is 0.
    """
    os.makedirs(f"{saved_dir}/load/logs", exist_ok=True)
    filename = os.path.basename(__file__)
    cmd = [
        sys.executable, "-m", "paddle.distributed.launch",
        "--log_dir", f"{saved_dir}/load/logs",
        "--gpus", "0,1", filename,
        "--cmd", "load",
        "--strategy", str(curent_strateggy),
        "--output_dir", str(saved_dir),
        "--load_dir", f"{saved_dir}/{saved_strategy}/save",
        "--load_way", str(load_way),
        "--output_param_path", str(output_path),
        "--seed", str(seed),
        "--gather_to", str(gather_to),
    ]
    _run_and_check(cmd)


def test_save_load(args):
    """Entry point: dispatch on ``args.cmd`` (``main``, ``save`` or ``load``)."""

    np.random.seed(args.seed)
    paddle.seed(args.seed)

    if args.cmd == "main":
        run_case(args)
        return

    paddle.distributed.init_parallel_env()
    strategy = fleet.DistributedStrategy()
    if args.strategy == "dp":
        strategy.hybrid_configs = {
            "dp_degree": 2,
            "mp_degree": 1,
            "pp_degree": 1,
            "sharding_degree": 1,
        }
    elif args.strategy == "sharding_stage2":
        strategy.hybrid_configs = {
            "dp_degree": 1,
            "mp_degree": 1,
            "pp_degree": 1,
            "sharding_degree": 2,
        }
    else:
        raise ValueError(f"Not supported strategy: {args.strategy}")

    fleet.init(is_collective=True, strategy=strategy)
    fleet.set_log_level("DEBUG")

    mlp1 = MLP()
    output_dir = os.path.join(args.output_dir, args.strategy, args.cmd)
    os.makedirs(output_dir, exist_ok=True)

    # ``args.strategy`` was validated above, so this mapping is total here.
    sharding_stage = "dp" if args.strategy == "dp" else 2
    cmd = args.cmd.lower()

    if cmd == "save":
        save_model(
            mlp1,
            output_dir,
            sharding_stage=sharding_stage,
            use_pure_fp16=False,
            opt_group=False,
            save_model=True,
        )
    elif cmd == "load":
        output_dir = args.load_dir
        rank = dist.get_rank()
        model_file = os.path.join(output_dir, f"rank{rank}model.pdparams")
        opt_file = os.path.join(output_dir, f"rank{rank}model.pdopt")
        g_model_file = os.path.join(
            output_dir, f"rank{args.gather_to}g_model.pdparams"
        )
        g_opt_file = os.path.join(
            output_dir, f"rank{args.gather_to}g_model.pdopt"
        )

        # Explicit dispatch instead of eval() on a command-line string.
        if args.load_way == "load":
            model_file = g_model_file
            opt_file = g_opt_file

            def load_(file_path):
                return load(file_path, place='cpu')

        elif args.load_way == "paddle.load":
            load_ = paddle.load
        else:
            raise ValueError(f"Not supported load_way: {args.load_way}")

        model = load_(model_file)
        opt = load_(opt_file)
        for k in opt.keys():
            print("opt k:", k)

        load_mode(
            mlp1,
            model,
            args.output_param_path,
            sharding_stage=sharding_stage,
            use_pure_fp16=False,
            opt_group=False,
            save_model=False,
            opt_state=opt,
        )
    else:
        raise ValueError(f"Not supported cmd: {args.cmd}")


def run_case(args):
    """Save once, reload through both load APIs, and compare the parameters."""

    saving_strategy, loading_strategy = args.test_case.split(":")[:2]

    output_dir = tempfile.mkdtemp()
    print("output dir:", output_dir)
    os.makedirs(output_dir + "/load_save", exist_ok=True)
    # save
    step_save(saving_strategy, output_dir, args.seed)

    # load
    p1 = os.path.join(output_dir, "m1.pdparams")
    p2 = os.path.join(output_dir, "m2.pdparams")

    # Reference run: same strategy, per-rank files read with paddle.load.
    step_load(
        saving_strategy,
        saving_strategy,
        output_dir,
        "paddle.load",
        p1,
        args.seed + 1,
        gather_to=args.gather_to,
    )
    # Run under test: target strategy, gathered files read with ``load``.
    step_load(
        saving_strategy,
        loading_strategy,
        output_dir,
        "load",
        p2,
        args.seed + 2,
        gather_to=args.gather_to,
    )

    # check
    step_check(p1, p2)

    # Only reached on success, so a failing run leaves its files behind.
    shutil.rmtree(output_dir)


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cmd", default="main", choices=["main", "save", "load"]
    )
    parser.add_argument(
        "--strategy", required=False, choices=["dp", "sharding_stage2"]
    )
    parser.add_argument(
        "--load_way", choices=["paddle.load", "load"], required=False
    )
    parser.add_argument("--load_dir", required=False)
    parser.add_argument("--output_dir", required=False)
    parser.add_argument("--output_param_path", required=False)
    parser.add_argument(
        "--test_case",
        required=False,
        choices=[
            "dp:dp",
            "dp:sharding_stage2",
            "sharding_stage2:dp",
            "sharding_stage2:sharding_stage2",
        ],
    )
    parser.add_argument("--gather_to", type=int, required=False, default=0)
    parser.add_argument("--seed", type=int, default=2022)

    args = parser.parse_args()
    test_save_load(args)
# ---------------------------------------------------------------------------
# Patch metadata that shares this span in the original (kept verbatim):
#   diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_dist_save_load.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_dist_save_load.py
#   new file mode 100644
#   index 0000000000..18aac82f86
#   --- /dev/null
#   +++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_dist_save_load.py
#   @@ -0,0 +1,38 @@
# ---------------------------------------------------------------------------
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest
import subprocess
import sys


def strategy_test(saving, loading, gather_to):
    """Run one save/load scenario in a child interpreter.

    Args:
        saving: strategy name used when saving (e.g. ``"dp"``).
        loading: strategy name used when loading.
        gather_to: rank forwarded to the child as ``--gather_to``.

    Raises:
        AssertionError: if the child process exits with a non-zero status.
    """
    # Pass an argument list rather than splitting a formatted string, so an
    # interpreter path containing spaces is not broken into pieces.
    cmd = [
        sys.executable,
        "dygraph_dist_save_load.py",
        "--test_case",
        f"{saving}:{loading}",
        "--gather_to",
        str(gather_to),
    ]
    returncode = subprocess.run(cmd).returncode
    # Raised explicitly (not via ``assert``) so the check survives ``python -O``.
    if returncode != 0:
        raise AssertionError(
            f"{saving}:{loading} (gather_to={gather_to}) exited with {returncode}"
        )


class TestDistSaveLoad(unittest.TestCase):
    """Checkpoint round trips between data parallel and sharding stage 2."""

    def test_dygraph_save_load_dp_sharding_stage2(self):
        strategy_test("dp", "sharding_stage2", 0)
        strategy_test("dp", "sharding_stage2", 1)
        strategy_test("sharding_stage2", "dp", 1)


if __name__ == "__main__":
    os.environ["FLAGS_enable_eager_mode"] = "1"
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
    unittest.main()

# ---------------------------------------------------------------------------
# Patch metadata that shares this span in the original (kept verbatim):
#   diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/testslist.csv b/python/paddle/fluid/tests/unittests/collective/fleet/testslist.csv
#   index c7fa546322..15cfa81b51 100644
#   --- a/python/paddle/fluid/tests/unittests/collective/fleet/testslist.csv
#   +++ b/python/paddle/fluid/tests/unittests/collective/fleet/testslist.csv
#   @@ -82,4 +82,5 @@
# ---------------------------------------------------------------------------
test_hdfs1,LINUX,,200,EXCLUSIVE:NIGHTLY,../../dist_test.sh,2,,http_proxy=;https_ test_hdfs2,LINUX,,200,EXCLUSIVE:NIGHTLY,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_hdfs3,LINUX,,200,EXCLUSIVE:NIGHTLY,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_fleet_checkpoint,LINUX,GPU;ROCM,200,EXCLUSIVE:NIGHTLY,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_fleet_log,,,,DIST,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_fleet_log,,,200,DIST,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_dygraph_dist_save_load,LINUX,GPU,200,DIST,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../.., diff --git a/python/paddle/incubate/distributed/utils/io/__init__.py b/python/paddle/incubate/distributed/utils/io/__init__.py new file mode 100644 index 0000000000..7eacf695c7 --- /dev/null +++ b/python/paddle/incubate/distributed/utils/io/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .dist_save import save +from .dist_load import load diff --git a/python/paddle/incubate/distributed/utils/io/dist_load.py b/python/paddle/incubate/distributed/utils/io/dist_load.py new file mode 100644 index 0000000000..38907489c8 --- /dev/null +++ b/python/paddle/incubate/distributed/utils/io/dist_load.py @@ -0,0 +1,120 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.fluid.framework import dygraph_only
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet
import re
import copy


@dygraph_only
def load(path, **configs):
    """
    Load an object saved by ``paddle.incubate.distributed.utils.io.save``.

    Note:
        The file to load must have been saved by the API
        ``paddle.incubate.distributed.utils.io.save``.

    Args:
        path(str|BytesIO): The path/buffer to load the target object from.
            Generally this is the target file path.
        **configs (dict, optional): other load configuration options, kept for
            compatibility. They may be removed in the future; do not use them
            unless necessary.
            The following option is handled by this function:
            (1) place(str): where to place the loaded state dict, either
                ``'cpu'`` or ``'gpu:<index>'``. Defaults to ``'cpu'`` when more
                than one process is running. If the state dict is too large,
                the place should be ``'cpu'``.
            Of the remaining options only ``model_filename``,
            ``params_filename`` and ``return_numpy`` are forwarded to
            ``paddle.load`` once a place is in effect; anything else is dropped.

    Returns:
        Object(Object): the loaded object.

    Raises:
        AssertionError: if model or pipeline parallelism is enabled, or if
            ``place`` is not a string of the form ``cpu`` / ``gpu:<index>``.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.distributed import fleet
            from paddle.incubate.distributed.utils.io import load

            fleet.init(is_collective=True)

            model = build_model()
            optimizer = build_optimizer(model)

            dist_model = fleet.distributed_model(model)
            dist_optimizer = fleet.distributed_optimizer(optimizer)

            # load model state dict
            model_state_dict = load(path="path/to/load.pdparams")
            dist_model.set_state_dict(model_state_dict)

            # load optimizer state dict
            optimizer_state_dict = load(path="path/to/load.pdopt")
            dist_optimizer.set_state_dict(optimizer_state_dict)

    """
    single_card = dist.get_world_size() == 1
    if single_card and "place" not in configs:
        # Plain single-process load: nothing distributed to take care of.
        return paddle.load(path, **configs)

    if not single_card:
        hcg = fleet.get_hybrid_communicate_group()
        assert (
            hcg.get_model_parallel_world_size() == 1
            and hcg.get_pipe_parallel_world_size() == 1
        ), "Sharding and DP are supported only now"

    # ``configs`` is this call's own kwargs dict, so filling in the default
    # does not leak to the caller.
    place = configs.setdefault("place", "cpu")
    assert isinstance(
        place, str
    ), f"configs[place] must be a str, but this is a {type(place)}"

    # ``+`` (not ``*``) so that a bare "gpu:" without an index is rejected here
    # instead of failing later inside ``paddle.set_device``.
    assert re.fullmatch(
        r"cpu|gpu:[0-9]+", place
    ), "configs[place] must be cpu, gpu:0, gpu:1 ..."

    return load_with_place(path, **configs)


def load_with_place(path, **configs):
    """
    Load ``path`` while the current device is temporarily set to ``place``.

    Args:
        path(str|BytesIO): The path/buffer to load from.
        **configs: may contain ``place``; other keys are filtered by
            ``_remove_not_supported_itmes`` before reaching ``paddle.load``.

    Returns:
        The object returned by ``paddle.load``.
    """
    place = configs.get("place")
    if place is None:
        return paddle.load(path)

    origin_place = paddle.get_device()
    paddle.set_device(place)
    try:
        return paddle.load(path, **_remove_not_supported_itmes(configs))
    finally:
        # Put the previous device back even when loading fails.
        paddle.set_device(origin_place)


def _remove_not_supported_itmes(configs):
    """Return a copy of ``configs`` holding only the keys passed on to ``paddle.load``."""
    __supported_by_load__ = (
        "model_filename",
        "params_filename",
        "return_numpy",
    )
    return {k: v for k, v in configs.items() if k in __supported_by_load__}

# ---------------------------------------------------------------------------
# Patch metadata that shares this span in the original (kept verbatim):
#   diff --git a/python/paddle/incubate/distributed/utils/io/dist_save.py b/python/paddle/incubate/distributed/utils/io/dist_save.py
#   new file mode 100644
#   index 0000000000..363f54bcc6
#   --- /dev/null
#   +++ b/python/paddle/incubate/distributed/utils/io/dist_save.py
#   @@ -0,0 +1,392 @@
# ---------------------------------------------------------------------------
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import re
import sys

import paddle
import paddle.distributed as dist
import paddle.distributed.fleet as fleet
from paddle.distributed.fleet.utils.log_util import logger
from paddle.fluid.framework import dygraph_only

__all__ = ["save"]


@dygraph_only
def save(state_dict, path, **configs):
    '''
    Save a state dict to the specified path in both distributed and single-card environment.

    Note:
        Now supports saving ``state_dict`` of Layer/Optimizer, Tensor and nested structure containing Tensor, Program.

    Note:
        Different from ``paddle.jit.save``, since the save result of ``paddle.save`` is a single file,
        there is no need to distinguish multiple saved files by adding a suffix. The argument ``path``
        of ``paddle.save`` will be directly used as the saved file name instead of a prefix.
        In order to unify the saved file name format, we recommend using the paddle standard suffix:
        1. for ``Layer.state_dict`` , recommend to use ``.pdparams`` ;
        2. for ``Optimizer.state_dict`` , recommend to use ``.pdopt`` .
        For specific examples, please refer to API code examples.

    Args:
        state_dict(Object) : The object to be saved.
        path(str|BytesIO) : The path/buffer of the object to be saved.
            If saved in the current directory, the input path string will be used as the file name.
        **configs(dict, optional): optional keyword arguments. The following options are currently supported:
            (1)use_binary_format(bool):
                Forwarded to ``paddle.save``.
            (2)protocol(int):
                Pickle protocol, forwarded to ``paddle.save``.
            (3)gather_to(int|list|tuple|None):
                The global rank(s) to gather to and save on. Default is None,
                which means every rank saves its own state with no gathering.
            (4)state_type(str):
                'params' or 'opt', specifying whether parameters or optimizer state is saved.
                Required when ``gather_to`` is given and more than one process is running.
            (5)max_grouped_size(str|int):
                Upper bound on the size of one group of objects transferred at a time.
                If str, the format must be num+'G/M/K', for example 3G, 2K, 10M. Default is 3G.

    Returns:
        None

    Raises:
        AssertionError: on a missing/invalid ``state_type``, an unsupported parallel
            setup, or parameter keys that differ between ranks.
        RuntimeError: if gathering or saving the optimizer state fails.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.distributed import fleet
            from paddle.incubate.distributed.utils.io import save

            fleet.init(is_collective=True)

            model = build_model()
            optimizer = build_optimizer(model)

            dist_model = fleet.distributed_model(model)
            dist_optimizer = fleet.distributed_optimizer(optimizer)

            # gather params to rank 0 and then save
            save(model.state_dict(), path="path/to/save.pdparams", gather_to=[0], state_type="params")

            # save whole params on all ranks
            save(model.state_dict(), path="path/to/save.pdparams", gather_to=[0, 1], state_type="params")

            # save optimizer state dict on rank 0
            save(optimizer.state_dict(), path="path/to/save.pdopt", gather_to=0, state_type="opt")

    '''

    gather_to = configs.get("gather_to", None)
    if dist.get_world_size() == 1 or gather_to is None:
        return paddle.save(
            state_dict, path, **_remove_not_supported_conf(configs)
        )

    # gather_to is not None and world size > 1
    state_type = configs.get("state_type", None)
    assert isinstance(
        state_type, str
    ), "must pass an arg state_type='params' or state_type='opt' to specify whether to save model state_dict or optimizer state_dict"
    assert state_type in [
        "params",
        "opt",
    ], "must pass an arg state_type='params' or state_type='opt'"

    # ``path`` may be a buffer; only a string path has a suffix to check.
    if isinstance(path, str) and not path.endswith(state_type):
        logger.warning(
            f"You are saving {state_type}, while the path ({path}) does not end with {state_type}"
        )

    hcg = fleet.get_hybrid_communicate_group()
    assert (
        hcg.get_model_parallel_world_size() == 1
        and hcg.get_pipe_parallel_world_size() == 1
    ), f"Only DP and Sharding is supported now. However, current MP={hcg.get_model_parallel_world_size()} , PP={hcg.get_pipe_parallel_world_size()}"

    sharding_group = hcg.get_sharding_parallel_group()
    dp_group = hcg.get_data_parallel_group()
    save_configs = _remove_not_supported_conf(configs)

    if state_type == "params":
        # ``_same_keys`` is a collective call, so it is evaluated outside the
        # ``assert`` expression and still runs when asserts are disabled.
        if dp_group.nranks > 1:
            same_in_dp = _same_keys(state_dict, dp_group)
            assert (
                same_in_dp
            ), "only sharding stage 1/2 and DP are supported now"
        if sharding_group.nranks > 1:
            same_in_sharding = _same_keys(state_dict, sharding_group)
            assert (
                same_in_sharding
            ), "only sharding stage 1/2 and DP are supported now"
        return paddle.save(state_dict, path, **save_configs)

    # state_type == "opt"
    if sharding_group.nranks == 1:
        return paddle.save(state_dict, path, **save_configs)
    if _same_keys(state_dict, sharding_group):
        # Every rank already holds the full state. Options such as
        # ``gather_to`` must be filtered out here as well.
        return paddle.save(state_dict, path, **save_configs)

    assert isinstance(gather_to, (list, tuple, int))
    if isinstance(gather_to, int):
        gather_to = [gather_to]
    max_size = configs.get("max_grouped_size", "3G")
    try:
        logger.info("state_dict_keys:" + str(state_dict.keys()))
        gathered_state_dict = _gather_state_dict(
            state_dict, gather_to, sharding_group, max_size=max_size
        )
        logger.info(
            "gathered_state_dict_keys:" + str(gathered_state_dict.keys())
        )
        if dist.get_rank() in gather_to:
            paddle.save(gathered_state_dict, path, **save_configs)
    except Exception as err:
        # Chain the cause so the real failure stays visible in the traceback.
        raise RuntimeError(
            f'''Saving failed. Following are some suggestions:
    1) pass the param max_grouped_size to turn the grouped size smaller (current value of max_grouped_size is {max_size})
    2) if sharding stage is 1, use paddle.save rather than paddle.distributed.save
    3) Contact the developers
'''
        ) from err


def _state_dict_groups(state_dict, max_size):
    """
    Description:
        Generator of state dict groups to transfer. Items are packed in
        iteration order; a group is closed before the item that would bring
        its accumulated size to ``max_size`` or beyond.
        Sizes are measured with ``sys.getsizeof`` on key and value. An item is
        never split, so the limit is raised to the largest single item.
        Empty groups are never yielded: an all-empty round is what ends the
        gathering loop in ``_grouped_gather_data_dict``.
    """
    sizes = {
        k: sys.getsizeof(v) + sys.getsizeof(k) for k, v in state_dict.items()
    }

    # now we only support to transfer at least one whole tensor
    max_size = max(max_size, max(sizes.values(), default=0))
    logger.debug(f"max tensor size: {max_size}")

    state_group = {}
    group_size = 0
    for key, value in state_dict.items():
        item_size = sizes[key]
        if state_group and group_size + item_size >= max_size:
            yield state_group
            state_group = {}
            group_size = 0
        state_group[key] = value
        group_size += item_size

    if state_group:
        yield state_group


def all_empty(dict_list):
    """
    Check if all items are empty
    """
    return all(len(v) == 0 for v in dict_list)


def _parse_mem_size_to_bits(max_size):
    """
    Parse an integer or a mem size str to an integer
    convert xxxG to xxx * 1024^3
    convert xxxM to xxx * 1024^2
    convert xxxK to xxx * 1024^1
    An integer is returned unchanged.
    """
    assert isinstance(max_size, (int, str))
    if isinstance(max_size, int):
        return max_size

    # At least one digit is required; a bare unit such as "G" is rejected here
    # rather than failing in ``int("")``.
    assert re.fullmatch(
        r"[0-9]+[GMK]", max_size
    ), f"Wrong max_size 's format, the format ust be like 10K, 9M, 200G , etc, or an integer. However this is {max_size}"
    unit_factor = {"G": 1024**3, "M": 1024**2, "K": 1024}
    return int(max_size[:-1]) * unit_factor[max_size[-1]]


def _gather_state_dict(state_dict, dst, group, max_size="3G"):
    """
    Description:
        Gather state dicts across all group ranks to dst. ``master_weights``
        is gathered separately; ``LR_Scheduler`` is copied from the local
        state without gathering.
    Args:
        state_dict(dict):
            local state dict
        dst(int|list|tuple):
            ranks the state dicts are gathered to
        group(ProcessGroup):
            group across which the state dicts are gathered
        max_size(int|str):
            The max limitation of the gathered tensor group size transferred at a time. Default is 3G.
            Each rank 's max tensor group before gathering is max_size // group.size
    Returns:
        Gathered state dict
    """
    assert isinstance(
        dst, (list, tuple, int)
    ), "dst' type must be one of int, list and tuple"
    if isinstance(dst, int):
        dst = [dst]

    max_size = _parse_mem_size_to_bits(max_size)
    max_size //= dist.get_world_size(group)

    logger.debug(f"len state_dict: {len(state_dict)}")

    # Shallow copy: popping the special entries must not touch the caller's dict.
    state_dict_ = copy.copy(state_dict)
    has_mw = "master_weights" in state_dict_
    has_lr = "LR_Scheduler" in state_dict_

    # Remove master_weights and LR_Scheduler to ensure that all the elements of the state dict are str->Tensor
    mw = state_dict_.pop("master_weights", None)
    lr = state_dict_.pop("LR_Scheduler", None)

    # Gather optimizer state_dict
    output = _grouped_gather_data_dict(state_dict_, dst, group, max_size)

    # Gather master_weights if it exists
    masters = None
    if isinstance(mw, dict):
        masters = _grouped_gather_data_dict(mw, dst, group, max_size)
    else:
        assert mw is None, f"Wrong type of master weights . type: {type(mw)}"

    # assign master_weights and LR_Scheduler
    # Because LR_Schedulers are same across group, it just needs to be reset
    if has_mw:
        output["master_weights"] = masters
    if has_lr:
        output["LR_Scheduler"] = lr
    return output


def _merge_gathered(output_state, s_list):
    """Fold every dict of one all-gather round into ``output_state``."""
    for s in s_list:
        for k, v in s.items():
            logger.debug(f"gathered: {k}, {v.shape}")
        output_state.update(s)


def _grouped_gather_data_dict(state_data_dict, dst, group, max_size):
    """
    Description:
        Gather state data dict by groups.
    Args:
        state_data_dict(dict):
            local dict to transfer. Every value must provide ``.numpy()``.
        dst(int|list|tuple):
            ranks the state dicts are gathered to
        group(ProcessGroup):
            group across which the state dicts are gathered
        max_size(int):
            upper bound for one transferred group, as used by ``_state_dict_groups``
    Returns:
        Gathered state_data_dict on the ranks in ``dst``, an empty dict elsewhere.
    Raises:
        TypeError: if a value cannot be converted with ``.numpy()``.
    """
    numpy_dict = {}
    logger.debug(f"len state_tict_ : {len(state_data_dict)}")

    for k, v in state_data_dict.items():
        try:
            numpy_dict[k] = v.numpy()
        except Exception as err:
            raise TypeError(
                f"the object (type of {type(v)}) of '{k}' is neither tensor nor parameter"
            ) from err

    is_dst = dist.get_rank() in dst
    total = 0
    output_state = {}

    logger.info("start all gather ...")
    # gather all state_dict by groups
    for state in _state_dict_groups(numpy_dict, max_size):
        s_list = []
        total += len(state)
        logger.info(f"gen to gather: {total} / {len(numpy_dict)}")
        dist.all_gather_object(s_list, state, group)
        if is_dst:
            _merge_gathered(output_state, s_list)

        logger.debug(
            f"s list size: {sum(len(s) for s in s_list)} output: {len(output_state)}"
        )

    # Ranks may have a different number of groups, so a rank that is done keeps
    # joining the collective with an empty dict. The loop ends on the first
    # round in which every rank contributed an empty dict.
    while True:
        s_list = []
        logger.debug("while True")
        dist.all_gather_object(s_list, {}, group)
        if all_empty(s_list):
            break
        if is_dst:
            _merge_gathered(output_state, s_list)
        logger.debug(
            f"s list size: {sum(len(s) for s in s_list)} output: {len(output_state)}"
        )

    logger.debug("all gathered ...")

    if not is_dst:
        return {}

    # convert numpy.ndarray to Tensor in cpu place
    place = paddle.CPUPlace()
    for k in output_state:
        output_state[k] = paddle.to_tensor(output_state[k], place=place)
        output_state[k].name = k
    return output_state


def _same_keys(state_dict, group):
    """
    Check whether all keys in each dict in the group are the same.
    Used in sharding strategy to determine whether a dict needs to be gathered.
    This performs an all-gather on ``group``.
    """
    keys = list(state_dict.keys())
    key_list = []
    logger.info(keys)
    dist.all_gather_object(key_list, keys, group=group)
    return all(other == keys for other in key_list)


def _remove_not_supported_conf(configs):
    """
    Return a copy of ``configs`` holding only the options forwarded to paddle.save
    """
    __supported_by_save__ = ("use_binary_format", "protocol")
    return {k: v for k, v in configs.items() if k in __supported_by_save__}

# ---------------------------------------------------------------------------
# Patch metadata that shares this span in the original (kept verbatim):
#   diff --git a/python/setup.py.in b/python/setup.py.in
#   index 92b75cd067..76daa99e4b 100755
#   --- a/python/setup.py.in
#   +++ b/python/setup.py.in
#   @@ -383,6 +383,8 @@ packages=['paddle',
#      'paddle.incubate.optimizer.functional',
#      'paddle.incubate.autograd',
#      'paddle.incubate.distributed',
#   +  'paddle.incubate.distributed.utils',
#   +  'paddle.incubate.distributed.utils.io',
#      'paddle.incubate.distributed.fleet',
#      'paddle.incubate.distributed.models',
#      'paddle.incubate.distributed.models.moe',
#   --
#   GitLab
# ---------------------------------------------------------------------------