未验证 提交 46823104 编写于 作者: B Baibaifan 提交者: GitHub

Add sharding stage3 offload (#38989)

上级 f4623876
......@@ -30,7 +30,6 @@ from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import Shar
from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import ShardingScaler
epoch = 10
batch_size = 32
paddle.seed(2021)
np.random.seed(2021)
base_lr = 0.1
......@@ -66,10 +65,10 @@ def reader_decorator(linear_size=1000):
def optimizer_setting(model, use_pure_fp16, opt_group=False):
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
optimizer = paddle.optimizer.AdamW(
optimizer = paddle.optimizer.Momentum(
parameters=[{
"params": model.parameters()
}] if opt_group else model.parameters(),
"params": list(model.parameters())
}] if opt_group else list(model.parameters()),
learning_rate=0.001,
weight_decay=0.00001,
grad_clip=clip,
......@@ -82,6 +81,7 @@ def train_mlp(model,
sharding_stage,
use_pure_fp16=False,
accumulate_grad=False,
batch_size=100,
opt_group=False,
recompute=False):
group = paddle.distributed.new_group([0, 1])
......@@ -104,10 +104,14 @@ def train_mlp(model,
optimizer,
group=group,
buffer_max_size=2**21,
accumulate_grads=accumulate_grad)
accumulate_grads=batch_size == 20)
elif sharding_stage == 3:
model = ShardingStage3(
model, optimizer=optimizer, group=group, sync_comm=recompute)
model,
optimizer=optimizer,
group=group,
accumulate_grads=batch_size == 20,
sync_comm=recompute)
train_reader = paddle.batch(
reader_decorator(), batch_size=batch_size, drop_last=True)
......@@ -131,21 +135,22 @@ def train_mlp(model,
loss = paddle.nn.functional.cross_entropy(
input=out, label=label)
avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
if not use_pure_fp16:
avg_loss.backward()
else:
scaler.scale(avg_loss).backward()
if not accumulate_grad:
if not use_pure_fp16:
avg_loss.backward()
optimizer.step()
else:
scaler.scale(avg_loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.clear_grad()
if accumulate_grad:
if not use_pure_fp16:
avg_loss.backward()
optimizer.step()
else:
scaler.scale(avg_loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.clear_grad()
......@@ -168,48 +173,50 @@ def test_stage2_stage3():
mlp8.set_state_dict(state_dict)
# fp32
stage2_params = train_mlp(
mlp1, sharding_stage=2, use_pure_fp16=False, opt_group=True)
mlp1, sharding_stage=2, use_pure_fp16=False, opt_group=False)
stage3_params = train_mlp(
mlp2, sharding_stage=3, use_pure_fp16=False, opt_group=True)
mlp2, sharding_stage=3, use_pure_fp16=False, opt_group=False)
for i in range(len(stage2_params)):
for j in range(len(stage3_params)):
if stage2_params[i].name == stage3_params[j].name:
np.testing.assert_allclose(
stage2_params[i].numpy(),
stage3_params[j].numpy(),
rtol=1e-6)
np.testing.assert_allclose(
stage2_params[i].numpy(),
stage3_params[i].numpy(),
rtol=1e-6,
atol=1e-6)
# fp32 accumulate grad
stage2_params = train_mlp(
stage3_params = train_mlp(
mlp3,
sharding_stage=2,
sharding_stage=3,
use_pure_fp16=False,
accumulate_grad=True,
opt_group=True)
stage3_params = train_mlp(
stage3_params_add = train_mlp(
mlp4,
sharding_stage=3,
use_pure_fp16=False,
accumulate_grad=True,
batch_size=20,
opt_group=True)
for i in range(len(stage2_params)):
for j in range(len(stage3_params)):
if stage2_params[i].name == stage3_params[j].name:
np.testing.assert_allclose(
stage2_params[i].numpy(),
stage3_params[j].numpy(),
rtol=1e-6)
for i in range(len(stage3_params)):
np.testing.assert_allclose(
stage3_params[i].numpy(),
stage3_params_add[i].numpy(),
rtol=1e-6,
atol=1e-6)
# fp16
stage2_params = train_mlp(
mlp5, sharding_stage=2, use_pure_fp16=True, opt_group=False)
stage3_params = train_mlp(
mlp6, sharding_stage=3, use_pure_fp16=True, opt_group=False)
for i in range(len(stage2_params)):
for j in range(len(stage3_params)):
if stage2_params[i].name == stage3_params[j].name:
np.testing.assert_allclose(
stage2_params[i].numpy(),
stage3_params[j].numpy(),
rtol=1e-6)
np.testing.assert_allclose(
stage2_params[i].numpy(),
stage3_params[i].numpy(),
rtol=1e-4,
atol=1e-4)
# fp16 recompute
stage3_params = train_mlp(
mlp7, sharding_stage=3, use_pure_fp16=True, opt_group=False)
......@@ -220,12 +227,8 @@ def test_stage2_stage3():
opt_group=False,
recompute=True)
for i in range(len(stage3_params)):
for j in range(len(stage3_params_re)):
if stage3_params[i].name == stage3_params_re[j].name:
np.testing.assert_allclose(
stage3_params[i].numpy(),
stage3_params_re[j].numpy(),
rtol=1e-6)
np.testing.assert_allclose(
stage3_params[i].numpy(), stage3_params_re[i].numpy(), rtol=1e-6)
return
......
# -*- coding: UTF-8 -*-
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import argparse
import ast
import time
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Linear
from paddle.distributed import fleet
from paddle.fluid.dygraph import nn
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import ShardingStage3
from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import ShardingScaler
epoch = 10
batch_size = 32
paddle.seed(2022)
np.random.seed(2022)
base_lr = 0.1
momentum_rate = 0.9
l2_decay = 1e-4
fleet.init(is_collective=True)
class MLP(fluid.Layer):
def __init__(self, linear_size=1000, param_attr=None, bias_attr=None):
super(MLP, self).__init__()
self._linear1 = Linear(linear_size, linear_size)
self._linear2 = Linear(linear_size, linear_size)
self._linear3 = Linear(linear_size, 10)
def forward(self, inputs):
y = self._linear1(inputs)
y = self._linear2(y)
y = self._linear3(y)
return y
def reader_decorator(linear_size=1000):
def __reader__():
for _ in range(100):
img = np.random.rand(linear_size).astype('float32')
label = np.ones(1).astype('int64')
yield img, label
return __reader__
def optimizer_setting(model, use_pure_fp16, opt_group=False):
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
optimizer = paddle.optimizer.AdamW(
parameters=[{
"params": model.parameters()
}] if opt_group else model.parameters(),
learning_rate=0.001,
weight_decay=0.00001,
grad_clip=clip,
multi_precision=use_pure_fp16)
return optimizer
def train_mlp(model,
use_pure_fp16=False,
accumulate_grad=False,
offload=False,
convert2cpu=False):
group = paddle.distributed.new_group([0, 1])
optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)
if use_pure_fp16:
model = paddle.amp.decorate(
models=model, level='O2', save_dtype='float32')
scaler = paddle.amp.GradScaler(init_loss_scaling=32768)
scaler = ShardingScaler(scaler)
model = ShardingStage3(
model, optimizer=optimizer, group=group, offload=offload)
train_reader = paddle.batch(
reader_decorator(), batch_size=batch_size, drop_last=True)
train_loader = paddle.io.DataLoader.from_generator(
capacity=32,
use_double_buffer=True,
iterable=True,
return_list=True,
use_multiprocess=True)
train_loader.set_sample_list_generator(train_reader)
for eop in range(epoch):
model.train()
for batch_id, data in enumerate(train_loader()):
img, label = data
label.stop_gradient = True
img.stop_gradient = True
with paddle.amp.auto_cast(True, level='O2'):
out = model(img)
loss = paddle.nn.functional.cross_entropy(
input=out, label=label)
avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
if not use_pure_fp16:
avg_loss.backward()
else:
scaler.scale(avg_loss).backward()
if not accumulate_grad:
if not use_pure_fp16:
optimizer.step()
else:
scaler.step(optimizer)
scaler.update()
optimizer.clear_grad()
if accumulate_grad:
if not use_pure_fp16:
optimizer.step()
else:
scaler.step(optimizer)
scaler.update()
optimizer.clear_grad()
if not convert2cpu:
model.get_all_parameters()
else:
model.get_all_parameters(convert2cpu)
return model.parameters()
def test_stage3_offload():
mlp, mlp1, mlp2, mlp3, mlp4, mlp5, mlp6 = MLP(), MLP(), MLP(), MLP(), MLP(
), MLP(), MLP()
state_dict = mlp.state_dict()
mlp1.set_state_dict(state_dict)
mlp2.set_state_dict(state_dict)
mlp3.set_state_dict(state_dict)
mlp4.set_state_dict(state_dict)
mlp5.set_state_dict(state_dict)
mlp6.set_state_dict(state_dict)
# fp32 offload
stage3_params = train_mlp(mlp1, use_pure_fp16=False)
stage3_params_offload = train_mlp(mlp2, use_pure_fp16=False, offload=True)
for i in range(len(stage3_params)):
np.testing.assert_allclose(
stage3_params[i].numpy(),
stage3_params_offload[i].numpy(),
rtol=1e-6,
atol=1e-8)
# fp16 offload
stage3_params = train_mlp(mlp3, use_pure_fp16=True)
stage3_params_offload = train_mlp(mlp4, use_pure_fp16=True, offload=True)
for i in range(len(stage3_params)):
np.testing.assert_allclose(
stage3_params[i].numpy(),
stage3_params_offload[i].numpy(),
rtol=1e-2,
atol=1e-2)
# fp32 accumulate grad offload
stage3_params = train_mlp(mlp5, use_pure_fp16=False, accumulate_grad=True)
stage3_params_offload = train_mlp(
mlp6,
use_pure_fp16=False,
accumulate_grad=True,
offload=True,
convert2cpu=True)
for i in range(len(stage3_params)):
np.testing.assert_allclose(
stage3_params[i].numpy(),
stage3_params_offload[i].numpy(),
rtol=1e-6,
atol=1e-8)
return
if __name__ == '__main__':
test_stage3_offload()
......@@ -23,10 +23,10 @@ from test_parallel_dygraph_dataparallel import TestMultipleGpus
class TestDygraphShardingStage2(TestMultipleGpus):
# check sharding logic as well as the accuracy with single mode
def test_dygraph_sharding_optimizer_stage2(self):
def test_dygraph_sharding_stage2(self):
self.run_mnist_2gpu('dygraph_sharding_stage2.py')
def test_dygraph_sharding_optimizer_stage2_offload(self):
def test_dygraph_sharding_stage2_offload(self):
self.run_mnist_2gpu('dygraph_sharding_stage2_offload.py')
......
......@@ -23,9 +23,12 @@ from test_parallel_dygraph_dataparallel import TestMultipleGpus
class TestDygraphShardingStage3(TestMultipleGpus):
# check sharding logic as well as the accuracy with single mode
def test_dygraph_sharding_optimizer_stage3(self):
def test_dygraph_sharding_stage3(self):
self.run_mnist_2gpu('dygraph_sharding_stage3.py')
def test_dygraph_sharding_stage3_offload(self):
self.run_mnist_2gpu('dygraph_sharding_stage3_offload.py')
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册