From f40ed5f421d64028c9781c2b77aeb2958327b090 Mon Sep 17 00:00:00 2001 From: Baibaifan <39549453+Baibaifan@users.noreply.github.com> Date: Wed, 9 Mar 2022 19:04:50 +0800 Subject: [PATCH] add_sharding_api (#40129) --- python/paddle/distributed/__init__.py | 1 + .../sharding_optimizer_stage2.py | 6 +- .../meta_parallel/sharding/sharding_stage2.py | 19 +- .../meta_parallel/sharding/sharding_stage3.py | 9 +- .../paddle/distributed/sharding/__init__.py | 17 ++ .../distributed/sharding/group_sharded.py | 211 ++++++++++++++++++ .../fluid/tests/unittests/CMakeLists.txt | 3 + .../unittests/dygraph_group_sharded_api.py | 147 ++++++++++++ .../unittests/dygraph_sharding_stage3.py | 8 +- .../test_dygraph_group_sharded_api.py | 31 +++ python/paddle/framework/io.py | 8 +- python/setup.py.in | 1 + 12 files changed, 437 insertions(+), 24 deletions(-) create mode 100644 python/paddle/distributed/sharding/__init__.py create mode 100644 python/paddle/distributed/sharding/group_sharded.py create mode 100644 python/paddle/fluid/tests/unittests/dygraph_group_sharded_api.py create mode 100644 python/paddle/fluid/tests/unittests/test_dygraph_group_sharded_api.py diff --git a/python/paddle/distributed/__init__.py b/python/paddle/distributed/__init__.py index fc299bc7b5..a0ae9bc29d 100644 --- a/python/paddle/distributed/__init__.py +++ b/python/paddle/distributed/__init__.py @@ -55,6 +55,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv # noqa: F401 from . import cloud_utils # noqa: F401 from . import utils # noqa: F401 +from .sharding import * # noqa: F401 __all__ = [ # noqa "spawn", diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py index 112c3887fc..a31f8bbfed 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py @@ -40,8 +40,6 @@ align = { Type.fp32.value: 4, } -__all__ = ["ShardingOptimizerStage2"] - class ShardingOptimizerStage2(Optimizer): """ @@ -136,7 +134,7 @@ class ShardingOptimizerStage2(Optimizer): # Update optimizer parameters and adjust parameter storage and use according to rank. 
self._update_opt_status() - @paddle.no_grad() + @paddle.autograd.no_grad() def _sync_params_and_buffers(self): """ Sync all model states for all ranks @@ -392,7 +390,7 @@ class ShardingOptimizerStage2(Optimizer): self._dtype_rank_params.clear() self._param2rank.clear() - @fluid.dygraph.no_grad + @paddle.autograd.no_grad() def _broadcast_params(self): """Broadcast the parameters of the current rank to each rank""" diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py index 392a7f3ac5..548f036067 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py @@ -63,8 +63,7 @@ class ShardingStage2(nn.Layer): sync_buffers=False, buffer_max_size=2**23, #8MB auto_refresh_trainable=True, - device="gpu", - use_grad_storage=True): + device="gpu"): super().__init__() # training options @@ -102,9 +101,10 @@ class ShardingStage2(nn.Layer): # Set grad storage size & Display param sizes and model sizes model_size = sum( [np.prod(p.shape) for p in self._layer.parameters()]).item() + assert buffer_max_size >= 0, "buffer_max_size must be GE than 0." self._buffer_max_size = self._rank_buffer_size(buffer_max_size, model_size) - self._use_grad_storage = use_grad_storage + self._use_grad_storage = buffer_max_size > 0 self._grad_storages = {} # {dtype: {rank: GradStorage}} self._has_grad_storage = [] self._grad_storage_list = [] @@ -255,7 +255,7 @@ class ShardingStage2(nn.Layer): # wait next func hook support self._setup_backward_hooks() - @paddle.no_grad() + @paddle.autograd.no_grad() def __sync_buffers(self): """ Sync all the param buffers from all ranks (exp: batch norm statistics). @@ -277,7 +277,7 @@ class ShardingStage2(nn.Layer): except AttributeError: return getattr(self._layer, name) - @paddle.no_grad() + @paddle.autograd.no_grad() def _clear_counters(self): """Reset all the grad reduce and call counters.""" if self.training: @@ -290,13 +290,13 @@ class ShardingStage2(nn.Layer): def _get_reduce_fn(self, index, param, dst_rank): """ There are two ways to reduce gradient. - - 1. Do not use use_grad_storage or exceeded buffer_max_size will be reduced separately. + - 1. Do not use self._use_grad_storage or exceeded buffer_max_size will be reduced separately. - 2. Use grad_storage Reduce the storage to get the full gradient from different ranks. """ if not self._use_grad_storage or not self._has_grad_storage[index]: # Direct reduction - @paddle.no_grad() + @paddle.autograd.no_grad() def reduce(*_): # Skip gradient reduction, do not change status information if self._grad_reduced[index]: @@ -336,7 +336,7 @@ class ShardingStage2(nn.Layer): else: # Buffer reduction - @paddle.no_grad() + @paddle.autograd.no_grad() def reduce(*_): # Skip gradient reduction, do not change status information if self._grad_reduced[index]: @@ -421,9 +421,6 @@ class ShardingStage2(nn.Layer): Integrate the parameters gradient into a continuous memory according to rank, and support the update of training parameters. 
""" - if not self._use_grad_storage: - return - # According to parameters's numel sort, allocate memory of parameter gradient to continuous memory according to rank self._grad_storages = {} self._has_grad_storage = [False for _ in self._trainable_params] diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py index de69836fdb..bcf63a54cc 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py @@ -84,6 +84,7 @@ class ShardingStage3(nn.Layer): self._offload = offload self._sync_comm = sync_comm # segmentation size + assert segment_size >= 0, "segment_size must be GE than 0." self._segment_size = segment_size global DEV @@ -158,7 +159,7 @@ class ShardingStage3(nn.Layer): self._redefine_opt_step() self._redefine_opt_clear() - @paddle.no_grad() + @paddle.autograd.no_grad() def _sync_params_and_buffers(self): """ Sync all model states for all ranks @@ -408,7 +409,7 @@ class ShardingStage3(nn.Layer): # register post forward hooks sub_layer.register_forward_post_hook(_forward_post_hook) - @paddle.no_grad() + @paddle.autograd.no_grad() def _sync_buffers(self): """ Sync all the param buffers from all ranks (exp: batch norm statistics). @@ -521,7 +522,7 @@ class ShardingStage3(nn.Layer): param._register_backward_hook(allreduce_function) def _get_allreduce_fn(self, param): - @paddle.no_grad() + @paddle.autograd.no_grad() def reduce(*_): if param.name in self._task_flow.full_grad.keys(): full_grad = self._task_flow.full_grad[param.name] @@ -840,7 +841,7 @@ def _allgather_buffer(trainable_params, return task_flow -@paddle.no_grad() +@paddle.autograd.no_grad() def _create_params_grad(trainable_params, param2buffer_size, task_flow): for param in trainable_params: if param.name in task_flow.full_grad.keys(): diff --git a/python/paddle/distributed/sharding/__init__.py b/python/paddle/distributed/sharding/__init__.py new file mode 100644 index 0000000000..d14e3dd099 --- /dev/null +++ b/python/paddle/distributed/sharding/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .group_sharded import group_sharded_parallel, save_group_sharded_model # noqa: F401 + +__all__ = ['group_sharded_parallel', 'save_group_sharded_model'] diff --git a/python/paddle/distributed/sharding/group_sharded.py b/python/paddle/distributed/sharding/group_sharded.py new file mode 100644 index 0000000000..2fdb20600f --- /dev/null +++ b/python/paddle/distributed/sharding/group_sharded.py @@ -0,0 +1,211 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import logging
+from enum import Enum
+
+import paddle
+
+from paddle.optimizer import Optimizer
+from paddle.distributed.utils import get_logger
+from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
+from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2
+from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import ShardingStage3
+from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import ShardingScaler
+
+logger_ = get_logger(logging.INFO)
+
+
+def group_sharded_parallel(model,
+                           optimizer,
+                           level,
+                           scaler=None,
+                           group=None,
+                           offload=False,
+                           sync_buffers=False,
+                           buffer_max_size=2**23,
+                           segment_size=2**20,
+                           sync_comm=False):
+    """
+    Use this API to configure and wrap the given model, optimizer and scaler for group sharded training.
+
+    Args:
+        model (Layer): The layer to be wrapped with group_sharded_parallel.
+        optimizer (Optimizer): The optimizer to be wrapped with group_sharded_parallel.
+        level (str): The group sharded level, one of `os`, `os_g` or `p_g_os`.
+        scaler (GradScaler, optional): The scaler to be wrapped with group_sharded_parallel. Defaults to None.
+        group (Group, optional): The communication group instance. Defaults to None.
+        offload (bool, optional): Whether to offload the optimizer state and gradients to CPU. Defaults to False.
+        sync_buffers (bool, optional): Whether to broadcast model buffers. Defaults to False.
+        buffer_max_size (int, optional): The max size of the buffer used to integrate gradients in `os_g`. Defaults to 2**23.
+        segment_size (int, optional): The smallest size of parameter to be sharded in `p_g_os`. Defaults to 2**20.
+        sync_comm (bool, optional): Whether to use synchronous communication; only used in `p_g_os`. Defaults to False.
+
+    Returns:
+        model: The given model wrapped for group sharded training.
+        optimizer: The given optimizer wrapped for group sharded training.
+        scaler: The given scaler wrapped for group sharded training.
+
+    Examples:
+        .. code-block:: python
+
+            # required: distributed
+            import paddle
+            from paddle.fluid.dygraph.nn import Linear
+            from paddle.distributed import fleet
+            from paddle.distributed.sharding import group_sharded_parallel
+
+            fleet.init(is_collective=True)
+            group = paddle.distributed.new_group([0, 1])
+            model = Linear(1000, 1000)
+
+            clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
+            optimizer = paddle.optimizer.AdamW(learning_rate=0.001, parameters=model.parameters(), weight_decay=0.00001, grad_clip=clip)
+            scaler = paddle.amp.GradScaler(init_loss_scaling=32768)
+            # wrap sharding model, optimizer and scaler
+            model, optimizer, scaler = group_sharded_parallel(model, optimizer, "p_g_os", scaler=scaler)
+
+            img, label = data
+            label.stop_gradient = True
+            img.stop_gradient = True
+
+            out = model(img)
+            loss = paddle.nn.functional.cross_entropy(input=out, label=label)
+
+            loss.backward()
+            optimizer.step()
+            optimizer.clear_grad()
+    """
+    # check option types
+    assert isinstance(
+        model,
+        paddle.nn.Layer), "The model must be the instance of paddle.nn.Layer."
+    assert isinstance(
+        optimizer, Optimizer
+    ), "The optimizer must be the instance of paddle.optimizer.Optimizer."
+    assert level in ['os', 'os_g', 'p_g_os'
+                     ], "The level must be os, os_g or p_g_os."
+
+    def check_dtype(param):
+        return param.dtype == paddle.float16
+
+    params_fp16 = list(filter(check_dtype, model.parameters()))
+    if scaler is None and len(params_fp16) > 0:
+        raise ValueError("A scaler is required when float16 parameters exist.")
+    # convert model/optimizer/scaler
+    if level in ['os', 'os_g']:
+        logger_.info("*" * 30)
+        logger_.info("Sharding level os is currently implemented with level os_g.")
+        logger_.info("*" * 30)
+        optimizer = ShardingOptimizerStage2(
+            params=model.parameters(),
+            optim=optimizer,
+            group=group,
+            offload=offload)
+        model = ShardingStage2(
+            model,
+            optimizer,
+            group=group,
+            sync_buffers=sync_buffers,
+            buffer_max_size=buffer_max_size)
+    elif level == 'p_g_os':
+        model = ShardingStage3(
+            model,
+            optimizer=optimizer,
+            group=group,
+            sync_buffers=sync_buffers,
+            segment_size=segment_size,
+            offload=offload,
+            sync_comm=sync_comm)
+    else:
+        raise ValueError("The level must be os, os_g or p_g_os.")
+    if params_fp16 and isinstance(scaler, paddle.amp.GradScaler):
+        scaler = ShardingScaler(scaler)
+    logger_.info("*" * 30)
+    logger_.info(
+        "If there is a communication hang when using group sharded, please check that the communication operations of all processes are consistent."
+    )
+    logger_.info("*" * 30)
+
+    return model, optimizer, scaler
+
+
+def save_group_sharded_model(model, output, optimizer=None):
+    """
+    Save the state_dict of a model (and optionally its optimizer) wrapped by group_sharded_parallel.
+
+    Args:
+        model (Layer): A model wrapped with group_sharded_parallel.
+        output (str): Save directory.
+        optimizer (Optimizer, optional): An optimizer wrapped with group_sharded_parallel. Defaults to None.
+
+    Examples:
+        .. code-block:: python
+
+            # required: distributed
+            import paddle
+            from paddle.fluid.dygraph.nn import Linear
+            from paddle.distributed import fleet
+            from paddle.distributed.sharding import group_sharded_parallel, save_group_sharded_model
+
+            fleet.init(is_collective=True)
+            group = paddle.distributed.new_group([0, 1])
+            model = Linear(1000, 1000)
+
+            clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
+            optimizer = paddle.optimizer.AdamW(learning_rate=0.001, parameters=model.parameters(), weight_decay=0.00001, grad_clip=clip)
+            scaler = paddle.amp.GradScaler(init_loss_scaling=32768)
+            # wrap sharding model, optimizer and scaler
+            model, optimizer, scaler = group_sharded_parallel(model, optimizer, "p_g_os", scaler=scaler)
+
+            img, label = data
+            label.stop_gradient = True
+            img.stop_gradient = True
+
+            out = model(img)
+            loss = paddle.nn.functional.cross_entropy(input=out, label=label)
+
+            loss.backward()
+            optimizer.step()
+            optimizer.clear_grad()
+
+            # save model and optimizer state_dict
+            save_group_sharded_model(model, output=output_dir, optimizer=optimizer)
+    """
+    logger_.info(
+        "==========Begin to save group sharded model and optimizer==========")
+    assert not os.path.isfile(
+        output
+    ), "Saving directory ({}) should be a directory, not a file".format(output)
+    os.makedirs(output, exist_ok=True)
+    output_model = os.path.join(output, "model.pdmodel")
+    if isinstance(model, ShardingStage2):
+        paddle.save(model._layer.state_dict(), output_model)
+    elif isinstance(model, ShardingStage3):
+        convert2cpu = model._offload
+        model.get_all_parameters(convert2cpu=convert2cpu)
+        paddle.save(model._layer.state_dict(), output_model)
+    else:
+        raise ValueError(
+            "Please use the layer which is wrapped with group_sharded_parallel.")
+
+    if optimizer is not None:
+        assert hasattr(
+            optimizer, "_optim"
+        ), "Please use the optimizer which is wrapped with group_sharded_parallel."
+ output_opt = os.path.join(output, "model.pdopt") + paddle.save(optimizer._optim.state_dict(), output_opt) + logger_.info( + "==========End to save group sharded model and optimizer==========") diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 5d861cddea..9b0c857576 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -47,6 +47,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_sharding_parallel) list(APPEND DIST_TEST_OPS test_dygraph_sharding_optimizer_stage2) list(APPEND DIST_TEST_OPS test_dygraph_sharding_stage2) list(APPEND DIST_TEST_OPS test_dygraph_sharding_stage3) +list(APPEND DIST_TEST_OPS test_dygraph_group_sharded_api) list(APPEND DIST_TEST_OPS test_auto_parallel_parallelizer) list(APPEND DIST_TEST_OPS test_parallel_dygraph_mp_layers) list(APPEND DIST_TEST_OPS test_hybrid_parallel_inference_helper) @@ -282,6 +283,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM)) list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_optimizer_stage2) list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_stage2) list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_stage3) + list(REMOVE_ITEM TEST_OPS test_dygraph_group_sharded_api) list(REMOVE_ITEM TEST_OPS test_auto_parallel_parallelizer) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_mp_layers) LIST(REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision) @@ -1123,6 +1125,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL) set_tests_properties(test_dygraph_sharding_optimizer_stage2 PROPERTIES TIMEOUT 120) set_tests_properties(test_dygraph_sharding_stage2 PROPERTIES TIMEOUT 120) set_tests_properties(test_dygraph_sharding_stage3 PROPERTIES TIMEOUT 120) + set_tests_properties(test_dygraph_group_sharded_api PROPERTIES TIMEOUT 120) set_tests_properties(test_auto_parallel_parallelizer PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_mp_layers PROPERTIES TIMEOUT 120) set_tests_properties(test_hybrid_parallel_inference_helper PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/dygraph_group_sharded_api.py b/python/paddle/fluid/tests/unittests/dygraph_group_sharded_api.py new file mode 100644 index 0000000000..d4832782c3 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_group_sharded_api.py @@ -0,0 +1,147 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import time +import shutil +import tempfile +import numpy as np + +import paddle +import paddle.fluid as fluid +from paddle.fluid.dygraph.nn import Linear +from paddle.distributed import fleet +from paddle.fluid.dygraph import nn +from paddle.distributed.sharding import group_sharded_parallel, save_group_sharded_model + +epoch = 10 +paddle.seed(2022) +np.random.seed(2022) +base_lr = 0.1 +momentum_rate = 0.9 +l2_decay = 1e-4 +batch_size = 100 +fleet.init(is_collective=True) + + +class MLP(fluid.Layer): + def __init__(self, linear_size=1000, param_attr=None, bias_attr=None): + super(MLP, self).__init__() + + self._linear1 = Linear(linear_size, linear_size) + self._linear2 = Linear(linear_size, linear_size) + self._linear3 = Linear(linear_size, 10) + + def forward(self, inputs): + y = self._linear1(inputs) + y = self._linear2(y) + y = self._linear3(y) + return y + + +def reader_decorator(linear_size=1000): + def __reader__(): + for _ in range(100): + img = np.random.rand(linear_size).astype('float32') + label = np.ones(1).astype('int64') + yield img, label + + return __reader__ + + +def optimizer_setting(model, use_pure_fp16, opt_group=False): + clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0) + optimizer = paddle.optimizer.Momentum( + parameters=[{ + "params": list(model.parameters()) + }] if opt_group else list(model.parameters()), + learning_rate=0.001, + weight_decay=0.00001, + grad_clip=clip, + multi_precision=use_pure_fp16) + + return optimizer + + +def train_mlp(model, shard_level, use_pure_fp16, output_dir): + group = paddle.distributed.new_group([0, 1]) + + optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16) + model = paddle.amp.decorate(models=model, level='O2', save_dtype='float32') + scaler = paddle.amp.GradScaler(init_loss_scaling=32768) + + model, optimizer, scaler = group_sharded_parallel( + model=model, optimizer=optimizer, level=shard_level, scaler=scaler) + + train_reader = paddle.batch( + reader_decorator(), batch_size=batch_size, drop_last=True) + + train_loader = paddle.io.DataLoader.from_generator( + capacity=32, + use_double_buffer=True, + iterable=True, + return_list=True, + use_multiprocess=True) + train_loader.set_sample_list_generator(train_reader) + + for eop in range(epoch): + model.train() + for batch_id, data in enumerate(train_loader()): + img, label = data + label.stop_gradient = True + img.stop_gradient = True + with paddle.amp.auto_cast(True, level='O2'): + out = model(img) + loss = paddle.nn.functional.cross_entropy( + input=out, label=label) + avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32)) + + if not use_pure_fp16: + avg_loss.backward() + optimizer.step() + else: + scaler.scale(avg_loss).backward() + scaler.step(optimizer) + scaler.update() + + optimizer.clear_grad() + + save_group_sharded_model(model, output=output_dir, optimizer=optimizer) + return model.parameters() + + +def test_sharding_api(): + mlp, mlp1, mlp2 = MLP(), MLP(), MLP() + state_dict = mlp.state_dict() + mlp1.set_state_dict(state_dict) + mlp2.set_state_dict(state_dict) + + output_dir = tempfile.mkdtemp() + + # fp16 + stage2_params = train_mlp( + mlp1, shard_level="os_g", use_pure_fp16=True, output_dir=output_dir) + stage3_params = train_mlp( + mlp2, shard_level="p_g_os", use_pure_fp16=True, output_dir=output_dir) + + for i in range(len(stage3_params)): + np.testing.assert_allclose( + stage2_params[i].numpy(), + stage3_params[i].numpy(), + rtol=1e-4, + atol=1e-3) + shutil.rmtree(output_dir) + + +if __name__ == '__main__': + test_sharding_api() diff 
--git a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py index 6b755cf4c2..bbbcb621fd 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py +++ b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py @@ -83,7 +83,7 @@ def train_mlp(model, accumulate_grad=False, batch_size=100, opt_group=False, - recompute=False, + sync_comm=False, test_minimize=False): group = paddle.distributed.new_group([0, 1]) if opt_group: @@ -104,7 +104,7 @@ def train_mlp(model, model, optimizer, group=group, buffer_max_size=2**21) elif sharding_stage == 3: model = ShardingStage3( - model, optimizer=optimizer, group=group, sync_comm=recompute) + model, optimizer=optimizer, group=group, sync_comm=sync_comm) # check optimizer.minimize() error if test_minimize: @@ -225,7 +225,7 @@ def test_stage2_stage3(): rtol=1e-4, atol=1e-3) - # fp16 recompute + # fp16 sync_comm stage3_params = train_mlp( mlp7, sharding_stage=3, use_pure_fp16=True, opt_group=False) stage3_params_re = train_mlp( @@ -233,7 +233,7 @@ def test_stage2_stage3(): sharding_stage=3, use_pure_fp16=True, opt_group=False, - recompute=True) + sync_comm=True) for i in range(len(stage3_params)): np.testing.assert_allclose( stage3_params[i].numpy(), stage3_params_re[i].numpy(), rtol=1e-6) diff --git a/python/paddle/fluid/tests/unittests/test_dygraph_group_sharded_api.py b/python/paddle/fluid/tests/unittests/test_dygraph_group_sharded_api.py new file mode 100644 index 0000000000..7c296c7e40 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dygraph_group_sharded_api.py @@ -0,0 +1,31 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import paddle.fluid as fluid + +from test_parallel_dygraph_dataparallel import TestMultipleGpus + + +class TestDygraphGroupSharded(TestMultipleGpus): + + # check group sharded logic as well as the accuracy with single mode + def test_dygraph_group_sharded(self): + self.run_mnist_2gpu('dygraph_group_sharded_api.py') + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/framework/io.py b/python/paddle/framework/io.py index 94b8bd29b2..f2d41b5e9b 100644 --- a/python/paddle/framework/io.py +++ b/python/paddle/framework/io.py @@ -46,6 +46,10 @@ def _build_saved_state_dict(state_dict): if value.type == core.VarDesc.VarType.VOCAB: save_dict[key] = value.value().get_map_tensor() else: + if not value.value().get_tensor()._is_initialized(): + raise ValueError( + "The saved tensor is not initialized. If you used group sharded, please use save_group_sharded_model." + ) save_dict[key] = value.numpy() name_table[key] = value.name else: @@ -466,7 +470,9 @@ def _parse_load_result(obj, return_numpy): def _save_lod_tensor(tensor, file_name): if not tensor._is_initialized(): - raise ValueError("The saved tensor is not initialized.") + raise ValueError( + "The saved tensor is not initialized. 
If you used group sharded, please use save_group_sharded_model first." + ) if _is_file_path(file_name): _seek = core.save_lod_tensor(tensor, file_name) # '_seek' is the end position of this tensor in the file. diff --git a/python/setup.py.in b/python/setup.py.in index 118f617361..3ce22892b6 100755 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -280,6 +280,7 @@ packages=['paddle', 'paddle.incubate.nn', 'paddle.incubate.passes', 'paddle.distribution', + 'paddle.distributed.sharding', 'paddle.distributed.fleet', 'paddle.distributed.fleet.base', 'paddle.distributed.fleet.elastic', -- GitLab
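
For quick reference, the sketch below condenses the end-to-end usage exercised by the new dygraph_group_sharded_api.py test into a single script. It is illustrative only, not part of the patch: the placeholder model, random batches, checkpoint directory, and the assumption of two visible GPUs launched via `python -m paddle.distributed.launch --gpus "0,1" train_sketch.py` are mine, not the author's.

# Illustrative sketch (not part of this patch), condensed from dygraph_group_sharded_api.py.
# Assumes two visible GPUs and a launch such as:
#   python -m paddle.distributed.launch --gpus "0,1" train_sketch.py
import paddle
from paddle.distributed import fleet
from paddle.distributed.sharding import group_sharded_parallel, save_group_sharded_model

fleet.init(is_collective=True)

# placeholder model/optimizer; any paddle.nn.Layer and paddle.optimizer.Optimizer will do
model = paddle.nn.Linear(1000, 10)
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
optimizer = paddle.optimizer.AdamW(
    learning_rate=0.001, parameters=model.parameters(), grad_clip=clip)

# pure fp16 decoration plus a loss scaler, mirroring the unit test
model = paddle.amp.decorate(models=model, level='O2', save_dtype='float32')
scaler = paddle.amp.GradScaler(init_loss_scaling=32768)

# wrap model, optimizer and scaler; "p_g_os" shards parameters, gradients and optimizer states
model, optimizer, scaler = group_sharded_parallel(
    model=model, optimizer=optimizer, level="p_g_os", scaler=scaler)

for step in range(10):
    img = paddle.rand([4, 1000])                 # placeholder input batch
    label = paddle.randint(0, 10, shape=[4, 1])  # placeholder labels
    with paddle.amp.auto_cast(True, level='O2'):
        out = model(img)
        loss = paddle.nn.functional.cross_entropy(input=out, label=label)
    avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
    scaler.scale(avg_loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.clear_grad()

# save the wrapped model and optimizer state to a directory
save_group_sharded_model(model, output="./group_sharded_ckpt", optimizer=optimizer)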