diff --git a/python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py index 98d131822fe3639f48acbff53907f937a49d8605..ce0bad4193dea16a993643babf556c3d1ba5b3dd 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py @@ -74,6 +74,7 @@ class DGCMomentumOptimizer(Optimizer): self._global_step_var = None self._dgc_clip_norm = None + self._num_trainers = num_trainers if grad_clip is not None: if not isinstance(grad_clip, ClipGradByNorm): raise TypeError( @@ -87,7 +88,6 @@ class DGCMomentumOptimizer(Optimizer): num_trainers > 0 ), "The value of num_trainers should be greater than 0!" - self._num_trainers = num_trainers self._dgc_clip_norm = grad_clip.clip_norm * (num_trainers**-0.5) self.regular_type, self.regular_coeff = self._get_regularization_param( @@ -212,7 +212,7 @@ class DGCMomentumOptimizer(Optimizer): ) self._nranks_var = self._add_nranks_var( - name=core.dgc.kDGCNRanksName(), value=-1 + name=core.dgc.kDGCNRanksName(), value=self._num_trainers ) # rampup begin step var for all_reduce_op_handle diff --git a/python/paddle/fluid/compiler.py b/python/paddle/fluid/compiler.py index 32020f3901e09f39811fac55daf482c9b21f8a53..bff15df5f9fcc0c289bf9dcfab2daf1ce029fac2 100644 --- a/python/paddle/fluid/compiler.py +++ b/python/paddle/fluid/compiler.py @@ -417,12 +417,6 @@ class CompiledProgram: ) self._exec_strategy.num_threads = 1 - if self._build_strategy.num_trainers > 1: - assert self._is_data_parallel, ( - "If you use multi-trainer to train the model, you should use " - "the data parallel model, i.e. calling with_data_parallel function." - ) - # TODO(wuyi): trainer endpoings should be passed in through # build_strategy, not program.xxx. # TODO(gongwb): let user to set them once. diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/dist_mnist_gradient_merge.py b/python/paddle/fluid/tests/unittests/collective/fleet/dist_mnist_gradient_merge.py index 35525d03ca7a6ba099a149acff36cce467ff42f1..efd67c5b0e86e0eb38fc044224078087ab83dd6d 100644 --- a/python/paddle/fluid/tests/unittests/collective/fleet/dist_mnist_gradient_merge.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/dist_mnist_gradient_merge.py @@ -13,7 +13,7 @@ # limitations under the License. 
from dist_mnist import cnn_model -from test_dist_base import TestDistRunnerBase, runtime_main +from test_dist_base import TestDistRunnerBase, _insert_comm_op, runtime_main import paddle import paddle.fluid as fluid @@ -27,7 +27,7 @@ fluid.default_main_program().random_seed = 1 class TestDistMnist2x2(TestDistRunnerBase): - def get_model(self, batch_size=2): + def get_model(self, batch_size=2, single_device=False): # Input data images = paddle.static.data( name='pixel', shape=[-1, 1, 28, 28], dtype=DTYPE @@ -53,7 +53,12 @@ class TestDistMnist2x2(TestDistRunnerBase): learning_rate=0.001, momentum=0.9 ) opt = fluid.optimizer.GradientMergeOptimizer(opt, 2) - + if single_device: + opt.minimize(avg_cost) + else: + opt._learning_rate = 0.001 + opt._learning_rate_map = {} + _insert_comm_op(opt, avg_cost) # Reader train_reader = paddle.batch( paddle.dataset.mnist.test(), batch_size=batch_size @@ -61,7 +66,7 @@ class TestDistMnist2x2(TestDistRunnerBase): test_reader = paddle.batch( paddle.dataset.mnist.test(), batch_size=batch_size ) - opt.minimize(avg_cost) + return ( inference_program, avg_cost, diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_dist_mnist_dgc_nccl.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_dist_mnist_dgc_nccl.py index 4774733f057b288aa2504c3229b9b4d4f9d14b11..dfd2e39be0139a0688d8350672c464e4f7731df9 100644 --- a/python/paddle/fluid/tests/unittests/collective/fleet/test_dist_mnist_dgc_nccl.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_dist_mnist_dgc_nccl.py @@ -54,7 +54,7 @@ class TestDistMnistNCCL2DGC(TestDistBase): if fluid.core.is_compiled_with_cuda(): self.check_with_place( - os.path.abspath("../../dist_mnist.py"), + os.path.abspath("../../dist_mnist_dgc.py"), delta=1e-5, check_error_log=True, log_name=flag_name, @@ -76,38 +76,5 @@ class TestDistMnistNCCL2DGC(TestDistBase): self.temp_dir.cleanup() -class TestDistMnistNCCL2DGCMultiCards(TestDistBase): - def _setup_config(self): - self._sync_mode = True - self._use_reduce = False - self._use_reader_alloc = False - self._nccl2_mode = True - self._use_dgc = True - - def test_dist_train(self): - import paddle.fluid as fluid - - if fluid.core.is_compiled_with_cuda(): - self.check_with_place_multi_cards( - os.path.abspath("../../dist_mnist.py"), - delta=1e-5, - check_error_log=True, - log_name=flag_name, - ) - - def tearDown(self): - import paddle.fluid as fluid - - if fluid.core.is_compiled_with_cuda(): - log_file = os.path.join( - self.temp_dir.name, - 'test_dist_mnist_dgc_nccl_dgc_2cards_local.log', - ) - result = count_of_sparse_all_reduce_calls(log_file) - # same as above, but use two cards - self.assertEqual(result, 6) - self.temp_dir.cleanup() - - if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_dist_mnist_gradient_merge.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_dist_mnist_gradient_merge.py index 3618422345b286e7d5b0cfa4b68d6eb6a60ced1e..fcaaa023ce96c22a51a4237c845e527f3112b5d4 100644 --- a/python/paddle/fluid/tests/unittests/collective/fleet/test_dist_mnist_gradient_merge.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_dist_mnist_gradient_merge.py @@ -27,6 +27,7 @@ class TestDistMnistGradMerge(TestDistBase): self._sync_mode = True self._use_reduce = False self._nccl2_mode = True + self._nccl2_reduce_layer = True def test_dist_train(self): if fluid.core.is_compiled_with_cuda(): diff --git a/python/paddle/fluid/tests/unittests/dist_allreduce_op.py 
b/python/paddle/fluid/tests/unittests/dist_allreduce_op.py index 044c6d78cac10fc15271cf3c7b34dec338ccbf50..efedce9140b7d686ffee4c627de4ad62f0dfe229 100644 --- a/python/paddle/fluid/tests/unittests/dist_allreduce_op.py +++ b/python/paddle/fluid/tests/unittests/dist_allreduce_op.py @@ -17,6 +17,7 @@ from functools import reduce from test_dist_base import TestDistRunnerBase, runtime_main import paddle +import paddle.distributed.fleet as fleet import paddle.fluid as fluid paddle.enable_static() @@ -109,13 +110,11 @@ class TestDistMnist2x2(TestDistRunnerBase): opt.minimize(avg_cost) else: # multi device or distributed multi device - params_grads = opt.backward(avg_cost) - data_parallel_param_grads = [] - for p, g in params_grads: - # NOTE: scale will be done on loss scale in multi_devices_graph_pass using nranks. - grad_reduce = fluid.layers.collective._allreduce(g) - data_parallel_param_grads.append([p, grad_reduce]) - opt.apply_gradients(data_parallel_param_grads) + strategy = fleet.DistributedStrategy() + strategy.without_graph_optimization = True + fleet.init(strategy=strategy) + optimizer = fleet.distributed_optimizer(opt) + optimizer.minimize(avg_cost) return ( inference_program, @@ -128,4 +127,5 @@ class TestDistMnist2x2(TestDistRunnerBase): if __name__ == "__main__": + runtime_main(TestDistMnist2x2) diff --git a/python/paddle/fluid/tests/unittests/dist_mnist_dgc.py b/python/paddle/fluid/tests/unittests/dist_mnist_dgc.py new file mode 100644 index 0000000000000000000000000000000000000000..702d23685cd7de60e592f54ff0e4e91363c11335 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dist_mnist_dgc.py @@ -0,0 +1,132 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from functools import reduce + +from test_dist_base import TestDistRunnerBase, _insert_comm_op, runtime_main + +import paddle +import paddle.fluid as fluid + +paddle.enable_static() + +DTYPE = "float32" +paddle.dataset.mnist.fetch() + +# Fix seed for test +fluid.default_startup_program().random_seed = 1 +fluid.default_main_program().random_seed = 1 + + +def cnn_model(data): + conv_pool_1 = fluid.nets.simple_img_conv_pool( + input=data, + filter_size=5, + num_filters=20, + pool_size=2, + pool_stride=2, + act="relu", + param_attr=fluid.ParamAttr( + initializer=paddle.nn.initializer.Constant(value=0.01) + ), + ) + conv_pool_2 = fluid.nets.simple_img_conv_pool( + input=conv_pool_1, + filter_size=5, + num_filters=50, + pool_size=2, + pool_stride=2, + act="relu", + param_attr=fluid.ParamAttr( + initializer=paddle.nn.initializer.Constant(value=0.01) + ), + ) + + SIZE = 10 + input_shape = conv_pool_2.shape + param_shape = [reduce(lambda a, b: a * b, input_shape[1:], 1)] + [SIZE] + scale = (2.0 / (param_shape[0] ** 2 * SIZE)) ** 0.5 + + predict = paddle.static.nn.fc( + x=conv_pool_2, + size=SIZE, + activation="softmax", + weight_attr=fluid.param_attr.ParamAttr( + initializer=paddle.nn.initializer.Constant(value=0.01) + ), + ) + return predict + + +class TestDistMnistDGC(TestDistRunnerBase): + def get_model(self, batch_size=2, use_dgc=False, build_strategy=None): + # Input data + images = paddle.static.data( + name='pixel', shape=[-1, 1, 28, 28], dtype=DTYPE + ) + label = paddle.static.data(name='label', shape=[-1, 1], dtype='int64') + + # Train program + predict = cnn_model(images) + cost = paddle.nn.functional.cross_entropy( + input=predict, label=label, reduction='none', use_softmax=False + ) + avg_cost = paddle.mean(x=cost) + + # Evaluator + batch_size_tensor = paddle.tensor.create_tensor(dtype='int64') + batch_acc = paddle.static.accuracy( + input=predict, label=label, total=batch_size_tensor + ) + + inference_program = fluid.default_main_program().clone() + if not use_dgc: + opt = fluid.optimizer.Momentum(learning_rate=self.lr, momentum=0.9) + else: + opt = paddle.distributed.fleet.meta_optimizers.DGCMomentumOptimizer( + learning_rate=self.lr, + momentum=0.9, + rampup_begin_step=2, + num_trainers=build_strategy.num_trainers + if build_strategy + else None, + ) + if use_dgc: + assert ( + build_strategy is not None + ), "build_strategy can be None with dgc" + _insert_comm_op(opt, avg_cost, build_strategy) + else: + opt.minimize(avg_cost) + + # Reader + train_reader = paddle.batch( + paddle.dataset.mnist.test(), batch_size=batch_size + ) + test_reader = paddle.batch( + paddle.dataset.mnist.test(), batch_size=batch_size + ) + + return ( + inference_program, + avg_cost, + train_reader, + test_reader, + batch_acc, + predict, + ) + + +if __name__ == "__main__": + runtime_main(TestDistMnistDGC) diff --git a/python/paddle/fluid/tests/unittests/dist_mnist_fp16_allreduce.py b/python/paddle/fluid/tests/unittests/dist_mnist_fp16_allreduce.py index 9aa662854274f9b473142fbc0f85c5baa78156c1..4a4771cc0b6abd3c5b7759d50bcedbe091ed6bba 100644 --- a/python/paddle/fluid/tests/unittests/dist_mnist_fp16_allreduce.py +++ b/python/paddle/fluid/tests/unittests/dist_mnist_fp16_allreduce.py @@ -13,9 +13,10 @@ # limitations under the License. 
from dist_mnist import cnn_model -from test_dist_base import TestDistRunnerBase, runtime_main +from test_dist_base import TestDistRunnerBase, _insert_comm_op, runtime_main import paddle +import paddle.distributed.fleet as fleet import paddle.fluid as fluid from paddle.distributed.fleet.meta_optimizers import ( FP16AllReduceOptimizer as FP16AllReduce, @@ -30,7 +31,7 @@ fluid.default_main_program().random_seed = 1 class TestDistMnist2x2(TestDistRunnerBase): - def get_model(self, batch_size=2): + def get_model(self, batch_size=2, single_device=False): # Input data images = paddle.static.data( name='pixel', shape=[-1, 1, 28, 28], dtype=DTYPE @@ -55,8 +56,15 @@ class TestDistMnist2x2(TestDistRunnerBase): opt = fluid.optimizer.MomentumOptimizer( learning_rate=0.001, momentum=0.9 ) + opt = FP16AllReduce(opt) + if not single_device: + fleet.init() + _insert_comm_op(opt, avg_cost) + else: + opt.minimize(avg_cost) + # Reader train_reader = paddle.batch( paddle.dataset.mnist.test(), batch_size=batch_size @@ -64,7 +72,6 @@ class TestDistMnist2x2(TestDistRunnerBase): test_reader = paddle.batch( paddle.dataset.mnist.test(), batch_size=batch_size ) - opt.minimize(avg_cost) return ( inference_program, avg_cost, diff --git a/python/paddle/fluid/tests/unittests/test_dist_allreduce_op.py b/python/paddle/fluid/tests/unittests/test_dist_allreduce_op.py index 62f598ee27f9438ed6316782bd4fb17c1531cdd6..6c3a56bfa6fcc99e5b31c06675983513da8b2a6c 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_allreduce_op.py +++ b/python/paddle/fluid/tests/unittests/test_dist_allreduce_op.py @@ -33,7 +33,9 @@ class TestDistMnistNCCL2(TestDistBase): import paddle.fluid as fluid if fluid.core.is_compiled_with_cuda(): - self.check_with_place("dist_allreduce_op.py", delta=1e-5) + self.check_with_place( + "dist_allreduce_op.py", delta=1e-5, check_error_log=True + ) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index b0e4b11d1d9bcbae6f9dce443bd5fb5d0a5d5fe3..7c649bbdc98585a9220f001e1cf62d9836ca4347 100755 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -28,6 +28,9 @@ import numpy as np import paddle import paddle.fluid as fluid import paddle.incubate.distributed.fleet.role_maker as role_maker +from paddle.distributed.fleet.meta_optimizers import ( + RawProgramOptimizer as RawProgram, +) from paddle.fluid import compiler from paddle.incubate.distributed.fleet.collective import ( DistributedStrategy, @@ -53,6 +56,35 @@ def eprint(*args, **kwargs): print(*args, file=sys.stderr, **kwargs) +def _insert_comm_op(opt, loss, build_strategy=None): + opt = RawProgram(opt) + role = paddle.distributed.fleet.base.role_maker.PaddleCloudRoleMaker( + is_collective=True + ) + strategy = paddle.distributed.fleet.DistributedStrategy() + if build_strategy is not None: + strategy.build_strategy = build_strategy + opt._set_basic_info(loss, role, opt, strategy) + + # following code is a copy of RawProgramOptimizer.minimize except init_comm_group + opt.endpoints = opt.role_maker._get_trainer_endpoints() + opt.current_endpoint = opt.endpoints[opt.role_maker._worker_index()] + opt.rank = opt.role_maker._worker_index() + opt.nranks = opt.role_maker._worker_num() + startup_program = paddle.static.default_startup_program() + opt.startup_program = startup_program + + block = loss.block + program = block.program + opt.main_program = program + + optimize_ops, params_grads = 
opt.inner_opt.minimize(loss, startup_program) + + opt.main_program = program + if opt.nranks > 1: + opt._transpile_main_program(loss) + + class TestDistRunnerBase: def get_model( self, @@ -399,6 +431,51 @@ class TestDistRunnerBase: ) def run_trainer(self, args): + from io import StringIO + + old_stdout = sys.stdout + sys.stdout = StringIO() + + build_stra = fluid.BuildStrategy() + # FIXME force disable enable_inplace and memory_optimize + build_stra.enable_inplace = False + build_stra.memory_optimize = False + + if args.fuse_all_reduce is not None: + sys.stderr.write('fuse_all_reduce={}'.format(args.fuse_all_reduce)) + build_stra.fuse_all_reduce_ops = args.fuse_all_reduce + + if args.hogwild: + build_stra.async_mode = True + + if args.enable_backward_deps: + build_stra.enable_backward_optimizer_op_deps = True + + if args.use_reduce: + build_stra.reduce_strategy = ( + fluid.BuildStrategy.ReduceStrategy.Reduce + ) + else: + build_stra.reduce_strategy = ( + fluid.BuildStrategy.ReduceStrategy.AllReduce + ) + pass_builder = None + if args.batch_merge_repeat > 1: + pass_builder = build_stra._finalize_strategy_and_create_passes() + mypass = pass_builder.insert_pass(0, "multi_batch_merge_pass") + mypass.set("num_repeats", args.batch_merge_repeat) + + if ( + args.update_method == "nccl2" + or args.update_method == "nccl2_reduce_layer" + ): + build_stra.num_trainers = len(args.endpoints.split(",")) + build_stra.trainer_id = args.trainer_id + else: + # case args.update_method == "nccl2_reduce_layer": + build_stra.num_trainers = 1 + build_stra.trainer_id = 0 + self.lr = args.lr if args.nccl2_reduce_layer_local_run: ( @@ -417,7 +494,11 @@ class TestDistRunnerBase: test_reader, batch_acc, predict, - ) = self.get_model(batch_size=args.batch_size, use_dgc=args.use_dgc) + ) = self.get_model( + batch_size=args.batch_size, + use_dgc=args.use_dgc, + build_strategy=build_stra, + ) else: ( test_program, @@ -503,52 +584,9 @@ class TestDistRunnerBase: exec_strategy = fluid.ExecutionStrategy() exec_strategy.num_threads = 1 - build_stra = fluid.BuildStrategy() - # FIXME force disable enable_inplace and memory_optimize - build_stra.enable_inplace = False - build_stra.memory_optimize = False - - if args.fuse_all_reduce is not None: - sys.stderr.write('fuse_all_reduce={}'.format(args.fuse_all_reduce)) - build_stra.fuse_all_reduce_ops = args.fuse_all_reduce - - if args.hogwild: - build_stra.async_mode = True - - if args.enable_backward_deps: - build_stra.enable_backward_optimizer_op_deps = True - - if args.use_reduce: - build_stra.reduce_strategy = ( - fluid.BuildStrategy.ReduceStrategy.Reduce - ) - else: - build_stra.reduce_strategy = ( - fluid.BuildStrategy.ReduceStrategy.AllReduce - ) - - pass_builder = None - if args.batch_merge_repeat > 1: - pass_builder = build_stra._finalize_strategy_and_create_passes() - mypass = pass_builder.insert_pass(0, "multi_batch_merge_pass") - mypass.set("num_repeats", args.batch_merge_repeat) - - if ( - args.update_method == "nccl2" - or args.update_method == "nccl2_reduce_layer" - ): - build_stra.num_trainers = len(args.endpoints.split(",")) - build_stra.trainer_id = args.trainer_id - else: - # case args.update_method == "nccl2_reduce_layer": - build_stra.num_trainers = 1 - build_stra.trainer_id = 0 - print_to_err(type(self).__name__, "begin to compile with data parallel") - binary = compiler.CompiledProgram(trainer_prog).with_data_parallel( - loss_name=avg_cost.name, - build_strategy=build_stra, - exec_strategy=exec_strategy, + binary = compiler.CompiledProgram( + trainer_prog, 
build_strategy=build_stra ) print_to_err(type(self).__name__, "program compiled with data parallel") @@ -584,8 +622,10 @@ class TestDistRunnerBase: if lr_scheduler is not None: lr_scheduler.step() - print_to_err(type(self).__name__, "trainer run finished") + print_to_err(type(self).__name__, "trainer run finished\n") + # print_to_err(type(self).__name__, "out_losses") + sys.stdout = old_stdout print_to_out(out_losses) diff --git a/python/paddle/fluid/tests/unittests/test_dist_mnist_fp16_allreduce.py b/python/paddle/fluid/tests/unittests/test_dist_mnist_fp16_allreduce.py index 37d0130d2f5aef70713fbec769af50dc6983f102..7b184719e4922c2f1489bbf0424814944f5a75f6 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_mnist_fp16_allreduce.py +++ b/python/paddle/fluid/tests/unittests/test_dist_mnist_fp16_allreduce.py @@ -22,12 +22,15 @@ class TestDistMnist2x2FP16AllReduce(TestDistBase): self._sync_mode = True self._use_reduce = False self._nccl2_mode = True + self._nccl2_reduce_layer = True def test_dist_train(self): import paddle.fluid as fluid if fluid.core.is_compiled_with_cuda(): - self.check_with_place("dist_mnist_fp16_allreduce.py", delta=1e-5) + self.check_with_place( + "dist_mnist_fp16_allreduce.py", delta=1e-5, check_error_log=True + ) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_dist_mnist_pg.py b/python/paddle/fluid/tests/unittests/test_dist_mnist_pg.py index f7ae4755d689b84f2ed6f2382de770ec72c6a06a..fc5fd3255002a8a4f98b1c2ecdc32abffef211d8 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_mnist_pg.py +++ b/python/paddle/fluid/tests/unittests/test_dist_mnist_pg.py @@ -36,7 +36,6 @@ class TestDistMnistNCCL2(TestDistBase): "dist_mnist.py", delta=1, need_envs={ - "FLAGS_enable_parallel_graph": "1", "FLAGS_sync_nccl_allreduce": "1", }, )
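
Note on the new test helper: the removed fluid.layers.collective._allreduce / with_data_parallel code path is replaced by the _insert_comm_op helper added to test_dist_base.py, which wraps the base optimizer in RawProgramOptimizer and transpiles the main program to insert the communication ops. A minimal sketch of how the reworked test models call it (the plain fc network and the single_device flag below are illustrative stand-ins for the cnn_model/get_model code in the diff; a collective fleet environment, e.g. one launched via paddle.distributed.launch, is assumed so PaddleCloudRoleMaker can resolve ranks):

    import paddle
    import paddle.fluid as fluid
    from test_dist_base import _insert_comm_op  # helper added in this diff

    paddle.enable_static()

    images = paddle.static.data(name='pixel', shape=[-1, 784], dtype='float32')
    label = paddle.static.data(name='label', shape=[-1, 1], dtype='int64')
    predict = paddle.static.nn.fc(x=images, size=10, activation='softmax')
    cost = paddle.nn.functional.cross_entropy(
        input=predict, label=label, reduction='none', use_softmax=False
    )
    avg_cost = paddle.mean(x=cost)

    opt = fluid.optimizer.Momentum(learning_rate=0.001, momentum=0.9)

    single_device = False  # the local run keeps the plain minimize() path
    if single_device:
        opt.minimize(avg_cost)
    else:
        # wraps opt in RawProgramOptimizer, runs the inner minimize, and
        # transpiles the main program so allreduce ops are inserted for
        # the parameter gradients (comm group init is skipped, unlike the
        # full RawProgramOptimizer.minimize)
        _insert_comm_op(opt, avg_cost)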
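
Note on the compile step: run_trainer now builds the BuildStrategy before get_model() (so DGC can read build_strategy.num_trainers) and no longer calls with_data_parallel, which this change removes from the test path. A sketch of the old versus new compilation, mirroring the diff (trainer_prog and avg_cost are assumed to come from get_model()):

    import paddle.fluid as fluid
    from paddle.fluid import compiler

    build_stra = fluid.BuildStrategy()
    build_stra.enable_inplace = False      # force-disabled, as in the test
    build_stra.memory_optimize = False
    build_stra.num_trainers = 2            # len(args.endpoints.split(","))
    build_stra.trainer_id = 0

    # old path (removed in this diff):
    # binary = compiler.CompiledProgram(trainer_prog).with_data_parallel(
    #     loss_name=avg_cost.name,
    #     build_strategy=build_stra,
    #     exec_strategy=exec_strategy,
    # )

    # new path: pass the build strategy directly to CompiledProgram
    binary = compiler.CompiledProgram(trainer_prog, build_strategy=build_stra)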