From 29d8781240753bbe2c00a970288cc8fc18c8d422 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Mon, 12 Aug 2019 11:48:49 +0800 Subject: [PATCH] Polish fleet API to support cuda collective mode and nccl2 mode. (#18966) Polish fleet API to support cuda collective mode and nccl2 mode --- .../framework/details/all_reduce_op_handle.cc | 1 + .../operators/distributed/grpc/grpc_client.cc | 2 +- paddle/fluid/platform/device_context.cc | 6 +- python/paddle/fluid/framework.py | 2 + .../fluid/incubate/fleet/base/fleet_base.py | 8 + .../fluid/incubate/fleet/base/role_maker.py | 30 +- .../incubate/fleet/collective/__init__.py | 301 ++++++++++-------- .../distribute_transpiler/__init__.py | 8 + python/paddle/fluid/optimizer.py | 1 + .../fluid/tests/unittests/CMakeLists.txt | 28 +- .../fluid/tests/unittests/dist_mnist.py | 12 +- .../fluid/tests/unittests/test_dist_base.py | 87 ++++- .../unittests/test_dist_mnist_fleetapi.py | 35 ++ .../fluid/transpiler/distribute_transpiler.py | 4 +- 14 files changed, 363 insertions(+), 162 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_dist_mnist_fleetapi.py diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.cc b/paddle/fluid/framework/details/all_reduce_op_handle.cc index f806a4fa847..e2a0097cb1c 100644 --- a/paddle/fluid/framework/details/all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/all_reduce_op_handle.cc @@ -22,6 +22,7 @@ // asynchronous nccl allreduce or synchronous issue: // https://github.com/PaddlePaddle/Paddle/issues/15049 +// If you want to change this default value, why?(gongwb) DEFINE_bool( sync_nccl_allreduce, true, "If set true, will call `cudaStreamSynchronize(nccl_stream)`" diff --git a/paddle/fluid/operators/distributed/grpc/grpc_client.cc b/paddle/fluid/operators/distributed/grpc/grpc_client.cc index 8504110c6e9..d06d4b63b60 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_client.cc +++ b/paddle/fluid/operators/distributed/grpc/grpc_client.cc @@ -449,7 +449,7 @@ void GRPCClient::Proceed() { // destructed at this moment. 
if (FLAGS_v >= 3) { std::string msg("GRPCClient Proceed end"); - fwrite(msg.c_str(), msg.length(), 1, stdout); + fwrite(msg.c_str(), msg.length(), 1, stderr); } } diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index c9ce7ed12e4..f8099c7e515 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -32,8 +32,10 @@ platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) { auto it = device_contexts_.find(place); if (it == device_contexts_.end()) { PADDLE_THROW( - "Place %s is not supported, Please re-compile with WITH_GPU " - "option", + "Place %s is not supported, Please check that your paddle compiles " + "with WITH_GPU " + "option or check that your train process hold the correct gpu_id if " + "you use Executor", place); } return it->second.get().get(); diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index d899363cbae..1b496614209 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -2848,6 +2848,8 @@ class Program(object): # use Deep gradient comrepssion or not self._enable_dgc = False + self._use_lamb = False + self._nccl_comm_num = 1 self._use_hierarchical_allreduce = False self._hierarchical_allreduce_inter_nranks = 0 diff --git a/python/paddle/fluid/incubate/fleet/base/fleet_base.py b/python/paddle/fluid/incubate/fleet/base/fleet_base.py index ac9b0f23276..658d971a731 100644 --- a/python/paddle/fluid/incubate/fleet/base/fleet_base.py +++ b/python/paddle/fluid/incubate/fleet/base/fleet_base.py @@ -232,6 +232,14 @@ class Fleet(object): def save_persistables(self, executor, dirname, main_program=None): pass + @abc.abstractmethod + def node_num(self): + pass + + @abc.abstractmethod + def node_id(self): + pass + class DistributedOptimizer(object): """ diff --git a/python/paddle/fluid/incubate/fleet/base/role_maker.py b/python/paddle/fluid/incubate/fleet/base/role_maker.py index e775250af97..ff99a912533 100644 --- a/python/paddle/fluid/incubate/fleet/base/role_maker.py +++ b/python/paddle/fluid/incubate/fleet/base/role_maker.py @@ -350,7 +350,7 @@ class PaddleCloudRoleMaker(RoleMakerBase): for i, ip in enumerate(self.pserver_ips.split(",")): eplist.append(':'.join([ip, ports[i]])) self.endpoints = ",".join(eplist) - self._trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "1")) + self._trainers_num = int(os.getenv("PADDLE_TRAINERS_NUM", "1")) # ip of current node, either a worker or a pserver current_ip = os.getenv("POD_IP", "") if current_ip == "": @@ -380,11 +380,31 @@ class PaddleCloudRoleMaker(RoleMakerBase): assert (self._training_role == "TRAINER") self._worker_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS") self._current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT") - if self._worker_endpoints: - self._worker_endpoints = self._worker_endpoints.split(",") - self._num_trainers = len(self._worker_endpoints) + assert self._worker_endpoints is not None, "can't find PADDLE_TRAINER_ENDPOINTS" + self._worker_endpoints = self._worker_endpoints.split(",") + self._trainers_num = len(self._worker_endpoints) + + self._node_ips = self._get_node_ips_from_endpoints( + self._worker_endpoints) + self._node_ip = self._current_endpoint.split(":")[0].strip() + + self._node_num = len(self._node_ips) + self._node_id = self._node_ips.index(self._node_ip) self._role_is_generated = True + def _get_node_ips_from_endpoints(self, endpoints): + ss = set() + ips = [] + for ep in endpoints: + ip = ep.split(":")[0].strip() + if ip not 
in ss: + ss.add(ip) + ips.append(ip) + else: + continue + + return ips + def get_pserver_endpoints(self): if not self._role_is_generated: self.generate_role() @@ -418,7 +438,7 @@ class PaddleCloudRoleMaker(RoleMakerBase): def worker_num(self): if not self._role_is_generated: self.generate_role() - return self._trainers + return self._trainers_num class UserDefinedRoleMaker(RoleMakerBase): diff --git a/python/paddle/fluid/incubate/fleet/collective/__init__.py b/python/paddle/fluid/incubate/fleet/collective/__init__.py index 4c72c9636a4..6f67ecc4a9a 100644 --- a/python/paddle/fluid/incubate/fleet/collective/__init__.py +++ b/python/paddle/fluid/incubate/fleet/collective/__init__.py @@ -21,60 +21,20 @@ from paddle.fluid.incubate.fleet.base.fleet_base import Fleet from paddle.fluid.incubate.fleet.base.fleet_base import Mode from paddle.fluid.incubate.fleet.base.fleet_base import DistributedOptimizer +from paddle.fluid import compiler -class DistributedStrategy(object): +import os +import sys + + +class LambConfig(object): def __init__(self): - # precision configs - self.use_fp16 = False - self.use_fp32 = True - # algorithmic communication - self.local_sgd = False - self.dgc = False - # communication topology configs - self.h_allreduce = False - - def build(self): - self.strategy_map = {} - # make sure we set single precision config True - if self.use_fp32 and self.use_fp16: - self.use_fp16 = False - # make sure we set single algorithmic communication True - if self.local_sgd and self.dgc: - self.local_sgd = False - self.strategy_map["fp16"] = self.use_fp16 - self.strategy_map["fp32"] = self.use_fp32 - self.strategy_map["localsgd"] = self.local_sgd - self.strategy_map["dgc"] = self.dgc - self.strategy_map["h_allreduce"] = self.h_allreduce - - -class DistributedOptimizerFactory(object): + pass + + +class DistFCConfig(object): def __init__(self): - self.strategy_to_optimizer_map() - - def strategy_to_optimizer_map(self): - pattern = {} - pattern["fp16"] = ["FP16SGDOptimizer", "FP16LocalSGDOptimizer"] - pattern["fp32"] = ["FP32SGDOptimizer", "FP32LocalSGDOptimizer"] - pattern["localsgd"] = ["FP16LocalSGDOptimizer", "FP32LocalSGDOptimizer"] - pattern["h_allreduce"] = [ - "FP32SGDOptimizer", - "FP32LocalSGDOptimizer", - "FP16SGDOptimizer", - "FP16LocalSGDOptimizer", - ] - self.pattern = pattern - - def create_by_strategy(self, optimizer, strategy): - if strategy == None: - strategy = DistributedStrategy() - strategy.build() - strategy_list = [] - for key in strategy.strategy_map: - if strategy.strategy_map[key]: - strategy_list.append(self.pattern[key]) - classname = list(set.intersection(*map(set, strategy_list)))[0] - return globals()[classname](optimizer, strategy) + pass class Collective(Fleet): @@ -82,6 +42,10 @@ class Collective(Fleet): super(Collective, self).__init__(Mode.COLLECTIVE) self._local_ip = 0 + self.startup_program = None + self._origin_program = None + self.main_program = None + def init_worker(self): logging.warn( "You should not call 'init_worker' method for collective mode.") @@ -103,10 +67,8 @@ class Collective(Fleet): "You should not call 'stop_worker' method for collective mode.") def distributed_optimizer(self, optimizer, strategy=None): - optimizer_factory = DistributedOptimizerFactory() - self._optimizer = \ - optimizer_factory.create_by_strategy(optimizer, strategy) + CollectiveOptimizer(optimizer, strategy) return self._optimizer def save_inference_model(self, @@ -117,16 +79,56 @@ class Collective(Fleet): main_program=None, export_for_deployment=True): 
io.save_inference_model(dirname, feeded_var_names, target_vars, - self._executor, main_program, None, None, + executor, main_program, None, None, export_for_deployment) def save_persistables(self, executor, dirname, main_program=None): - io.save_persistables(self._executor, dirname, main_program, None) + io.save_persistables(executor, dirname, main_program, None) + + def node_num(self): + return self._role_maker._node_num + + def node_id(self): + return self._role_maker._node_id fleet = Collective() +class DistributedStrategy(fluid.BuildStrategy): + """ + Init function of DistributedStrategy + """ + + def __init__(self): + super(DistributedStrategy, self).__init__() + self.fuse_memory_size = -1 + self.fuse_layer_size = 1 + + self.use_local_sgd = False + self.use_dist_fc = False + + self.local_sgd_config = None # LocalSGDConfig + self.dist_fc_config = None # DistFCConfig + self.mode = "nccl2" # or collective + self.collective_mode = None # local_sgd or grad_allreduce + + self.nccl_comm_num = 2 + + self.exec_strategy = fluid.ExecutionStrategy() + sync_allreduce = os.getenv("FLAGS_sync_nccl_allreduce") + if sync_allreduce == "0": + self._exec_strategy.num_threads = self.nccl_comm_num + 1 + if sef.use_hierarchical_allreduce: + self._exec_strategy.num_threads = 2 * self.nccl_comm_num + 1 + if self._exec_strategy.num_threads > 4: + print( + sys.stderr, + "WARNING: if you use use_hierarchical_allreduce or " + "with multi nccl comm, please set FLAGS_sync_nccl_allreduce = 0" + ) + + class CollectiveOpBasedOptimizer(DistributedOptimizer): """ Collective Operator Base Class For Distributed Optimizer @@ -134,6 +136,9 @@ class CollectiveOpBasedOptimizer(DistributedOptimizer): """ def __init__(self, optimizer, strategy=None): + assert isinstance( + strategy, + DistributedStrategy), "strategy must be DistributedStrategy" super(CollectiveOpBasedOptimizer, self).__init__(optimizer, strategy) def backward(self, @@ -149,69 +154,6 @@ class CollectiveOpBasedOptimizer(DistributedOptimizer): return self._optimizer.apply_gradients(params_grads) -class FP16SGDOptimizer(CollectiveOpBasedOptimizer): - """ - do all reduce within every minibatch - """ - - def __init__(self, optimizer, strategy=None): - super(FP16SGDOptimizer, self).__init__(optimizer, strategy) - - def minimize(self, - loss, - startup_program=None, - parameter_list=None, - no_grad_set=None): - pass - - -class FP32LocalSGDOptimizer(CollectiveOpBasedOptimizer): - def __init__(self, optimizer, strategy=None): - super(FP32LocalSGDOptimizer, self).__init__(optimizer, strategy) - - def minimize(self, - loss, - startup_program=None, - parameter_list=None, - no_grad_set=None): - opts, param_and_grads = self._optimizer.minimize(loss) - config = fluid.DistributeTranspilerConfig() - config.mode = 'collective' - config.collective_mode = 'local_sgd' - t = fluid.DistributeTranspiler(config=config) - t.transpile( - trainer_id=fleet.worker_index(), - trainers=fleet.worker_endpoints(), - current_endpoint=fleet.worker_endpoints()[fleet.worker_index()], - startup_program=startup_program, - program=loss.block.program) - return opts, param_and_grads - - -class FP32SGDOptimizer(CollectiveOpBasedOptimizer): - def __init__(self, optimizer, strategy=None): - super(FP32SGDOptimizer, self).__init__(optimizer, strategy) - - def minimize(self, - loss, - startup_program=None, - parameter_list=None, - no_grad_set=None): - opts, param_and_grads = self._optimizer.minimize(loss) - config = fluid.DistributeTranspilerConfig() - config.mode = 'collective' - config.collective_mode = 
'grad_allreduce' - t = fluid.DistributeTranspiler(config=config) - - t.transpile( - trainer_id=fleet.worker_index(), - trainers=fleet.worker_endpoints(), - current_endpoint=fleet.worker_endpoints()[fleet.worker_index()], - startup_program=startup_program, - program=loss.block.program) - return opts, param_and_grads - - class CollectiveOptimizer(DistributedOptimizer): """ DistributedOptimizer is a wrapper for paddle.fluid.optimizer @@ -223,9 +165,9 @@ class CollectiveOptimizer(DistributedOptimizer): training. """ - def __init__(self, optimizer, strategy=None): + def __init__(self, optimizer, strategy=DistributedStrategy()): super(CollectiveOptimizer, self).__init__(optimizer, strategy) - self.strategy = strategy + self.print_config = False def backward(self, loss, @@ -239,6 +181,95 @@ class CollectiveOptimizer(DistributedOptimizer): def apply_gradients(self, params_grads): return self._optimizer.apply_gradients(params_grads) + def _check_condition(self, name, **kwargs): + for k, v in kwargs.iterms(): + if v is True: + assert False, "you can't use %s and %s together" % (name, k) + + def _check_collective_mode(self, main_program, optimizer, strategy): + """ + Check the conflict condtions. + """ + if strategy.use_local_sgd: + self._check_condition( + "use_local_sgd", + use_dgc=main_program._enable_dgc, + use_dist_fc=strategy.use_dist_fc, + use_lamb=main_program._use_lamb) + assert strategy.local_sgd_config is not None, "DistributedStrategy.local_sgd_config should be set" + + if strategy.use_dist_fc: + self._check_condition( + "use_dist_fc", + use_dgc=main_program._enable_dgc, + use_local_sgd=strategy.use_local_sgd, + use_lamb=main_program._use_lamb) + assert strategy.dist_fc_config is not None, "DistributedStrategy.dist_fc_config should be set" + + if self._strategy.collective_mode=="local_sgd" \ + or self._strategy.collective_mode == "grad_allreduce": + assert self._strategy.mode == "collective", \ + "local_sgd and grad_allreduce can be used under collective mode" + + def _transpile(self, startup_program, main_program): + """ + Transpile the programs to distributed programs. And add the variables. 
+        """
+        if self._strategy.fuse_all_reduce_ops:
+            os.environ['FLAGS_fuse_parameter_memory_size'] = str(
+                self._strategy.fuse_memory_size)
+            os.environ['FLAGS_fuse_parameter_groups_size'] = str(
+                self._strategy.fuse_layer_size)
+
+        worker_endpoints = fleet.worker_endpoints()
+        trainer_id = fleet.worker_index()
+        current_endpoint = fleet.worker_endpoints()[trainer_id]
+        worker_endpoints_env = ','.join(worker_endpoints)
+        trainers_num = fleet.worker_num()
+
+        if self.print_config:
+            print("worker_endpoints:{} trainers_num:{} current_endpoint:{} \
+                  trainer_id:{}".format(worker_endpoints, trainers_num,
+                                        current_endpoint, trainer_id))
+
+        # call transpiler
+        config = dist_transpiler.DistributeTranspilerConfig()
+        config.mode = self._strategy.mode
+        config.collective_mode = self._strategy.collective_mode
+
+        config.nccl_comm_num = self._strategy.nccl_comm_num
+        config.use_hierarchical_allreduce = self._strategy.use_hierarchical_allreduce
+        config.hierarchical_allreduce_inter_nranks = self._strategy.hierarchical_allreduce_inter_nranks
+
+        t = dist_transpiler.DistributeTranspiler(config=config)
+        t.transpile(
+            trainer_id=trainer_id,
+            trainers=worker_endpoints_env,
+            startup_program=startup_program,
+            program=main_program,
+            current_endpoint=current_endpoint)
+
+    def _try_to_compile(self, startup_program, main_program):
+        self._transpile(startup_program, main_program)
+
+        if self._strategy.mode == "collective":
+            return main_program
+
+        self._strategy.num_trainers = fleet.worker_num()
+        self._strategy.trainer_id = fleet.worker_index()
+        self._strategy.trainers_endpoints = fleet.worker_endpoints()
+        self._strategy.enable_backward_optimizer_op_deps = True
+
+        self._compiled_program = compiler.CompiledProgram(main_program)
+
+        self._compiled_program.with_data_parallel(
+            loss_name=self._loss.name,
+            build_strategy=self._strategy,
+            exec_strategy=self._strategy.exec_strategy,
+            share_vars_from=None)
+
+        return self._compiled_program
+
     def minimize(self,
                  loss,
                  startup_program=None,
                  parameter_list=None,
                  no_grad_set=None):
@@ -260,24 +291,20 @@ class CollectiveOptimizer(DistributedOptimizer):
         process, but currently the optimization part is written into Fleet(). A
         user does not need to care about how to startup a pserver node.
""" - optimize_ops, param_grads = self._optimizer.minimize( - loss, startup_program, parameter_list, no_grad_set) + main_program = loss.block.program + if startup_program is None: + startup_program = fluid.default_startup_program() + fleet.startup_program = startup_program - worker_endpoints = fleet.worker_endpoints() - trainer_id = fleet.worker_index() - current_endpoint = fleet.worker_endpoints()[trainer_id] + self._loss = loss - startup_program = startup_program if startup_program else \ - fluid.framework.default_startup_program + self._check_collective_mode(main_program, self._optimizer, + self._strategy) - # call transpiler - config = dist_transpiler.DistributeTranspilerConfig() - config.mode = "nccl2" - t = dist_transpiler.DistributeTranspiler(config=config) - t.transpile( - trainer_id, - trainers=','.join(worker_endpoints), - startup_program=startup_program, - current_endpoint=current_endpoint) + optimize_ops, param_grads = self._optimizer.minimize( + loss, startup_program, parameter_list, no_grad_set) + + fleet._origin_program = main_program + fleet.main_program = self._try_to_compile(startup_program, main_program) return optimize_ops, param_grads diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py index 8c230c58e32..a13512d130d 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py @@ -239,6 +239,14 @@ class DistributedTranspiler(Fleet): self.main_program, self.startup_program = \ self._transpiler.get_pserver_programs(self.server_endpoints()[self.server_index()]) + def node_num(self): + logging.warn( + "You should not call 'node_num' method for collective mode.") + + def node_id(self): + logging.warn( + "You should not call 'node_id' method for collective mode.") + fleet = DistributedTranspiler() diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 2eb68b82f03..ad4eecb07f6 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -2176,6 +2176,7 @@ class LambOptimizer(AdamOptimizer): def _append_optimize_op(self, block, param_and_grad): assert isinstance(block, framework.Block) + block.program._use_lamb = True moment1 = self._get_accumulator(self._moment1_acc_str, param_and_grad[0]) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index bc98523d85e..410c853cda0 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -8,6 +8,7 @@ if(NOT WITH_DISTRIBUTE) list(REMOVE_ITEM TEST_OPS test_simple_dist_transpiler) list(REMOVE_ITEM TEST_OPS test_listen_and_serv_op) LIST(REMOVE_ITEM TEST_OPS test_dist_mnist) + LIST(REMOVE_ITEM TEST_OPS test_dist_mnist_fleetapi) LIST(REMOVE_ITEM TEST_OPS test_dist_mnist_dgc_nccl) LIST(REMOVE_ITEM TEST_OPS test_dist_mnist_hallreduce) LIST(REMOVE_ITEM TEST_OPS test_dist_mnist_multi_comm) @@ -236,29 +237,32 @@ if(WITH_DISTRIBUTE) if(NOT APPLE) set_tests_properties(test_dist_mnist PROPERTIES TIMEOUT 350 LABELS "RUN_TYPE=EXCLUSIVE") set_tests_properties(test_dist_mnist_dgc_nccl PROPERTIES TIMEOUT 350 LABELS "RUN_TYPE=EXCLUSIVE") - set_tests_properties(test_dist_mnist_hallreduce PROPERTIES TIMEOUT 350 LABELS "RUN_TYPE=EXCLUSIVE") - set_tests_properties(test_dist_mnist_multi_comm PROPERTIES TIMEOUT 350 
LABELS "RUN_TYPE=EXCLUSIVE") - set_tests_properties(test_dist_mnist_ring_allreduce PROPERTIES TIMEOUT 350 LABELS "RUN_TYPE=EXCLUSIVE") - set_tests_properties(test_dist_mnist_backward_deps PROPERTIES TIMEOUT 350 LABELS "RUN_TYPE=EXCLUSIVE") + set_tests_properties(test_dist_mnist_hallreduce PROPERTIES TIMEOUT 350 LABELS "RUN_TYPE=EXCLUSIVE") + set_tests_properties(test_dist_mnist_multi_comm PROPERTIES TIMEOUT 350 LABELS "RUN_TYPE=EXCLUSIVE") + set_tests_properties(test_dist_mnist_ring_allreduce PROPERTIES TIMEOUT 350 LABELS "RUN_TYPE=EXCLUSIVE") + set_tests_properties(test_dist_mnist_backward_deps PROPERTIES TIMEOUT 350 LABELS "RUN_TYPE=EXCLUSIVE") + set_tests_properties(test_dist_mnist_fleetapi PROPERTIES TIMEOUT 350 LABELS "RUN_TYPE=EXCLUSIVE") set_tests_properties(test_dist_mnist_lars PROPERTIES TIMEOUT 350 LABELS "RUN_TYPE=EXCLUSIVE") set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 350 LABELS "RUN_TYPE=EXCLUSIVE") - set_tests_properties(test_dist_simnet_bow PROPERTIES TIMEOUT 350 LABELS "RUN_TYPE=EXCLUSIVE") - set_tests_properties(test_dist_text_classification PROPERTIES TIMEOUT 350 LABELS "RUN_TYPE=EXCLUSIVE") + set_tests_properties(test_dist_simnet_bow PROPERTIES TIMEOUT 350 LABELS "RUN_TYPE=EXCLUSIVE") + set_tests_properties(test_dist_text_classification PROPERTIES TIMEOUT 350 LABELS "RUN_TYPE=EXCLUSIVE") - list(REMOVE_ITEM TEST_OPS test_dist_se_resnext_dgc) + list(REMOVE_ITEM TEST_OPS test_dist_se_resnext_dgc) list(REMOVE_ITEM TEST_OPS test_dist_se_resnext_sync) list(REMOVE_ITEM TEST_OPS test_dist_se_resnext_async) - list(REMOVE_ITEM TEST_OPS test_dist_se_resnext_sync_with_memopt) + list(REMOVE_ITEM TEST_OPS test_dist_se_resnext_sync_with_memopt) + py_test_modules(test_dist_se_resnext_dgc MODULES test_dist_se_resnext_dgc) - py_test_modules(test_dist_se_resnext_sync MODULES test_dist_se_resnext_sync) + py_test_modules(test_dist_se_resnext_sync MODULES test_dist_se_resnext_sync) py_test_modules(test_dist_se_resnext_nccl MODULES test_dist_se_resnext_nccl) bash_test_modules(test_launch MODULES test_launch.sh) + # FIXME(typhoonzero): add these tests back # py_test_modules(test_dist_transformer MODULES test_dist_transformer) # set_tests_properties(test_dist_transformer PROPERTIES TIMEOUT 1000) - set_tests_properties(test_dist_se_resnext_dgc PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE") - set_tests_properties(test_dist_se_resnext_sync PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE") - set_tests_properties(test_dist_se_resnext_nccl PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE") + set_tests_properties(test_dist_se_resnext_dgc PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE") + set_tests_properties(test_dist_se_resnext_sync PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE") + set_tests_properties(test_dist_se_resnext_nccl PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE") endif(NOT APPLE) # py_test_modules(test_dist_transpiler MODULES test_dist_transpiler) endif() diff --git a/python/paddle/fluid/tests/unittests/dist_mnist.py b/python/paddle/fluid/tests/unittests/dist_mnist.py index c598260e13c..25616155b10 100644 --- a/python/paddle/fluid/tests/unittests/dist_mnist.py +++ b/python/paddle/fluid/tests/unittests/dist_mnist.py @@ -29,6 +29,7 @@ import os import signal from functools import reduce from test_dist_base import TestDistRunnerBase, runtime_main +from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy DTYPE = "float32" paddle.dataset.mnist.fetch() @@ -73,7 +74,7 @@ def cnn_model(data): class TestDistMnist2x2(TestDistRunnerBase): - def get_model(self, batch_size=2, use_dgc=False): + def get_model(self, 
batch_size=2, use_dgc=False, dist_strategy=None): # Input data images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype=DTYPE) label = fluid.layers.data(name='label', shape=[1], dtype='int64') @@ -104,7 +105,14 @@ class TestDistMnist2x2(TestDistRunnerBase): paddle.dataset.mnist.test(), batch_size=batch_size) test_reader = paddle.batch( paddle.dataset.mnist.test(), batch_size=batch_size) - opt.minimize(avg_cost) + + if dist_strategy: + dist_opt = fleet.distributed_optimizer( + optimizer=opt, strategy=dist_strategy) + _, param_grads = dist_opt.minimize(avg_cost) + else: + opt.minimize(avg_cost) + return inference_program, avg_cost, train_reader, test_reader, batch_acc, predict diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index b42cea3114c..36646023052 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -31,6 +31,9 @@ import paddle.fluid.dygraph as dygraph from paddle.fluid.dygraph.base import to_variable from paddle.fluid.dygraph.parallel import DataParallel +from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy +import paddle.fluid.incubate.fleet.base.role_maker as role_maker + RUN_STEP = 5 DEFAULT_BATCH_SIZE = 2 @@ -44,6 +47,10 @@ def my_print(class_name, log_str): sys.stderr.buffer.write(pickle.dumps(print_str)) +def eprint(*args, **kwargs): + print(*args, file=sys.stderr, **kwargs) + + class TestDistRunnerBase(object): def get_model(self, batch_size=DEFAULT_BATCH_SIZE, @@ -96,6 +103,72 @@ class TestDistRunnerBase(object): exe.run(pserver_prog) my_print(type(self).__name__, "run pserver main program done.") + def run_gpu_fleet_api_trainer(self, args): + assert args.update_method == "nccl2" + + self.lr = args.lr + + exec_strategy = fluid.ExecutionStrategy() + exec_strategy.num_threads = 1 + + dist_strategy = DistributedStrategy() + dist_strategy.exec_strategy = exec_strategy + dist_strategy.fuse_memory_size = 1 #MB + dist_strategy.fuse_laryer_size = 1 + + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + my_print("gpu_fleet", "fleet.node_num:") + #"fleet.node_id:", fleet.node_id(), + #"fleet.trainer_num:", fleet.worker_num()) + + test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \ + self.get_model(batch_size=args.batch_size, dist_strategy=dist_strategy) + + trainer_prog = fleet._origin_program + dist_prog = fleet.main_program + + device_id = int(os.getenv("FLAGS_selected_gpus", "0")) + place = fluid.CUDAPlace(device_id) + + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + eprint(type(self).__name__, "run worker startup program done.") + + feed_var_list = [ + var for var in trainer_prog.global_block().vars.values() + if var.is_data + ] + + feeder = fluid.DataFeeder(feed_var_list, place) + reader_generator = train_reader() + + def get_data(): + origin_batch = next(reader_generator) + if args.update_method != "local" and args.use_reader_alloc: + new_batch = [] + for offset, item in enumerate(origin_batch): + if offset % 2 == args.trainer_id: + new_batch.append(item) + return new_batch + else: + return origin_batch + + my_print(type(self).__name__, "begin to train on trainer") + out_losses = [] + for i in six.moves.xrange(RUN_STEP): + loss, = exe.run(dist_prog, + fetch_list=[avg_cost.name], + feed=feeder.feed(get_data())) + out_losses.append(loss[0]) + my_print(type(self).__name__, "run step %d finished" % i) + 
my_print(type(self).__name__, "trainer run finished") + + if six.PY2: + print(pickle.dumps(out_losses)) + else: + sys.stdout.buffer.write(pickle.dumps(out_losses)) + def run_trainer(self, args): self.lr = args.lr if args.nccl2_reduce_layer_local_run: @@ -318,6 +391,7 @@ def runtime_main(test_class): parser.add_argument('--nccl_comm_num', type=int, required=False, default=1) parser.add_argument('--enable_backward_deps', action='store_true') parser.add_argument('--use_hallreduce', action='store_true') + parser.add_argument('--gpu_fleet_api', action='store_true') parser.add_argument( '--hallreduce_inter_nranks', type=int, required=False, default=2) parser.add_argument( @@ -344,6 +418,8 @@ def runtime_main(test_class): model = test_class() if args.role == "pserver" and args.update_method == "pserver": model.run_pserver(args) + elif args.gpu_fleet_api: + model.run_gpu_fleet_api_trainer(args) else: model.run_trainer(args) @@ -397,6 +473,7 @@ class TestDistBase(unittest.TestCase): self._dygraph = False self._nccl_comm_num = 1 self._enable_backward_deps = False + self._gpu_fleet_api = False self._use_hallreduce = False self._setup_config() self._after_setup_config() @@ -600,7 +677,9 @@ class TestDistBase(unittest.TestCase): env.update({ "CUDA_VISIBLE_DEVICES": "{}".format(trainer_id), "PADDLE_TRAINERS_NUM": "{}".format(trainer_num), - "PADDLE_TRAINER_ID": "{}".format(trainer_id) + "PADDLE_TRAINER_ID": "{}".format(trainer_id), + "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints, + "PADDLE_CURRENT_ENDPOINT": ep, }) else: env.update({'CPU_NUM': '1'}) @@ -620,6 +699,9 @@ class TestDistBase(unittest.TestCase): if self._enable_backward_deps: tr_cmd += " --enable_backward_deps" + if self._gpu_fleet_api: + tr_cmd += " --gpu_fleet_api" + return tr_cmd, env def _run_cluster_nccl2(self, model, envs, nccl2_reduce_layer, @@ -669,6 +751,9 @@ class TestDistBase(unittest.TestCase): pipes[i].close() sys.stderr.write('trainer {} stderr: {}\n'.format(i, tr_err)) + if check_error_log: + print("outs[0]:", outs[0]) + print("outs[1]:", outs[1]) return pickle.loads(outs[0]), pickle.loads(outs[1]) def check_with_place(self, diff --git a/python/paddle/fluid/tests/unittests/test_dist_mnist_fleetapi.py b/python/paddle/fluid/tests/unittests/test_dist_mnist_fleetapi.py new file mode 100644 index 00000000000..30f8592e1da --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dist_mnist_fleetapi.py @@ -0,0 +1,35 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function +import unittest +from test_dist_base import TestDistBase + + +class TestDistMnistNCCL2FleetApi(TestDistBase): + def _setup_config(self): + self._sync_mode = True + self._use_reduce = False + self._use_reader_alloc = False + self._nccl2_mode = True + self._gpu_fleet_api = True + + def test_dist_train(self): + import paddle.fluid as fluid + if fluid.core.is_compiled_with_cuda(): + self.check_with_place("dist_mnist.py", delta=1e-5) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 722531abe4b..5251b2be14b 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -174,7 +174,7 @@ class DistributeTranspilerConfig(object): hierarchical_allreduce_inter_nranks = 0 # if mode is collective - # supported modes: sgd, local_sgd + # supported modes: grad_allreduce, local_sgd collective_mode = None @@ -431,7 +431,7 @@ class DistributeTranspiler(object): trainers_num = len(self.origin_program._trainers_endpoints) # selected automaticly if self.config.hierarchical_allreduce_inter_nranks <= 1: - self.config.hierarchical_allreduce_inter_nranks = fluid.core.get_cuda_device_count( + self.config.hierarchical_allreduce_inter_nranks = core.get_cuda_device_count( ) assert trainers_num > self.config.hierarchical_allreduce_inter_nranks, \ -- GitLab
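
Usage sketch for reviewers: the snippet below strings the polished API together the same way run_gpu_fleet_api_trainer does in test_dist_base.py. It assumes a launcher has already exported the PADDLE_TRAINING_ROLE, PADDLE_TRAINER_ENDPOINTS, PADDLE_CURRENT_ENDPOINT and FLAGS_selected_gpus environment variables that PaddleCloudRoleMaker and the test harness read; the small fc network and the random feed data are placeholders for illustration, not part of this change.

# A minimal sketch (not part of the patch) wiring the fleet collective API
# together as run_gpu_fleet_api_trainer does in test_dist_base.py.
# Assumes the PADDLE_* / FLAGS_selected_gpus environment variables are set by
# a launcher; the tiny fc network and random feed data are placeholders.
import os

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy


def build_net():
    # Placeholder model; any network that produces a scalar loss works here.
    x = fluid.layers.data(name='x', shape=[13], dtype='float32')
    y = fluid.layers.data(name='y', shape=[1], dtype='float32')
    y_pred = fluid.layers.fc(input=x, size=1)
    return fluid.layers.mean(
        fluid.layers.square_error_cost(input=y_pred, label=y))


# PaddleCloudRoleMaker reads PADDLE_TRAINING_ROLE, PADDLE_TRAINER_ENDPOINTS
# and PADDLE_CURRENT_ENDPOINT from the environment.
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)

strategy = DistributedStrategy()  # defaults to mode="nccl2"
strategy.exec_strategy = fluid.ExecutionStrategy()
# For the transpiler-based collective path one would instead set
# strategy.mode = "collective" and strategy.collective_mode = "grad_allreduce".

avg_cost = build_net()
optimizer = fluid.optimizer.SGD(learning_rate=0.01)
# minimize() transpiles the program and, in nccl2 mode, builds the compiled
# data-parallel program; the results are published on the fleet singleton.
fleet.distributed_optimizer(optimizer, strategy=strategy).minimize(avg_cost)

place = fluid.CUDAPlace(int(os.getenv("FLAGS_selected_gpus", "0")))
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

feed = {
    'x': np.random.random((4, 13)).astype('float32'),
    'y': np.random.random((4, 1)).astype('float32')
}
# fleet.main_program is the program to run; fleet._origin_program keeps the
# untranspiled program around for building feed lists or saving persistables.
loss_val, = exe.run(fleet.main_program, feed=feed, fetch_list=[avg_cost.name])

With the old factory-based FP16/FP32 optimizers removed, every knob now lives on DistributedStrategy and the transpile/compile steps happen inside CollectiveOptimizer.minimize, so a trainer script only interacts with fleet.*.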