Unverified commit 4ef6b845, authored by Yi Liu and committed by GitHub

adapt fleet api for localsgd and support nccl comm configuration in executor (#19443)

test=develop
Parent 65c73684
......@@ -25,6 +25,7 @@ from paddle.fluid import compiler
import os
import sys
import six
class LambConfig(object):
......@@ -99,7 +100,6 @@ class DistributedStrategy(fluid.BuildStrategy):
self.use_local_sgd = False
self.use_dist_fc = False
self.local_sgd_config = None # LocalSGDConfig
self.dist_fc_config = None # DistFCConfig
self.mode = "nccl2" # or collective
self.collective_mode = None # local_sgd or grad_allreduce
......@@ -107,6 +107,9 @@ class DistributedStrategy(fluid.BuildStrategy):
self.exec_strategy = fluid.ExecutionStrategy()
# configurations below are used for unit test
self._ut4grad_allreduce = False
class CollectiveOpBasedOptimizer(DistributedOptimizer):
"""
......@@ -161,7 +164,7 @@ class CollectiveOptimizer(DistributedOptimizer):
return self._optimizer.apply_gradients(params_grads)
def _check_condition(self, name, **kwargs):
- for k, v in kwargs.iterms():
+ for k, v in six.iteritems(kwargs):
if v is True:
assert False, "you can't use %s and %s together" % (name, k)
......@@ -170,12 +173,13 @@ class CollectiveOptimizer(DistributedOptimizer):
Check the conflict conditions.
"""
if strategy.use_local_sgd:
strategy.mode = "collective"
strategy.collective_mode = "local_sgd"
self._check_condition(
"use_local_sgd",
use_dgc=main_program._enable_dgc,
use_dist_fc=strategy.use_dist_fc,
use_lamb=main_program._use_lamb)
assert strategy.local_sgd_config is not None, "DistributedStrategy.local_sgd_config should be set"
if strategy.use_dist_fc:
self._check_condition(
......@@ -185,6 +189,14 @@ class CollectiveOptimizer(DistributedOptimizer):
use_lamb=main_program._use_lamb)
assert strategy.dist_fc_config is not None, "DistributedStrategy.dist_fc_config should be set"
if strategy._ut4grad_allreduce:
strategy.mode = "collective"
strategy.collective_mode = "grad_allreduce"
self._check_condition(
"_ut4grad_allreduce",
use_dgc=main_program._enable_dgc,
use_lamb=main_program._use_lamb)
if self._strategy.collective_mode=="local_sgd" \
or self._strategy.collective_mode == "grad_allreduce":
assert self._strategy.mode == "collective", \
......
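For context, a minimal sketch of how user code would drive the flags this hunk checks. The role_maker / fleet calls mirror the ones visible in test_dist_base.py below; the import paths are the usual fleet collective locations in this codebase, and the LocalSGDConfig-style object is only a stand-in, since the assert above merely requires local_sgd_config to be non-None:

```python
# Sketch only (not part of this PR): exercises the use_local_sgd routing added above,
# which sets mode="collective" and collective_mode="local_sgd" on the strategy.
from paddle.fluid.incubate.fleet.base import role_maker
from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy


class _LocalSGDConfigStub(object):
    """Stand-in for the real LocalSGDConfig; the check above only needs a non-None value."""
    pass


def build_local_sgd_optimizer(base_optimizer):
    role = role_maker.PaddleCloudRoleMaker(is_collective=True)
    fleet.init(role)

    dist_strategy = DistributedStrategy()
    dist_strategy.use_local_sgd = True
    dist_strategy.local_sgd_config = _LocalSGDConfigStub()

    # minimize() on the returned optimizer later triggers the conflict checks in this hunk.
    return fleet.distributed_optimizer(base_optimizer, strategy=dist_strategy)
```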
......@@ -122,6 +122,10 @@ class TestDistRunnerBase(object):
dist_strategy.exec_strategy = exec_strategy
dist_strategy.fuse_memory_size = 1 #MB
dist_strategy.fuse_laryer_size = 1
if args.use_local_sgd:
dist_strategy.use_local_sgd = True
if args.ut4grad_allreduce:
dist_strategy._ut4grad_allreduce = True
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)
......@@ -396,6 +400,8 @@ def runtime_main(test_class):
parser.add_argument('--enable_backward_deps', action='store_true')
parser.add_argument('--use_hallreduce', action='store_true')
parser.add_argument('--gpu_fleet_api', action='store_true')
parser.add_argument('--use_local_sgd', action='store_true')
parser.add_argument('--ut4grad_allreduce', action='store_true')
parser.add_argument(
'--hallreduce_inter_nranks', type=int, required=False, default=2)
parser.add_argument(
......@@ -478,6 +484,8 @@ class TestDistBase(unittest.TestCase):
self._nccl_comm_num = 1
self._enable_backward_deps = False
self._gpu_fleet_api = False
self._use_local_sgd = False
self._ut4grad_allreduce = False
self._use_hallreduce = False
self._setup_config()
self._after_setup_config()
......@@ -731,6 +739,10 @@ class TestDistBase(unittest.TestCase):
if self._gpu_fleet_api:
tr_cmd += " --gpu_fleet_api"
if self._use_local_sgd:
tr_cmd += " --use_local_sgd"
if self._ut4grad_allreduce:
tr_cmd += " --ut4grad_allreduce"
if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
env['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '')
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
from test_dist_base import TestDistBase
class TestDistMnistLocalSGDFleetApi(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._use_reduce = False
self._use_reader_alloc = False
self._nccl2_mode = True
self._gpu_fleet_api = True
self._use_local_sgd = True
def test_dist_train(self):
import paddle.fluid as fluid
if fluid.core.is_compiled_with_cuda():
self.check_with_place("dist_mnist.py", delta=1e-5)
class TestDistMnistGradAllReduceFleetApi(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._use_reduce = False
self._use_reader_alloc = False
self._nccl2_mode = True
self._gpu_fleet_api = True
self._ut4grad_allreduce = True
def test_dist_train(self):
import paddle.fluid as fluid
if fluid.core.is_compiled_with_cuda():
self.check_with_place("dist_mnist.py", delta=1e-5)
if __name__ == "__main__":
unittest.main()
......@@ -278,10 +278,12 @@ class LocalSGD(Collective):
Collective._transpile_startup_program(self)
block = self.startup_program.global_block()
non_dist_params = []
for param in block.iter_parameters():
- if param.is_distributed:
- continue
+ if not param.is_distributed:
+ non_dist_params.append(param)
for param in non_dist_params:
snapshot = block.create_var(
name=self.snapshot_name(param.name),
shape=param.shape,
......
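The snapshot variable created here for every non-distributed parameter is what local SGD later averages against. As a rough, framework-agnostic sketch of that idea (this is not the transpiler-generated program; sgd_step and allreduce_mean are made-up stand-ins for the local optimizer update and the NCCL allreduce-average):

```python
# Conceptual sketch of one local-SGD round, assuming array-like parameters.
def local_sgd_round(param, snapshot, local_steps, sgd_step, allreduce_mean):
    for _ in range(local_steps):              # purely local updates, no communication
        param = sgd_step(param)
    delta = param - snapshot                  # this worker's drift since the last sync
    param = snapshot + allreduce_mean(delta)  # same result as averaging params across workers
    snapshot = param                          # refresh the snapshot for the next round
    return param, snapshot
```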
......@@ -334,9 +334,9 @@ class DistributeTranspiler(object):
transpiler = None
if collective_mode == 'grad_allreduce':
- transpiler = collective.GradAllReduce()
+ transpiler = collective.GradAllReduce(self.config.nccl_comm_num)
elif collective_mode == 'local_sgd':
- transpiler = collective.LocalSGD()
+ transpiler = collective.LocalSGD(self.config.nccl_comm_num)
else:
raise ValueError('invalid collective_mode: %s' % collective_mode)
......
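For reference, a sketch of the transpiler-level entry point that now carries nccl_comm_num down to collective.LocalSGD / collective.GradAllReduce. The mode and collective_mode fields on the config are assumed to mirror how the fleet CollectiveOptimizer fills them, and the endpoint values are placeholders:

```python
# Sketch only (not part of this PR): drive DistributeTranspiler in collective mode.
from paddle.fluid.transpiler import DistributeTranspiler, DistributeTranspilerConfig


def transpile_collective(trainer_id, endpoints, current_endpoint,
                         main_prog, startup_prog, nccl_comm_num=2):
    config = DistributeTranspilerConfig()
    config.mode = "collective"            # assumed field that selects the collective code path
    config.collective_mode = "local_sgd"  # or "grad_allreduce"
    config.nccl_comm_num = nccl_comm_num  # now forwarded to the collective transpilers

    t = DistributeTranspiler(config=config)
    t.transpile(
        trainer_id=trainer_id,
        trainers=",".join(endpoints),     # worker endpoints, e.g. ["127.0.0.1:6170", ...]
        current_endpoint=current_endpoint,
        startup_program=startup_prog,
        program=main_prog)
    return t
```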