Unverified commit aec6e8a9 authored by minghaoBD, committed by GitHub

[Paddle-ASP] ASP sharding (#37725)

Parent 9c1167cf
......@@ -53,6 +53,7 @@ class ShardingOptimizer(MetaOptimizerBase):
"AMPOptimizer",
"LarsOptimizer",
"LambOptimizer",
"ASPOptimizer",
# "ModelParallelOptimizer",
# "PipelineOptimizer",
]
......
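The hunk above adds ASPOptimizer to the sharding meta-optimizer's compatibility whitelist, so ASP and sharding can be enabled in the same DistributedStrategy. A minimal sketch of combining the two switches (config values are illustrative, mirroring the test further down):

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.asp = True       # prune supported layers to n:m (2:4) sparsity
strategy.sharding = True  # shard parameters and optimizer states across ranks
strategy.sharding_configs = {
    "segment_broadcast_MB": 32,  # illustrative, as in the test below
    "sharding_degree": 8,
    "mp_degree": 1,
}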
......@@ -16,12 +16,17 @@
Functions for Auto SParsity (ASP) training and inference.
"""
import os
import copy
import numpy as np
import paddle
from paddle.fluid import global_scope, program_guard, layers
from paddle.fluid.initializer import ConstantInitializer
from paddle.fluid.contrib import sparsity
from paddle.fluid import core
OpRole = core.op_proto_and_checker_maker.OpRole
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
__all__ = [
'decorate', 'prune_model', 'set_excluded_layers', 'reset_excluded_layers'
......@@ -150,7 +155,8 @@ def prune_model(main_program=None,
n=2,
m=4,
mask_algo='mask_1d',
with_mask=True):
with_mask=True,
sharding=False):
r"""
Pruning parameters of supported layers in :attr:`main_program` via
specified mask generation function given by :attr:`mask_algo`. This
......@@ -173,6 +179,7 @@ def prune_model(main_program=None,
mask_algo (string, optional): The function name to generate sparse masks. Default is `mask_1d`.
The valid inputs should be one of 'mask_1d', 'mask_2d_greedy' and 'mask_2d_best'.
with_mask (bool, optional): Whether to also prune the mask Variables related to parameters. True means the masks are pruned as well; False means they are not. Default is True.
sharding (bool, optional): Whether sharding (a model-parallel strategy) is enabled during training. Set it to True when pruning a model trained with sharding, e.g. when sharding was adopted to avoid OOM. Default is False.
Returns:
dictionary: A dictionary with key: `parameter name` (string) and value: its corresponding mask Variable.
Examples:
......@@ -214,8 +221,12 @@ def prune_model(main_program=None,
# Must call `exe.run(startup_program)` first before calling `sparsity.prune_model`
sparsity.prune_model(main_program, mask_algo='mask_2d_best')
"""
device = paddle.device.get_device()
place = paddle.set_device(device)
if sharding:
gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
place = paddle.CUDAPlace(gpu_id)
else:
device = paddle.device.get_device()
place = paddle.set_device(device)
MaskAlgo_mapping = {
'mask_1d': sparsity.MaskAlgo.MASK_1D,
......@@ -528,8 +539,11 @@ class ASPHelper(object):
'Y': asp_info.mask_vars[param_grad[0].name]
},
outputs={'Out': param_grad[0]},
attrs={'axis': -1,
'use_mkldnn': False})
attrs={
'axis': -1,
'use_mkldnn': False,
OP_ROLE_KEY: OpRole.Optimize
})
class OptimizerWithSparsityGuarantee(object):
......
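With the new sharding argument, prune_model resolves the place from FLAGS_selected_gpus (one CUDAPlace per rank) instead of paddle.device.get_device(), and the mask-apply elementwise_mul is tagged OpRole.Optimize so the sharding pass keeps it in the optimizer segment. A minimal usage sketch, assuming train_program comes from the usual fleet setup and exe.run(startup_program) has already run:

from paddle.static import sparsity

# Must be called after the startup program has initialized the parameters.
# sharding=True makes each rank prune on its own GPU, resolved internally as:
#     gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
#     place = paddle.CUDAPlace(gpu_id)
sparsity.prune_model(train_program, mask_algo='mask_1d', sharding=True)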
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2021 NVIDIA Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker
import unittest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import os
import sys
from paddle.static import sparsity
from paddle.fluid.contrib.sparsity.asp import ASPHelper
import numpy as np
cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
if cuda_visible_devices is None or cuda_visible_devices == "":
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
else:
os.environ['CUDA_VISIBLE_DEVICES'] = cuda_visible_devices.split(',')[0]
paddle.enable_static()
class TestFleetWithASPSharding(unittest.TestCase):
def setUp(self):
os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36213"
os.environ["PADDLE_CURRENT_ENDPOINTS"] = "127.0.0.1:36213"
os.environ["PADDLE_TRAINERS_NUM"] = "1"
os.environ["PADDLE_TRAINER_ID"] = "0"
os.environ['FLAGS_enable_parallel_graph'] = "0"
os.environ['FLAGS_fraction_of_gpu_memory_to_use'] = "0.1"
os.environ['FLAGS_sync_nccl_allreduce'] = "1"
os.environ['FLAGS_eager_delete_tensor_gb'] = "0"
os.environ['FLAGS_fuse_parameter_memory_size'] = "32"
os.environ['FLAGS_fuse_parameter_groups_size'] = "50"
os.environ['FLAGS_check_nan_inf'] = "0"
def net(self, main_prog, startup_prog):
with fluid.program_guard(main_prog, startup_prog):
input_x = paddle.static.data(
name="x", shape=[-1, 32], dtype='float32')
input_y = paddle.static.data(name="y", shape=[-1, 1], dtype='int64')
fc_1 = fluid.layers.fc(input=input_x, size=64, act='tanh')
fc_2 = fluid.layers.fc(input=fc_1, size=64, act='tanh')
fc_3 = fluid.layers.fc(input=fc_2, size=64, act='tanh')
fc_4 = fluid.layers.fc(input=fc_3, size=64, act='tanh')
prediction = fluid.layers.fc(input=fc_4, size=2, act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=input_y)
avg_cost = paddle.mean(x=cost)
dist_strategy = paddle.distributed.fleet.DistributedStrategy()
dist_strategy.sharding = True
dist_strategy.sharding_configs = {
"sharding_segment_strategy": "segment_broadcast_MB",
"segment_broadcast_MB": 32,
"segment_anchors": None,
"sharding_degree": 8,
"mp_degree": 1,
"hybrid_dp": False,
"gradient_merge_acc_step": 1
}
dist_strategy.nccl_comm_num = 1
dist_strategy.asp = True
return avg_cost, dist_strategy, input_x, input_y
def test_with_asp_sharding(self):
if sys.platform == 'win32':
return
print(sys.platform)
fleet.init(is_collective=True)
train_prog, startup_prog = fluid.Program(), fluid.Program()
avg_cost, strategy, input_x, input_y = self.net(train_prog,
startup_prog)
with fluid.program_guard(train_prog, startup_prog):
optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(
optimizer, strategy=strategy)
optimizer.minimize(avg_cost)
if paddle.fluid.is_compiled_with_cuda():
place = fluid.CUDAPlace(
int(os.environ.get('FLAGS_selected_gpus', 0)))
else:
place = fluid.CPUPlace()
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(feed_list=[input_x, input_y], place=place)
exe.run(startup_prog)
sparsity.prune_model(train_prog, sharding=True)
data = (np.random.randn(64, 32), np.random.randint(2, size=(64, 1)))
exe.run(train_prog, feed=feeder.feed([data]))
for param in train_prog.global_block().all_parameters():
if ASPHelper._is_supported_layer(train_prog, param.name):
mat = np.array(fluid.global_scope().find_var(param.name)
.get_tensor())
self.assertTrue(
paddle.fluid.contrib.sparsity.check_sparsity(
mat.T, n=2, m=4))
if __name__ == "__main__":
unittest.main()
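The assertion loop above relies on check_sparsity, which verifies the n:m pattern: every group of m consecutive elements holds at most n non-zeros. A standalone sketch of the same check on a hand-built 2:4-sparse matrix:

import numpy as np
from paddle.fluid.contrib import sparsity

mat = np.array([[0., 1., 0., 2., 3., 0., 0., 4.],
                [5., 0., 6., 0., 0., 7., 8., 0.]])
# Every group of 4 consecutive elements contains at most 2 non-zeros.
assert sparsity.check_sparsity(mat, n=2, m=4)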
......@@ -229,5 +229,7 @@ class TestFleetMetaOptimizer(unittest.TestCase):
"micro_batch_size": 2,
"accumulate_steps": 4,
}
elif name == 'asp':
strategy.asp = True
else:
raise NotImplementedError()
......@@ -190,6 +190,53 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
'momentum', 'momentum'
])
def test_sharding_amp_asp_optimizer(self):
train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program()
avg_cost, strategy = self.net(train_prog, startup_prog)
self.set_strategy(strategy, 'sharding')
self.set_strategy(strategy, 'amp')
self.set_strategy(strategy, 'asp')
self.optimizer(avg_cost, strategy, train_prog, startup_prog)
ops = [op.type for op in avg_cost.block.ops]
vars = [x.name for x in train_prog.list_vars()]
parameters = [
x.name for x in train_prog.list_vars() if x.persistable == True
]
self.assertIn('@BroadCast', ''.join(vars))
self.assertIn('cast', ops)
self.assertIn('check_finite_and_unscale', ops)
self.assertEqual(
set(parameters),
set([
'fc_2.b_0', 'num_good_steps_0', 'fc_2.w_0', 'loss_scaling_0',
'num_bad_steps_0', 'fc_2.w_0_velocity_0', 'fc_2.w_0_asp_mask',
'learning_rate_0', 'fc_1.b_0', 'fc_1.w_0_asp_mask',
'fc_0.w_0_asp_mask', 'fc_1.b_0_velocity_0',
'fc_2.b_0_velocity_0'
]))
self.assertEqual(ops, [
'cast', 'cast', 'cast', 'fill_constant', 'fill_constant',
'fill_constant', 'c_sync_calc_stream', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_sync_comm_stream', 'cast', 'mul', 'elementwise_add', 'cast',
'tanh', 'cast', 'mul', 'elementwise_add', 'cast', 'tanh', 'cast',
'mul', 'elementwise_add', 'softmax', 'cast', 'cross_entropy2',
'mean', 'elementwise_mul', 'fill_constant', 'elementwise_mul_grad',
'mean_grad', 'cross_entropy_grad2', 'cast', 'softmax_grad',
'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream', 'cast',
'cast', 'cast', 'check_finite_and_unscale', 'cast',
'c_allreduce_max', 'cast', 'update_loss_scaling', 'momentum',
'momentum', 'momentum', 'elementwise_mul'
])
def test_sharding_weight_decay(self):
train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program()
......