From 62af59030c734058a4b0321509321793d35a55ad Mon Sep 17 00:00:00 2001 From: Haohongxiang <86215757+haohongxiang@users.noreply.github.com> Date: Mon, 28 Mar 2022 19:34:11 +0800 Subject: [PATCH] [Dygraph] Add unittests for DataParallel in eager mode (#40709) * add uts for EagerReducer * add more uts * fix bugs * fix bugs * modify * modify uts * fix bugs * update * update * update * solve conflicts and merge * add some other uts * modify time of uts * update * update * update * remove uts of resnet --- .../distributed/fleet/base/fleet_base.py | 3 +- python/paddle/fluid/dygraph/parallel.py | 28 ++- .../fluid/tests/unittests/CMakeLists.txt | 10 +- ...llel_dygraph_dataparallel_in_eager_mode.py | 34 ++- ...el_dygraph_gradient_check_in_eager_mode.py | 8 +- .../unittests/parallel_dygraph_no_sync.py | 98 ++++---- .../parallel_dygraph_no_sync_control_flow.py | 96 +------- .../parallel_dygraph_no_sync_unused_params.py | 100 +------- .../tests/unittests/spawn_runner_base.py | 7 + .../fluid/tests/unittests/test_dist_base.py | 217 +++++++++++++++++- ...llel_dygraph_control_flow_in_eager_mode.py | 84 +++++++ .../test_parallel_dygraph_dataparallel.py | 8 + ...t_parallel_dygraph_dataparallel_cpuonly.py | 6 + .../test_parallel_dygraph_no_sync.py | 6 + ..._parallel_dygraph_no_sync_in_eager_mode.py | 111 +++++++++ 15 files changed, 534 insertions(+), 282 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_parallel_dygraph_control_flow_in_eager_mode.py create mode 100644 python/paddle/fluid/tests/unittests/test_parallel_dygraph_no_sync_in_eager_mode.py diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index 17c5fbd4f79..217d4fd43be 100755 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -1053,8 +1053,7 @@ class Fleet(object): last_comm_buffer_size=self._user_defined_strategy. last_comm_group_size_MB, find_unused_parameters=self._user_defined_strategy. - find_unused_parameters, - static_graph=True if recompute_enable else False) + find_unused_parameters) elif self._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL: model = TensorParallel( model, self._hcg, strategy=self._user_defined_strategy) diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py index ee7d15bbe93..64388aadb2f 100644 --- a/python/paddle/fluid/dygraph/parallel.py +++ b/python/paddle/fluid/dygraph/parallel.py @@ -30,7 +30,7 @@ from paddle.fluid.dygraph import to_variable, no_grad from paddle.utils import deprecated from ..layers import collective from paddle.fluid.dygraph import base as imperative_base -from paddle.fluid.framework import ParamBase, EagerParamBase, _in_legacy_dygraph +from paddle.fluid.framework import ParamBase, _in_legacy_dygraph, _non_static_mode, in_dygraph_mode __all__ = ["prepare_context", "ParallelEnv", "DataParallel"] @@ -574,19 +574,18 @@ class DataParallel(layers.Layer): comm_buffer_size=25, last_comm_buffer_size=1, find_unused_parameters=False, - process_group=None, - gradient_as_buffer_view=False, - static_graph=False): + process_group=None): super(DataParallel, self).__init__(layers.full_name() + "_data_parallel") + assert _non_static_mode(), \ + "It's not supported to construct DataParallel in static mode." + self._layers = layers self.find_unused_parameters = find_unused_parameters self.grad_need_sync = True self.process_group = process_group - self.gradient_as_buffer_view = gradient_as_buffer_view - self.static_graph = static_graph - self.var_dtype = core.eager.Tensor if not _in_legacy_dygraph( + self.var_dtype = core.eager.Tensor if in_dygraph_mode( ) else core.VarBase # NOTE(chenweihang): The ParallelStrategy here is not strictly a strategy. @@ -604,19 +603,19 @@ class DataParallel(layers.Layer): "ParallelContext must be initialized before. You should use init_parallel_env() before" \ "constructing the DataParallel." - if self.process_group is None and (not _in_legacy_dygraph()): + if self.process_group is None and in_dygraph_mode(): raise RuntimeError( - "Process group should be built in DataParallel of eager mode." + "Process group should be built for DataParallel in eager mode." ) # sync buffer and params # TODO(liuyuhui) Currently not support xpu. xpu is # still broadcasting parameters when calling layer if not paddle.is_compiled_with_xpu(): - if not _in_legacy_dygraph(): + if in_dygraph_mode(): sync_eager_params( self._layers, comm_group=self.process_group) - else: + elif _in_legacy_dygraph(): sync_params_buffers(self._layers) self.comm_buffer_size = int(comm_buffer_size * 1024 * 1024) @@ -670,7 +669,7 @@ class DataParallel(layers.Layer): check_layer_sparse(sublayer) for sublayer, _ in layers_param ] - if not _in_legacy_dygraph(): + if in_dygraph_mode(): self.group_indices = core.eager_assign_group_by_size( trainable_parameters, is_sparse_gradient, [self.last_comm_buffer_size, self.comm_buffer_size]) @@ -681,7 +680,7 @@ class DataParallel(layers.Layer): self.process_group, [self.last_comm_buffer_size, self.comm_buffer_size], self.find_unused_parameters) - else: + elif _in_legacy_dygraph(): self.group_indices = core.assign_group_by_size( trainable_parameters, is_sparse_gradient, [self.last_comm_buffer_size, self.comm_buffer_size]) @@ -694,8 +693,7 @@ class DataParallel(layers.Layer): self.find_unused_parameters) def _find_varbase(self, obj): - var_type = core.eager.Tensor if not _in_legacy_dygraph( - ) else core.VarBase + var_type = core.eager.Tensor if in_dygraph_mode() else core.VarBase if isinstance(obj, var_type): return [obj] if isinstance(obj, (list, tuple)): diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 26a1076a64c..bc0f391c651 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -39,7 +39,9 @@ if (WITH_GPU OR WITH_XPU OR WITH_ASCEND OR WITH_ASCEND_CL) endif() list(APPEND DIST_TEST_OPS test_parallel_dygraph_unused_variables) list(APPEND DIST_TEST_OPS test_parallel_dygraph_control_flow) +list(APPEND DIST_TEST_OPS test_parallel_dygraph_control_flow_in_eager_mode) list(APPEND DIST_TEST_OPS test_parallel_dygraph_no_sync) +list(APPEND DIST_TEST_OPS test_parallel_dygraph_no_sync_in_eager_mode) list(APPEND DIST_TEST_OPS test_parallel_dygraph_no_sync_gradient_check) list(APPEND DIST_TEST_OPS test_parallel_dygraph_dataparallel) list(APPEND DIST_TEST_OPS test_parallel_dygraph_pipeline_parallel) @@ -275,7 +277,9 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM)) LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_transformer) LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sync_batch_norm) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_control_flow) + list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_control_flow_in_eager_mode) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_no_sync) + list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_no_sync_in_eager_mode) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_no_sync_gradient_check) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_dataparallel) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_pipeline_parallel) @@ -748,7 +752,7 @@ if(WITH_DISTRIBUTE) set(dist_ut_port 20001) foreach(TEST_OP ${DIST_TEST_OPS}) bash_test_modules(${TEST_OP} START_BASH dist_test.sh SERIAL LABELS "RUN_TYPE=EXCLUSIVE" ENVS "PADDLE_DIST_UT_PORT=${dist_ut_port}") - MATH(EXPR dist_ut_port "${dist_ut_port}+30") + MATH(EXPR dist_ut_port "${dist_ut_port}+20") if(dist_ut_port GREATER_EQUAL 22998) message(FATAL_ERROR "available ports have been exhausted:${dist_ut_port}") endif() @@ -1121,9 +1125,12 @@ set_tests_properties(test_cumprod_op PROPERTIES TIMEOUT 120) set_tests_properties(test_split_program PROPERTIES TIMEOUT 120) if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL) set_tests_properties(test_parallel_dygraph_dataparallel PROPERTIES TIMEOUT 120) + set_tests_properties(test_parallel_dygraph_mnist PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 150) set_tests_properties(test_parallel_dygraph_control_flow PROPERTIES TIMEOUT 200) + set_tests_properties(test_parallel_dygraph_control_flow_in_eager_mode PROPERTIES TIMEOUT 150) set_tests_properties(test_parallel_dygraph_no_sync PROPERTIES TIMEOUT 150) + set_tests_properties(test_parallel_dygraph_no_sync_in_eager_mode PROPERTIES TIMEOUT 150) set_tests_properties(test_parallel_dygraph_no_sync_gradient_check PROPERTIES TIMEOUT 30) set_tests_properties(test_parallel_dygraph_pipeline_parallel PROPERTIES TIMEOUT 200) set_tests_properties(test_parallel_dygraph_tensor_parallel PROPERTIES TIMEOUT 200) @@ -1214,6 +1221,7 @@ set_tests_properties(test_eigvals_op PROPERTIES TIMEOUT 400) set_tests_properties(test_tensordot PROPERTIES TIMEOUT 1000) set_tests_properties(test_tensordot PROPERTIES LABELS "RUN_TYPE=NIGHTLY") if (WITH_GLOO) + set_tests_properties(test_parallel_dygraph_dataparallel_cpuonly PROPERTIES TIMEOUT 30) set_tests_properties(test_parallel_dygraph_unused_variables_gloo PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_sparse_embedding_gloo PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_sparse_embedding_over_height_gloo PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_dataparallel_in_eager_mode.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_dataparallel_in_eager_mode.py index 91c340c35d4..d48a7f09ce7 100644 --- a/python/paddle/fluid/tests/unittests/parallel_dygraph_dataparallel_in_eager_mode.py +++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_dataparallel_in_eager_mode.py @@ -17,6 +17,7 @@ from __future__ import print_function import unittest import os +import copy import numpy as np import random import socket @@ -30,28 +31,22 @@ import paddle.distributed as dist from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.optimizer import SGD from paddle.fluid.initializer import NumpyArrayInitializer - - -def net_is_used(port, ip='127.0.0.1'): - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - try: - s.connect((ip, port)) - s.shutdown(2) - return True - except Exception as e: - return False +from test_parallel_dygraph_dataparallel import get_dist_port_from_flags def init_process_group(strategy=None): nranks = ParallelEnv().nranks rank = ParallelEnv().local_rank is_master = True if rank == 0 else False - for port in range(20000, 21000): - if not net_is_used(port): - store = paddle.fluid.core.TCPStore("127.0.0.1", port, is_master, - nranks) - group = core.ProcessGroupNCCL(store, rank, nranks) - return group + envs = copy.copy(os.environ.copy()) + port = get_dist_port_from_flags() + store = paddle.fluid.core.TCPStore("127.0.0.1", port, is_master, nranks) + if 'PADDLE_DISTRI_BACKEND' in envs.keys() and envs[ + 'PADDLE_DISTRI_BACKEND'] == 'gloo': + group = core.ProcessGroupGloo(store, rank, nranks) + else: + group = core.ProcessGroupNCCL(store, rank, nranks) + return group class LinearModel(nn.Layer): @@ -75,11 +70,12 @@ class TestDistTraning(unittest.TestCase): def test_multiple_gpus(self): process_group = init_process_group() self.generate_reducer("float32", process_group) - self.generate_reducer("float16", process_group) + if paddle.get_device() != "cpu": + self.generate_reducer("float16", process_group) def generate_reducer(self, dtype, process_group): - dev_id = ParallelEnv().dev_id - np.random.seed(2022 + dev_id) + local_rank = ParallelEnv().local_rank + np.random.seed(2022 + local_rank) paddle.set_default_dtype(dtype) w_1 = paddle.ParamAttr(initializer=NumpyArrayInitializer( diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_gradient_check_in_eager_mode.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_gradient_check_in_eager_mode.py index 214f41c78a3..bf337d48643 100644 --- a/python/paddle/fluid/tests/unittests/parallel_dygraph_gradient_check_in_eager_mode.py +++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_gradient_check_in_eager_mode.py @@ -17,6 +17,7 @@ from __future__ import print_function import unittest import os +import copy import paddle import numpy as np @@ -39,7 +40,11 @@ def init_process_group(strategy=None): nranks = ParallelEnv().nranks rank = ParallelEnv().local_rank is_master = True if rank == 0 else False - store = paddle.fluid.core.TCPStore("127.0.0.1", 6174, is_master, nranks) + current_env = copy.copy(os.environ.copy()) + port = 6175 + if 'PADDLE_DIST_UT_PORT' in current_env.keys(): + port = int(current_env['PADDLE_DIST_UT_PORT']) + store = paddle.fluid.core.TCPStore("127.0.0.1", port, is_master, nranks) group = core.ProcessGroupNCCL(store, rank, nranks) return group @@ -107,7 +112,6 @@ class TestDistTraning(unittest.TestCase): w2_grad_sum = np.zeros((in_dim, out_dim), dtype='float32') for step_id in range(5): - print("==============", step_id) random_input = paddle.rand(shape=(batch, in_dim)) random_input.stop_gradient = True diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_no_sync.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_no_sync.py index 0e7e1a32cfa..f5af896f73e 100644 --- a/python/paddle/fluid/tests/unittests/parallel_dygraph_no_sync.py +++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_no_sync.py @@ -26,8 +26,10 @@ import paddle import paddle.fluid as fluid import paddle.distributed as dist import paddle.fluid.dygraph as dygraph +from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.fluid import core from paddle.fluid.dygraph.nn import Linear +from paddle.fluid.framework import _test_eager_guard from test_dist_base import print_to_err, print_to_out, runtime_main, TestParallelDyGraphRunnerBase seed = 90 @@ -68,6 +70,18 @@ class TestNoSync(TestParallelDyGraphRunnerBase): return loss def run_trainer(self, args): + if args.eager_mode: + self.run_trainer_in_eager_mode(args) + else: + self.run_trainer_func(args) + + def run_trainer_with_spawn(self, args): + if args.eager_mode: + return self.run_trainer_with_spawn_in_eager_mode(args) + else: + return self.run_trainer_with_spawn_func(args) + + def run_trainer_func(self, args): if fluid.core.is_compiled_with_cuda(): device_id = int(os.getenv("FLAGS_selected_gpus", "0")) place = fluid.CUDAPlace(device_id) @@ -86,56 +100,46 @@ class TestNoSync(TestParallelDyGraphRunnerBase): print_to_err( type(self).__name__, "begin to prepare context in dygraph with nccl2") - if not args.find_unused_parameters: - model = paddle.DataParallel( - model, find_unused_parameters=False) - else: - model = paddle.DataParallel( - model, find_unused_parameters=True) - print_to_err(type(self).__name__, "model built in dygraph") - out_losses = [] - print_to_err(type(self).__name__, "begin to run dygraph training") - for step_id, data in enumerate(train_reader()): - data = self._get_data(data, args) - if step_id == RUN_STEP: - break - if step_id % 3 != 0: - if args.update_method == "nccl2": - with model.no_sync(): - loss = self.run_one_loop(model, opt, data) - loss.backward() - else: - loss = self.run_one_loop(model, opt, data) - loss.backward() - else: - loss = self.run_one_loop(model, opt, data) - loss.backward() - opt.minimize(loss) - print_to_err( - type(self).__name__, - "loss at step %d: %f" % (step_id, loss.numpy())) - out_losses.append(loss.numpy()) + model = paddle.DataParallel( + model, find_unused_parameters=args.find_unused_parameters) + print_to_err(type(self).__name__, "model built in dygraph") + return self.model_train(args, model, opt, train_reader) - if not args.accumulate_gradient: - model.clear_gradients() - print_to_out(out_losses) + def run_trainer_in_eager_mode(self, args): + if fluid.core.is_compiled_with_cuda(): + device_id = int(os.getenv("FLAGS_selected_gpus", "0")) + place = fluid.CUDAPlace(device_id) + else: + assert ("Only support CUDAPlace for now.") - def run_trainer_with_spawn(self, args): - fluid.default_startup_program().random_seed = seed - fluid.default_main_program().random_seed = seed - np.random.seed(seed) - random.seed(seed) - args.trainer_id = dist.get_rank() - - if args.update_method == "nccl2": - dist.init_parallel_env() - model, train_reader, opt = self.get_model() - if args.update_method == "nccl2": - if args.find_unused_parameters: - model = paddle.DataParallel(model, find_unused_parameters=True) - else: - model = paddle.DataParallel(model, find_unused_parameters=False) + with fluid.dygraph.guard(place): + fluid.default_startup_program().random_seed = seed + fluid.default_main_program().random_seed = seed + np.random.seed(seed) + random.seed(seed) + + with _test_eager_guard(): + model, train_reader, opt = self.get_model() + if args.update_method == "nccl2": + dist.init_parallel_env() + print_to_err( + type(self).__name__, + "begin to prepare context in dygraph with nccl2") + + nranks = ParallelEnv().nranks + rank = ParallelEnv().local_rank + is_master = True if rank == 0 else False + store = paddle.fluid.core.TCPStore( + "127.0.0.1", args.dist_port, is_master, nranks) + group = core.ProcessGroupNCCL(store, rank, nranks) + model = paddle.DataParallel( + model, + process_group=group, + find_unused_parameters=args.find_unused_parameters) + print_to_err(type(self).__name__, "model built in dygraph") + return self.model_train(args, model, opt, train_reader) + def model_train(self, args, model, opt, train_reader): out_losses = [] for step_id, data in enumerate(train_reader()): data = self._get_data(data, args) diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_no_sync_control_flow.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_no_sync_control_flow.py index ebc0cd7d6f3..8b3e1b9aedd 100644 --- a/python/paddle/fluid/tests/unittests/parallel_dygraph_no_sync_control_flow.py +++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_no_sync_control_flow.py @@ -28,7 +28,8 @@ import paddle.distributed as dist import paddle.fluid.dygraph as dygraph from paddle.fluid import core from paddle.fluid.dygraph.nn import Linear -from test_dist_base import print_to_err, print_to_out, runtime_main, TestParallelDyGraphRunnerBase +from test_dist_base import runtime_main +from parallel_dygraph_no_sync import TestNoSync seed = 90 RUN_STEP = 20 @@ -54,7 +55,7 @@ class SimpleNetControlFlow(fluid.Layer): return x -class TestNoSyncControlFlow(TestParallelDyGraphRunnerBase): +class TestNoSyncControlFlow(TestNoSync): def get_model(self): model = SimpleNetControlFlow() train_reader = paddle.batch( @@ -71,97 +72,6 @@ class TestNoSyncControlFlow(TestParallelDyGraphRunnerBase): loss = out.sum() / len(batch) return loss - def run_trainer(self, args): - if fluid.core.is_compiled_with_cuda(): - device_id = int(os.getenv("FLAGS_selected_gpus", "0")) - place = fluid.CUDAPlace(device_id) - else: - assert ("Only support CUDAPlace for now.") - - with fluid.dygraph.guard(place): - fluid.default_startup_program().random_seed = seed - fluid.default_main_program().random_seed = seed - np.random.seed(seed) - random.seed(seed) - model, train_reader, opt = self.get_model() - - if args.update_method == "nccl2": - dist.init_parallel_env() - print_to_err( - type(self).__name__, - "begin to prepare context in dygraph with nccl2") - if not args.find_unused_parameters: - model = paddle.DataParallel( - model, find_unused_parameters=False) - else: - model = paddle.DataParallel( - model, find_unused_parameters=True) - print_to_err(type(self).__name__, "model built in dygraph") - out_losses = [] - print_to_err(type(self).__name__, "begin to run dygraph training") - for step_id, data in enumerate(train_reader()): - data = self._get_data(data, args) - if step_id == RUN_STEP: - break - if step_id % 3 != 0: - if args.update_method == "nccl2": - with model.no_sync(): - loss = self.run_one_loop(model, opt, data) - loss.backward() - else: - loss = self.run_one_loop(model, opt, data) - loss.backward() - else: - loss = self.run_one_loop(model, opt, data) - loss.backward() - opt.minimize(loss) - print_to_err( - type(self).__name__, - "loss at step %d: %f" % (step_id, loss.numpy())) - out_losses.append(loss.numpy()) - - if not args.accumulate_gradient: - model.clear_gradients() - print_to_out(out_losses) - - def run_trainer_with_spawn(self, args): - fluid.default_startup_program().random_seed = seed - fluid.default_main_program().random_seed = seed - np.random.seed(seed) - random.seed(seed) - args.trainer_id = dist.get_rank() - - if args.update_method == "nccl2": - dist.init_parallel_env() - model, train_reader, opt = self.get_model() - if args.update_method == "nccl2": - model = paddle.DataParallel(model, find_unused_parameters=True) - - out_losses = [] - for step_id, data in enumerate(train_reader()): - data = self._get_data(data, args) - if step_id == RUN_STEP: - break - if step_id % 3 != 0: - if args.update_method == "nccl2": - with model.no_sync(): - loss = self.run_one_loop(model, opt, data) - loss.backward() - else: - loss = self.run_one_loop(model, opt, data) - loss.backward() - else: - loss = self.run_one_loop(model, opt, data) - loss.backward() - opt.minimize(loss) - print_to_err( - type(self).__name__, - "loss at step %d: %f" % (step_id, loss.numpy())) - out_losses.append(loss.numpy()) - model.clear_gradients() - print_to_out(out_losses) - return out_losses - def fake_sample_reader(): def __reader__(): diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_no_sync_unused_params.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_no_sync_unused_params.py index a5ab327b778..5aecf13bc15 100644 --- a/python/paddle/fluid/tests/unittests/parallel_dygraph_no_sync_unused_params.py +++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_no_sync_unused_params.py @@ -28,7 +28,8 @@ import paddle.distributed as dist import paddle.fluid.dygraph as dygraph from paddle.fluid import core from paddle.fluid.dygraph.nn import Linear -from test_dist_base import print_to_err, print_to_out, runtime_main, TestParallelDyGraphRunnerBase +from test_dist_base import runtime_main +from parallel_dygraph_no_sync import TestNoSync seed = 90 RUN_STEP = 20 @@ -53,7 +54,7 @@ class SimpleNetUnusedParam(fluid.Layer): return x -class TestNoSyncUnusedParam(TestParallelDyGraphRunnerBase): +class TestNoSyncUnusedParam(TestNoSync): def get_model(self): model = SimpleNetUnusedParam() train_reader = paddle.batch( @@ -70,101 +71,6 @@ class TestNoSyncUnusedParam(TestParallelDyGraphRunnerBase): loss = out.sum() / len(batch) return loss - def run_trainer(self, args): - if fluid.core.is_compiled_with_cuda(): - device_id = int(os.getenv("FLAGS_selected_gpus", "0")) - place = fluid.CUDAPlace(device_id) - else: - assert ("Only support CUDAPlace for now.") - - with fluid.dygraph.guard(place): - fluid.default_startup_program().random_seed = seed - fluid.default_main_program().random_seed = seed - np.random.seed(seed) - random.seed(seed) - model, train_reader, opt = self.get_model() - - if args.update_method == "nccl2": - dist.init_parallel_env() - print_to_err( - type(self).__name__, - "begin to prepare context in dygraph with nccl2") - if not args.find_unused_parameters: - model = paddle.DataParallel( - model, find_unused_parameters=False) - else: - model = paddle.DataParallel( - model, find_unused_parameters=True) - print_to_err(type(self).__name__, "model built in dygraph") - out_losses = [] - print_to_err(type(self).__name__, "begin to run dygraph training") - for step_id, data in enumerate(train_reader()): - data = self._get_data(data, args) - if step_id == RUN_STEP: - break - if step_id % 3 != 0: - if args.update_method == "nccl2": - with model.no_sync(): - loss = self.run_one_loop(model, opt, data) - loss.backward() - else: - loss = self.run_one_loop(model, opt, data) - loss.backward() - else: - loss = self.run_one_loop(model, opt, data) - loss.backward() - opt.minimize(loss) - print_to_err( - type(self).__name__, - "loss at step %d: %f" % (step_id, loss.numpy())) - out_losses.append(loss.numpy()) - - if not args.accumulate_gradient: - model.clear_gradients() - print_to_out(out_losses) - - def run_trainer_with_spawn(self, args): - paddle.disable_static() - fluid.default_startup_program().random_seed = seed - fluid.default_main_program().random_seed = seed - np.random.seed(seed) - random.seed(seed) - args.trainer_id = dist.get_rank() - - if args.update_method == "nccl2": - dist.init_parallel_env() - model, train_reader, opt = self.get_model() - if args.update_method == "nccl2": - if args.find_unused_parameters: - model = paddle.DataParallel(model, find_unused_parameters=True) - else: - model = paddle.DataParallel(model, find_unused_parameters=False) - - out_losses = [] - for step_id, data in enumerate(train_reader()): - data = self._get_data(data, args) - if step_id == RUN_STEP: - break - if step_id % 3 != 0: - if args.update_method == "nccl2": - with model.no_sync(): - loss = self.run_one_loop(model, opt, data) - loss.backward() - else: - loss = self.run_one_loop(model, opt, data) - loss.backward() - else: - loss = self.run_one_loop(model, opt, data) - loss.backward() - opt.minimize(loss) - print_to_err( - type(self).__name__, - "loss at step %d: %f" % (step_id, loss.numpy())) - out_losses.append(loss.numpy()) - model.clear_gradients() - print_to_out(out_losses) - return out_losses - def fake_sample_reader(): def __reader__(): diff --git a/python/paddle/fluid/tests/unittests/spawn_runner_base.py b/python/paddle/fluid/tests/unittests/spawn_runner_base.py index 2719e28fea0..e7057f95d28 100644 --- a/python/paddle/fluid/tests/unittests/spawn_runner_base.py +++ b/python/paddle/fluid/tests/unittests/spawn_runner_base.py @@ -21,6 +21,7 @@ import paddle # used by model.run_trainer in test_dist_base from test_dist_base import RUN_STEP +from test_parallel_dygraph_dataparallel import get_dist_port_from_flags # NOTE: compatible TestParallelDyGraphRunnerBase args @@ -28,6 +29,8 @@ class SpawnAssistTestArgs(object): update_method = "local" trainer_id = 0 find_unused_parameters = False + eager_mode = False + dist_port = get_dist_port_from_flags() class TestDistSpawnRunner(unittest.TestCase): @@ -52,10 +55,14 @@ class TestDistSpawnRunner(unittest.TestCase): result_list.append(res_queue.get()) return result_list + def _args_config(self, args): + return + def check_dist_result_with_spawn(self, test_class, delta=1e-3): # 0. prepare model and args model = test_class() args = SpawnAssistTestArgs() + self._args_config(args) # 1. calc signal card loss losses = self._run(model, args) diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index b1a5a5a9540..a2faf1e395d 100755 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -31,9 +31,11 @@ import time import paddle import paddle.fluid as fluid from paddle.fluid import compiler +import paddle.fluid.core as core import paddle.fluid.dygraph as dygraph from paddle.fluid.dygraph.base import to_variable -from paddle.fluid.dygraph.parallel import DataParallel +from paddle.fluid.dygraph.parallel import DataParallel, ParallelEnv +from paddle.fluid.framework import _test_eager_guard from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy import paddle.fluid.incubate.fleet.base.role_maker as role_maker @@ -541,7 +543,12 @@ class TestParallelDyGraphRunnerBase(object): return batch def run_trainer(self, args): + if args.eager_mode: + self.run_trainer_in_eager_mode(args) + else: + self.run_trainer_func(args) + def run_trainer_func(self, args): seed = 90 if args.update_method == 'gloo': place = fluid.CPUPlace() @@ -614,7 +621,82 @@ class TestParallelDyGraphRunnerBase(object): model.clear_gradients() print_to_out(out_losses) + def run_trainer_in_eager_mode(self, args): + seed = 90 + if args.update_method == 'gloo': + place = fluid.CPUPlace() + elif fluid.core.is_compiled_with_cuda(): + device_id = int(os.getenv("FLAGS_selected_gpus", "0")) + place = fluid.CUDAPlace(device_id) + elif fluid.core.is_compiled_with_xpu(): + device_id = int(os.getenv("FLAGS_selected_xpus", "0")) + place = fluid.XPUPlace(device_id) + elif fluid.core.is_compiled_with_npu(): + device_id = int(os.getenv("FLAGS_selected_npus", "0")) + place = fluid.NPUPlace(device_id) + else: + assert ("Only support CUDAPlace or XPUPlace or CPU(Gloo) for now.") + + with _test_eager_guard(): + with fluid.dygraph.guard(place): + fluid.default_startup_program().random_seed = seed + fluid.default_main_program().random_seed = seed + np.random.seed(seed) + import random + random.seed(seed) + + model, train_reader, opt = self.get_model() + + #if args.update_method == "nccl2": + if args.update_method in ["nccl2", "gloo"]: + paddle.distributed.init_parallel_env() + nranks = ParallelEnv().nranks + rank = ParallelEnv().local_rank + is_master = True if rank == 0 else False + store = paddle.fluid.core.TCPStore( + "127.0.0.1", args.dist_port, is_master, nranks) + if args.update_method == "nccl2": + group = core.ProcessGroupNCCL(store, rank, nranks) + elif args.update_method == "gloo": + group = core.ProcessGroupGloo(store, rank, nranks) + + print_to_err( + type(self).__name__, + "begin to prepare context in dygraph with nccl2") + model = dygraph.parallel.DataParallel( + model, + process_group=group, + find_unused_parameters=args.find_unused_parameters) + print_to_err(type(self).__name__, "model built in dygraph") + + out_losses = [] + print_to_err( + type(self).__name__, "begin to run dygraph training") + for step_id, data in enumerate(train_reader()): + data = self._get_data(data, args) + if step_id == RUN_STEP: + break + loss = self.run_one_loop(model, opt, data) + if step_id % 10 == 0: + print_to_err( + type(self).__name__, + "loss at step %d: %f" % (step_id, loss.numpy())) + out_losses.append(loss.numpy()) + + loss.backward() + + opt.minimize(loss) + if not args.accumulate_gradient: + model.clear_gradients() + print_to_out(out_losses) + def run_trainer_with_spawn(self, args): + if args.eager_mode: + return self.run_trainer_with_spawn_in_eager_mode(args) + else: + return self.run_trainer_with_spawn_func(args) + + def run_trainer_with_spawn_func(self, args): # 1. enable dygraph paddle.disable_static() @@ -634,10 +716,8 @@ class TestParallelDyGraphRunnerBase(object): # 4. train model model, train_reader, opt = self.get_model() if args.update_method in ["nccl2", "gloo"]: - if args.find_unused_parameters: - model = paddle.DataParallel(model, find_unused_parameters=True) - else: - model = paddle.DataParallel(model, find_unused_parameters=False) + model = paddle.DataParallel( + model, find_unused_parameters=args.find_unused_parameters) out_losses = [] for step_id, data in enumerate(train_reader()): @@ -653,7 +733,64 @@ class TestParallelDyGraphRunnerBase(object): model.clear_gradients() return out_losses + def run_trainer_with_spawn_in_eager_mode(self, args): + # 1. enable dygraph + paddle.disable_static() + + # 2. init seed + seed = 90 + paddle.static.default_startup_program().random_seed = seed + paddle.static.default_main_program().random_seed = seed + np.random.seed(seed) + random.seed(seed) + # get trainer id + args.trainer_id = paddle.distributed.get_rank() + + # 3. init parallel env + if args.update_method in ["nccl2", "gloo"]: + paddle.distributed.init_parallel_env() + + # 4. build process group + nranks = ParallelEnv().nranks + rank = ParallelEnv().local_rank + is_master = True if rank == 0 else False + store = paddle.fluid.core.TCPStore("127.0.0.1", args.dist_port, + is_master, nranks) + if args.update_method == "nccl2": + group = core.ProcessGroupNCCL(store, rank, nranks) + elif args.update_method == "gloo": + group = core.ProcessGroupGloo(store, rank, nranks) + + # 5. train model + with _test_eager_guard(): + model, train_reader, opt = self.get_model() + if args.update_method in ["nccl2", "gloo"]: + model = paddle.DataParallel( + model, + process_group=group, + find_unused_parameters=args.find_unused_parameters) + + out_losses = [] + for step_id, data in enumerate(train_reader()): + data = self._get_data(data, args) + if step_id == RUN_STEP: + break + loss = self.run_one_loop(model, opt, data) + out_losses.append(loss.numpy()) + + loss.backward() + + opt.minimize(loss) + model.clear_gradients() + return out_losses + def run_use_fleet_api_trainer(self, args): + if args.eager_mode: + self.run_use_fleet_api_trainer_in_eager_mode(args) + else: + self.run_use_fleet_api_trainer_func(args) + + def run_use_fleet_api_trainer_func(self, args): import paddle.distributed.fleet as fleet import paddle.distributed.fleet.base.role_maker as role_maker # 1. enable dygraph @@ -698,6 +835,52 @@ class TestParallelDyGraphRunnerBase(object): opt.clear_grad() print_to_out(out_losses) + def run_use_fleet_api_trainer_in_eager_mode(self, args): + import paddle.distributed.fleet as fleet + import paddle.distributed.fleet.base.role_maker as role_maker + # 1. enable dygraph + paddle.disable_static() + + # 2. init seed + seed = 90 + paddle.static.default_startup_program().random_seed = seed + paddle.static.default_main_program().random_seed = seed + np.random.seed(seed) + random.seed(seed) + # get trainer id + args.trainer_id = paddle.distributed.get_rank() + + # set strategy + strategy = fleet.DistributedStrategy() + if args.find_unused_parameters: + strategy.find_unused_parameters = True + + # 3. init parallel env + if args.update_method == "nccl2" or "bkcl" or "hccl": + fleet.init(is_collective=True, strategy=strategy) + + # 4. train model + with _test_eager_guard(): + model, train_reader, opt = self.get_model() + if args.update_method == "nccl2" or "bkcl" or "hccl": + opt = fleet.distributed_optimizer(opt) + model = fleet.distributed_model(model) + + out_losses = [] + for step_id, data in enumerate(train_reader()): + data = self._get_data(data, args) + if step_id == RUN_STEP: + break + loss = self.run_one_loop(model, opt, data) + out_losses.append(loss.numpy()) + + loss.backward() + + opt.step() + if not args.accumulate_gradient: + opt.clear_grad() + print_to_out(out_losses) + def runtime_main(test_class): parser = argparse.ArgumentParser(description='Run dist test.') @@ -728,6 +911,8 @@ def runtime_main(test_class): parser.add_argument( '--current_endpoint', type=str, required=False, default="") parser.add_argument('--sync_mode', action='store_true') + parser.add_argument('--eager_mode', action='store_true') + parser.add_argument('--dist_port', type=int, required=False, default=6175) parser.add_argument('--use_cuda', action='store_true') parser.add_argument('--use_cpu', action='store_true') parser.add_argument('--use_xpu', action='store_true') @@ -820,6 +1005,8 @@ class TestDistBase(unittest.TestCase): self._port_set = set() self._python_interp = sys.executable self._sync_mode = True + self._dist_port = 6175 + self._eager_mode = False self._hogwild_mode = False self._enforce_place = None self._use_reduce = False @@ -861,10 +1048,10 @@ class TestDistBase(unittest.TestCase): self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % ( self._find_free_port(), self._find_free_port()) else: - print("set begin_port:", DIST_UT_PORT) self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % ( DIST_UT_PORT, DIST_UT_PORT + 1) DIST_UT_PORT += 2 + self._dist_port = DIST_UT_PORT self._after_setup_config() @@ -981,6 +1168,10 @@ class TestDistBase(unittest.TestCase): if len(devices) > 1 and self._use_dgc: cmd += " --use_dgc" + if self._eager_mode: + cmd += " --eager_mode" + cmd += " --dist_port {}".format(self._dist_port) + if self._accumulate_gradient: cmd += " --accumulate_gradient" @@ -1054,6 +1245,11 @@ class TestDistBase(unittest.TestCase): if self._sync_mode: tr0_cmd += " --sync_mode" tr1_cmd += " --sync_mode" + if self._eager_mode: + tr0_cmd += " --eager_mode" + tr1_cmd += " --eager_mode" + tr0_cmd += " --dist_port {}".format(self._dist_port) + tr1_cmd += " --dist_port {}".format(self._dist_port) if self._hogwild_mode: tr0_cmd += " --hogwild" tr1_cmd += " --hogwild" @@ -1159,6 +1355,11 @@ class TestDistBase(unittest.TestCase): }) assert self._use_dgc == False, "gloo not support use dgc" + + if self._eager_mode: + tr_cmd += " --eager_mode" + tr_cmd += " --dist_port {}".format(self._dist_port) + if self._accumulate_gradient: tr_cmd += " --accumulate_gradient" @@ -1236,6 +1437,10 @@ class TestDistBase(unittest.TestCase): if self._use_dgc: tr_cmd += " --use_dgc" + if self._eager_mode: + tr_cmd += " --eager_mode" + tr_cmd += " --dist_port {}".format(self._dist_port) + if self._accumulate_gradient: tr_cmd += " --accumulate_gradient" diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_control_flow_in_eager_mode.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_control_flow_in_eager_mode.py new file mode 100644 index 00000000000..dde0c4b260c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_control_flow_in_eager_mode.py @@ -0,0 +1,84 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import os +import sys +import unittest + +import paddle.fluid as fluid +from test_dist_base import TestDistBase +from spawn_runner_base import TestDistSpawnRunner + +flag_name = os.path.splitext(__file__)[0] + + +class TestDygraphControlFlowSameEager(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._nccl2_mode = True + self._eager_mode = True + self._dygraph = True + self._find_unused_parameters = True + + def test_net(self): + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "parallel_dygraph_control_flow_same.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +class TestDygraphControlFlowSameAccGradEager(TestDygraphControlFlowSameEager): + def _setup_config(self): + self._sync_mode = False + self._nccl2_mode = True + self._eager_mode = True + self._dygraph = True + self._accumulate_gradient = True + self._find_unused_parameters = True + + +class TestDygraphControlFlowDiffEager(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._nccl2_mode = True + self._eager_mode = True + self._dygraph = True + self._find_unused_parameters = True + + def test_net(self): + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "parallel_dygraph_control_flow_different.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +class TestFleetDygraphControlFlowDiffAccGradEager( + TestDygraphControlFlowDiffEager): + def _setup_config(self): + self._sync_mode = False + self._nccl2_mode = True + self._eager_mode = True + self._dygraph = True + self._accumulate_gradient = True + self._find_unused_parameters = True + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py index 2530fc07753..cbf08856e7e 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py @@ -16,6 +16,7 @@ from __future__ import print_function import unittest import time +import paddle import paddle.fluid as fluid import copy import os @@ -143,6 +144,13 @@ def start_local_trainers(cluster, return procs +def get_dist_port_from_flags(): + DIST_UT_PORT = 6175 + if os.getenv("PADDLE_DIST_UT_PORT"): + DIST_UT_PORT = int(os.getenv("PADDLE_DIST_UT_PORT")) + return DIST_UT_PORT + + class TestMultipleGpus(unittest.TestCase): def run_mnist_2gpu(self, target_file_name): if not fluid.core.is_compiled_with_cuda( diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel_cpuonly.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel_cpuonly.py index 6caf0c54ae6..587824a1dc7 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel_cpuonly.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel_cpuonly.py @@ -16,6 +16,7 @@ from __future__ import print_function import unittest import time +import paddle import paddle.fluid as fluid import copy import os @@ -130,5 +131,10 @@ class TestDataParallelGradientCheck(TestMultipleGpus): self.run_mnist_2gpu('parallel_dygraph_gradient_check.py') +class TestDataParallelGradientCheckInEagerMode(TestMultipleGpus): + def test_multiple_gpus_dynamic(self): + self.run_mnist_2gpu('parallel_dygraph_dataparallel_in_eager_mode.py') + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_no_sync.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_no_sync.py index a1a8ae52eb7..2e364e5d4d9 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_no_sync.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_no_sync.py @@ -83,6 +83,9 @@ class TestParallelDygraphNoSyncSpawn(TestDistSpawnRunner): class TestParallelDygraphNoSyncUnusedParamSpawn(TestDistSpawnRunner): + def _args_config(self, args): + args.find_unused_parameters = True + def test_no_sync_with_spawn(self): if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4): self.check_dist_result_with_spawn( @@ -90,6 +93,9 @@ class TestParallelDygraphNoSyncUnusedParamSpawn(TestDistSpawnRunner): class TestParallelDygraphNoSyncControlFlowSpawn(TestDistSpawnRunner): + def _args_config(self, args): + args.find_unused_parameters = True + def test_no_sync_with_spawn(self): if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4): self.check_dist_result_with_spawn( diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_no_sync_in_eager_mode.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_no_sync_in_eager_mode.py new file mode 100644 index 00000000000..d0e7d413952 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_no_sync_in_eager_mode.py @@ -0,0 +1,111 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import os +import sys +import unittest + +import paddle.fluid as fluid +from test_dist_base import TestDistBase +from spawn_runner_base import TestDistSpawnRunner +from parallel_dygraph_no_sync import TestNoSync +from parallel_dygraph_no_sync_unused_params import TestNoSyncUnusedParam +from parallel_dygraph_no_sync_control_flow import TestNoSyncControlFlow + +flag_name = os.path.splitext(__file__)[0] + + +class TestParallelDygraphNoSync(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._eager_mode = True + self._nccl2_mode = True + self._dygraph = True + self._find_unused_parameters = False + + def test_no_sync(self): + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "parallel_dygraph_no_sync.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +class TestParallelDygraphNoSyncUnusedParam(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._eager_mode = True + self._nccl2_mode = True + self._dygraph = True + self._find_unused_parameters = True + + def test_no_sync_ununsed_param(self): + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "parallel_dygraph_no_sync_unused_params.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +class TestParallelDygraphNoSyncControlFlow(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._eager_mode = True + self._nccl2_mode = True + self._dygraph = True + self._find_unused_parameters = True + + def test_no_sync_control_flow(self): + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "parallel_dygraph_no_sync_control_flow.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +class TestParallelDygraphNoSyncSpawn(TestDistSpawnRunner): + def test_no_sync_with_spawn(self): + if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4): + self.check_dist_result_with_spawn(test_class=TestNoSync, delta=1e-5) + + +class TestParallelDygraphNoSyncUnusedParamSpawn(TestDistSpawnRunner): + def _args_config(self, args): + args.find_unused_parameters = True + args.eager_mode = True + + def test_no_sync_with_spawn(self): + if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4): + self.check_dist_result_with_spawn( + test_class=TestNoSyncUnusedParam, delta=1e-5) + + +class TestParallelDygraphNoSyncControlFlowSpawn(TestDistSpawnRunner): + def _args_config(self, args): + args.find_unused_parameters = True + args.eager_mode = True + + def test_no_sync_with_spawn(self): + if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4): + self.check_dist_result_with_spawn( + test_class=TestNoSyncControlFlow, delta=1e-5) + + +if __name__ == "__main__": + unittest.main() -- GitLab