From 20e19776b00616f8263ac1fc72957fbe966a960a Mon Sep 17 00:00:00 2001 From: Baibaifan <39549453+Baibaifan@users.noreply.github.com> Date: Thu, 2 Dec 2021 11:18:51 +0800 Subject: [PATCH] Add dygraph sharding stage2 (#37707) --- .../sharding_optimizer_stage2.py | 6 - .../meta_parallel/sharding/sharding_stage2.py | 505 ++++++++++++++++++ .../fluid/tests/unittests/CMakeLists.txt | 3 + .../unittests/dygraph_sharding_stage2.py | 204 +++++++ .../unittests/test_dygraph_sharding_stage2.py | 31 ++ 5 files changed, 743 insertions(+), 6 deletions(-) create mode 100644 python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py create mode 100644 python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py create mode 100644 python/paddle/fluid/tests/unittests/test_dygraph_sharding_stage2.py diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py index 9595896188b..ba1b5222394 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py @@ -68,7 +68,6 @@ class ShardingOptimizerStage2(Optimizer): broadcast_fp16=False, offload=False, device="gpu", - accumulation_steps=None, **kw): super().__init__(optim._learning_rate, params, kw) @@ -86,7 +85,6 @@ class ShardingOptimizerStage2(Optimizer): self._optim = optim self._local_params = params self._default_device = device - self._accumulation_steps = accumulation_steps assert group is not None, "Distributed communication group is must be gived" self.group = group @@ -136,10 +134,6 @@ class ShardingOptimizerStage2(Optimizer): def local_params(self): return self._local_params - @property - def accumulation_steps(self): - return self._accumulation_steps - @property def param2rank(self): """Map the params to the rank which owns them""" diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py new file mode 100644 index 00000000000..8ac4a7e99c7 --- /dev/null +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py @@ -0,0 +1,505 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#Taken and modified for fairscale from: +# https://github.com/facebookresearch/fairscale/blob/main/fairscale/nn/data_parallel/sharded_ddp.py +#Commit: 8acbec718f3c70a6b9785470bb9e05cd84fc3f8e + +import os +import contextlib +import logging +import time +import functools +import numpy as np +from itertools import chain +from functools import reduce +from collections import deque + +import paddle +from paddle import nn +import paddle.distributed as dist + +from ...utils.internal_storage import GradStorage +from .sharding_utils import Taskflow, Type + + +def _trainable(param): + return param.trainable + + +class ShardingStage2(nn.Layer): + """ + A wrapper for Sharding Stage2 Layer in Dygraph. + .. warning: ShardingStage2 encapsulates the layer strategy and integrates it into the nn.Layer. + .. ZeRO: https://arxiv.org/pdf/1910.02054.pdf. + """ + + # TODO (Baibaifan) + # Feature Notes:: + # 1. Unified memory for param and param.grad to InternalStorage. + # 2. Divide param.grad according to rank to centrally apply for and release GPU memory. + # 3. Dynamically adjust training parameters and models。 + # 4. Support offload function. + # 5. Support the establishment of independent communication groups. + + def __init__( + self, + layer, + sharding_optimizer, + group, + sync_buffers=False, + pertrain_sync_models=True, + buffer_max_size=2**23, #8MB + auto_refresh_trainable=True, + device="gpu", + use_grad_storage=True, + accumulate_grads=False): + super().__init__() + + # training options + self._layer = layer + self._sharding_optimizers = [sharding_optimizer] if not isinstance( + sharding_optimizer, list) else sharding_optimizer + self._sync_buffers = sync_buffers + self._auto_refresh_trainable = auto_refresh_trainable + + # Gradient accumulation, Gradient flip + self._accumulate_grads = accumulate_grads + + # Communication related attributes + assert group is not None, "Distributed communication group is must be gived" + self._group = group + self._world_size_scaling = 1.0 / self._group.nranks + assert self._group.nranks > 1, "Training must be distributed, ranks must be greater than 1" + self._rank = self._group.rank + self._global_root_rank = 0 # picking rank 0 as the reference + self._global_ranks = self._group.ranks + self._default_device = device + + # Global statistical parameters + self._all_params = list( + chain( + * [optim.local_params for optim in self._sharding_optimizers])) + self._trainable_params = [] + self._grad_reduced = [] + self._trainable_param2rank = {} + self._trainable_param2align = {} + self._trainable_mask = list(map(_trainable, self._all_params)) + self._param_grads = [] + + # Set grad storage size & Display param sizes and model sizes + model_size = sum( + [np.prod(p.shape) for p in self._layer.parameters()]).item() + self._buffer_max_size = self._rank_buffer_size(buffer_max_size, + model_size) + self._use_grad_storage = use_grad_storage + self._grad_storages = {} # {dtype: {rank: GradStorage}} + self._has_grad_storage = [] + self._grad_storage_list = [] + + # Set backward pass hooks + self._bw_hooks = [] + + # Synchronous all ranks models + if pertrain_sync_models: + self._sync_params_and_buffers() + + # Set tasks flow + self._tasks_flow = deque() + + def forward(self, *inputs, **kwargs): + """ + A wrapper for Sharding Stage2 layer. + - Fresh trainable params or rebuild grad storage + - Sync layer's buffer params + - Clear all flags states + - Forward for origin layers + """ + + # Whether to need to reset trainable parameters + needs_fresh = len(self._bw_hooks) == 0 and self.training + + if self._auto_refresh_trainable: + needs_fresh |= self._detect_train_change() + + # Front hook + self._init_internal_storage(needs_fresh) + + # Sync layer's buffers state + if self._sync_buffers: + self.__sync_buffers() + + # Normal FW on the base model + fw = self._layer(*inputs, **kwargs) + + return fw + + def clear_gradients(self): + """ + Set zero to the gradient of the optimizer's current rank trainable parameters. + """ + # Release grad storages + for dtype in self._grad_storages.keys(): + if self._rank in self._grad_storages[dtype].keys(): + self._grad_storages[dtype][self._rank].buffer.zero_() + + # Release params + for param in self._trainable_params: + if param.name in self._param_grads and param.grad is not None: + param.clear_gradient() + + def grad_scale(self): + """ + Before the gradient accumulation, scale the gradient. + """ + # Scale grad storages + for dtype in self._grad_storages.keys(): + if self._rank in self._grad_storages[dtype].keys(): + self._grad_storages[dtype][self._rank].buffer.scale_( + scale=self._world_size_scaling) + + # Scale params + for param in self._trainable_params: + if param.name in self._param_grads and param.grad is not None: + param.grad.scale_(scale=self._world_size_scaling) + param._reset_grad_inplace_version() + + def _init_internal_storage(self, needs_fresh): + """ + Judge Fresh trainable params or rebuild grad storage. + """ + if needs_fresh: + self._fresh_trainable() + else: + self._build_grad_storages() + + # Clear all flags state + self._clear_counters() + + def to(self, device=None, dtype=None, blocking=True): + """ + Synchronously or asynchronously convert the data type of the layer, the device is not supported now. + """ + assert device == self._default_device, "New devices are not supported, because of the optimizer state is not sync" + + def _fresh_trainable(self): + """ Whether to update training parameters. """ + + # Make sure that this is not done while gradients are waiting to be reduced (if no_sync context for instance) + if reduce(lambda x, y: x or y, self._grad_reduced, False): + logging.warning("Grads waiting to be reduced.") + + self._trainable_params = list( + filter(lambda x: x.trainable, self._all_params)) + self._trainable_params.sort(key=lambda x: np.prod(x.shape)) + + self._trainable_param2rank = {} + for optim in self._sharding_optimizers: + # Need to be wrappered for Sharding Stage2 Optimizer + if len(optim.param_storages.keys()) == 0: + optim.update_opt_status() + + # Get the parameters split by the optimizer according to rank + for per_rank_params in optim.dtype_rank_params.values( + ): # all the params from all ranks + for params in per_rank_params: + for param in filter(lambda x: x.trainable, params): + self._trainable_param2rank[ + param.name] = optim.param2rank[param.name] + self._trainable_param2align[ + param.name] = optim._param2align[param.name] + + self._setup_use_grad_storage() + + # wait next func hook support + self._setup_backward_hooks() + + @paddle.no_grad() + def __sync_buffers(self): + """ + Sync all the param buffers from all ranks (exp: batch norm statistics). + """ + + for buffer in self._layer.buffers(include_sublayers=True): + dist.broadcast( + buffer, + self._global_root_rank, + self._group, + use_calc_stream=True) + # Multi stream operation will be supported later + dist.wait(tensor=buffer, group=self._group, use_calc_stream=True) + + def __getattr__(self, name): + """Forward missing attributes to wrapped layer.""" + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self._layer, name) + + @paddle.no_grad() + def _clear_counters(self): + """Reset all the grad reduce and call counters.""" + if self.training: + self._grad_reduced = [True for _ in self._trainable_params] + + if self._use_grad_storage: + for grad_storage in self._grad_storage_list: + grad_storage.reset_checked_in() + + if not self._accumulate_grads: + self._grads_flipped = False + + def _get_reduce_fn(self, index, param, dst_rank): + """ + There are two ways to reduce gradient. + - 1. Do not use use_grad_storage or exceeded buffer_max_size will be reduced separately. + - 2. Use grad_storage Reduce the storage to get the full gradient from different ranks. + """ + + if not self._use_grad_storage or not self._has_grad_storage[index]: + # Direct reduction + @paddle.no_grad() + def reduce(*_): + # Skip gradient reduction, do not change status information + if self._grad_reduced[index]: + assert param.grad is not None, "Parameter gradient cannot be None" + + # Change reduce information + self._grad_reduced[index] = False + if not self._accumulate_grads: + param.grad.scale_(scale=self._world_size_scaling) + param._reset_grad_inplace_version() + + # Clear the gradient that does not belong to the current rank through the callback function + def cleanup(): + if dst_rank != self._rank: + param.clear_gradient(False) + + # Synchronize the reduce parameter gradient + self._tasks_flow.append( + Taskflow( + task=dist.reduce( + tensor=param.grad, + dst=dst_rank, + group=self._group, + use_calc_stream=True), + callback=cleanup)) + + # Multi stream operation will be supported later + dist.wait( + tensor=param.grad, + group=self._group, + use_calc_stream=True) + + # Clear the task flow and trigger callback to clear the redundant gradient + self._clear_task_flow() + + else: + # Buffer reduction + @paddle.no_grad() + def reduce(*_): + # Skip gradient reduction, do not change status information + if self._grad_reduced[index]: + assert param.grad is not None, "Parameter gradient cannot be None" + + # Change reduce information + self._grad_reduced[index] = False + grad_storage = self._grad_storages[param.dtype][dst_rank] + grad_storage.params_checked_in += 1 + + if grad_storage.all_checked_in: + assert grad_storage.buffer is not None + + # Normalize all ranks grad_storage + if not self._accumulate_grads: + grad_storage.buffer.scale_( + scale=self._world_size_scaling) + + # Clearing up the grad_storage buffer + def cleanup(): + if dst_rank != self._rank: + for p in grad_storage._params: + p.clear_gradient(False) + p._gradient_set_empty(False) + + grad_storage.buffer.value().get_tensor()._clear( + ) + + # Reduce the bucket + grad_storage.sent = True + self._tasks_flow.append( + Taskflow( + task=dist.reduce( + tensor=grad_storage.buffer, + dst=grad_storage.destination, + group=self._group, + use_calc_stream=True), + callback=cleanup)) + + # Multi stream operation will be supported later + dist.wait( + tensor=grad_storage.buffer, + group=self._group, + use_calc_stream=True) + + # Clear the task flow and trigger callback to clear the redundant gradient + self._clear_task_flow() + + return reduce + + def _setup_backward_hooks(self): + """ + Set the backward hook to synchronize the gradients of all rank by reduce group ranks. + """ + + # Remove previous backward hooks + while len(self._bw_hooks) > 0: + self._bw_hooks.pop().remove() + + # Go through the parameters, attach the hook + self._grad_accs = [] + if not self.training: + return + + for index, param in enumerate(self._trainable_params): + dst_rank = self._trainable_param2rank[param.name] + + reduce_function = self._get_reduce_fn(index, param, dst_rank) + + self._bw_hooks.append( + param._register_backward_hook(reduce_function)) + + @paddle.no_grad() + def _sync_params_and_buffers(self): + """ + Sync all model states for all ranks + """ + + for t in self._layer.parameters(): + dist.broadcast( + t, + src=self._global_root_rank, + group=self._group, + use_calc_stream=True) + + # Multi stream operation will be supported later + dist.wait(tensor=t, group=self._group, use_calc_stream=True) + + def _setup_use_grad_storage(self): + """ + Integrate the parameters gradient into a continuous memory according to rank, and support the update of training parameters. + """ + + if not self._use_grad_storage: + return + + # According to parameters's numel sort, allocate memory of parameter gradient to continuous memory according to rank + self._grad_storages = {} + self._has_grad_storage = [False for _ in self._trainable_params] + + for index, param in enumerate(self._trainable_params): + dst_rank = self._trainable_param2rank[param.name] + + if param.dtype not in self._grad_storages.keys(): + self._grad_storages[param.dtype] = {} + + if dst_rank not in self._grad_storages[param.dtype].keys(): + self._grad_storages[param.dtype][dst_rank] = GradStorage( + self._buffer_max_size[param.dtype], + dtype=param.dtype, + device=self._default_device, + destination=dst_rank, + parm2align=self._trainable_param2align) + + # Criteria to decide whether this parameter is to be put in GradStorage + if self._grad_storages[param.dtype][dst_rank].can_add_grad_view( + param, self._trainable_param2align[param.name]): + self._grad_storages[param.dtype][dst_rank].add_grad( + param, self._trainable_param2align[param.name]) + self._has_grad_storage[index] = True + else: + self._param_grads.append(param.name) + print( + "Can not add param: {}, param's shape: {}, param align: {}, grad_storages fill: {}, ". + format(param.name, param.shape, self._trainable_param2align[ + param.name], self._grad_storages[param.dtype][dst_rank] + ._fill)) + + self._grad_storage_list = list( + chain(* [ + self._grad_storages[dtype].values() + for dtype in self._grad_storages.keys() + ])) + + def _clear_task_flow(self): + """Try to consume the previous tasks.""" + while len(self._tasks_flow) > 0: + task = self._tasks_flow.popleft() + if task.callback is not None: + task.callback() + + def _detect_train_change(self): + # Current trainable parameters + trainable_mask = list(map(_trainable, self._all_params)) + + # Whether parameters trainability changed + trainability_changed = trainable_mask != self._trainable_mask + + # The whole model is not trainable but we still have grad hooks + trainability_changed |= not self.training and len(self._bw_hooks) > 0 + + if trainability_changed: + logging.warning( + "Trainable params changed, because of eval/train mode or parameter freezing/unfreeze." + ) + self._trainable_mask = trainable_mask + + return trainability_changed + + def _build_grad_storages(self): + """ + Rebuild grad storages. + """ + # Rebuild fp16/fp32 grad storages + for dtype in self._grad_storages.keys(): + for dst_rank, grad_storage in self._grad_storages[dtype].items(): + if dst_rank != self._rank: + grad_storage.manumal_relase() + grad_storage.rebuild() + + def _rank_buffer_size(self, buffer_max_size, model_size): + """ + Generate the minimum buffer size for each rank & Display param sizes and model sizes. + """ + + # Initialize buffer size + rank_buffer_size = {} + for shard_opt in self._sharding_optimizers: + if shard_opt.rank_buffer_size: + for dtype in shard_opt.rank_buffer_size.keys(): + sizes = max(shard_opt.rank_buffer_size[dtype].values()) + rank_buffer_size[dtype] = min(sizes, buffer_max_size) + + if Type.fp16.value in rank_buffer_size.keys(): + # FP16 GradStorage and model size + print( + "====== FP16 GradStorage size: {:.2f}M parameters, Model size {:.2f}M parameters ======". + format(rank_buffer_size[Type.fp16.value] / 2**19, model_size / 2 + **19)) + if Type.fp32.value in rank_buffer_size.keys(): + # FP32 GradStorage and model size + print( + "====== FP32 GradStorage size: {:.2f}M parameters, Model size {:.2f}M parameters ======". + format(rank_buffer_size[Type.fp32.value] / 2**18, model_size / 2 + **18)) + return rank_buffer_size diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 099dadd6173..15f857f6087 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -33,6 +33,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_pipeline_parallel) list(APPEND DIST_TEST_OPS test_parallel_dygraph_tensor_parallel) list(APPEND DIST_TEST_OPS test_parallel_dygraph_sharding_parallel) list(APPEND DIST_TEST_OPS test_dygraph_sharding_optimizer_stage2) +list(APPEND DIST_TEST_OPS test_dygraph_sharding_stage2) list(APPEND DIST_TEST_OPS test_auto_parallel_parallelizer) list(APPEND DIST_TEST_OPS test_parallel_dygraph_mp_layers) list(APPEND DIST_TEST_OPS test_hybrid_parallel_inference_helper) @@ -244,6 +245,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM)) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_tensor_parallel) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sharding_parallel) list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_optimizer_stage2) + list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_stage2) list(REMOVE_ITEM TEST_OPS test_auto_parallel_parallelizer) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_mp_layers) LIST(REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision) @@ -1039,6 +1041,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL) set_tests_properties(test_parallel_dygraph_tensor_parallel PROPERTIES TIMEOUT 200) set_tests_properties(test_parallel_dygraph_sharding_parallel PROPERTIES TIMEOUT 120) set_tests_properties(test_dygraph_sharding_optimizer_stage2 PROPERTIES TIMEOUT 120) + set_tests_properties(test_dygraph_sharding_stage2 PROPERTIES TIMEOUT 120) set_tests_properties(test_auto_parallel_parallelizer PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_mp_layers PROPERTIES TIMEOUT 120) set_tests_properties(test_hybrid_parallel_inference_helper PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py new file mode 100644 index 00000000000..bc62d18c860 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py @@ -0,0 +1,204 @@ +# -*- coding: UTF-8 -*- + +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import argparse +import ast +import time +import paddle +import paddle.fluid as fluid +from paddle.fluid.dygraph.nn import Linear +from paddle.distributed import fleet +from paddle.fluid.dygraph import nn + +from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import DygraphShardingOptimizer +from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2 +from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2 + +seed = 2021 +epoch = 2 +batch_size = 32 + +strategy = fleet.DistributedStrategy() +strategy.hybrid_configs = { + "dp_degree": 2, + "mp_degree": 1, + "pp_degree": 1, + "sharding_degree": 1 +} +fleet.init(is_collective=True, strategy=strategy) + +np.random.seed(seed) +paddle.seed(seed) + + +class MLP(fluid.Layer): + def __init__(self, param_attr=None, bias_attr=None): + super(MLP, self).__init__() + + self._linear1 = Linear(10000, 10000) + self._linear2 = Linear(10000, 10000) + self._linear3 = Linear(10000, 10) + + def forward(self, inputs): + y = self._linear1(inputs) + y = self._linear2(y) + y = self._linear3(y) + return y + + +def reader_decorator(): + def __reader__(): + for _ in range(100): + img = np.random.rand(10000).astype('float32') + label = np.ones(1).astype('int64') + yield img, label + + return __reader__ + + +def optimizer_setting(model, use_pure_fp16, stage=1): + clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0) + optimizer = paddle.optimizer.AdamW( + parameters=model.parameters(), + learning_rate=0.001, + weight_decay=0.00001, + grad_clip=clip, + multi_precision=use_pure_fp16) + + return optimizer + + +def train_mlp(model, + sharding_stage, + use_pure_fp16=False, + all_test=False, + accumulate_grad=False): + if sharding_stage == 1: + hcg = fleet.get_hybrid_communicate_group() + group = hcg.get_check_parallel_group() + else: + group = paddle.distributed.new_group([0, 1]) + optimizer = optimizer_setting( + model=model, use_pure_fp16=use_pure_fp16, stage=sharding_stage) + + if use_pure_fp16: + model, optimizer = paddle.amp.decorate( + models=model, + optimizers=optimizer, + level='O2', + save_dtype='float32') + + if sharding_stage == 2: + optimizer = ShardingOptimizerStage2( + params=model.parameters(), optim=optimizer, group=group) + if all_test: + model = ShardingStage2( + model, optimizer, group=group, accumulate_grads=accumulate_grad) + else: + model = ShardingStage2(model, optimizer, group=group) + else: + optimizer = fleet.distributed_optimizer(optimizer) + model = fleet.distributed_model(model) + + train_reader = paddle.batch( + reader_decorator(), batch_size=batch_size, drop_last=True) + + train_loader = paddle.io.DataLoader.from_generator( + capacity=32, + use_double_buffer=True, + iterable=True, + return_list=True, + use_multiprocess=True) + train_loader.set_sample_list_generator(train_reader) + + for eop in range(epoch): + model.train() + + for batch_id, data in enumerate(train_loader()): + img, label = data + label.stop_gradient = True + img.stop_gradient = True + + with paddle.amp.auto_cast(enable=use_pure_fp16, level='O2'): + out = model(img) + loss = paddle.nn.functional.cross_entropy( + input=out, label=label) + + avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32)) + avg_loss.backward() + + if accumulate_grad and batch_id == 2: + model.grad_scale() + optimizer.step() + model.clear_gradients() + return model.parameters() + + if not accumulate_grad: + optimizer.step() + + if sharding_stage == 2: + model.clear_gradients() + else: + optimizer.clear_grad() + + if all_test and batch_id == 2: + return model.parameters() + + if sharding_stage == 2: + model.to(device="gpu") + + return model.parameters() + + +def test_stage1_stage2(): + mlp = MLP() + state_dict = mlp.state_dict() + mlp1 = MLP() + mlp2 = MLP() + mlp3 = MLP() + mlp4 = MLP() + mlp1.set_state_dict(state_dict) + mlp2.set_state_dict(state_dict) + mlp3.set_state_dict(state_dict) + mlp4.set_state_dict(state_dict) + stage1_params = train_mlp(mlp, sharding_stage=1, use_pure_fp16=False) + stage2_params = train_mlp(mlp, sharding_stage=2, use_pure_fp16=False) + for i in range(len(stage1_params)): + np.testing.assert_allclose( + stage1_params[i].numpy(), stage2_params[i].numpy(), rtol=1e-6) + + stage2_params = train_mlp( + mlp3, sharding_stage=2, use_pure_fp16=True, all_test=True) + stage2_accumulate_grad = train_mlp( + mlp4, + sharding_stage=2, + use_pure_fp16=True, + all_test=True, + accumulate_grad=True) + for i in range(len(stage2_params)): + for j in range(len(stage2_accumulate_grad)): + if stage2_params[i].name == stage2_accumulate_grad[j].name: + np.testing.assert_allclose( + stage2_params[i].numpy(), + stage2_accumulate_grad[j].numpy(), + rtol=1e-6) + + return + + +if __name__ == '__main__': + test_stage1_stage2() diff --git a/python/paddle/fluid/tests/unittests/test_dygraph_sharding_stage2.py b/python/paddle/fluid/tests/unittests/test_dygraph_sharding_stage2.py new file mode 100644 index 00000000000..c5cf8c5d5ed --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dygraph_sharding_stage2.py @@ -0,0 +1,31 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import paddle.fluid as fluid + +from test_parallel_dygraph_dataparallel import TestMultipleGpus + + +class TestDygraphShardingStage2(TestMultipleGpus): + + # check sharding logic as well as the accuracy with single mode + def test_dygraph_sharding_optimizer_stage2(self): + self.run_mnist_2gpu('dygraph_sharding_stage2.py') + + +if __name__ == "__main__": + unittest.main() -- GitLab