From 0449b8412f56c8a7e4b9c79390a46eae59c52799 Mon Sep 17 00:00:00 2001
From: Roc <30228238+sljlp@users.noreply.github.com>
Date: Thu, 3 Nov 2022 19:26:39 +0800
Subject: [PATCH] Save dygraph model for auto inference (#47463)

---
 .../distributed/fleet/layers/mpu/mp_layers.py |   9 +
 python/paddle/fluid/dygraph/layers.py         |  12 +
 .../unittests/collective/fleet/CMakeLists.txt |   8 +
 .../fleet/dygraph_save_for_auto_infer.py      | 472 ++++++++++++++++++
 .../fleet/test_dygraph_save_for_auto_infer.py |  49 ++
 .../unittests/collective/fleet/testslist.csv  |   1 +
 .../incubate/distributed/utils/io/__init__.py |   2 +-
 .../distributed/utils/io/dist_save.py         |   4 +-
 .../distributed/utils/io/save_for_auto.py     | 357 +++++++++++++
 9 files changed, 911 insertions(+), 3 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/collective/fleet/dygraph_save_for_auto_infer.py
 create mode 100644 python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_save_for_auto_infer.py
 create mode 100644 python/paddle/incubate/distributed/utils/io/save_for_auto.py

diff --git a/python/paddle/distributed/fleet/layers/mpu/mp_layers.py b/python/paddle/distributed/fleet/layers/mpu/mp_layers.py
index 38eb8da785..8224d2a7b9 100644
--- a/python/paddle/distributed/fleet/layers/mpu/mp_layers.py
+++ b/python/paddle/distributed/fleet/layers/mpu/mp_layers.py
@@ -144,6 +144,8 @@ class VocabParallelEmbedding(Layer):
         )
 
         self.weight.is_distributed = True if self.is_mp else False
+        if self.weight.is_distributed:
+            setattr(self.weight, "split_axis", 0)
 
     def forward(self, x):
         if self.is_mp:
@@ -276,6 +278,9 @@ class ColumnParallelLinear(Layer):
 
         self.weight.is_distributed = True if self.is_mp else False
 
+        if self.weight.is_distributed:
+            setattr(self.weight, "split_axis", 1)
+
         if has_bias:
             # initialize bias to zero like Megatron
             self.bias = self.create_parameter(
@@ -285,6 +290,8 @@ class ColumnParallelLinear(Layer):
                 is_bias=True,
             )
             self.bias.is_distributed = True if self.is_mp else False
+            if self.bias.is_distributed:
+                setattr(self.bias, "split_axis", 0)
         else:
             self.bias = None
 
@@ -437,6 +444,8 @@ class RowParallelLinear(Layer):
         )
 
         self.weight.is_distributed = True if self.is_mp else False
+        if self.weight.is_distributed:
+            setattr(self.weight, "split_axis", 0)
 
         if has_bias:
             self.bias = self.create_parameter(
diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py
index ef824d6a2e..5187f9ae72 100644
--- a/python/paddle/fluid/dygraph/layers.py
+++ b/python/paddle/fluid/dygraph/layers.py
@@ -62,6 +62,17 @@ _first_cap_re = re.compile('(.)([A-Z][a-z]+)')
 _all_cap_re = re.compile('([a-z])([A-Z])')
 
 
+def _scope_dist2single(dist_scope):
+    mapping = {
+        "row_parallel_linear": "linear",
+        "column_parallel_linear": "linear",
+        "vocab_parallel_embedding": "embedding",
+        # "parallel_cross_entropy": "cross_entropy" is omitted on purpose: mp_layers
+        # defines parallel_cross_entropy, but it has no parameters, so no mapping is necessary.
+ } + return mapping.get(dist_scope, dist_scope) + + def _convert_camel_to_snake(name): s1 = _first_cap_re.sub(r'\1_\2', name) return _all_cap_re.sub(r'\1_\2', s1).lower() @@ -137,6 +148,7 @@ class Layer(object): self.training = True if name_scope is None: name_scope = _convert_camel_to_snake(self.__class__.__name__) + name_scope = _scope_dist2single(name_scope) self._full_name = unique_name.generate(name_scope) self._helper = LayerObjectHelper(self._full_name) self._built = False diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/CMakeLists.txt b/python/paddle/fluid/tests/unittests/collective/fleet/CMakeLists.txt index f853e96204..a0252356b6 100644 --- a/python/paddle/fluid/tests/unittests/collective/fleet/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/collective/fleet/CMakeLists.txt @@ -952,3 +952,11 @@ if((WITH_GPU) AND (LINUX)) set_tests_properties(test_dygraph_dist_save_load PROPERTIES TIMEOUT "200" LABELS "RUN_TYPE=DIST") endif() +if((WITH_GPU) AND (LINUX)) + py_test_modules( + test_dygraph_save_for_auto_infer MODULES test_dygraph_save_for_auto_infer + ENVS + "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python") + set_tests_properties(test_dygraph_save_for_auto_infer + PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST") +endif() diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_save_for_auto_infer.py b/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_save_for_auto_infer.py new file mode 100644 index 0000000000..04a0038a6a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_save_for_auto_infer.py @@ -0,0 +1,472 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
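+
+# Test driver launched by test_dygraph_save_for_auto_infer.py: "--cmd main"
+# trains and saves a model under one distributed strategy in a subprocess,
+# reloads it with the auto-parallel static Engine, and checks that the
+# dygraph and static predictions match.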
+
+import os
+import shutil
+import numpy as np
+import tempfile
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid.dygraph.nn import Linear, Embedding
+from paddle.distributed import fleet
+from paddle.distributed.fleet.layers.mpu.mp_layers import (
+    RowParallelLinear,
+    ColumnParallelLinear,
+    VocabParallelEmbedding,
+)
+
+from paddle.distributed.auto_parallel import engine
+
+from paddle.distributed.sharding.group_sharded import group_sharded_parallel
+from paddle.distributed.fleet.meta_parallel.parallel_layers.pp_layers import (
+    PipelineLayer,
+    LayerDesc,
+)
+
+import sys
+import subprocess
+import argparse
+import copy
+from paddle import distributed as dist
+
+from paddle.distributed.utils.log_utils import get_logger
+
+from paddle.fluid.dataloader.dataset import IterableDataset
+
+from paddle.incubate.distributed.utils.io import save_for_auto_inference
+
+logger = get_logger("INFO", __file__)
+
+
+epoch = 2
+linear_size = 1000
+
+
+class MLP_pipe(PipelineLayer):
+    def __init__(
+        self,
+        embedding_size=1000,
+        linear_size=1000,
+        param_attr=None,
+        bias_attr=None,
+    ):
+        desc = [
+            LayerDesc(
+                VocabParallelEmbedding,
+                num_embeddings=embedding_size,
+                embedding_dim=linear_size,
+            ),
+            LayerDesc(
+                RowParallelLinear,
+                in_features=linear_size,
+                out_features=linear_size,
+                has_bias=True,
+            ),
+            LayerDesc(
+                ColumnParallelLinear,
+                in_features=linear_size,
+                out_features=linear_size,
+                gather_output=True,
+                has_bias=True,
+            ),
+            LayerDesc(Linear, input_dim=linear_size, output_dim=10),
+        ]
+        super(MLP_pipe, self).__init__(
+            desc,
+            num_stages=2,
+            loss_fn=paddle.nn.CrossEntropyLoss(),
+            topology=fleet.get_hybrid_communicate_group()._topo,
+        )
+
+
+class MLP_Hybrid(fluid.Layer):
+    def __init__(
+        self,
+        embedding_size=1000,
+        linear_size=1000,
+        param_attr=None,
+        bias_attr=None,
+    ):
+        super(MLP_Hybrid, self).__init__()
+        self.embedding = VocabParallelEmbedding(embedding_size, linear_size)
+        self._linear1 = RowParallelLinear(
+            linear_size, linear_size, has_bias=True, input_is_parallel=True
+        )
+        self._linear2 = ColumnParallelLinear(
+            linear_size, linear_size, gather_output=True, has_bias=True
+        )
+        self._linear3 = Linear(linear_size, 10)
+
+    def forward(self, src):
+        inputs = self.embedding(src)
+        # slice the input by hand to work around a bug in row parallel linear
+        mp_group = (
+            fleet.get_hybrid_communicate_group().get_model_parallel_group()
+        )
+        step = inputs.shape[-1] // mp_group.nranks
+        mp_rank = dist.get_rank(mp_group)
+        mp_rank = mp_rank if mp_rank >= 0 else 0
+        inputs = inputs[..., step * mp_rank : step * mp_rank + step]
+
+        y = self._linear1(inputs)
+        y = self._linear2(y)
+        y = self._linear3(y)
+        return y
+
+
+class MLP(fluid.Layer):
+    def __init__(
+        self,
+        embedding_size=1000,
+        linear_size=1000,
+        param_attr=None,
+        bias_attr=None,
+    ):
+        super(MLP, self).__init__()
+        self.embedding = Embedding((embedding_size, linear_size))
+        self._linear1 = Linear(linear_size, linear_size)
+        self._linear2 = Linear(linear_size, linear_size)
+        self._linear3 = Linear(linear_size, 10)
+
+    def forward(self, src):
+        inputs = self.embedding(src)
+        y = self._linear1(inputs)
+        y = self._linear2(y)
+        y = self._linear3(y)
+        return y
+
+
+def gen_uniq_random_numbers(low, high, size, seed):
+    assert np.prod(size) <= high - low
+    pool = list(range(low, high))
+    data = np.zeros(size).astype("int32").reshape(-1)
+    # sample with a fixed seed so every strategy sees identical unique data
+    np.random.seed(10245)
+    for i in range(np.prod(size)):
+        pos = int(np.random.randint(0, len(pool)))
+        data[i] = pool[pos]
+        pool.remove(pool[pos])
+    np.random.seed(seed)
+    return data.reshape(size)
+
+
+class RangeIterableDataset(IterableDataset):
+    def __init__(
+        self, data_path, ebd=1000, start=0, end=100, linear_size=1000, seed=1024
+    ):
+        self.start = start
+        self.end = end
+        self.img = gen_uniq_random_numbers(0, 1000, (100, 1), seed)
+
+    def __iter__(self):
+        for idx in range(self.start, self.end):
+            label = np.ones(1).astype('int32')
+            yield self.img[idx], label
+
+
+def optimizer_setting(args, model):
+    optimizer = paddle.optimizer.SGD(
+        learning_rate=0.0 if args.strategy == "static" else 0.01,
+        parameters=model.parameters(),
+        weight_decay=0.01,
+    )
+
+    return optimizer
+
+
+def train_mlp(args, model, loss, opt_state=None, save_model=False):
+    optimizer = optimizer_setting(args, model=model)
+
+    if args.strategy in ["mp", "dp", "pp"]:
+        model = fleet.distributed_model(model)
+        optimizer = fleet.distributed_optimizer(optimizer)
+    elif args.strategy == "sharding_stage2":
+        model, optimizer, _ = wrap_sharding_2_3(
+            model, optimizer, None, False, 2
+        )
+    elif args.strategy == "sharding_stage3":
+        model, optimizer, _ = wrap_sharding_2_3(
+            model, optimizer, None, False, 3
+        )
+    elif args.strategy != "single":
+        raise ValueError(f"Not supported strategy: {args.strategy}")
+
+    dataset = RangeIterableDataset(
+        data_path=os.path.join(args.output_dir, "data.npy"), seed=args.seed
+    )
+
+    train_loader = paddle.io.DataLoader(dataset, batch_size=100, drop_last=True)
+
+    if dist.get_world_size() > 1:
+        pp_degree = (
+            fleet.get_hybrid_communicate_group().get_pipe_parallel_world_size()
+        )
+    else:
+        pp_degree = 0
+
+    model.train()
+    for epo in range(epoch):
+        for step, data in enumerate(train_loader()):
+            img, label = data
+            label.stop_gradient = True
+            img.stop_gradient = True
+            if pp_degree <= 1:
+                out = model(img)
+                avg_loss = loss(out, label)
+                paddle.device.cuda.synchronize()
+                avg_loss.backward()
+                optimizer.step()
+            else:
+                avg_loss = model.train_batch(data, optimizer)
+
+    model.eval()
+    print("=============== predict in dygraph mode =================")
+    for step, data in enumerate(train_loader()):
+        img, label = data
+        if pp_degree <= 1:
+            out = model(img)
+            out = out.numpy()
+        else:
+            out = model.eval_batch(data)
+            out = np.array(out)
+
+    paddle.device.cuda.synchronize()
+    if save_model:
+        return model, optimizer, out
+
+    return None
+
+
+def train_mlp_static(args, model, loss, opt_state=None, save_model=False):
+    optimizer = optimizer_setting(args, model=model)
+    model = engine.Engine(model, loss=loss, optimizer=optimizer, strategy=None)
+
+    dataset = RangeIterableDataset(
+        data_path=os.path.join(args.output_dir, "data.npy"), seed=args.seed
+    )
+    model.load(os.path.join(args.load_dir, "saved"), load_optimizer=False)
+    model.fit(dataset, epochs=1)
+    model.save(os.path.join(args.output_dir, "static_save"))
+    paddle.device.cuda.synchronize()
+    print("=============== predict in static mode =================")
+    out = model.predict(dataset, verbose=1000)
+
+    if save_model:
+        return model, optimizer
+    return out
+
+
+def step_check(output_dir):
+    p1 = os.path.join(output_dir, "static.npy")
+    p2 = os.path.join(output_dir, "dygraph.npy")
+    m1 = np.load(p1).reshape(-1)
+    m2 = np.load(p2).reshape(-1)
+    try:
+        assert np.allclose(m1, m2, rtol=1e-5, atol=1e-6)
+    except AssertionError:
+        diff = m1 - m2
+        logger.error(f"max diff: {diff.max()}, min diff: {diff.min()}")
+        logger.error(f"{m1[:10]}")
+        logger.error(f"{m2[:10]}")
+        raise ValueError("diff is too large")
+
+
+def step_save(strategy, output_dir, seed):
+    python_exe = sys.executable
+    # prepare the log directory for the launched training subprocess
+    os.makedirs(output_dir + "/logs", exist_ok=True)
"/logs", exist_ok=True) + filename = os.path.basename(__file__) + if strategy != "single": + cmd = ( + f"{python_exe} -m paddle.distributed.launch --log_dir {output_dir}/logs" + f" --gpus 0,1 {filename} --cmd save --strategy {strategy} --output_dir {output_dir} --seed {seed}" + ) + else: + cmd = f"{python_exe} {filename} --cmd save --strategy {strategy} --output_dir {output_dir} --seed {seed}" + + logger.info(f"exe: {cmd}") + p = subprocess.Popen(cmd.split()) + p.communicate() + assert p.poll() == 0 + + +def step_load(curent_strateggy, saved_dir, seed): + python_exe = sys.executable + os.makedirs(f"{saved_dir}/load/logs", exist_ok=True) + filename = os.path.basename(__file__) + # load dp + cmd = ( + f"{python_exe} -m paddle.distributed.launch --log_dir {saved_dir}/load/logs" + f" --gpus 0 {filename} --cmd load --strategy {curent_strateggy} --output_dir {saved_dir} --load_dir {saved_dir} --seed {seed}" + ) + logger.info(f"exe: {cmd}") + env = copy.copy(os.environ) + env["CUDA_VISIBLE_DEVICES"] = "0" + p = subprocess.Popen(cmd.split(), env=env) + p.communicate() + assert p.poll() == 0 + + +def wrap_sharding_2_3(model, optimizer, scaler, sharding_offload, stage): + group = fleet.get_hybrid_communicate_group().get_sharding_parallel_group() + level = "p_g_os" if stage == 3 else "os_g" + return group_sharded_parallel( + model=model, + optimizer=optimizer, + level=level, + scaler=scaler, + group=group, + offload=sharding_offload, + ) + + +def test_save_load(args): + + np.random.seed(args.seed) + paddle.seed(args.seed) + + if args.cmd == "main": + run_case(args) + return + + paddle.distributed.init_parallel_env() + strategy = fleet.DistributedStrategy() + if args.strategy == "dp": + strategy.hybrid_configs = { + "dp_degree": 2, + "mp_degree": 1, + "pp_degree": 1, + "sharding_degree": 1, + } + elif args.strategy in ["sharding_stage2", "sharding_stage3"]: + strategy.hybrid_configs = { + "dp_degree": 1, + "mp_degree": 1, + "pp_degree": 1, + "sharding_degree": 2, + } + elif args.strategy == "mp": + strategy.hybrid_configs = { + "dp_degree": 1, + "mp_degree": 2, + "pp_degree": 1, + "sharding_degree": 1, + } + elif args.strategy == "pp": + strategy.hybrid_configs = { + "dp_degree": 1, + "mp_degree": 1, + "pp_degree": 2, + "sharding_degree": 1, + } + strategy.pipeline_configs = { + "accumulate_steps": 10, + "micro_batch_size": 10, + } + elif args.strategy == "static": + paddle.enable_static() + elif args.strategy != "single": + raise ValueError(f"Not supported strategy: {args.strategy}") + + loss = paddle.nn.CrossEntropyLoss() + + fleet.set_log_level("INFO") + if dist.get_world_size() <= 1: + mlp1 = MLP() + if args.strategy == "static": + out_static = train_mlp_static(args, mlp1, loss, save_model=False) + np.save(os.path.join(args.output_dir, "static.npy"), out_static) + else: + model, _, out_dygraph = train_mlp(args, mlp1, loss, save_model=True) + np.save(os.path.join(args.output_dir, "dygraph.npy"), out_dygraph) + else: + fleet.init(is_collective=True, strategy=strategy) + pp_group = ( + fleet.get_hybrid_communicate_group().get_pipe_parallel_group() + ) + if pp_group.nranks > 1: + mlp1 = MLP_pipe() + else: + mlp1 = MLP_Hybrid() + model, _, out_dygraph = train_mlp(args, mlp1, loss, save_model=True) + if ( + dist.get_world_size() == 0 + or dist.get_rank() == dist.get_world_size() - 1 + ): + np.save(os.path.join(args.output_dir, "dygraph.npy"), out_dygraph) + + if args.cmd == "save": + save_for_auto_inference(os.path.join(args.output_dir, "saved"), model) + + +def run_case(args): + + saving_strategy = 
+    loading_strategy = args.test_case.split(":")[1]
+
+    output_dir = tempfile.mkdtemp()
+    if os.path.isdir(output_dir):
+        shutil.rmtree(output_dir)
+    os.makedirs(output_dir, exist_ok=True)
+    try:
+        step_save(saving_strategy, output_dir, args.seed)
+        step_load(loading_strategy, output_dir, args.seed + 1)
+        step_check(output_dir)
+    except Exception as e:
+        shutil.rmtree(output_dir)
+        raise RuntimeError(f"Test failed.\n {e}")
+    shutil.rmtree(output_dir)
+
+
+if __name__ == '__main__':
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--cmd", default="main", choices=["main", "save", "load"]
+    )
+    parser.add_argument(
+        "--strategy",
+        required=False,
+        choices=[
+            "single",
+            "dp",
+            "mp",
+            "pp",
+            "sharding_stage2",
+            "sharding_stage3",
+            "static",
+        ],
+    )
+    parser.add_argument(
+        "--load_way", choices=["paddle.load", "load"], required=False
+    )
+    parser.add_argument("--load_dir", required=False)
+    parser.add_argument("--output_dir", required=False)
+    parser.add_argument("--output_param_path", required=False)
+    parser.add_argument(
+        "--test_case",
+        required=False,
+        choices=[
+            "dp:static",
+            "mp:static",
+            "pp:static",
+            "sharding_stage2:static",
+            "sharding_stage3:static",
+            "single:static",
+        ],
+    )
+    parser.add_argument("--gather_to", required=False, default=0)
+    parser.add_argument("--seed", type=int, default=2022)
+
+    args = parser.parse_args()
+    test_save_load(args)
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_save_for_auto_infer.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_save_for_auto_infer.py
new file mode 100644
index 0000000000..db0a3ed92a
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_save_for_auto_infer.py
@@ -0,0 +1,49 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
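+
+# Entry-point unit test: each case runs dygraph_save_for_auto_infer.py in a
+# subprocess with one saving strategy and asserts that it exits cleanly.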
+
+import os
+import unittest
+import subprocess
+import sys
+
+
+def strategy_test(saving, seed=1024, loading="static"):
+    cmd = f"{sys.executable} dygraph_save_for_auto_infer.py --test_case {saving}:{loading} --cmd main --seed {seed}"
+    p = subprocess.Popen(cmd.split())
+    p.communicate()
+    assert p.poll() == 0
+
+
+class TestHybrid(unittest.TestCase):
+    def test_dygraph_save_for_auto_hybrid(self):
+        strategy_test("dp")
+        strategy_test("mp")
+        strategy_test("pp")
+
+
+class TestSharding(unittest.TestCase):
+    def test_dygraph_save_for_auto_sharding(self):
+        strategy_test("sharding_stage2")
+        strategy_test("sharding_stage3")
+
+
+class TestSingleCard(unittest.TestCase):
+    def test_dygraph_save_for_auto_single(self):
+        strategy_test("single")
+
+
+if __name__ == "__main__":
+    os.environ["FLAGS_enable_eager_mode"] = "1"
+    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/testslist.csv b/python/paddle/fluid/tests/unittests/collective/fleet/testslist.csv
index 14284a1205..4dfe1c35e8 100644
--- a/python/paddle/fluid/tests/unittests/collective/fleet/testslist.csv
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/testslist.csv
@@ -84,3 +84,4 @@ test_hdfs3,LINUX,,200,EXCLUSIVE:NIGHTLY,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
 test_fleet_checkpoint,LINUX,GPU;ROCM,200,EXCLUSIVE:NIGHTLY,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
 test_fleet_log,,,200,DIST,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
 test_dygraph_dist_save_load,LINUX,GPU,200,DIST,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
+test_dygraph_save_for_auto_infer,LINUX,GPU,300,DIST,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
diff --git a/python/paddle/incubate/distributed/utils/io/__init__.py b/python/paddle/incubate/distributed/utils/io/__init__.py
index 7eacf695c7..de970a1339 100644
--- a/python/paddle/incubate/distributed/utils/io/__init__.py
+++ b/python/paddle/incubate/distributed/utils/io/__init__.py
@@ -12,5 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .dist_save import save
+from .dist_save import save, save_for_auto_inference
 from .dist_load import load
diff --git a/python/paddle/incubate/distributed/utils/io/dist_save.py b/python/paddle/incubate/distributed/utils/io/dist_save.py
index 363f54bcc6..2244aa974f 100644
--- a/python/paddle/incubate/distributed/utils/io/dist_save.py
+++ b/python/paddle/incubate/distributed/utils/io/dist_save.py
@@ -20,10 +20,10 @@ from paddle.distributed.fleet.utils.log_util import logger
 from paddle.fluid.framework import dygraph_only
 import copy
 import sys
-
+from .save_for_auto import save_for_auto_inference
 from paddle.distributed.fleet.utils.log_util import logger
 
-__all__ = ["save"]
+__all__ = ["save", "save_for_auto_inference"]
 
 
 @dygraph_only
diff --git a/python/paddle/incubate/distributed/utils/io/save_for_auto.py b/python/paddle/incubate/distributed/utils/io/save_for_auto.py
new file mode 100644
index 0000000000..30b1ac0c9b
--- /dev/null
+++ b/python/paddle/incubate/distributed/utils/io/save_for_auto.py
@@ -0,0 +1,357 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle.distributed as dist
+import paddle.distributed.fleet as fleet
+import re
+import paddle
+from paddle.distributed.fleet.utils.log_util import logger
+import os
+import pickle
+from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage3 import (
+    GroupShardedStage3,
+)
+from paddle.fluid.framework import dygraph_only
+import copy
+
+import numpy as np
+
+__all__ = ["save_for_auto_inference"]
+
+
+@dygraph_only
+def save_for_auto_inference(path_prefix, dist_model, cvt2cpu=False):
+    """
+    Description:
+        Save model parameters for auto-parallel inference.
+        Supports dp + mp + pp + sharding(stage1) and dp + sharding stage2-3.
+        MoE is not supported until MoE is supported in auto parallel mode.
+    Args:
+        path_prefix: path prefix to save.
+            If `path_prefix` ends with a path separator, it is treated as a
+            directory; parameters are saved in it and automatically named
+            saved_parameters.
+            Otherwise, parameters are saved with the names
+            path_prefix_dist{global_rank}.pdparams and path_prefix_dist{global_rank}.pdattr
+
+        dist_model:
+            model in distributed mode
+        cvt2cpu: whether to move parameters to CPU when using sharding stage 3.
+            This flag only takes effect when using sharding stage 3.
+    Returns:
+        None
+    Examples:
+        dist_model = build_distributed_model()
+
+        path_prefix = "path/to/save_infer"
+
+        save_for_auto_inference(path_prefix, dist_model=dist_model, cvt2cpu=False)
+
+    Outputs:
+        path/to/save_infer_dist0.pdparams path/to/save_infer_dist1.pdparams path/to/save_infer_dist2.pdparams ...
+        path/to/save_infer_dist0.pdattr path/to/save_infer_dist1.pdattr path/to/save_infer_dist2.pdattr ...
+
+    """
+
+    save_dir, basename_prefix = _get_abs_saved_prefix(path_prefix)
+
+    if isinstance(dist_model, GroupShardedStage3):
+        dist_model.get_all_parameters(cvt2cpu)
+
+    wrapped_dict = _get_wrapped_dist_state_dict(dist_model.state_dict())
+    global_rank = paddle.distributed.get_rank()
+
+    # save parameters
+    paddle.save(
+        wrapped_dict,
+        os.path.join(save_dir, f"{basename_prefix}_dist{global_rank}.pdparams"),
+    )
+
+    # save attributes
+    _save_param_attr(
+        wrapped_dict,
+        os.path.join(save_dir, f"{basename_prefix}_dist{global_rank}.pdattr"),
+    )
+
+    # unset dims mapping after saving attrs
+    for _, dist_param in wrapped_dict.items():
+        _unset_dims_mapping(dist_param)
+
+
+def _is_first_used(param):
+    return not hasattr(param, "is_firstly_shared") or param.is_firstly_shared
+
+
+def _get_all_ranks_of_pp(pp_rank, dp_degree, mp_degree, pp_degree):
+    """
+    Description:
+        get all global ranks involving the given pp_rank
+    """
+
+    process_group = []
+
+    world_size = dp_degree * mp_degree * pp_degree
+
+    for i in range(dp_degree):
+        for k in range(mp_degree):
+            process_group.append(
+                i * world_size // dp_degree
+                + pp_rank * world_size // dp_degree // pp_degree
+                + k
+            )
+    return process_group
+
+
+def _save_param_attr(state_dict_, path, dims_mapping_dict=None):
+    """
+    Description:
+        save params' attr dict
+    Args:
+        state_dict_:
+            state dict whose attrs are to be saved; if it is an optimizer
+            state, the master weights and LR scheduler entries are removed.
+        path:
+            path to save
+        dims_mapping_dict:
+            Dims mapping dict, mapping from parameter name in state_dict_ to dims_mapping.
+            If a parameter in state_dict_ has the attribute 'dims_mapping', the dict entry is ignored.
+            If a parameter has no attribute 'dims_mapping', the dict must contain the parameter's name.
+    """
+    state_dict = copy.copy(state_dict_)
+
+    # remove master_weights and LRScheduler, which need no parameter attributes to save
+    state_dict.pop("master_weights", None)
+    state_dict.pop("LR_Scheduler", None)
+
+    if dims_mapping_dict is not None:
+        assert isinstance(
+            dims_mapping_dict, dict
+        ), "dims_mapping_dict must be an instance of dict"
+        for k in state_dict.keys():
+            assert (
+                k in dims_mapping_dict
+            ), f"param {k} cannot find dims mapping in dims_mapping_dict"
+    if dist.get_world_size() > 1:
+        hcg = fleet.get_hybrid_communicate_group()
+        dp_degree = hcg.get_data_parallel_world_size()
+        mp_degree = hcg.get_model_parallel_world_size()
+        pp_degree = hcg.get_pipe_parallel_world_size()
+        sharding_degree = hcg.get_sharding_parallel_world_size()
+        dp_degree = dp_degree * sharding_degree
+
+        pp_group = hcg.get_pipe_parallel_group()
+    else:
+        pp_degree = 1
+        dp_degree = 1
+        mp_degree = 1
+        pp_group = None
+        hcg = None
+
+    logger.debug(f"dp degree * sharding degree : {dp_degree}")
+    logger.debug(f"mp degree: {mp_degree}")
+    logger.debug(f"pp degree: {pp_degree}")
+
+    pp_rank = dist.get_rank(pp_group)
+
+    # Why does the condition 'pp_rank < 0' exist?
+    # Because pp_rank is set to -1 if pp_degree == 1.
+    pp_rank = 0 if pp_rank <= 0 else pp_rank
+
+    if dist.get_world_size() > 1:
+        process_group = _get_all_ranks_of_pp(
+            pp_rank, dp_degree, mp_degree, pp_degree
+        )
+    else:
+        process_group = [0]
+
+    attr_dict = {}
+    for k, v in state_dict.items():
+        dims = len(v.shape)
+        logger.debug(f"shape: , {k}, {dims}")
+        attr_d = {
+            "process_shape": [dp_degree, mp_degree] if hcg else [1],
+            "process_group": process_group,
+            "dims_mapping": v.dims_mapping
+            if hasattr(v, "dims_mapping")
+            else [-1 for _ in v.shape],
+        }
+        attr_dict[k] = attr_d
+
+    with open(path, "wb") as f:
+        pickle.dump(attr_dict, f)
+
+
+def _unset_dims_mapping(param):
+    if hasattr(param, "dims_mapping"):
+        delattr(param, "dims_mapping")
+
+
+def _get_dims_mapping(dist_parameter, mp_group):
+    """
+    Description:
+        return the splitting mapping:
+            {tensor_name: splitting_strategy}
+    Args:
+        dist_parameter(Parameter): a distributed model parameter
+        mp_group(ProcessGroup): Model Parallel communication group
+    Return:
+        The splitting mapping
+    Examples:
+        a splitting_strategy of the form (-1, -1, -1, 0) means the tensor has
+        4 dims and its last dim is split along the first axis of the process mesh
+
+        Mesh Examples: (2, 4) means dp=2, mp=4
+
+    """
+
+    import numpy as np
+
+    dist_shape = np.array(dist_parameter.shape)
+    if hasattr(dist_parameter, "split_axis"):
+        axis = getattr(dist_parameter, "split_axis")
+        mapping = [-1 for _ in dist_shape]
+        mapping[axis] = 1
+        logger.debug(
+            f"{dist_parameter.name} has attr split_axis: mapping: {mapping}"
+        )
+    else:
+        mapping = [-1 for _ in dist_shape]
+        logger.debug(f"normal parameter: {dist_parameter.name}")
+    return mapping
+
+
+def _get_abs_saved_prefix(path_prefix):
+    """
+    Description:
+        Get the absolute dir path and basename prefix of path_prefix, creating
+        path_prefix's directories if needed.
+        If path_prefix is a directory name, basename is set to 'saved_parameters'.
+        If path_prefix is a file name, basename is extracted from path_prefix.
+    Args:
+        path_prefix: str
+    Return:
+        (dirpath: str, basename: str)
+    """
+    abs_prefix = os.path.abspath(path_prefix)
+    if abs_prefix[-1] == os.path.sep:
+        save_dir = abs_prefix
+        basename_prefix = "saved_parameters"
+    else:
+        save_dir = os.path.dirname(abs_prefix)
+        basename_prefix = os.path.basename(abs_prefix)
+    os.makedirs(save_dir, exist_ok=True)
+    return save_dir, basename_prefix
+
+
+def _name_mapping_dist2single(state_dict, pp_group):
+
+    key_list = []
+    param_keys = [
+        v.name
+        for _, v in state_dict.items()
+        if isinstance(v, paddle.Tensor) and _is_first_used(v)
+    ]
+
+    if pp_group.nranks == 1:
+        return {k: k for k in param_keys}
+
+    dist.all_gather_object(key_list, param_keys, pp_group)
+
+    # count how many layers of each type appear in each pp stage:
+    # e.g. {"linear": [0, 2, 0, 1, 1, ...]}
+    param_types = {}
+
+    matcher = re.compile(r"^\w+_\d+(?=\.)")
+
+    for pp, keys in enumerate(key_list):
+        param_type_idx = {}
+        for k in keys:
+            matched = matcher.search(k)
+            logger.debug(f"matched: {k}: {matched}")
+            assert (
+                matched is not None
+            ), f"the name of param, '{k}', does not satisfy the format 'name_idx.xxx'"
+            name_idx = k[matched.start() : matched.end()]
+            logger.debug(f"get param_type_idx: {name_idx}")
+
+            if name_idx in param_type_idx:
+                continue
+
+            name = "_".join(name_idx.split("_")[:-1])
+            idx = int(name_idx.split("_")[-1])
+            param_type_idx.update({name_idx: (name, idx)})
+            if name not in param_types:
+                param_types[name] = [0] * pp_group.nranks
+            param_types[name][pp] += 1
+
+        # check that the layer indices of each type are continuous
+        types_idx = {}
+        for _, v in param_type_idx.items():
+            if v[0] not in types_idx:
+                types_idx.update({v[0]: [v[1]]})
+            else:
+                types_idx[v[0]].append(v[1])
+        for k, v in types_idx.items():
+            assert v == list(
+                range(v[0], v[-1] + 1)
+            ), f"{k} is not continuous: {v}"
+
+    logger.debug(f"param type: {param_types}")
+
+    # compute the starting index of each layer type per pp stage
+    for k in param_types.keys():
+        param_types[k] = np.cumsum([0] + param_types[k][:-1])
+
+    logger.debug(f"param types: {param_types}")
+
+    name_mapping = {}
+    pp_rank = dist.get_rank(pp_group)
+    for k in key_list[pp_rank]:
+        matched = matcher.search(k)
+        name_idx = k[matched.start() : matched.end()]
+        name = "_".join(name_idx.split("_")[:-1])
+        idx = int(name_idx.split("_")[-1])
+        logger.debug(f"idx: {idx}")
+
+        new_idx = param_types[name][pp_rank] + idx
+        logger.debug(f"new idx: {new_idx}")
+        new_name_idx = name + "_" + str(new_idx)
+        name_mapping[k] = new_name_idx + k[matched.end() :]
+
+    return name_mapping
+
+
+def _get_wrapped_dist_state_dict(dist_state_dict):
+
+    wrapped_state_dict = dict()
+    if dist.get_world_size() <= 1:
+        for _, v in dist_state_dict.items():
+            wrapped_state_dict[v.name] = v
+        return wrapped_state_dict
+
+    hcg = fleet.get_hybrid_communicate_group()
+
+    pp_group = hcg.get_pipe_parallel_group()
+    mp_group = hcg.get_model_parallel_group()
+    logger.debug("execute _name_mapping_dist2single")
+
+    name_mapping = _name_mapping_dist2single(dist_state_dict, pp_group)
+    for _, v in dist_state_dict.items():
+        if not _is_first_used(v):
+            logger.debug(f"not first used : {v.name}")
+            continue
+        wrapped_state_dict[name_mapping[v.name]] = v
+        setattr(v, "dims_mapping", _get_dims_mapping(v, mp_group))
+        logger.debug(
+            f"saving param: {v.name} -> {name_mapping[v.name]} shape: {v.shape}"
+        )
+    return wrapped_state_dict
-- 
GitLab
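
A minimal usage sketch of the new API, following the docstring and the test
driver above. `build_model` and the mesh degrees are illustrative placeholders,
not part of this patch; run it with `python -m paddle.distributed.launch --gpus 0,1`:

    import paddle
    import paddle.distributed.fleet as fleet
    from paddle.incubate.distributed.utils.io import save_for_auto_inference

    # set up a hybrid-parallel environment (degrees are an example)
    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": 1,
        "mp_degree": 2,
        "pp_degree": 1,
        "sharding_degree": 1,
    }
    fleet.init(is_collective=True, strategy=strategy)

    model = fleet.distributed_model(build_model())  # build_model(): any dygraph model

    # ... train the model ...

    # every rank calls this; rank i writes save_infer_dist{i}.pdparams with the
    # wrapped parameters and save_infer_dist{i}.pdattr with a pickled dict of
    # {"process_shape", "process_group", "dims_mapping"} per parameter, which
    # the auto-parallel engine consumes at load time
    save_for_auto_inference("path/to/save_infer", model)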