Unverified commit 20e19776, authored by Baibaifan, committed by GitHub

Add dygraph sharding stage2 (#37707)

Parent: 29ebf621
...@@ -68,7 +68,6 @@ class ShardingOptimizerStage2(Optimizer):
                  broadcast_fp16=False,
                  offload=False,
                  device="gpu",
-                 accumulation_steps=None,
                  **kw):
 
         super().__init__(optim._learning_rate, params, kw)
...@@ -86,7 +85,6 @@ class ShardingOptimizerStage2(Optimizer):
         self._optim = optim
         self._local_params = params
         self._default_device = device
-        self._accumulation_steps = accumulation_steps
 
         assert group is not None, "Distributed communication group is must be gived"
         self.group = group
...@@ -136,10 +134,6 @@ class ShardingOptimizerStage2(Optimizer):
     def local_params(self):
         return self._local_params
 
-    @property
-    def accumulation_steps(self):
-        return self._accumulation_steps
-
     @property
     def param2rank(self):
         """Map the params to the rank which owns them"""
......
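The hunk above removes the accumulation_steps argument and its property from ShardingOptimizerStage2; with this commit, gradient accumulation is signalled instead through the accumulate_grads flag of the ShardingStage2 layer wrapper added below. A minimal construction sketch, based on the test script in this commit and assuming `model` is an already-built nn.Layer and the collective environment has been initialized:

import paddle
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2

# A communication group is still mandatory for the stage-2 optimizer.
group = paddle.distributed.new_group([0, 1])
base_opt = paddle.optimizer.AdamW(
    parameters=model.parameters(), learning_rate=0.001)
# accumulation_steps is no longer accepted here; the optimizer only shards
# optimizer states and parameter updates across the group.
opt = ShardingOptimizerStage2(
    params=model.parameters(), optim=base_opt, group=group)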
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Taken and modified from fairscale:
# https://github.com/facebookresearch/fairscale/blob/main/fairscale/nn/data_parallel/sharded_ddp.py
# Commit: 8acbec718f3c70a6b9785470bb9e05cd84fc3f8e
import os
import contextlib
import logging
import time
import functools
import numpy as np
from itertools import chain
from functools import reduce
from collections import deque
import paddle
from paddle import nn
import paddle.distributed as dist
from ...utils.internal_storage import GradStorage
from .sharding_utils import Taskflow, Type
def _trainable(param):
return param.trainable
class ShardingStage2(nn.Layer):
"""
A wrapper for Sharding Stage2 Layer in Dygraph.
.. warning: ShardingStage2 encapsulates the layer strategy and integrates it into the nn.Layer.
.. ZeRO: https://arxiv.org/pdf/1910.02054.pdf.
"""
# TODO (Baibaifan)
# Feature Notes::
# 1. Unified memory for param and param.grad to InternalStorage.
# 2. Divide param.grad according to rank to centrally apply for and release GPU memory.
    # 3. Dynamically adjust training parameters and models.
# 4. Support offload function.
# 5. Support the establishment of independent communication groups.
def __init__(
self,
layer,
sharding_optimizer,
group,
sync_buffers=False,
pertrain_sync_models=True,
buffer_max_size=2**23, #8MB
auto_refresh_trainable=True,
device="gpu",
use_grad_storage=True,
accumulate_grads=False):
super().__init__()
# training options
self._layer = layer
self._sharding_optimizers = [sharding_optimizer] if not isinstance(
sharding_optimizer, list) else sharding_optimizer
self._sync_buffers = sync_buffers
self._auto_refresh_trainable = auto_refresh_trainable
# Gradient accumulation, Gradient flip
self._accumulate_grads = accumulate_grads
# Communication related attributes
        assert group is not None, "Distributed communication group must be given"
self._group = group
self._world_size_scaling = 1.0 / self._group.nranks
assert self._group.nranks > 1, "Training must be distributed, ranks must be greater than 1"
self._rank = self._group.rank
self._global_root_rank = 0 # picking rank 0 as the reference
self._global_ranks = self._group.ranks
self._default_device = device
# Global statistical parameters
self._all_params = list(
chain(
* [optim.local_params for optim in self._sharding_optimizers]))
self._trainable_params = []
self._grad_reduced = []
self._trainable_param2rank = {}
self._trainable_param2align = {}
self._trainable_mask = list(map(_trainable, self._all_params))
self._param_grads = []
# Set grad storage size & Display param sizes and model sizes
model_size = sum(
[np.prod(p.shape) for p in self._layer.parameters()]).item()
self._buffer_max_size = self._rank_buffer_size(buffer_max_size,
model_size)
self._use_grad_storage = use_grad_storage
self._grad_storages = {} # {dtype: {rank: GradStorage}}
self._has_grad_storage = []
self._grad_storage_list = []
# Set backward pass hooks
self._bw_hooks = []
# Synchronous all ranks models
if pertrain_sync_models:
self._sync_params_and_buffers()
# Set tasks flow
self._tasks_flow = deque()
def forward(self, *inputs, **kwargs):
"""
A wrapper for Sharding Stage2 layer.
- Fresh trainable params or rebuild grad storage
- Sync layer's buffer params
- Clear all flags states
- Forward for origin layers
"""
# Whether to need to reset trainable parameters
needs_fresh = len(self._bw_hooks) == 0 and self.training
if self._auto_refresh_trainable:
needs_fresh |= self._detect_train_change()
# Front hook
self._init_internal_storage(needs_fresh)
# Sync layer's buffers state
if self._sync_buffers:
self.__sync_buffers()
# Normal FW on the base model
fw = self._layer(*inputs, **kwargs)
return fw
def clear_gradients(self):
"""
Set zero to the gradient of the optimizer's current rank trainable parameters.
"""
# Release grad storages
for dtype in self._grad_storages.keys():
if self._rank in self._grad_storages[dtype].keys():
self._grad_storages[dtype][self._rank].buffer.zero_()
# Release params
for param in self._trainable_params:
if param.name in self._param_grads and param.grad is not None:
param.clear_gradient()
def grad_scale(self):
"""
Before the gradient accumulation, scale the gradient.
"""
# Scale grad storages
for dtype in self._grad_storages.keys():
if self._rank in self._grad_storages[dtype].keys():
self._grad_storages[dtype][self._rank].buffer.scale_(
scale=self._world_size_scaling)
# Scale params
for param in self._trainable_params:
if param.name in self._param_grads and param.grad is not None:
param.grad.scale_(scale=self._world_size_scaling)
param._reset_grad_inplace_version()
def _init_internal_storage(self, needs_fresh):
"""
Judge Fresh trainable params or rebuild grad storage.
"""
if needs_fresh:
self._fresh_trainable()
else:
self._build_grad_storages()
# Clear all flags state
self._clear_counters()
def to(self, device=None, dtype=None, blocking=True):
"""
        Synchronously or asynchronously convert the data type of the layer; changing the device is not supported for now.
"""
        assert device == self._default_device, "New devices are not supported, because the optimizer state is not synchronized"
def _fresh_trainable(self):
""" Whether to update training parameters. """
# Make sure that this is not done while gradients are waiting to be reduced (if no_sync context for instance)
if reduce(lambda x, y: x or y, self._grad_reduced, False):
logging.warning("Grads waiting to be reduced.")
self._trainable_params = list(
filter(lambda x: x.trainable, self._all_params))
self._trainable_params.sort(key=lambda x: np.prod(x.shape))
self._trainable_param2rank = {}
for optim in self._sharding_optimizers:
            # Needs to be wrapped for the Sharding Stage2 Optimizer
if len(optim.param_storages.keys()) == 0:
optim.update_opt_status()
# Get the parameters split by the optimizer according to rank
for per_rank_params in optim.dtype_rank_params.values(
): # all the params from all ranks
for params in per_rank_params:
for param in filter(lambda x: x.trainable, params):
self._trainable_param2rank[
param.name] = optim.param2rank[param.name]
self._trainable_param2align[
param.name] = optim._param2align[param.name]
self._setup_use_grad_storage()
# wait next func hook support
self._setup_backward_hooks()
@paddle.no_grad()
def __sync_buffers(self):
"""
        Sync all the param buffers from all ranks (e.g. batch norm statistics).
"""
for buffer in self._layer.buffers(include_sublayers=True):
dist.broadcast(
buffer,
self._global_root_rank,
self._group,
use_calc_stream=True)
# Multi stream operation will be supported later
dist.wait(tensor=buffer, group=self._group, use_calc_stream=True)
def __getattr__(self, name):
"""Forward missing attributes to wrapped layer."""
try:
return super().__getattr__(name)
except AttributeError:
return getattr(self._layer, name)
@paddle.no_grad()
def _clear_counters(self):
"""Reset all the grad reduce and call counters."""
if self.training:
self._grad_reduced = [True for _ in self._trainable_params]
if self._use_grad_storage:
for grad_storage in self._grad_storage_list:
grad_storage.reset_checked_in()
if not self._accumulate_grads:
self._grads_flipped = False
def _get_reduce_fn(self, index, param, dst_rank):
"""
        There are two ways to reduce a gradient.
        - 1. Parameters that do not use grad_storage, or that exceed buffer_max_size, are reduced separately.
        - 2. Parameters placed in grad_storage are reduced as one storage buffer to gather the full gradient from the different ranks.
"""
if not self._use_grad_storage or not self._has_grad_storage[index]:
# Direct reduction
@paddle.no_grad()
def reduce(*_):
# Skip gradient reduction, do not change status information
if self._grad_reduced[index]:
assert param.grad is not None, "Parameter gradient cannot be None"
# Change reduce information
self._grad_reduced[index] = False
if not self._accumulate_grads:
param.grad.scale_(scale=self._world_size_scaling)
param._reset_grad_inplace_version()
# Clear the gradient that does not belong to the current rank through the callback function
def cleanup():
if dst_rank != self._rank:
param.clear_gradient(False)
# Synchronize the reduce parameter gradient
self._tasks_flow.append(
Taskflow(
task=dist.reduce(
tensor=param.grad,
dst=dst_rank,
group=self._group,
use_calc_stream=True),
callback=cleanup))
# Multi stream operation will be supported later
dist.wait(
tensor=param.grad,
group=self._group,
use_calc_stream=True)
# Clear the task flow and trigger callback to clear the redundant gradient
self._clear_task_flow()
else:
# Buffer reduction
@paddle.no_grad()
def reduce(*_):
# Skip gradient reduction, do not change status information
if self._grad_reduced[index]:
assert param.grad is not None, "Parameter gradient cannot be None"
# Change reduce information
self._grad_reduced[index] = False
grad_storage = self._grad_storages[param.dtype][dst_rank]
grad_storage.params_checked_in += 1
if grad_storage.all_checked_in:
assert grad_storage.buffer is not None
# Normalize all ranks grad_storage
if not self._accumulate_grads:
grad_storage.buffer.scale_(
scale=self._world_size_scaling)
# Clearing up the grad_storage buffer
def cleanup():
if dst_rank != self._rank:
for p in grad_storage._params:
p.clear_gradient(False)
p._gradient_set_empty(False)
grad_storage.buffer.value().get_tensor()._clear(
)
# Reduce the bucket
grad_storage.sent = True
self._tasks_flow.append(
Taskflow(
task=dist.reduce(
tensor=grad_storage.buffer,
dst=grad_storage.destination,
group=self._group,
use_calc_stream=True),
callback=cleanup))
# Multi stream operation will be supported later
dist.wait(
tensor=grad_storage.buffer,
group=self._group,
use_calc_stream=True)
# Clear the task flow and trigger callback to clear the redundant gradient
self._clear_task_flow()
return reduce
def _setup_backward_hooks(self):
"""
        Set backward hooks that reduce each trainable parameter's gradient to its destination rank within the group.
"""
# Remove previous backward hooks
while len(self._bw_hooks) > 0:
self._bw_hooks.pop().remove()
# Go through the parameters, attach the hook
self._grad_accs = []
if not self.training:
return
for index, param in enumerate(self._trainable_params):
dst_rank = self._trainable_param2rank[param.name]
reduce_function = self._get_reduce_fn(index, param, dst_rank)
self._bw_hooks.append(
param._register_backward_hook(reduce_function))
@paddle.no_grad()
def _sync_params_and_buffers(self):
"""
Sync all model states for all ranks
"""
for t in self._layer.parameters():
dist.broadcast(
t,
src=self._global_root_rank,
group=self._group,
use_calc_stream=True)
# Multi stream operation will be supported later
dist.wait(tensor=t, group=self._group, use_calc_stream=True)
def _setup_use_grad_storage(self):
"""
        Integrate the parameter gradients into continuous memory according to rank, and support updating the training parameters.
"""
if not self._use_grad_storage:
return
        # Parameters are sorted by numel; allocate each parameter's gradient into continuous memory according to its rank
self._grad_storages = {}
self._has_grad_storage = [False for _ in self._trainable_params]
for index, param in enumerate(self._trainable_params):
dst_rank = self._trainable_param2rank[param.name]
if param.dtype not in self._grad_storages.keys():
self._grad_storages[param.dtype] = {}
if dst_rank not in self._grad_storages[param.dtype].keys():
self._grad_storages[param.dtype][dst_rank] = GradStorage(
self._buffer_max_size[param.dtype],
dtype=param.dtype,
device=self._default_device,
destination=dst_rank,
parm2align=self._trainable_param2align)
# Criteria to decide whether this parameter is to be put in GradStorage
if self._grad_storages[param.dtype][dst_rank].can_add_grad_view(
param, self._trainable_param2align[param.name]):
self._grad_storages[param.dtype][dst_rank].add_grad(
param, self._trainable_param2align[param.name])
self._has_grad_storage[index] = True
else:
self._param_grads.append(param.name)
print(
"Can not add param: {}, param's shape: {}, param align: {}, grad_storages fill: {}, ".
format(param.name, param.shape, self._trainable_param2align[
param.name], self._grad_storages[param.dtype][dst_rank]
._fill))
self._grad_storage_list = list(
chain(* [
self._grad_storages[dtype].values()
for dtype in self._grad_storages.keys()
]))
def _clear_task_flow(self):
"""Try to consume the previous tasks."""
while len(self._tasks_flow) > 0:
task = self._tasks_flow.popleft()
if task.callback is not None:
task.callback()
def _detect_train_change(self):
# Current trainable parameters
trainable_mask = list(map(_trainable, self._all_params))
# Whether parameters trainability changed
trainability_changed = trainable_mask != self._trainable_mask
# The whole model is not trainable but we still have grad hooks
trainability_changed |= not self.training and len(self._bw_hooks) > 0
if trainability_changed:
logging.warning(
"Trainable params changed, because of eval/train mode or parameter freezing/unfreeze."
)
self._trainable_mask = trainable_mask
return trainability_changed
def _build_grad_storages(self):
"""
Rebuild grad storages.
"""
# Rebuild fp16/fp32 grad storages
for dtype in self._grad_storages.keys():
for dst_rank, grad_storage in self._grad_storages[dtype].items():
if dst_rank != self._rank:
grad_storage.manumal_relase()
grad_storage.rebuild()
def _rank_buffer_size(self, buffer_max_size, model_size):
"""
Generate the minimum buffer size for each rank & Display param sizes and model sizes.
"""
# Initialize buffer size
rank_buffer_size = {}
for shard_opt in self._sharding_optimizers:
if shard_opt.rank_buffer_size:
for dtype in shard_opt.rank_buffer_size.keys():
sizes = max(shard_opt.rank_buffer_size[dtype].values())
rank_buffer_size[dtype] = min(sizes, buffer_max_size)
if Type.fp16.value in rank_buffer_size.keys():
# FP16 GradStorage and model size
print(
"====== FP16 GradStorage size: {:.2f}M parameters, Model size {:.2f}M parameters ======".
format(rank_buffer_size[Type.fp16.value] / 2**19, model_size / 2
**19))
if Type.fp32.value in rank_buffer_size.keys():
# FP32 GradStorage and model size
print(
"====== FP32 GradStorage size: {:.2f}M parameters, Model size {:.2f}M parameters ======".
format(rank_buffer_size[Type.fp32.value] / 2**18, model_size / 2
**18))
return rank_buffer_size
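Putting the wrapper together: a minimal training-loop sketch that assumes the `opt` and `group` objects from the sketch above and a hypothetical data `loader`; it mirrors the accumulate-grad path exercised by the test script added further down in this commit:

from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2

# Wrap the layer; backward hooks will reduce every gradient to its owner rank.
model = ShardingStage2(model, opt, group=group, accumulate_grads=True)
for img, label in loader:  # loader is a placeholder for any dygraph data loader
    loss = paddle.nn.functional.cross_entropy(input=model(img), label=label)
    avg_loss = paddle.mean(loss)
    avg_loss.backward()
# With accumulate_grads=True the reduce does not pre-scale gradients, so scale
# once by 1/nranks before the step, then release the rank-local grad storages.
model.grad_scale()
opt.step()
model.clear_gradients()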
...@@ -33,6 +33,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_pipeline_parallel)
 list(APPEND DIST_TEST_OPS test_parallel_dygraph_tensor_parallel)
 list(APPEND DIST_TEST_OPS test_parallel_dygraph_sharding_parallel)
 list(APPEND DIST_TEST_OPS test_dygraph_sharding_optimizer_stage2)
+list(APPEND DIST_TEST_OPS test_dygraph_sharding_stage2)
 list(APPEND DIST_TEST_OPS test_auto_parallel_parallelizer)
 list(APPEND DIST_TEST_OPS test_parallel_dygraph_mp_layers)
 list(APPEND DIST_TEST_OPS test_hybrid_parallel_inference_helper)
...@@ -244,6 +245,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
     list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_tensor_parallel)
     list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sharding_parallel)
     list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_optimizer_stage2)
+    list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_stage2)
     list(REMOVE_ITEM TEST_OPS test_auto_parallel_parallelizer)
     list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_mp_layers)
     LIST(REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision)
...@@ -1039,6 +1041,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
     set_tests_properties(test_parallel_dygraph_tensor_parallel PROPERTIES TIMEOUT 200)
     set_tests_properties(test_parallel_dygraph_sharding_parallel PROPERTIES TIMEOUT 120)
     set_tests_properties(test_dygraph_sharding_optimizer_stage2 PROPERTIES TIMEOUT 120)
+    set_tests_properties(test_dygraph_sharding_stage2 PROPERTIES TIMEOUT 120)
     set_tests_properties(test_auto_parallel_parallelizer PROPERTIES TIMEOUT 120)
     set_tests_properties(test_parallel_dygraph_mp_layers PROPERTIES TIMEOUT 120)
     set_tests_properties(test_hybrid_parallel_inference_helper PROPERTIES TIMEOUT 120)
......
# -*- coding: UTF-8 -*-
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import argparse
import ast
import time
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Linear
from paddle.distributed import fleet
from paddle.fluid.dygraph import nn
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import DygraphShardingOptimizer
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2
seed = 2021
epoch = 2
batch_size = 32
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
"dp_degree": 2,
"mp_degree": 1,
"pp_degree": 1,
"sharding_degree": 1
}
fleet.init(is_collective=True, strategy=strategy)
np.random.seed(seed)
paddle.seed(seed)
class MLP(fluid.Layer):
def __init__(self, param_attr=None, bias_attr=None):
super(MLP, self).__init__()
self._linear1 = Linear(10000, 10000)
self._linear2 = Linear(10000, 10000)
self._linear3 = Linear(10000, 10)
def forward(self, inputs):
y = self._linear1(inputs)
y = self._linear2(y)
y = self._linear3(y)
return y
def reader_decorator():
def __reader__():
for _ in range(100):
img = np.random.rand(10000).astype('float32')
label = np.ones(1).astype('int64')
yield img, label
return __reader__
def optimizer_setting(model, use_pure_fp16, stage=1):
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
optimizer = paddle.optimizer.AdamW(
parameters=model.parameters(),
learning_rate=0.001,
weight_decay=0.00001,
grad_clip=clip,
multi_precision=use_pure_fp16)
return optimizer
def train_mlp(model,
sharding_stage,
use_pure_fp16=False,
all_test=False,
accumulate_grad=False):
if sharding_stage == 1:
hcg = fleet.get_hybrid_communicate_group()
group = hcg.get_check_parallel_group()
else:
group = paddle.distributed.new_group([0, 1])
optimizer = optimizer_setting(
model=model, use_pure_fp16=use_pure_fp16, stage=sharding_stage)
if use_pure_fp16:
model, optimizer = paddle.amp.decorate(
models=model,
optimizers=optimizer,
level='O2',
save_dtype='float32')
if sharding_stage == 2:
optimizer = ShardingOptimizerStage2(
params=model.parameters(), optim=optimizer, group=group)
if all_test:
model = ShardingStage2(
model, optimizer, group=group, accumulate_grads=accumulate_grad)
else:
model = ShardingStage2(model, optimizer, group=group)
else:
optimizer = fleet.distributed_optimizer(optimizer)
model = fleet.distributed_model(model)
train_reader = paddle.batch(
reader_decorator(), batch_size=batch_size, drop_last=True)
train_loader = paddle.io.DataLoader.from_generator(
capacity=32,
use_double_buffer=True,
iterable=True,
return_list=True,
use_multiprocess=True)
train_loader.set_sample_list_generator(train_reader)
for eop in range(epoch):
model.train()
for batch_id, data in enumerate(train_loader()):
img, label = data
label.stop_gradient = True
img.stop_gradient = True
with paddle.amp.auto_cast(enable=use_pure_fp16, level='O2'):
out = model(img)
loss = paddle.nn.functional.cross_entropy(
input=out, label=label)
avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
avg_loss.backward()
if accumulate_grad and batch_id == 2:
model.grad_scale()
optimizer.step()
model.clear_gradients()
return model.parameters()
if not accumulate_grad:
optimizer.step()
if sharding_stage == 2:
model.clear_gradients()
else:
optimizer.clear_grad()
if all_test and batch_id == 2:
return model.parameters()
if sharding_stage == 2:
model.to(device="gpu")
return model.parameters()
def test_stage1_stage2():
mlp = MLP()
state_dict = mlp.state_dict()
mlp1 = MLP()
mlp2 = MLP()
mlp3 = MLP()
mlp4 = MLP()
mlp1.set_state_dict(state_dict)
mlp2.set_state_dict(state_dict)
mlp3.set_state_dict(state_dict)
mlp4.set_state_dict(state_dict)
stage1_params = train_mlp(mlp, sharding_stage=1, use_pure_fp16=False)
stage2_params = train_mlp(mlp, sharding_stage=2, use_pure_fp16=False)
for i in range(len(stage1_params)):
np.testing.assert_allclose(
stage1_params[i].numpy(), stage2_params[i].numpy(), rtol=1e-6)
stage2_params = train_mlp(
mlp3, sharding_stage=2, use_pure_fp16=True, all_test=True)
stage2_accumulate_grad = train_mlp(
mlp4,
sharding_stage=2,
use_pure_fp16=True,
all_test=True,
accumulate_grad=True)
for i in range(len(stage2_params)):
for j in range(len(stage2_accumulate_grad)):
if stage2_params[i].name == stage2_accumulate_grad[j].name:
np.testing.assert_allclose(
stage2_params[i].numpy(),
stage2_accumulate_grad[j].numpy(),
rtol=1e-6)
return
if __name__ == '__main__':
test_stage1_stage2()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle.fluid as fluid
from test_parallel_dygraph_dataparallel import TestMultipleGpus
class TestDygraphShardingStage2(TestMultipleGpus):
# check sharding logic as well as the accuracy with single mode
def test_dygraph_sharding_optimizer_stage2(self):
self.run_mnist_2gpu('dygraph_sharding_stage2.py')
if __name__ == "__main__":
unittest.main()
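The unit test above drives dygraph_sharding_stage2.py on two GPUs through TestMultipleGpus.run_mnist_2gpu. Outside the test harness, the same script can presumably be launched with Paddle's distributed launcher, along the lines of the command below (flag spelling may differ between Paddle releases):

python -m paddle.distributed.launch --gpus=0,1 dygraph_sharding_stage2.py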