diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py index a31f8bbfed0c919ce3e848c5e5f66407f76d5a20..a2c741667ed77f7ffbd2e650ff162c2762a1ced6 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py @@ -25,10 +25,9 @@ from collections import OrderedDict import paddle import paddle.fluid as fluid from paddle.fluid import core -import paddle.distributed as dist from paddle.optimizer import Optimizer from paddle.fluid.clip import ClipGradByGlobalNorm -from paddle.distributed.collective import _get_global_group +from paddle.distributed.collective import _get_global_group, new_group, broadcast, wait from ...utils.internal_storage import ParamStorage, GradStorage from ...meta_parallel.sharding.sharding_utils import Type, device_guard, ShardingClipGrad @@ -91,8 +90,8 @@ class ShardingOptimizerStage2(Optimizer): filter(lambda x: x.trainable and x.dtype == Type.fp16.value, self._local_params))) > 0 - self.group = dist.new_group(_get_global_group() - .ranks) if group is None else group + self.group = new_group(_get_global_group() + .ranks) if group is None else group self.world_size = self.group.nranks self.rank = self.group.rank @@ -141,14 +140,14 @@ class ShardingOptimizerStage2(Optimizer): """ for p in self._local_params: - dist.broadcast( + broadcast( p, src=self._global_root_rank, group=self.group, use_calc_stream=True) # Multi stream operation will be supported later - dist.wait(tensor=p, group=self.group, use_calc_stream=True) + wait(tensor=p, group=self.group, use_calc_stream=True) def _generate_master_params(self, trainable_params): if self.offload: @@ -385,6 +384,12 @@ class ShardingOptimizerStage2(Optimizer): raise RuntimeError( "optimizer.minimize() not support now, please use optimizer.step()") + def set_state_dict(self, state_dict): + self._optim.set_state_dict(state_dict) + + def state_dict(self): + return self._optim.state_dict() + def _clear_cache(self): self.__segment_params.clear() self._dtype_rank_params.clear() @@ -399,14 +404,14 @@ class ShardingOptimizerStage2(Optimizer): # Exchange all the shards with the other ranks for dtype_per_rank in self.param_storages.values(): for dst_rank, internal_storage in dtype_per_rank.items(): - dist.broadcast( + broadcast( tensor=internal_storage.buffer, src=self.group.ranks[dst_rank], group=self.group, use_calc_stream=True) # Multi stream operation will be supported later - dist.wait( + wait( tensor=internal_storage.buffer, group=self.group, use_calc_stream=True) diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py index 548f036067eba702f9958cf3a10175b41a798e26..c6f05023e6138597dfd906cd53854b70231d6130 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py @@ -28,7 +28,7 @@ from types import MethodType import paddle from paddle import nn -import paddle.distributed as dist +from paddle.distributed import collective as dist from paddle.distributed.collective import _get_global_group from ...utils.internal_storage import GradStorage @@ -158,6 +158,17 @@ class ShardingStage2(nn.Layer): return fw + def set_state_dict(self, state_dict, use_structured_name=True): + self._layer.set_state_dict( + state_dict, use_structured_name=use_structured_name) + + def state_dict(self, + destination=None, + include_sublayers=True, + structured_name_prefix=""): + return self._layer.state_dict( + destination=None, include_sublayers=True, structured_name_prefix="") + def _clear_gradients(self): """ Set zero to the gradient of the optimizer's current rank trainable parameters. diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py index bcf63a54cc4ec4b6f8be818c1059bc8fb186e1a4..9886ca4e2deace4c625ead51852841e7c761be21 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py @@ -20,7 +20,6 @@ import logging import functools import numpy as np from itertools import chain -from functools import reduce from types import MethodType from collections import deque, OrderedDict @@ -28,9 +27,9 @@ import paddle from paddle import nn from paddle.autograd import PyLayer import paddle.fluid.core as core -import paddle.distributed as dist from paddle.fluid.framework import ParamBase from paddle.fluid.clip import ClipGradByGlobalNorm +from paddle.distributed import collective as dist from paddle.distributed.collective import _get_global_group from .sharding_utils import Type, ShardingClipGrad, device_guard @@ -249,6 +248,17 @@ class ShardingStage3(nn.Layer): return fw + def set_state_dict(self, state_dict, use_structured_name=True): + self._layer.set_state_dict( + state_dict, use_structured_name=use_structured_name) + + def state_dict(self, + destination=None, + include_sublayers=True, + structured_name_prefix=""): + return self._layer.state_dict( + destination=None, include_sublayers=True, structured_name_prefix="") + def _handle_unslice_params(self): buffer_size = dict() buffer_size[Type.fp32.value] = 0 @@ -523,7 +533,7 @@ class ShardingStage3(nn.Layer): def _get_allreduce_fn(self, param): @paddle.autograd.no_grad() - def reduce(*_): + def allreduce_(*_): if param.name in self._task_flow.full_grad.keys(): full_grad = self._task_flow.full_grad[param.name] # Only support sync allreduce current rank's layer now @@ -573,7 +583,7 @@ class ShardingStage3(nn.Layer): if self._offload: param.fw_storage = _device2cpu(param.fw_storage, True) - return reduce + return allreduce_ def _param2align(self, param): # CUDA alignment 256 bytes diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py index 0a42b993d5bf2387d1110ae5478b6162ce175483..89b59254e5b9105a55c68f3ef871396de1bd9199 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py @@ -21,7 +21,6 @@ import numpy as np from types import MethodType import paddle -import paddle.distributed as dist from paddle import _C_ops from paddle.fluid import core from paddle.fluid import layers diff --git a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py index 06935e212c3cb19519c558042cea3210910a8975..fb01fd46c0d28455067fd44139e177a81b25566a 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py +++ b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py @@ -14,8 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import shutil import numpy as np import argparse +import tempfile import ast import time import paddle @@ -88,7 +91,8 @@ def train_mlp(model, batch_size=100, use_pure_fp16=False, accumulate_grad=False, - opt_group=False): + opt_group=False, + save_model=False): if sharding_stage == "dp": hcg = fleet.get_hybrid_communicate_group() group = hcg.get_check_parallel_group() @@ -147,6 +151,9 @@ def train_mlp(model, if accumulate_grad: optimizer.step() optimizer.clear_grad() + + if save_model: + return model, optimizer return model.parameters() @@ -158,11 +165,13 @@ def test_dp_stage2(): mlp3 = MLP() mlp4 = MLP() mlp5 = MLP() + mlp6 = MLP() mlp1.set_state_dict(state_dict) mlp2.set_state_dict(state_dict) mlp3.set_state_dict(state_dict) mlp4.set_state_dict(state_dict) mlp5.set_state_dict(state_dict) + mlp6.set_state_dict(state_dict) # DP VS stage2 dp_params = train_mlp( @@ -186,10 +195,29 @@ def test_dp_stage2(): # stage2 param list VS param group stage2_params = train_mlp( - mlp2, sharding_stage=2, use_pure_fp16=False, opt_group=True) + mlp5, sharding_stage=2, use_pure_fp16=False, opt_group=True) for i in range(len(dp_params)): np.testing.assert_allclose( dp_params[i].numpy(), stage2_params[i].numpy(), rtol=1e-6) + + # save/load model + output_dir = tempfile.mkdtemp() + model_file = os.path.join(output_dir, "model.pdmodel") + optimizer_file = os.path.join(output_dir, "model.pdopt") + model_stage2, optimizer_stage2 = train_mlp( + mlp6, + sharding_stage=2, + use_pure_fp16=False, + opt_group=False, + save_model=True) + paddle.save(model_stage2.state_dict(), model_file) + paddle.save(optimizer_stage2.state_dict(), optimizer_file) + m_state_dict = paddle.load(model_file) + opt_state_dict = paddle.load(optimizer_file) + model_stage2.set_state_dict(m_state_dict) + optimizer_stage2.set_state_dict(opt_state_dict) + shutil.rmtree(output_dir) + return diff --git a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py index bbbcb621fd466edbb496de6c36714ee7f982ca0a..82821cd7ee644b5209a594e9a43de7636cdd4958 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py +++ b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py @@ -14,6 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import shutil +import tempfile import numpy as np import argparse import ast @@ -84,7 +87,8 @@ def train_mlp(model, batch_size=100, opt_group=False, sync_comm=False, - test_minimize=False): + test_minimize=False, + save_model=False): group = paddle.distributed.new_group([0, 1]) if opt_group: optimizer = optimizer_setting( @@ -162,12 +166,15 @@ def train_mlp(model, optimizer.clear_grad() if sharding_stage == 3: model.get_all_parameters() + + if save_model: + return model, optimizer return model.parameters() def test_stage2_stage3(): - mlp, mlp1, mlp2, mlp3, mlp4, mlp5, mlp6, mlp7, mlp8, mlp9 = MLP(), MLP( - ), MLP(), MLP(), MLP(), MLP(), MLP(), MLP(), MLP(), MLP() + mlp, mlp1, mlp2, mlp3, mlp4, mlp5, mlp6, mlp7, mlp8, mlp9, mlp10 = MLP( + ), MLP(), MLP(), MLP(), MLP(), MLP(), MLP(), MLP(), MLP(), MLP(), MLP() state_dict = mlp.state_dict() mlp1.set_state_dict(state_dict) mlp2.set_state_dict(state_dict) @@ -178,6 +185,7 @@ def test_stage2_stage3(): mlp7.set_state_dict(state_dict) mlp8.set_state_dict(state_dict) mlp9.set_state_dict(state_dict) + mlp10.set_state_dict(state_dict) # fp32 stage2_params = train_mlp( @@ -238,9 +246,27 @@ def test_stage2_stage3(): np.testing.assert_allclose( stage3_params[i].numpy(), stage3_params_re[i].numpy(), rtol=1e-6) + # save/load model + output_dir = tempfile.mkdtemp() + model_file = os.path.join(output_dir, "model.pdmodel") + optimizer_file = os.path.join(output_dir, "model.pdopt") + model_stage3, optimizer_stage3 = train_mlp( + mlp9, + sharding_stage=3, + use_pure_fp16=False, + opt_group=False, + save_model=True) + paddle.save(model_stage3.state_dict(), model_file) + paddle.save(optimizer_stage3.state_dict(), optimizer_file) + m_state_dict = paddle.load(model_file) + opt_state_dict = paddle.load(optimizer_file) + model_stage3.set_state_dict(m_state_dict) + optimizer_stage3.set_state_dict(opt_state_dict) + shutil.rmtree(output_dir) + # check optimizer.minimize() error train_mlp( - mlp9, + mlp10, sharding_stage=3, use_pure_fp16=False, opt_group=False,