Unverified · Commit 5a9f6889 authored by JZ-LIANG, committed by GitHub

[Sharding] add new features (#28568)

* add lars to fleet meta optimizer

* add lamb to proto

* add lamb to fleet meta optimizer

* fixed syntax bug

* fixed syntax bug

* fixed syntax error in lamb, add config setter of lamb in distributed_strategy

* trigger unittest to rerun

* add new unittest func for lamb

* revise unittest for lars and lamb

* revise dgc meta unittest

* revise lars document in distribute_strategy

* revise lars lamb document in distributed_strategy.py

* revise lars lamb document in distributed_strategy.py

* add weight decay exclude logic to lars

* restore optimizer.py

* restore optimizer.py as develop except lars

* add epsilon and exclude fn to distributed_strategy

* add lars epsilon

* revise unittest for fleet lars and lamb

* revise lars lamb unittest for CI coverage

* revise lars argument api

* revise lars argument api

* revise lars argument api

* revise api doc of lars

* fix op role

* add sharding save and add_sync_comm_for_test function

* add comm_analyse to utils

* revise sharding_utils

* add sharding saving unittest

* revise sharding utils for unittest
Parent 8c75b255
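This PR adds three sharding utilities to paddle.distributed.fleet.meta_optimizers.sharding.utils: comm_analyse, add_sync_comm_for_test, and sharding_save_persistables. The following is a minimal usage sketch (not part of the commit) showing how they might be wired into a fleet sharding training script; build_net is a hypothetical model-building helper, and the strategy settings mirror the unit tests below.

import paddle
import paddle.fluid as fluid
import paddle.distributed.fleet as fleet
from paddle.distributed.fleet.meta_optimizers.sharding.utils import (
    comm_analyse, add_sync_comm_for_test, sharding_save_persistables)

paddle.enable_static()
fleet.init(is_collective=True)

train_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(train_prog, startup_prog):
    avg_cost = build_net()  # hypothetical: builds the model and returns the loss
    strategy = fleet.DistributedStrategy()
    strategy.sharding = True
    strategy.sharding_configs = {"fuse_broadcast_MB": 0.2}
    optimizer = fluid.optimizer.Momentum(learning_rate=0.01, momentum=0.9)
    optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
    optimizer.minimize(avg_cost)

# Print/record the per-variable broadcast and allreduce sizes of the sharded program.
comm_analyse(train_prog)

# A test program cloned from the sharding main program may have lost some
# c_sync_comm_stream ops to pruning; add them back before evaluation.
test_prog = train_prog.clone(for_test=True)
add_sync_comm_for_test(test_prog, strategy)

# Each rank saves the persistables it owns (rank 0 saves all of its persistables,
# the other ranks save their parameter and optimizer-state shards).
exe = fluid.Executor(fluid.CUDAPlace(0))
exe.run(startup_prog)
sharding_save_persistables(exe, "./sharding_model", main_program=train_prog)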
@@ -11,13 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.fluid import core
from functools import reduce
from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
import re
import os


def check_broadcast(block):
@@ -126,11 +127,25 @@ def check_allreduce_sum(block):
    return

def get_valid_op_role(block, insert_idx):
    """
    Return OpRole.Forward or OpRole.Backward for the op at or after insert_idx.
    """
    # Walk forward until an op with a decisive role is found; out-of-range
    # indices fall back to OpRole.Backward.
    if insert_idx >= len(block.ops):
        return OpRole.Backward
    op_role = block.ops[insert_idx].attr('op_role')
    if op_role in [int(OpRole.Backward), int(OpRole.Optimize)]:
        return OpRole.Backward
    if op_role in [int(OpRole.Forward), int(OpRole.Loss)]:
        return OpRole.Forward
    return get_valid_op_role(block, insert_idx + 1)

def insert_sync_calc_op(block, insert_idx, calc_dep_vars):
    """
    _insert_sync_calc_op
    """
-    op_role = block.ops[insert_idx].attr('op_role')
+    op_role = get_valid_op_role(block, insert_idx)
    block._insert_op_without_sync(
        insert_idx,
        type='c_sync_calc_stream',
@@ -144,7 +159,7 @@ def insert_sync_comm_ops(block, insert_idx, nrings, comm_dep_vars):
    """
    _insert_sync_comm_ops
    """
-    op_role = block.ops[insert_idx].attr('op_role')
+    op_role = get_valid_op_role(block, insert_idx)
    for i in range(nrings):
        block._insert_op_without_sync(
            insert_idx,
@@ -160,7 +175,7 @@ def insert_fill_constant_ops(block, insert_idx, fill_constant_vars):
    """
    _add_fill_constant_ops
    """
-    op_role = block.ops[insert_idx].attr('op_role')
+    op_role = get_valid_op_role(block, insert_idx)
    for broadcast_name in fill_constant_vars:
        broadcast_var = block.var(broadcast_name)
        block._insert_op_without_sync(
@@ -180,7 +195,7 @@ def insert_cast_ops(block, insert_idx, cast_ops):
    """
    _add_cast_ops
    """
-    op_role = block.ops[insert_idx].attr('op_role')
+    op_role = get_valid_op_role(block, insert_idx)
    for fp16_name, fp32_name in cast_ops.items():
        block._insert_op_without_sync(
            insert_idx,
@@ -217,7 +232,7 @@ def insert_broadcast_ops(block, insert_idx, nrings, broadcast2root):
    _add_broadcast_ops
    """
    ring_id = -1
-    op_role = block.ops[insert_idx].attr('op_role')
+    op_role = get_valid_op_role(block, insert_idx)
    for broadcast_name, root_device in broadcast2root:
        ring_id = (ring_id + 1) % nrings
        block._insert_op_without_sync(
@@ -272,3 +287,115 @@ def insert_scale_loss_grad_ops(block, scale=1.0):
                outputs={'Out': loss_grad_var},
                attrs={'scale': scale,
                       OP_ROLE_KEY: OpRole.Backward})

def comm_analyse(main_program):
    """
    Analyse the parameter sizes that need to be broadcast/allreduced during sharding training.
    """
    reduce_vars = {}
    broadcast_vars = {}
    block = main_program.global_block()
    for op in block.ops:
        if op.type == "c_broadcast":
            var_name = op.desc.input_arg_names()[0]
            broadcast_vars[var_name] = get_var_size(block.var(var_name))
        elif op.type == "c_allreduce_sum":
            var_name = op.desc.input_arg_names()[0]
            reduce_vars[var_name] = get_var_size(block.var(var_name))

    varsize_count = {}
    gap = 1

    for k, v in broadcast_vars.items():
        print("broadcast: {}: {} KB".format(k, v))
        if (int(v / gap) in varsize_count):
            varsize_count[int(v / gap)] += 1
        else:
            varsize_count[int(v / gap)] = 1

    for k, v in reduce_vars.items():
        print("allreduce: {}: {} KB".format(k, v))
        if (int(v / gap) in varsize_count):
            varsize_count[int(v / gap)] += 1
        else:
            varsize_count[int(v / gap)] = 1

    with open("nccl_size.txt", 'w') as f:
        sorted_varsize = sorted(varsize_count.items(), key=lambda x: x[0])
        for varsize, count in sorted_varsize:
            print("NCCL size {}~{} KB: {}".format(varsize, varsize + 1, count))
            f.write("NCCL size {}~{} KB: {}\n".format(varsize, varsize + 1,
                                                      count))


def add_sync_comm_for_test(program, dist_strategy):
    """
    When a test program is cloned from the sharding main program,
    some of the sync_comm ops may be pruned by mistake; this function
    adds the sync_comm ops back to the test program.
    """
    # NOTE (liangjianzhong): only one comm stream is supported for now; using
    # more than one comm stream will cause errors. This should be revised in the future.
    block = program.global_block()
    not_sync_vars = set([])
    for op in block.ops:
        if op.type in ["c_broadcast", "c_allreduce"]:
            for input_name in op.desc.input_arg_names():
                not_sync_vars.add(input_name)
        if op.type == "c_sync_comm_stream":
            for input_name in op.desc.input_arg_names():
                not_sync_vars.remove(input_name)
    if not_sync_vars:
        for nccl_id in range(dist_strategy.nccl_comm_num):
            block.append_op(
                type='c_sync_comm_stream',
                inputs={'X': list(not_sync_vars)},
                outputs={'Out': list(not_sync_vars)},
                attrs={
                    'ring_id': nccl_id,
                    'op_role': core.op_proto_and_checker_maker.OpRole.Forward
                })
    return


def sharding_save_persistables(exe, dirname, main_program, filename=None):
    """
    With sharding, some persistable vars are unique and partitioned across ranks,
    while others are duplicated and exist on every rank with different values.
    This function handles model saving for sharding training.
    """

    def is_opt_vars(var):
        # NOTE(liangjianzhong): these checks should be updated when new compatible
        # optimizers are added; currently only Momentum and Adam work with sharding.
        checks = [
            "_moment1_0", "_moment2_0", "_beta1_pow_acc_0", "_beta2_pow_acc_0",
            "_velocity_0"
        ]
        for check in checks:
            if var.name.endswith(check):
                return True
        return False

    def is_trainable(var):
        return isinstance(var,
                          paddle.fluid.framework.Parameter) and var.trainable

    def sharding_predicate(var):
        return is_trainable(var) or is_opt_vars(var)

    if int(os.environ.get('PADDLE_TRAINER_ID', 0)) == 0:
        paddle.fluid.io.save_persistables(
            exe, dirname, main_program=main_program, filename=None)
    else:
        paddle.fluid.io.save_vars(
            exe,
            dirname,
            main_program=main_program,
            predicate=sharding_predicate,
            filename=None)

    return
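As a quick follow-up check, the directory written by sharding_save_persistables can be inspected directly: with Momentum, every saved parameter should be accompanied by its _velocity_0 slot, plus the shared learning_rate_0 variable (this is exactly what the test_dist_sharding_save unit test further down asserts). A small sketch; check_sharding_save_dir is an illustrative helper, not part of this commit.

import os

def check_sharding_save_dir(dirname):
    # List the files written by sharding_save_persistables and report any
    # parameter that is missing its Momentum velocity slot.
    saved = sorted(os.listdir(dirname))
    params = [name for name in saved
              if not name.endswith("_velocity_0") and name != "learning_rate_0"]
    missing = [p + "_velocity_0" for p in params if p + "_velocity_0" not in saved]
    return saved, missing

saved, missing = check_sharding_save_dir("./ut_sharding_save_model")
print("saved files:", saved)
print("missing velocity slots:", missing)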
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle
import paddle.fluid as fluid
from test_dist_base import TestDistRunnerBase, runtime_main
from dist_mnist import cnn_model
# from paddle.fluid.incubate.fleet.collective import fleet
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker
from paddle.distributed.fleet.meta_optimizers.sharding.utils import sharding_save_persistables
import os
import six
import sys
import pickle
# Fix seed for test
fluid.default_startup_program().random_seed = 1
fluid.default_main_program().random_seed = 1

def runtime_main():
    import paddle.distributed.fleet as fleet

    # model definition
    train_prog = paddle.fluid.Program()
    startup_prog = paddle.fluid.Program()
    role = role_maker.PaddleCloudRoleMaker(is_collective=True)
    fleet.init(role)

    with fluid.program_guard(train_prog, startup_prog):
        with fluid.unique_name.guard():
            input_x = paddle.fluid.layers.data(
                name="x", shape=[32], dtype='float32')
            input_y = paddle.fluid.layers.data(
                name="y", shape=[1], dtype='int64')

            fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
            fc_2 = paddle.fluid.layers.fc(input=fc_1, size=256, act='tanh')
            prediction = paddle.fluid.layers.fc(input=[fc_2],
                                                size=2,
                                                act='softmax')
            cost = paddle.fluid.layers.cross_entropy(
                input=prediction, label=input_y)
            avg_cost = paddle.fluid.layers.mean(x=cost)

            strategy = paddle.distributed.fleet.DistributedStrategy()
            strategy.sharding = True
            strategy.sharding_configs = {"fuse_broadcast_MB": 0.2}

            optimizer = paddle.fluid.optimizer.Momentum(
                learning_rate=0.01, momentum=0.9)
            optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
            optimizer.minimize(avg_cost)

    # execution
    device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
    place = fluid.CUDAPlace(device_id)
    exe = fluid.Executor(place)
    exe.run(startup_prog)

    dirname = "./ut_sharding_save_model"
    sharding_save_persistables(
        exe, dirname, main_program=train_prog, filename=None)

    out_losses = []
    if six.PY2:
        print(pickle.dumps(out_losses))
    else:
        sys.stdout.buffer.write(pickle.dumps(out_losses))


if __name__ == "__main__":
    # NOTE(liangjianzhong): this dist unittest should be implemented using runtime_main
    # from test_dist_base.py, but that runtime_main uses the fleet and DistributedStrategy
    # from paddle.fluid.incubate.fleet.collective, which are not supported by sharding
    # (paddle.distributed.fleet). This should be updated in the future.
    # runtime_main(TestDistMnist2x2)
    runtime_main()
\ No newline at end of file
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import shutil
import os
import unittest
from test_dist_base import TestDistBase
import paddle
paddle.enable_static()

class TestDistMnistFleetSave(TestDistBase):
    def _setup_config(self):
        self._sync_mode = True
        self._use_reduce = False
        self._use_reader_alloc = False
        self._nccl2_mode = True
        self._gpu_fleet_api = True
        self._sharding_save = True
        self._enforce_place = "GPU"

    def _rm_temp_files(self, dirname):
        shutil.rmtree(dirname)

    def _test_saved_files(self, dirname):
        sharding_save_files = sorted(os.listdir(dirname))

        check_files = [
            'fc_0.b_0', 'fc_0.b_0_velocity_0', 'fc_0.w_0', 'fc_0.w_0_velocity_0',
            'fc_1.b_0', 'fc_1.b_0_velocity_0', 'fc_1.w_0', 'fc_1.w_0_velocity_0',
            'fc_2.b_0', 'fc_2.b_0_velocity_0', 'fc_2.w_0', 'fc_2.w_0_velocity_0',
            'learning_rate_0'
        ]

        if sharding_save_files != check_files:
            self._rm_temp_files(dirname)
            raise ValueError("Test Failed.")
        self._rm_temp_files(dirname)

        return True

    def check_with_place(self,
                         model_file,
                         delta=1e-3,
                         check_error_log=True,
                         need_envs={},
                         log_name=""):
        required_envs = self._get_required_envs(check_error_log, need_envs)

        tr0_losses, tr1_losses = self._run_cluster_nccl2(
            model_file,
            required_envs,
            False,
            check_error_log,
            log_name=log_name)

        dirname = './ut_sharding_save_model'
        self._test_saved_files(dirname)

    def test_dist_train(self):
        import paddle.fluid as fluid
        if fluid.core.is_compiled_with_cuda():
            self.check_with_place("dist_sharding_save.py", delta=1e-5)


if __name__ == "__main__":
    unittest.main()
@@ -17,8 +17,11 @@ import paddle
import os
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker
import paddle.fluid.core as core
import paddle.fluid as fluid
from fleet_meta_optimizer_base import TestFleetMetaOptimizer
from paddle.distributed.fleet.meta_optimizers.sharding.utils import add_sync_comm_for_test, sharding_save_persistables, comm_analyse

paddle.enable_static()
@@ -270,6 +273,25 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
            'momentum'
        ])

    def test_sharding_clone_for_test(self):
        train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program()
        avg_cost, strategy = self.net(train_prog, startup_prog)
        self.set_strategy(strategy, 'sharding')
        self.optimizer(avg_cost, strategy, train_prog, startup_prog)
        comm_analyse(train_prog)
        test_prog = train_prog.clone(for_test=True)
        add_sync_comm_for_test(test_prog, strategy)
        ops = [op.type for op in test_prog.global_block().ops]

        self.assertEqual(ops, [
            'fill_constant', 'fill_constant', 'fill_constant',
            'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream',
            'mul', 'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh',
            'mul', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean'
        ])


if __name__ == "__main__":
    unittest.main()