From 5af64631bfc0d05ff31176eb6ee09b11bfa9aeed Mon Sep 17 00:00:00 2001 From: Baibaifan <39549453+Baibaifan@users.noreply.github.com> Date: Thu, 25 Nov 2021 16:30:24 +0800 Subject: [PATCH] Add InternalStorage and add ShardingOptimizerStage2 (#37489) --- .../dygraph_optimizer/__init__.py | 2 +- .../dygraph_sharding_optimizer.py | 276 +++++++++++++++++- .../fleet/meta_parallel/sharding/__init__.py | 15 + .../meta_parallel/sharding/sharding_utils.py | 93 ++++++ .../fleet/utils/internal_storage.py | 242 +++++++++++++++ .../fluid/tests/unittests/CMakeLists.txt | 3 + .../dygraph_sharding_optimizer_stage2.py | 134 +++++++++ .../test_dygraph_sharding_optimizer_stage2.py | 31 ++ 8 files changed, 785 insertions(+), 11 deletions(-) create mode 100644 python/paddle/distributed/fleet/meta_parallel/sharding/__init__.py create mode 100644 python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py create mode 100644 python/paddle/distributed/fleet/utils/internal_storage.py create mode 100644 python/paddle/fluid/tests/unittests/dygraph_sharding_optimizer_stage2.py create mode 100644 python/paddle/fluid/tests/unittests/test_dygraph_sharding_optimizer_stage2.py diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/__init__.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/__init__.py index 28260d7aa18..f7b346e5228 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/__init__.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and from .hybrid_parallel_optimizer import HybridParallelOptimizer from .hybrid_parallel_gradscaler import HybridParallelGradScaler -from .dygraph_sharding_optimizer import DygraphShardingOptimizer +# from .dygraph_sharding_optimizer import DygraphShardingOptimizer __all__ = [] diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py index 4bddde6b5b6..9512c43425b 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,16 +11,35 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+# Taken and modified from fairscale:
+#     https://github.com/facebookresearch/fairscale/blob/main/fairscale/optim/oss.py
+# Commit: 8acbec718f3c70a6b9785470bb9e05cd84fc3f8e
 
-######
+import numpy as np
+from itertools import chain
 from functools import reduce
+from collections import OrderedDict
 
 import paddle
 from paddle import framework
+import paddle.distributed as dist
+from paddle.optimizer import Optimizer
+
 from ...utils.log_util import logger
+from ...utils.internal_storage import ParamStorage
+from ...meta_parallel.sharding.sharding_utils import Type
+
+# CUDA alignment 256 bytes
+alignment = {"gpu": 256, }
+align = {
+    Type.fp16.value: 2,
+    Type.fp32.value: 4,
+}
+
+__all__ = ["ShardingOptimizerStage2"]
 
 
-def _is_trainable(param: paddle.Tensor) -> bool:
+def _is_trainable(param):
     return not param.stop_gradient
 
 
@@ -41,13 +60,8 @@ class DygraphShardingOptimizer(object):
     # 3. dynamic trainable params, which is the case bewteen pretraining and finetuning
     # 4. option to choose fuse comm (more GPU MEM need) or un-fuse comm
 
-    def __init__(
-            self,
-            hcg,
-            user_defined_strategy,
-            params,
-            inner_optimizer_class,
-            **inner_optimizer_kargs, ):
+    def __init__(self, hcg, user_defined_strategy, params,
+                 inner_optimizer_class, **inner_optimizer_kargs):
 
         if not isinstance(params, list):
             raise TypeError(
@@ -196,3 +210,245 @@ class DygraphShardingOptimizer(object):
 
     def __getattr__(self, item):
        return getattr(self._inner_optimizer, item)
+
+
+class ShardingOptimizerStage2(Optimizer):
+    """
+    A wrapper for the Sharding Stage2 Optimizer in dygraph mode.
+
+    .. warning: ShardingOptimizerStage2 encapsulates the optimization strategy and integrates it into the optimizer.
+
+    .. ZeRO: https://arxiv.org/pdf/1910.02054.pdf
+
+    """
+
+    # TODO (Baibaifan)
+    # Feature Notes:
+    # 1. Unified memory for parameters and parameters.grad to InternalStorage.
+    # 2. Support the segmentation of optimizer parameters and partial updating of parameters.
+    # 3. Dynamically adjust training parameters and models.
+    # 4. Support offload function.
+    # 5. Support the establishment of independent communication groups.
+    # 6. Broadcast_fp16 is not supported now.
+    def __init__(self,
+                 params,
+                 optim,
+                 group,
+                 broadcast_fp16=False,
+                 offload=False,
+                 device="gpu",
+                 accumulation_steps=None,
+                 **kw):
+
+        super().__init__(optim._learning_rate, params, kw)
+
+        # Segmentation information
+        self._dtype_rank_params = OrderedDict(
+        )  # {dtype: [rank0_params, rank1_params, ...]}
+        self._param2rank = {}
+        self._segment_params = []
+        self._rank_buffer_size = {}  # {dtype: {rank: numel + alignment}}
+        self._param2align = {}  # {param.name: align}
+
+        # Default information
+        self._optim_defaults = kw
+        self._optim = optim
+        self._local_params = params
+        self._default_device = device
+        self._accumulation_steps = accumulation_steps
+
+        assert group is not None, "A distributed communication group must be given"
+        self.group = group
+        self.world_size = group.nranks
+        self.rank = group.rank
+
+        self.broadcast_fp16 = broadcast_fp16
+        self.param_storages = {}  # {dtype: {rank: InternalStorage}}
+        self.offload = offload  # Used for offload
+
+        # Update optimizer parameters and adjust parameter storage and use according to rank.
+        self.update_opt_status()
+
+    def update_opt_status(self):
+        """Update optimizer status and parameter storage information, and special functions to be developed.
+ """ + # func 1 + self._integration_params() + + # fun 2 TODO + + # Segement helpers + + def segment_params(self): + """ + Divide all optimizer parameters equally into rank. + """ + if len(self._segment_params) == 0: + self._segment_params, param_lists = [ + [] for _ in range(self.world_size) + ], [[] for _ in range(self.world_size)] + sizes = [0] * self.world_size + for param in self._local_params: + # Add this param to rank with smallest size. + rank = sizes.index(min(sizes)) + param_lists[rank].append(param) + + # Statistical real numels + sizes[rank] += np.prod(param.shape) if param.trainable else 0 + + for rank, params in enumerate(param_lists): + # param_group_rank = copy.copy(params) + self._segment_params[rank].extend(params) + return self._segment_params + + @property + def local_params(self): + return self._local_params + + @property + def accumulation_steps(self): + return self._accumulation_steps + + @property + def param2rank(self): + """Map the params to the rank which owns them""" + if len(self._param2rank) == 0: + for rank, params in enumerate(self.segment_params()): + for param in params: + self._param2rank[param.name] = rank + return self._param2rank + + @property + def dtype_rank_params(self): + """ + Divide the parameters into groups according to rank and dtype. + """ + if len(self._dtype_rank_params) == 0: + # Assign the parameters of each rank according to the type + for param in self._local_params: + if param.dtype not in self._dtype_rank_params.keys(): + self._dtype_rank_params[ + param.dtype] = [[] for _ in range(self.world_size)] + self._dtype_rank_params[param.dtype][self.param2rank[ + param.name]].append(param) + + # Sort per rank params by size + for dtype in self._dtype_rank_params.keys(): + for rank_params in self._dtype_rank_params[dtype]: + rank_params.sort(key=lambda x: np.prod(x.shape)) + + return self._dtype_rank_params + + @property + def rank_buffer_size(self): + """ + Count the memory size of the parameters corresponding to rank under the corresponding dtype. + """ + # CUDA alignment 256 bytes + if len(self._rank_buffer_size) == 0: + for dtype in self.dtype_rank_params.keys(): + if dtype not in self._rank_buffer_size.keys(): + self._rank_buffer_size[dtype] = {} + for dst_rank, per_rank_params in enumerate( + self.dtype_rank_params[dtype]): + if dst_rank not in self._rank_buffer_size[dtype].keys(): + self._rank_buffer_size[dtype][dst_rank] = 0 + for param in per_rank_params: + if not param.trainable: + continue + size = np.prod(param.shape) * align[dtype] + remaining = size % alignment[self._default_device] + ali = 0 if remaining == 0 else alignment[ + self._default_device] - remaining + align_ = ali // align[dtype] + self._rank_buffer_size[dtype][dst_rank] += np.prod( + param.shape) + align_ + self._param2align[param.name] = align_ + + return self._rank_buffer_size + + def _integration_params(self): + """ + Integrate the parameters into a continuous memory according to rank, and support the update of training parameters. 
+ """ + + for dtype, per_rank_params in self.dtype_rank_params.items(): + if dtype not in self.param_storages.keys(): + self.param_storages[dtype] = {} + + for dst_rank, params in enumerate(per_rank_params): + if len(params) > 0: + + # Merge all the trainable params in a single InternalStorage + trainable_params = list( + filter(lambda x: x.trainable, params)) + if trainable_params: + param_storage = ParamStorage( + size=self.rank_buffer_size[dtype][dst_rank], + dtype=dtype, + device=self._default_device) + + param_storage.add_rank_params(trainable_params, + self._param2align) + self.param_storages[dtype][dst_rank] = param_storage + + # Clear the InternalStorage keys which are not in use anymore + dtype_in_use = list(self.dtype_rank_params.keys()) + dtype_to_pop = list( + filter(lambda x: x not in dtype_in_use, self.param_storages.keys())) + for d in dtype_to_pop: + self.param_storages.pop(d) + + def step(self): + """ + A wrapper for Optimizer's step function to finish the update operation of the optimizer. + """ + + # Synchronize optimizer parameters for the current rank + if len(self.dtype_rank_params.keys( + )) == 1 and Type.fp32.value in self.dtype_rank_params.keys(): + self._optim._parameter_list = self.dtype_rank_params[ + Type.fp32.value][self.rank] + elif len(self.dtype_rank_params.keys( + )) == 1 and Type.fp16.value in self.dtype_rank_params.keys(): + self._optim._parameter_list = self.dtype_rank_params[ + Type.fp16.value][self.rank] + else: + self._optim._parameter_list = self.dtype_rank_params[ + Type.fp16.value][self.rank] + self.dtype_rank_params[ + Type.fp32.value][self.rank] + + # Run the optimizer of the current rank step + self._optim.step() + + # Synchronize all the updated shards in between the ranks + self._broadcast_params() + + # Return full parameters to optimizer parameters + self._optim._parameter_list = self._local_params + + def clear_cache(self): + self._segment_params.clear() + self._dtype_rank_params.clear() + self._param2rank.clear() + + @paddle.no_grad() + def _broadcast_params(self): + """Broadcast the parameters of the current rank to each rank""" + + assert self._default_device == "gpu", "Only supported gpu" + + # Exchange all the shards with the other ranks + for dtype_per_rank in self.param_storages.values(): + for dst_rank, internal_storage in dtype_per_rank.items(): + dist.broadcast( + tensor=internal_storage.buffer, + src=dst_rank, + group=self.group, + use_calc_stream=True) + + # Multi stream operation will be supported later + dist.wait( + tensor=internal_storage.buffer, + group=self.group, + use_calc_stream=True) diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/__init__.py b/python/paddle/distributed/fleet/meta_parallel/sharding/__init__.py new file mode 100644 index 00000000000..845879ba38f --- /dev/null +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from .sharding_utils import GpuInfo
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
new file mode 100644
index 00000000000..4cf40005c1f
--- /dev/null
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
@@ -0,0 +1,93 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import contextlib
+from collections import abc
+from enum import Enum
+from math import inf
+
+import paddle
+import paddle.distributed as dist
+from paddle.fluid import core
+
+# Set global device id
+global dev_id
+if core.is_compiled_with_cuda():
+    dev_id = int(os.environ.get('FLAGS_selected_gpus', 0))
+elif core.is_compiled_with_npu():
+    dev_id = int(os.environ.get('FLAGS_selected_npus', 0))
+else:
+    raise ValueError("This device type is not supported.")
+
+
+class Taskflow:
+    """
+    Task flow: a one-way linked list for task acquisition.
+    """
+
+    def __init__(self, task, callback):
+        self.task = task
+        self.callback = callback
+
+
+class Type(Enum):
+    """
+    Type of trainable parameters
+    """
+    fp16 = paddle.float16
+    fp32 = paddle.float32
+
+
+def GpuInfo(fn):
+    """
+    Display GPU usage information before and after running the decorated function.
+    """
+
+    def used(*args, **kw):
+        # Before using
+        b_info = os.popen("nvidia-smi -i {} | grep MiB".format(str(
+            dev_id))).read()
+        before_info = (int(b_info.split()[8][:-3]),
+                       int(b_info.split()[10][:-3]))
+        print(
+            "====== Current device {} ====== Total {} MiB, used {} MiB ======".
+            format(str(dev_id), str(before_info[1]), str(before_info[0])))
+        result = fn(*args, **kw)
+        # After using
+        a_info = os.popen("nvidia-smi -i {} | grep MiB".format(str(
+            dev_id))).read()
+        after_info = (int(a_info.split()[8][:-3]), int(a_info.split()[10][:-3]))
+        print(
+            "====== Current device {} ====== Total {} MiB, used {} MiB, this call used {} MiB ======".
+            format(
+                str(dev_id),
+                str(after_info[1]),
+                str(after_info[0]), str(after_info[0] - before_info[0])))
+        return result
+
+    return used
+
+
+@contextlib.contextmanager
+def device_guard(dev_id, device="cpu"):
+    origin_device = paddle.device.get_device()
+    if device == "cpu":
+        paddle.set_device(device)
+    elif device == "gpu":
+        paddle.set_device("gpu:{}".format(dev_id))
+    try:
+        yield
+    finally:
+        paddle.set_device(origin_device)
diff --git a/python/paddle/distributed/fleet/utils/internal_storage.py b/python/paddle/distributed/fleet/utils/internal_storage.py
new file mode 100644
index 00000000000..96947221f31
--- /dev/null
+++ b/python/paddle/distributed/fleet/utils/internal_storage.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#Taken and modified for fairscale from: +# https://github.com/facebookresearch/fairscale/blob/main/fairscale/nn/misc/param_bucket.py +#Commit: 8acbec718f3c70a6b9785470bb9e05cd84fc3f8e + +import os +import time +import numpy as np + +import paddle +from paddle.fluid import core +from ..meta_parallel.sharding.sharding_utils import Type, device_guard + +# Set global device id +global dev_id +if core.is_compiled_with_cuda(): + dev_id = int(os.environ.get('FLAGS_selected_gpus', 0)) +elif core.is_compiled_with_npu(): + dev_id = int(os.environ.get('FLAGS_selected_npus', 0)) +else: + raise ValueError("This device doesn't support.") + + +class InternalStorage: + """ + This is a basic class, which is responsible for consolidating the basic storage tensor. + + """ + + # Support integration parameter tensor + def __init__(self, size, dtype, device, convert_cpu=False): + self._params = [] + self._param_ids = [] + self._fill = 0 + self._device = device + self._dtype = dtype + + # The actual flat tensor + size = [size] if isinstance(size, int) else size + if convert_cpu: + value = np.zeros( + size, + dtype=np.float16) if Type.fp16.value == dtype else np.zeros( + size, dtype=np.float32) + self.buffer = core.VarBase(value=value, place=core.CPUPlace()) + else: + self.buffer = paddle.zeros(size, dtype=dtype) + + +class ParamStorage(InternalStorage): + """ + This is a basic class to simplify the handling of parameter InternalStorages. + """ + + def __init__(self, size, dtype, device): + super().__init__(size, dtype, device, convert_cpu=True) + self.param2align = None + + @paddle.no_grad() + def add_rank_params(self, trainable_params, param2align): + """ + Add new parameters to the InternalStorage. Params becomes a view of this InternalStorage buffer. 
+ """ + + assert all([ + id(param) not in self._param_ids for param in trainable_params + ]), "The same param cannot be checked in twice" + assert self.buffer is not None + + self.param2align = param2align + + cpu_param_shape = list() + for param in trainable_params: + p_shape = self._add_param_as_view(param, param2align[param.name]) + cpu_param_shape.append(p_shape) + + # buffer covert from cpu to cuda + self.buffer = self.buffer.cuda(dev_id) + self._fill = 0 + + for idx, param in enumerate(trainable_params): + self._convert_buffer(param, cpu_param_shape[idx], + param2align[param.name]) + self._params.append(param) + self._param_ids.append(id(param)) + + @paddle.no_grad() + def _add_param_as_view(self, param, align): + + assert ( + param.dtype == self.buffer.dtype + ), "Different types for the InternalStorage and the param, cannot proceed: {} - {}".format( + param.dtype, self.buffer.dtype) + + var_end = self._fill + np.prod(param.shape) + offset = var_end + align + assert offset <= np.prod(self.buffer.shape) + + p_shape = param.shape + + origin_state = param.stop_gradient + param.stop_gradient = True + param.flatten_() + param.stop_gradient = origin_state + + # Copy the current param value + with device_guard(dev_id, "cpu"): + tmp_var = core.VarBase(tensor=self.buffer._slice(self._fill, + var_end)) + param_cpu = param.cpu() + param.value().get_tensor()._clear() + tmp_var.set_value(param_cpu) + + self._fill = offset + return p_shape + + @paddle.no_grad() + def _convert_buffer(self, param, p_shape, align): + + var_end = self._fill + np.prod(p_shape) + offset = var_end + align + assert offset <= np.prod(self.buffer.shape) + + # Convert the param value + tmp_tensor = self.buffer._slice(self._fill, var_end) + param.value().get_tensor()._share_data_with(tmp_tensor) + param.value().get_tensor()._set_dims(p_shape) + + self._fill = offset + + +class GradStorage(InternalStorage): + """ + This is a basic class to simplify the handling of gradient InternalStorages + """ + + def __init__(self, size, dtype, device, destination, parm2align): + if isinstance(size, np.int64): + size = size.tolist() + super().__init__(size, dtype, device) + + self._max_size = size + self._release = False + + self.params_checked_in = 0 + self.destination = destination + self._parm2align = parm2align + self.sent = False + + def reset_checked_in(self): + """ Reset the counter of the parameter grads which have been checked in + """ + self.params_checked_in = 0 + self.sent = False + + @property + def all_checked_in(self): + """ Judge all the expected gradient check-in happened """ + return len(self._params) == self.params_checked_in + + def can_add_grad_view(self, param, align): + """ Is there enough InternalStorage to add this parameter gradient, and whether this param have already checked in. + """ + return self._fill + np.prod( + param.shape) + align <= self._max_size and id( + param) not in self._param_ids + + @paddle.no_grad() + def add_grad(self, param, align): + """ + Add a new parameter gradient to the InternalStorage. Param.grad becomes a view of this InternalStorage buffer. + """ + + assert id( + param + ) not in self._param_ids, "The same gradients cannot be checked in twice" + + self._add_grad_as_view(param, align) + self._params.append(param) + self._param_ids.append(id(param)) + + @paddle.no_grad() + def manumal_relase(self): + """ + Release the buffer from InternalStorage. The InternalStorage will need to be rebuilt before use. 
+ """ + if not self._release: + for p in self._params: + if p.grad is not None: + p.clear_gradient(False) + p._gradient_set_empty(False) + + self.buffer = None + self._fill = 0 + self.params_checked_in = 0 + self._release = True + + @paddle.no_grad() + def rebuild(self): + """ + Given the parameter gradients which have been registered previously, rebuild the whole InternalStorage. + """ + assert len(self._params) > 0 + + if self._release: + self.buffer = paddle.zeros( + [self._max_size], dtype=self._params[0].dtype) + + for p in self._params: + self._add_grad_as_view(p, self._parm2align[p.name]) + + self._release = False + + @paddle.no_grad() + def _add_grad_as_view(self, param, align): + assert np.prod( + self.buffer.shape + ) > 0, "Cannot add a gradient to a released InternalStorage, please rebuild" + assert param.dtype == self.buffer.dtype + + grad_end = self._fill + np.prod(param.shape) + offset = grad_end + align + assert offset <= np.prod(self.buffer.shape) + + # Copy the current grad value to InternalStorage + assert self._device == "gpu" + tmp_var = core.VarBase(self.buffer._slice(self._fill, grad_end)) + param._copy_gradient_from(tmp_var) + tmp_var.value().get_tensor()._clear() + self._fill = offset diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 4698b1dcb27..2b49119c4fb 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -32,6 +32,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_dataparallel) list(APPEND DIST_TEST_OPS test_parallel_dygraph_pipeline_parallel) list(APPEND DIST_TEST_OPS test_parallel_dygraph_tensor_parallel) list(APPEND DIST_TEST_OPS test_parallel_dygraph_sharding_parallel) +list(APPEND DIST_TEST_OPS test_dygraph_sharding_optimizer_stage2) list(APPEND DIST_TEST_OPS test_auto_parallel_parallelizer) list(APPEND DIST_TEST_OPS test_parallel_dygraph_mp_layers) list(APPEND DIST_TEST_OPS test_hybrid_parallel_inference_helper) @@ -242,6 +243,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM)) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_pipeline_parallel) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_tensor_parallel) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sharding_parallel) + list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_optimizer_stage2) list(REMOVE_ITEM TEST_OPS test_auto_parallel_parallelizer) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_mp_layers) LIST(REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision) @@ -1036,6 +1038,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL) set_tests_properties(test_parallel_dygraph_pipeline_parallel PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_tensor_parallel PROPERTIES TIMEOUT 200) set_tests_properties(test_parallel_dygraph_sharding_parallel PROPERTIES TIMEOUT 120) + set_tests_properties(test_dygraph_sharding_optimizer_stage2 PROPERTIES TIMEOUT 120) set_tests_properties(test_auto_parallel_parallelizer PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_mp_layers PROPERTIES TIMEOUT 120) set_tests_properties(test_hybrid_parallel_inference_helper PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/dygraph_sharding_optimizer_stage2.py b/python/paddle/fluid/tests/unittests/dygraph_sharding_optimizer_stage2.py new file mode 100644 index 00000000000..7a5ec28dd1a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_sharding_optimizer_stage2.py @@ -0,0 +1,134 @@ +# -*- coding: UTF-8 -*- + +# Copyright (c) 2021 
PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import numpy as np +import argparse +import ast +import time +import paddle +import paddle.fluid as fluid +from paddle.fluid.dygraph.nn import Linear +from paddle.distributed import fleet + +from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import GpuInfo +from paddle.distributed.fleet.utils.internal_storage import GradStorage +from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ShardingOptimizerStage2 + +base_lr = 0.1 +momentum_rate = 0.9 +l2_decay = 1e-4 + +epoch = 100 +batch_size = 32 +class_dim = 102 + + +class MLP(fluid.Layer): + def __init__(self, param_attr=None, bias_attr=None): + super(MLP, self).__init__() + + self._linear1 = Linear(10, 10) + self._linear2 = Linear(10, 10) + + def forward(self, inputs): + y = self._linear1(inputs) + y = self._linear2(y) + return y + + +def reader_decorator(): + def __reader__(): + for _ in range(100): + img = np.random.rand(10).astype('float32') + label = np.ones(1).astype('int64') + yield img, label + + return __reader__ + + +def optimizer_setting(parameter_list=None): + optimizer = paddle.optimizer.Momentum( + learning_rate=base_lr, + momentum=momentum_rate, + weight_decay=paddle.regularizer.L2Decay(l2_decay), + parameters=parameter_list) + return optimizer + + +@GpuInfo +def train_mlp(): + fleet.init(is_collective=True) + group = paddle.distributed.new_group([0, 1]) + + mlp = MLP() + + optimizer = optimizer_setting(parameter_list=mlp.parameters()) + oss_optimizer = ShardingOptimizerStage2( + params=mlp.parameters(), optim=optimizer, group=group) + # cover grad_storage code + trainable_param2align = dict() + for p in mlp.parameters(): + trainable_param2align[p.name] = 0 + grad_storage = GradStorage( + 10000, + dtype=paddle.float32, + device="gpu", + destination=0, + parm2align=trainable_param2align) + for p in mlp.parameters(): + grad_storage.can_add_grad_view(p, trainable_param2align[p.name]) + grad_storage.add_grad(p, trainable_param2align[p.name]) + grad_storage.manumal_relase() + grad_storage.rebuild() + grad_storage.reset_checked_in() + + train_reader = paddle.batch( + reader_decorator(), batch_size=batch_size, drop_last=True) + + train_loader = paddle.io.DataLoader.from_generator( + capacity=32, + use_double_buffer=True, + iterable=True, + return_list=True, + use_multiprocess=True) + train_loader.set_sample_list_generator(train_reader) + + for eop in range(epoch): + mlp.train() + + for batch_id, data in enumerate(train_loader()): + img, label = data + label.stop_gradient = True + img.stop_gradient = True + + out = mlp(img) + loss = paddle.nn.functional.cross_entropy(input=out, label=label) + avg_loss = paddle.mean(x=loss) + acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1) + acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5) + + dy_out = avg_loss.numpy() + + avg_loss.backward() + oss_optimizer.step() + + # oss_optimizer clear cache + 
oss_optimizer.clear_cache() + + +if __name__ == '__main__': + train_mlp() diff --git a/python/paddle/fluid/tests/unittests/test_dygraph_sharding_optimizer_stage2.py b/python/paddle/fluid/tests/unittests/test_dygraph_sharding_optimizer_stage2.py new file mode 100644 index 00000000000..deb180a2fe1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dygraph_sharding_optimizer_stage2.py @@ -0,0 +1,31 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import paddle.fluid as fluid + +from test_parallel_dygraph_dataparallel import TestMultipleGpus + + +class TestDygraphShardingOptimizerStage2(TestMultipleGpus): + + # check sharding logic as well as the accuracy with single mode + def test_dygraph_sharding_optimizer_stage2(self): + self.run_mnist_2gpu('dygraph_sharding_optimizer_stage2.py') + + +if __name__ == "__main__": + unittest.main() -- GitLab
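
A minimal usage sketch, distilled from dygraph_sharding_optimizer_stage2.py above: it shows how ShardingOptimizerStage2 wraps a regular Momentum optimizer so that each rank only stores and updates its own parameter shard. The script layout, model, and hyperparameters are illustrative assumptions rather than part of the patch; it assumes two GPUs launched via `python -m paddle.distributed.launch --gpus 0,1 demo.py`.

import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ShardingOptimizerStage2


def main():
    # Initialize collective mode and build the communication group used by the sharded optimizer.
    fleet.init(is_collective=True)
    group = paddle.distributed.new_group([0, 1])

    # A toy model; any dygraph Layer works here.
    model = paddle.nn.Sequential(
        paddle.nn.Linear(10, 10), paddle.nn.Linear(10, 10))

    # The inner optimizer still sees the full parameter list at construction time.
    inner_opt = paddle.optimizer.Momentum(
        learning_rate=0.1, momentum=0.9, parameters=model.parameters())

    # Each rank keeps only its own shard in contiguous ParamStorage buffers;
    # step() updates the local shard and broadcasts it back to all ranks.
    opt = ShardingOptimizerStage2(
        params=model.parameters(), optim=inner_opt, group=group)

    for _ in range(10):
        x = paddle.rand([32, 10])
        label = paddle.ones([32, 1], dtype='int64')
        loss = paddle.nn.functional.cross_entropy(model(x), label)
        loss.backward()
        opt.step()
        opt.clear_cache()


if __name__ == "__main__":
    main()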