未验证 提交 f40ed5f4 编写于 作者: B Baibaifan 提交者: GitHub

add_sharding_api (#40129)

上级 1defc8f3
...@@ -55,6 +55,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv # noqa: F401 ...@@ -55,6 +55,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv # noqa: F401
from . import cloud_utils # noqa: F401 from . import cloud_utils # noqa: F401
from . import utils # noqa: F401 from . import utils # noqa: F401
from .sharding import * # noqa: F401
__all__ = [ # noqa __all__ = [ # noqa
"spawn", "spawn",
......
...@@ -40,8 +40,6 @@ align = { ...@@ -40,8 +40,6 @@ align = {
Type.fp32.value: 4, Type.fp32.value: 4,
} }
__all__ = ["ShardingOptimizerStage2"]
class ShardingOptimizerStage2(Optimizer): class ShardingOptimizerStage2(Optimizer):
""" """
...@@ -136,7 +134,7 @@ class ShardingOptimizerStage2(Optimizer): ...@@ -136,7 +134,7 @@ class ShardingOptimizerStage2(Optimizer):
# Update optimizer parameters and adjust parameter storage and use according to rank. # Update optimizer parameters and adjust parameter storage and use according to rank.
self._update_opt_status() self._update_opt_status()
@paddle.no_grad() @paddle.autograd.no_grad()
def _sync_params_and_buffers(self): def _sync_params_and_buffers(self):
""" """
Sync all model states for all ranks Sync all model states for all ranks
...@@ -392,7 +390,7 @@ class ShardingOptimizerStage2(Optimizer): ...@@ -392,7 +390,7 @@ class ShardingOptimizerStage2(Optimizer):
self._dtype_rank_params.clear() self._dtype_rank_params.clear()
self._param2rank.clear() self._param2rank.clear()
@fluid.dygraph.no_grad @paddle.autograd.no_grad()
def _broadcast_params(self): def _broadcast_params(self):
"""Broadcast the parameters of the current rank to each rank""" """Broadcast the parameters of the current rank to each rank"""
......
...@@ -63,8 +63,7 @@ class ShardingStage2(nn.Layer): ...@@ -63,8 +63,7 @@ class ShardingStage2(nn.Layer):
sync_buffers=False, sync_buffers=False,
buffer_max_size=2**23, #8MB buffer_max_size=2**23, #8MB
auto_refresh_trainable=True, auto_refresh_trainable=True,
device="gpu", device="gpu"):
use_grad_storage=True):
super().__init__() super().__init__()
# training options # training options
...@@ -102,9 +101,10 @@ class ShardingStage2(nn.Layer): ...@@ -102,9 +101,10 @@ class ShardingStage2(nn.Layer):
# Set grad storage size & Display param sizes and model sizes # Set grad storage size & Display param sizes and model sizes
model_size = sum( model_size = sum(
[np.prod(p.shape) for p in self._layer.parameters()]).item() [np.prod(p.shape) for p in self._layer.parameters()]).item()
assert buffer_max_size >= 0, "buffer_max_size must be GE than 0."
self._buffer_max_size = self._rank_buffer_size(buffer_max_size, self._buffer_max_size = self._rank_buffer_size(buffer_max_size,
model_size) model_size)
self._use_grad_storage = use_grad_storage self._use_grad_storage = buffer_max_size > 0
self._grad_storages = {} # {dtype: {rank: GradStorage}} self._grad_storages = {} # {dtype: {rank: GradStorage}}
self._has_grad_storage = [] self._has_grad_storage = []
self._grad_storage_list = [] self._grad_storage_list = []
...@@ -255,7 +255,7 @@ class ShardingStage2(nn.Layer): ...@@ -255,7 +255,7 @@ class ShardingStage2(nn.Layer):
# wait next func hook support # wait next func hook support
self._setup_backward_hooks() self._setup_backward_hooks()
@paddle.no_grad() @paddle.autograd.no_grad()
def __sync_buffers(self): def __sync_buffers(self):
""" """
Sync all the param buffers from all ranks (exp: batch norm statistics). Sync all the param buffers from all ranks (exp: batch norm statistics).
...@@ -277,7 +277,7 @@ class ShardingStage2(nn.Layer): ...@@ -277,7 +277,7 @@ class ShardingStage2(nn.Layer):
except AttributeError: except AttributeError:
return getattr(self._layer, name) return getattr(self._layer, name)
@paddle.no_grad() @paddle.autograd.no_grad()
def _clear_counters(self): def _clear_counters(self):
"""Reset all the grad reduce and call counters.""" """Reset all the grad reduce and call counters."""
if self.training: if self.training:
...@@ -290,13 +290,13 @@ class ShardingStage2(nn.Layer): ...@@ -290,13 +290,13 @@ class ShardingStage2(nn.Layer):
def _get_reduce_fn(self, index, param, dst_rank): def _get_reduce_fn(self, index, param, dst_rank):
""" """
There are two ways to reduce gradient. There are two ways to reduce gradient.
- 1. Do not use use_grad_storage or exceeded buffer_max_size will be reduced separately. - 1. Do not use self._use_grad_storage or exceeded buffer_max_size will be reduced separately.
- 2. Use grad_storage Reduce the storage to get the full gradient from different ranks. - 2. Use grad_storage Reduce the storage to get the full gradient from different ranks.
""" """
if not self._use_grad_storage or not self._has_grad_storage[index]: if not self._use_grad_storage or not self._has_grad_storage[index]:
# Direct reduction # Direct reduction
@paddle.no_grad() @paddle.autograd.no_grad()
def reduce(*_): def reduce(*_):
# Skip gradient reduction, do not change status information # Skip gradient reduction, do not change status information
if self._grad_reduced[index]: if self._grad_reduced[index]:
...@@ -336,7 +336,7 @@ class ShardingStage2(nn.Layer): ...@@ -336,7 +336,7 @@ class ShardingStage2(nn.Layer):
else: else:
# Buffer reduction # Buffer reduction
@paddle.no_grad() @paddle.autograd.no_grad()
def reduce(*_): def reduce(*_):
# Skip gradient reduction, do not change status information # Skip gradient reduction, do not change status information
if self._grad_reduced[index]: if self._grad_reduced[index]:
...@@ -421,9 +421,6 @@ class ShardingStage2(nn.Layer): ...@@ -421,9 +421,6 @@ class ShardingStage2(nn.Layer):
Integrate the parameters gradient into a continuous memory according to rank, and support the update of training parameters. Integrate the parameters gradient into a continuous memory according to rank, and support the update of training parameters.
""" """
if not self._use_grad_storage:
return
# According to parameters's numel sort, allocate memory of parameter gradient to continuous memory according to rank # According to parameters's numel sort, allocate memory of parameter gradient to continuous memory according to rank
self._grad_storages = {} self._grad_storages = {}
self._has_grad_storage = [False for _ in self._trainable_params] self._has_grad_storage = [False for _ in self._trainable_params]
......
...@@ -84,6 +84,7 @@ class ShardingStage3(nn.Layer): ...@@ -84,6 +84,7 @@ class ShardingStage3(nn.Layer):
self._offload = offload self._offload = offload
self._sync_comm = sync_comm self._sync_comm = sync_comm
# segmentation size # segmentation size
assert segment_size >= 0, "segment_size must be GE than 0."
self._segment_size = segment_size self._segment_size = segment_size
global DEV global DEV
...@@ -158,7 +159,7 @@ class ShardingStage3(nn.Layer): ...@@ -158,7 +159,7 @@ class ShardingStage3(nn.Layer):
self._redefine_opt_step() self._redefine_opt_step()
self._redefine_opt_clear() self._redefine_opt_clear()
@paddle.no_grad() @paddle.autograd.no_grad()
def _sync_params_and_buffers(self): def _sync_params_and_buffers(self):
""" """
Sync all model states for all ranks Sync all model states for all ranks
...@@ -408,7 +409,7 @@ class ShardingStage3(nn.Layer): ...@@ -408,7 +409,7 @@ class ShardingStage3(nn.Layer):
# register post forward hooks # register post forward hooks
sub_layer.register_forward_post_hook(_forward_post_hook) sub_layer.register_forward_post_hook(_forward_post_hook)
@paddle.no_grad() @paddle.autograd.no_grad()
def _sync_buffers(self): def _sync_buffers(self):
""" """
Sync all the param buffers from all ranks (exp: batch norm statistics). Sync all the param buffers from all ranks (exp: batch norm statistics).
...@@ -521,7 +522,7 @@ class ShardingStage3(nn.Layer): ...@@ -521,7 +522,7 @@ class ShardingStage3(nn.Layer):
param._register_backward_hook(allreduce_function) param._register_backward_hook(allreduce_function)
def _get_allreduce_fn(self, param): def _get_allreduce_fn(self, param):
@paddle.no_grad() @paddle.autograd.no_grad()
def reduce(*_): def reduce(*_):
if param.name in self._task_flow.full_grad.keys(): if param.name in self._task_flow.full_grad.keys():
full_grad = self._task_flow.full_grad[param.name] full_grad = self._task_flow.full_grad[param.name]
...@@ -840,7 +841,7 @@ def _allgather_buffer(trainable_params, ...@@ -840,7 +841,7 @@ def _allgather_buffer(trainable_params,
return task_flow return task_flow
@paddle.no_grad() @paddle.autograd.no_grad()
def _create_params_grad(trainable_params, param2buffer_size, task_flow): def _create_params_grad(trainable_params, param2buffer_size, task_flow):
for param in trainable_params: for param in trainable_params:
if param.name in task_flow.full_grad.keys(): if param.name in task_flow.full_grad.keys():
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .group_sharded import group_sharded_parallel, save_group_sharded_model # noqa: F401
__all__ = ['group_sharded_parallel', 'save_group_sharded_model']
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import logging
from enum import Enum
import paddle
from paddle.optimizer import Optimizer
from paddle.distributed.utils import get_logger
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import ShardingStage3
from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import ShardingScaler
logger_ = get_logger(logging.INFO)
def group_sharded_parallel(model,
optimizer,
level,
scaler=None,
group=None,
offload=False,
sync_buffers=False,
buffer_max_size=2**23,
segment_size=2**20,
sync_comm=False):
"""
Use this module to configure and wrap up the parameters of the group shared module.
Args:
model (Layer): The layer to be wrapped with group_sharded_parallel.
optimizer (Optimizer): The optimizer to be wrapped with group_sharded_parallel.
level (str): The different level of the group sharded. Such as `os`, `os_g`, `p_g_os`.
scaler (GradScaler, optional): The scaler to be wrapped with group_sharded_parallel. Defaults to None.
group (Group, optional): The group instance. Defaults to None.d
offload (bool, optional): Whether to perform optimizer state and gradient transfer CPU. Defaults to False.
sync_buffers (bool, optional): Whether to broadcast model buffers. Defaults to False.
buffer_max_size (int, optional): The max size of the buffer used to integrate gradient in `os_g`. Defaults to 2**23.
segment_size (int, optional): The smallest size of parameter to be sharded in `p_g_os`. Defaults to 2**20.
sync_comm (bool, optional): Whether to use synchronous communication, only in `p_g_os` used. Defaults to False.
Returns:
model: A wrapper for group sharded given model.
optimizer: A wrapper for group sharded given optimizer.
scaler: A wrapper for group sharded given scaler.
Examples:
.. code-block:: python
# required: distributed
import paddle
from paddle.fluid.dygraph.nn import Linear
from paddle.distributed import fleet
from paddle.distributed.sharding import group_sharded_parallel
fleet.init(is_collective=True)
group = paddle.distributed.new_group([0, 1])
model = Linear(1000, 1000)
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
optimizer = paddle.optimizer.AdamW(learning_rate=0.001, parameters=model.parameters(), weight_decay=0.00001, grad_clip=clip)
# wrap sharding model, optimizer and scaler
model, optimizer, scaler = group_sharded_parallel(model, optimizer, "p_g", scaler=scaler)
img, label = data
label.stop_gradient = True
img.stop_gradient = True
out = model(img)
loss = paddle.nn.functional.cross_entropy(input=out, label=label)
loss.backward()
optimizer.step()
optimizer.clear_grad()
"""
# check optition type
assert isinstance(
model,
paddle.nn.Layer), "The model must be the instance of paddle.nn.Layer."
assert isinstance(
optimizer, Optimizer
), "The optimizer must be the instance of paddle.optimizer.Optimizer."
assert level in ['os', 'os_g', 'p_g_os'
], "The level must be os, os_g or p_g_os."
def check_dtype(param):
return param.dtype == paddle.float16
params_fp16 = filter(check_dtype, model.parameters())
if scaler is None and len(params_fp16) > 0:
raise ValueError("Please enter the correct scaler.")
# convert model/optimizer/scaler
if level in ['os', 'os_g']:
logger_.info("*" * 30)
logger_.info("Sharded level os uses sharded level os_g achieved now.")
logger_.info("*" * 30)
optimizer = ShardingOptimizerStage2(
params=model.parameters(),
optim=optimizer,
group=group,
offload=offload)
model = ShardingStage2(
model,
optimizer,
group=group,
sync_buffers=sync_buffers,
buffer_max_size=buffer_max_size)
elif level == 'p_g_os':
model = ShardingStage3(
model,
optimizer=optimizer,
group=group,
sync_buffers=sync_buffers,
segment_size=segment_size,
offload=offload,
sync_comm=sync_comm)
else:
raise ValueError("Please enter the correct level.")
if params_fp16 and isinstance(scaler, paddle.amp.GradScaler):
scaler = ShardingScaler(scaler)
logger_.info("*" * 30)
logger_.info(
"If there is a communication hang using group sharded, please check whether the communication operations of each process are unified."
)
logger_.info("*" * 30)
return model, optimizer, scaler
def save_group_sharded_model(model, output, optimizer=None):
"""
Group sharded encapsulated model and optimizer state saving module.
Args:
model (Layer): A wrapper for group sharded given model.
output (str): Save directory.
optimizer (Optimizer, optional): Group sharded encapsulated optimizer. Defaults to None.
Examples:
.. code-block:: python
# required: distributed
import paddle
from paddle.fluid.dygraph.nn import Linear
from paddle.distributed import fleet
from paddle.distributed.sharding import group_sharded_parallel, save_group_sharded_model
fleet.init(is_collective=True)
group = paddle.distributed.new_group([0, 1])
model = Linear(1000, 1000)
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
optimizer = paddle.optimizer.AdamW(learning_rate=0.001, parameters=model.parameters(), weight_decay=0.00001, grad_clip=clip)
# wrap sharding model, optimizer and scaler
model, optimizer, scaler = group_sharded_parallel(model, optimizer, "p_g", scaler=scaler)
img, label = data
label.stop_gradient = True
img.stop_gradient = True
out = model(img)
loss = paddle.nn.functional.cross_entropy(input=out, label=label)
loss.backward()
optimizer.step()
optimizer.clear_grad()
# save model and optimizer state_dict
save_group_sharded_model(model, optimizer,output=output_dir)
"""
logger_.info(
"==========Begin to save group sharded model and optimizer==========")
assert not os.path.isfile(
output
), "Saving directory ({}) should be a directory, not a file".format(output)
os.makedirs(output, exist_ok=True)
output_model = os.path.join(output, "model.pdmodel")
if isinstance(model, ShardingStage2):
paddle.save(model._layer.state_dict(), output_model)
elif isinstance(model, ShardingStage3):
convert2cpu = True if model._offload else False
model.get_all_parameters(convert2cpu=convert2cpu)
paddle.save(model._layer.state_dict(), output_model)
else:
raise ValueError(
"Please use the layer which is wrapped with group_sharded_parallel.")
if optimizer is not None:
assert hasattr(
optimizer, "_optim"
), "Please use the optimizer which is wrapped with group_sharded_parallel."
output_opt = os.path.join(output, "model.pdopt")
paddle.save(optimizer._optim.state_dict(), output_opt)
logger_.info(
"==========End to save group sharded model and optimizer==========")
...@@ -47,6 +47,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_sharding_parallel) ...@@ -47,6 +47,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_sharding_parallel)
list(APPEND DIST_TEST_OPS test_dygraph_sharding_optimizer_stage2) list(APPEND DIST_TEST_OPS test_dygraph_sharding_optimizer_stage2)
list(APPEND DIST_TEST_OPS test_dygraph_sharding_stage2) list(APPEND DIST_TEST_OPS test_dygraph_sharding_stage2)
list(APPEND DIST_TEST_OPS test_dygraph_sharding_stage3) list(APPEND DIST_TEST_OPS test_dygraph_sharding_stage3)
list(APPEND DIST_TEST_OPS test_dygraph_group_sharded_api)
list(APPEND DIST_TEST_OPS test_auto_parallel_parallelizer) list(APPEND DIST_TEST_OPS test_auto_parallel_parallelizer)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_mp_layers) list(APPEND DIST_TEST_OPS test_parallel_dygraph_mp_layers)
list(APPEND DIST_TEST_OPS test_hybrid_parallel_inference_helper) list(APPEND DIST_TEST_OPS test_hybrid_parallel_inference_helper)
...@@ -282,6 +283,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM)) ...@@ -282,6 +283,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_optimizer_stage2) list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_optimizer_stage2)
list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_stage2) list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_stage2)
list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_stage3) list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_stage3)
list(REMOVE_ITEM TEST_OPS test_dygraph_group_sharded_api)
list(REMOVE_ITEM TEST_OPS test_auto_parallel_parallelizer) list(REMOVE_ITEM TEST_OPS test_auto_parallel_parallelizer)
list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_mp_layers) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_mp_layers)
LIST(REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision) LIST(REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision)
...@@ -1123,6 +1125,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL) ...@@ -1123,6 +1125,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
set_tests_properties(test_dygraph_sharding_optimizer_stage2 PROPERTIES TIMEOUT 120) set_tests_properties(test_dygraph_sharding_optimizer_stage2 PROPERTIES TIMEOUT 120)
set_tests_properties(test_dygraph_sharding_stage2 PROPERTIES TIMEOUT 120) set_tests_properties(test_dygraph_sharding_stage2 PROPERTIES TIMEOUT 120)
set_tests_properties(test_dygraph_sharding_stage3 PROPERTIES TIMEOUT 120) set_tests_properties(test_dygraph_sharding_stage3 PROPERTIES TIMEOUT 120)
set_tests_properties(test_dygraph_group_sharded_api PROPERTIES TIMEOUT 120)
set_tests_properties(test_auto_parallel_parallelizer PROPERTIES TIMEOUT 120) set_tests_properties(test_auto_parallel_parallelizer PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_dygraph_mp_layers PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_mp_layers PROPERTIES TIMEOUT 120)
set_tests_properties(test_hybrid_parallel_inference_helper PROPERTIES TIMEOUT 120) set_tests_properties(test_hybrid_parallel_inference_helper PROPERTIES TIMEOUT 120)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import shutil
import tempfile
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Linear
from paddle.distributed import fleet
from paddle.fluid.dygraph import nn
from paddle.distributed.sharding import group_sharded_parallel, save_group_sharded_model
epoch = 10
paddle.seed(2022)
np.random.seed(2022)
base_lr = 0.1
momentum_rate = 0.9
l2_decay = 1e-4
batch_size = 100
fleet.init(is_collective=True)
class MLP(fluid.Layer):
def __init__(self, linear_size=1000, param_attr=None, bias_attr=None):
super(MLP, self).__init__()
self._linear1 = Linear(linear_size, linear_size)
self._linear2 = Linear(linear_size, linear_size)
self._linear3 = Linear(linear_size, 10)
def forward(self, inputs):
y = self._linear1(inputs)
y = self._linear2(y)
y = self._linear3(y)
return y
def reader_decorator(linear_size=1000):
def __reader__():
for _ in range(100):
img = np.random.rand(linear_size).astype('float32')
label = np.ones(1).astype('int64')
yield img, label
return __reader__
def optimizer_setting(model, use_pure_fp16, opt_group=False):
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
optimizer = paddle.optimizer.Momentum(
parameters=[{
"params": list(model.parameters())
}] if opt_group else list(model.parameters()),
learning_rate=0.001,
weight_decay=0.00001,
grad_clip=clip,
multi_precision=use_pure_fp16)
return optimizer
def train_mlp(model, shard_level, use_pure_fp16, output_dir):
group = paddle.distributed.new_group([0, 1])
optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)
model = paddle.amp.decorate(models=model, level='O2', save_dtype='float32')
scaler = paddle.amp.GradScaler(init_loss_scaling=32768)
model, optimizer, scaler = group_sharded_parallel(
model=model, optimizer=optimizer, level=shard_level, scaler=scaler)
train_reader = paddle.batch(
reader_decorator(), batch_size=batch_size, drop_last=True)
train_loader = paddle.io.DataLoader.from_generator(
capacity=32,
use_double_buffer=True,
iterable=True,
return_list=True,
use_multiprocess=True)
train_loader.set_sample_list_generator(train_reader)
for eop in range(epoch):
model.train()
for batch_id, data in enumerate(train_loader()):
img, label = data
label.stop_gradient = True
img.stop_gradient = True
with paddle.amp.auto_cast(True, level='O2'):
out = model(img)
loss = paddle.nn.functional.cross_entropy(
input=out, label=label)
avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
if not use_pure_fp16:
avg_loss.backward()
optimizer.step()
else:
scaler.scale(avg_loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.clear_grad()
save_group_sharded_model(model, output=output_dir, optimizer=optimizer)
return model.parameters()
def test_sharding_api():
mlp, mlp1, mlp2 = MLP(), MLP(), MLP()
state_dict = mlp.state_dict()
mlp1.set_state_dict(state_dict)
mlp2.set_state_dict(state_dict)
output_dir = tempfile.mkdtemp()
# fp16
stage2_params = train_mlp(
mlp1, shard_level="os_g", use_pure_fp16=True, output_dir=output_dir)
stage3_params = train_mlp(
mlp2, shard_level="p_g_os", use_pure_fp16=True, output_dir=output_dir)
for i in range(len(stage3_params)):
np.testing.assert_allclose(
stage2_params[i].numpy(),
stage3_params[i].numpy(),
rtol=1e-4,
atol=1e-3)
shutil.rmtree(output_dir)
if __name__ == '__main__':
test_sharding_api()
...@@ -83,7 +83,7 @@ def train_mlp(model, ...@@ -83,7 +83,7 @@ def train_mlp(model,
accumulate_grad=False, accumulate_grad=False,
batch_size=100, batch_size=100,
opt_group=False, opt_group=False,
recompute=False, sync_comm=False,
test_minimize=False): test_minimize=False):
group = paddle.distributed.new_group([0, 1]) group = paddle.distributed.new_group([0, 1])
if opt_group: if opt_group:
...@@ -104,7 +104,7 @@ def train_mlp(model, ...@@ -104,7 +104,7 @@ def train_mlp(model,
model, optimizer, group=group, buffer_max_size=2**21) model, optimizer, group=group, buffer_max_size=2**21)
elif sharding_stage == 3: elif sharding_stage == 3:
model = ShardingStage3( model = ShardingStage3(
model, optimizer=optimizer, group=group, sync_comm=recompute) model, optimizer=optimizer, group=group, sync_comm=sync_comm)
# check optimizer.minimize() error # check optimizer.minimize() error
if test_minimize: if test_minimize:
...@@ -225,7 +225,7 @@ def test_stage2_stage3(): ...@@ -225,7 +225,7 @@ def test_stage2_stage3():
rtol=1e-4, rtol=1e-4,
atol=1e-3) atol=1e-3)
# fp16 recompute # fp16 sync_comm
stage3_params = train_mlp( stage3_params = train_mlp(
mlp7, sharding_stage=3, use_pure_fp16=True, opt_group=False) mlp7, sharding_stage=3, use_pure_fp16=True, opt_group=False)
stage3_params_re = train_mlp( stage3_params_re = train_mlp(
...@@ -233,7 +233,7 @@ def test_stage2_stage3(): ...@@ -233,7 +233,7 @@ def test_stage2_stage3():
sharding_stage=3, sharding_stage=3,
use_pure_fp16=True, use_pure_fp16=True,
opt_group=False, opt_group=False,
recompute=True) sync_comm=True)
for i in range(len(stage3_params)): for i in range(len(stage3_params)):
np.testing.assert_allclose( np.testing.assert_allclose(
stage3_params[i].numpy(), stage3_params_re[i].numpy(), rtol=1e-6) stage3_params[i].numpy(), stage3_params_re[i].numpy(), rtol=1e-6)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle.fluid as fluid
from test_parallel_dygraph_dataparallel import TestMultipleGpus
class TestDygraphGroupSharded(TestMultipleGpus):
# check group sharded logic as well as the accuracy with single mode
def test_dygraph_group_sharded(self):
self.run_mnist_2gpu('dygraph_group_sharded_api.py')
if __name__ == "__main__":
unittest.main()
...@@ -46,6 +46,10 @@ def _build_saved_state_dict(state_dict): ...@@ -46,6 +46,10 @@ def _build_saved_state_dict(state_dict):
if value.type == core.VarDesc.VarType.VOCAB: if value.type == core.VarDesc.VarType.VOCAB:
save_dict[key] = value.value().get_map_tensor() save_dict[key] = value.value().get_map_tensor()
else: else:
if not value.value().get_tensor()._is_initialized():
raise ValueError(
"The saved tensor is not initialized. If you used group sharded, please use save_group_sharded_model."
)
save_dict[key] = value.numpy() save_dict[key] = value.numpy()
name_table[key] = value.name name_table[key] = value.name
else: else:
...@@ -466,7 +470,9 @@ def _parse_load_result(obj, return_numpy): ...@@ -466,7 +470,9 @@ def _parse_load_result(obj, return_numpy):
def _save_lod_tensor(tensor, file_name): def _save_lod_tensor(tensor, file_name):
if not tensor._is_initialized(): if not tensor._is_initialized():
raise ValueError("The saved tensor is not initialized.") raise ValueError(
"The saved tensor is not initialized. If you used group sharded, please use save_group_sharded_model firstly."
)
if _is_file_path(file_name): if _is_file_path(file_name):
_seek = core.save_lod_tensor(tensor, file_name) _seek = core.save_lod_tensor(tensor, file_name)
# '_seek' is the end position of this tensor in the file. # '_seek' is the end position of this tensor in the file.
......
...@@ -280,6 +280,7 @@ packages=['paddle', ...@@ -280,6 +280,7 @@ packages=['paddle',
'paddle.incubate.nn', 'paddle.incubate.nn',
'paddle.incubate.passes', 'paddle.incubate.passes',
'paddle.distribution', 'paddle.distribution',
'paddle.distributed.sharding',
'paddle.distributed.fleet', 'paddle.distributed.fleet',
'paddle.distributed.fleet.base', 'paddle.distributed.fleet.base',
'paddle.distributed.fleet.elastic', 'paddle.distributed.fleet.elastic',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册