From 60c3ef3ab8143979905af9e0c30600c0a67743ed Mon Sep 17 00:00:00 2001 From: 123malin Date: Thu, 10 Sep 2020 10:58:58 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90paddle.fleet=E3=80=91parameter=5Fserve?= =?UTF-8?q?r=5Foptimizer=20support=20auto=5Fstrategy=20(#27181)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * parameter_server_optimizer support auto_strategy --- .../distributed/fleet/base/fleet_base.py | 8 +- .../fleet/meta_optimizers/amp_optimizer.py | 5 +- .../fleet/meta_optimizers/dgc_optimizer.py | 7 +- .../gradient_merge_optimizer.py | 7 +- .../graph_execution_optimizer.py | 30 +++---- .../fleet/meta_optimizers/lamb_optimizer.py | 7 +- .../fleet/meta_optimizers/lars_optimizer.py | 7 +- .../meta_optimizers/localsgd_optimizer.py | 11 ++- .../meta_optimizers/meta_optimizer_base.py | 2 +- .../parameter_server_graph_optimizer.py | 10 ++- .../parameter_server_optimizer.py | 78 +++++++++--------- .../meta_optimizers/pipeline_optimizer.py | 7 +- .../meta_optimizers/recompute_optimizer.py | 5 +- .../fluid/tests/unittests/CMakeLists.txt | 2 - .../test_dist_fleet_a_sync_optimizer_auto.py | 76 ------------------ ..._dist_fleet_a_sync_optimizer_auto_async.py | 79 +++++++++++++++++++ ...st_dist_fleet_a_sync_optimizer_auto_geo.py | 67 ++++++++++++++++ 17 files changed, 251 insertions(+), 157 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_dist_fleet_a_sync_optimizer_auto_async.py create mode 100644 python/paddle/fluid/tests/unittests/test_dist_fleet_a_sync_optimizer_auto_geo.py diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index b918949269..0dfcd5f325 100644 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -231,7 +231,7 @@ class Fleet(object): Returns: int: worker numbers - + Examples: .. code-block:: python @@ -737,7 +737,7 @@ class Fleet(object): """ Set the value of the learning rate manually in the optimizer. Only work in dygraph mode - + Args: value (float|Tensor): the value of learning rate @@ -877,7 +877,7 @@ class Fleet(object): """ Execute the optimizer once. Only work in dygraph mode - + Returns: None Examples: @@ -1019,7 +1019,7 @@ class Fleet(object): if self.user_defined_strategy._is_strict_auto(): # turn on all the strategy for each optimizer for opt in distributed_optimizer_list: - opt._enable_strategy(self.user_defined_strategy) + opt._enable_strategy(self.user_defined_strategy, context) valid_optimizer_list = [] valid_graph_optimizer_list = [] diff --git a/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py index 938bd25884..31a9913701 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py @@ -34,6 +34,9 @@ class AMPOptimizer(MetaOptimizerBase): loss, role_maker, user_defined_optimizer, user_defined_strategy) def _can_apply(self): + if not self.role_maker._is_collective: + return False + if self.user_defined_strategy.amp: return True return False @@ -42,7 +45,7 @@ class AMPOptimizer(MetaOptimizerBase): dist_strategy.amp = False dist_strategy.amp_configs = {} - def _enable_strategy(self, dist_strategy): + def _enable_strategy(self, dist_strategy, context): dist_strategy.amp = True dist_strategy.amp_configs = { "init_loss_scaling": 32768.0, diff --git a/python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py index d292f58456..3f6ed1ed2f 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py @@ -53,6 +53,9 @@ class DGCOptimizer(MetaOptimizerBase): name=opt._name) def _can_apply(self): + if not self.role_maker._is_collective: + return False + if self.user_defined_strategy.dgc: if not isinstance(self.inner_opt, Momentum): logging.warn("dgc only works on Momentum optimizer") @@ -69,7 +72,7 @@ class DGCOptimizer(MetaOptimizerBase): dist_strategy.dgc = False dist_strategy.dgc_configs = {} - def _enable_strategy(self, dist_strategy): + def _enable_strategy(self, dist_strategy, context): dist_strategy.dgc = True dist_strategy.dgc_configs = {"rampup_begin_step": 0, "rampup_step": 1} @@ -89,5 +92,5 @@ class DGCOptimizer(MetaOptimizerBase): no_grad_set=None): optimize_ops, params_grads = \ self.dgc_opt.minimize(loss, startup_program, - parameter_list, no_grad_set) + parameter_list, no_grad_set) return optimize_ops, params_grads diff --git a/python/paddle/distributed/fleet/meta_optimizers/gradient_merge_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/gradient_merge_optimizer.py index bb0c631e08..f1b3680976 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/gradient_merge_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/gradient_merge_optimizer.py @@ -37,15 +37,18 @@ class GradientMergeOptimizer(MetaOptimizerBase): self.user_defined_strategy.gradient_merge_configs["avg"]) def _can_apply(self): + if not self.role_maker._is_collective: + return False + can_apply = (self.user_defined_strategy.gradient_merge == True) and \ - self.user_defined_strategy.gradient_merge_configs["k_steps"] > 1 + self.user_defined_strategy.gradient_merge_configs["k_steps"] > 1 return can_apply def _disable_strategy(self, dist_strategy): dist_strategy.gradient_merge = False dist_strategy.gradient_merge_configs = {} - def _enable_strategy(self, dist_strategy): + def _enable_strategy(self, dist_strategy, context): # we currently do not support auto-enable gradient merge return diff --git a/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py index 03304f1b68..6c1cc3d7a9 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py @@ -48,7 +48,7 @@ class GraphExecutionOptimizer(MetaOptimizerBase): callbacks=None): pass - # should fix the variable + # should fix the variable def _setup_nccl_op(self, startup_program, main_program, build_strategy): trainer_endpoints = self.role_maker.get_trainer_endpoints() trainers = trainer_endpoints @@ -94,31 +94,31 @@ class GraphExecutionOptimizer(MetaOptimizerBase): dist_strategy = self.user_defined_strategy local_build_strategy = paddle.fluid.BuildStrategy() local_build_strategy.enable_sequential_execution = \ - dist_strategy.build_strategy.enable_sequential_execution + dist_strategy.build_strategy.enable_sequential_execution local_build_strategy.fuse_elewise_add_act_ops = \ - dist_strategy.build_strategy.fuse_elewise_add_act_ops + dist_strategy.build_strategy.fuse_elewise_add_act_ops local_build_strategy.fuse_bn_act_ops = \ - dist_strategy.build_strategy.fuse_bn_act_ops + dist_strategy.build_strategy.fuse_bn_act_ops local_build_strategy.enable_auto_fusion = \ - dist_strategy.build_strategy.enable_auto_fusion + dist_strategy.build_strategy.enable_auto_fusion local_build_strategy.fuse_relu_depthwise_conv = \ - dist_strategy.build_strategy.fuse_relu_depthwise_conv + dist_strategy.build_strategy.fuse_relu_depthwise_conv local_build_strategy.fuse_broadcast_ops = \ - dist_strategy.build_strategy.fuse_broadcast_ops + dist_strategy.build_strategy.fuse_broadcast_ops local_build_strategy.fuse_all_optimizer_ops = \ - dist_strategy.build_strategy.fuse_all_optimizer_ops + dist_strategy.build_strategy.fuse_all_optimizer_ops local_build_strategy.enable_inplace = \ - dist_strategy.build_strategy.enable_inplace + dist_strategy.build_strategy.enable_inplace local_build_strategy.use_hierarchical_allreduce = \ - dist_strategy.use_hierarchical_allreduce + dist_strategy.use_hierarchical_allreduce local_build_strategy.hierarchical_allreduce_inter_nranks = \ - dist_strategy.hierarchical_allreduce_inter_nranks + dist_strategy.hierarchical_allreduce_inter_nranks local_build_strategy.sync_batch_norm = \ - dist_strategy.sync_batch_norm + dist_strategy.sync_batch_norm local_build_strategy.fuse_all_reduce_ops = \ - dist_strategy.fuse_all_reduce_ops + dist_strategy.fuse_all_reduce_ops local_build_strategy.nccl_comm_num = \ - dist_strategy.nccl_comm_num + dist_strategy.nccl_comm_num if self.user_defined_strategy.recompute == True: logging.warn( @@ -190,7 +190,7 @@ class GraphExecutionOptimizer(MetaOptimizerBase): # TODO(guru4elephant): should close all PE related flags here return - def _enable_strategy(self, dist_strategy): + def _enable_strategy(self, dist_strategy, context): # by default, graph execution strategy is enabled return diff --git a/python/paddle/distributed/fleet/meta_optimizers/lamb_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/lamb_optimizer.py index bfa186a1e7..df9887759e 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/lamb_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/lamb_optimizer.py @@ -62,6 +62,9 @@ class LambOptimizer(MetaOptimizerBase): name=opt._name) def _can_apply(self): + if not self.role_maker._is_collective: + return False + if self.user_defined_strategy.lamb: if not isinstance(self.inner_opt, AdamOptimizer): logging.warn( @@ -75,7 +78,7 @@ class LambOptimizer(MetaOptimizerBase): dist_strategy.lamb = False dist_strategy.lamb_configs = {} - def _enable_strategy(self, dist_strategy): + def _enable_strategy(self, dist_strategy, context): dist_strategy.lamb = True dist_strategy.lamb_configs = { "lamb_weight_decay": 0.01, @@ -102,5 +105,5 @@ class LambOptimizer(MetaOptimizerBase): no_grad_set=None): optimize_ops, params_grads = \ self.lamb_opt.minimize(loss, startup_program, - parameter_list, no_grad_set) + parameter_list, no_grad_set) return optimize_ops, params_grads diff --git a/python/paddle/distributed/fleet/meta_optimizers/lars_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/lars_optimizer.py index ec7a7eb18b..609d8b85e7 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/lars_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/lars_optimizer.py @@ -49,6 +49,9 @@ class LarsOptimizer(MetaOptimizerBase): epsilon=configs['epsilon']) def _can_apply(self): + if not self.role_maker._is_collective: + return False + if self.user_defined_strategy.lars: if not isinstance(self.inner_opt, Momentum): logging.warn( @@ -62,7 +65,7 @@ class LarsOptimizer(MetaOptimizerBase): dist_strategy.lars = False dist_strategy.lars_configs = {} - def _enable_strategy(self, dist_strategy): + def _enable_strategy(self, dist_strategy, context): dist_strategy.lars = True dist_strategy.lars_configs = { "lars_coeff": 0.01, @@ -89,5 +92,5 @@ class LarsOptimizer(MetaOptimizerBase): no_grad_set=None): optimize_ops, params_grads = \ self.lars_opt.minimize(loss, startup_program, - parameter_list, no_grad_set) + parameter_list, no_grad_set) return optimize_ops, params_grads diff --git a/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py index 3c1318301b..4d33dfe745 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py @@ -29,6 +29,9 @@ class LocalSGDOptimizer(MetaOptimizerBase): self.snapshot_key = '@SNAPSHOT' def _can_apply(self): + if not self.role_maker._is_collective: + return False + if not self.user_defined_strategy.localsgd: return False @@ -36,15 +39,15 @@ class LocalSGDOptimizer(MetaOptimizerBase): return False return isinstance(self.inner_opt, paddle.optimizer.momentum.Momentum) \ - or isinstance(self.inner_opt, paddle.fluid.optimizer.Momentum) \ - or isinstance(self.inner_opt, paddle.optimizer.sgd.SGD) \ - or isinstance(self.inner_opt, paddle.fluid.optimizer.SGD) + or isinstance(self.inner_opt, paddle.fluid.optimizer.Momentum) \ + or isinstance(self.inner_opt, paddle.optimizer.sgd.SGD) \ + or isinstance(self.inner_opt, paddle.fluid.optimizer.SGD) def _disable_strategy(self, dist_strategy): dist_strategy.localsgd = False dist_strategy.localsgd_configs = {} - def _enable_strategy(self, dist_strategy): + def _enable_strategy(self, dist_strategy, context): dist_strategy.localsgd = True dist_strategy.localsgd_configs = {"k_steps": 1} diff --git a/python/paddle/distributed/fleet/meta_optimizers/meta_optimizer_base.py b/python/paddle/distributed/fleet/meta_optimizers/meta_optimizer_base.py index b105c25b3a..a12ca50442 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/meta_optimizer_base.py +++ b/python/paddle/distributed/fleet/meta_optimizers/meta_optimizer_base.py @@ -48,7 +48,7 @@ class MetaOptimizerBase(Optimizer): raise NotImplementedError("you should implement disable strategy in {}". format(type(self).__name__)) - def _enable_strategy(self, dist_strategy): + def _enable_strategy(self, dist_strategy, context=None): raise NotImplementedError("you should implement enable strategy in {}". format(type(self).__name__)) diff --git a/python/paddle/distributed/fleet/meta_optimizers/parameter_server_graph_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/parameter_server_graph_optimizer.py index c9260dd2f8..7dc532c86e 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/parameter_server_graph_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/parameter_server_graph_optimizer.py @@ -24,6 +24,9 @@ class ParameterServerGraphOptimizer(ParameterServerOptimizer): self.meta_optimizers_white_list = [] def _can_apply(self): + if self.role_maker._is_collective: + return False + k_steps = self.user_defined_strategy.a_sync_configs["k_steps"] if k_steps < 0: return False @@ -37,12 +40,11 @@ class ParameterServerGraphOptimizer(ParameterServerOptimizer): return True def _disable_strategy(self, dist_strategy): - dist_strategy.a_sync_configs = {} + return - def _enable_strategy(self, dist_strategy): + def _enable_strategy(self, dist_strategy, context): # only open up the async mode for auto-parallel - dist_strategy.a_sync = True - dist_strategy.a_sync_configs = {} + return def _is_graph_out(self): return True diff --git a/python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py index 7dca7b9cb8..51d4d34316 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py @@ -32,8 +32,6 @@ class ParameterServerOptimizer(MetaOptimizerBase): def _can_apply(self): if self.role_maker._is_collective: return False - if self.user_defined_strategy.auto == True: - return True k_steps = self.user_defined_strategy.a_sync_configs["k_steps"] return True if k_steps >= 0 else False @@ -134,7 +132,7 @@ class ParameterServerOptimizer(MetaOptimizerBase): return _main, _startup - def _try_auto_apply_geo(self, program, compiled_config): + def _can_apply_geo(self, dist_strategy, program): def get_sys_free_mem(): plat = platform.system() if platform.system() == "Darwin": @@ -163,36 +161,28 @@ class ParameterServerOptimizer(MetaOptimizerBase): "%s platform is unsupported is parameter server optimizer" % (platform.system())) - if self.user_defined_strategy.auto == False: - return - - a_sync_configs = self.user_defined_strategy.a_sync_configs - if a_sync_configs["k_steps"] >= 0: - return - - self.user_defined_strategy.a_sync = True if not isinstance(self.inner_opt, fluid.optimizer.SGDOptimizer): - # auto async - a_sync_configs["k_steps"] = 0 - self.user_defined_strategy.a_sync_configs = a_sync_configs - return + return False - from paddle.fluid.incubate.fleet.parameter_server.ir.vars_metatools import dtype_to_size free = get_sys_free_mem() - param_grad_pairs = compiled_config.origin_sparse_pairs + compiled_config.origin_dense_pairs - processed_var_names = set(["@EMPTY@"]) + from paddle.fluid.incubate.fleet.parameter_server.ir import vars_metatools + processed_var_names = set(["@EMPTY@"]) param_memory_size = 0 - for param_grad_pair in param_grad_pairs: - param, grad = param_grad_pair + for varname in program.global_block().vars: + var = program.global_block().vars[varname] + if not var.persistable or var.desc.type( + ) != core.VarDesc.VarType.LOD_TENSOR: + continue + param = vars_metatools.create_var_struct(var) param_memory_size += param.m_size - processed_var_names.add(param.name) + processed_var_names.add(varname) upper_mem_use = param_memory_size * 5.0 program_tmp_vars = dict() - batch_size = 1024 + eval_batch_size = 1024 for op in program.global_block().ops: for var_name in op.output_arg_names: if var_name in processed_var_names: @@ -215,23 +205,21 @@ class ParameterServerOptimizer(MetaOptimizerBase): data_count *= (-x) else: data_count *= x - program_tmp_vars[var_name] = (data_count, neg_dim_count, - dtype_to_size[var.dtype]) + program_tmp_vars[var_name] = ( + data_count, neg_dim_count, + vars_metatools.dtype_to_size[var.dtype]) for varname in program_tmp_vars: data_count, neg_dim_count, type_size = program_tmp_vars[varname] if neg_dim_count == 1: - data_count *= batch_size + data_count *= eval_batch_size var_memory = data_count * type_size upper_mem_use += var_memory if upper_mem_use < free: - # auto geo - a_sync_configs["k_steps"] = 800 + return True else: - # auto async - a_sync_configs["k_steps"] = 0 - self.user_defined_strategy.a_sync_configs = a_sync_configs + return False def minimize_impl(self, loss, @@ -240,6 +228,7 @@ class ParameterServerOptimizer(MetaOptimizerBase): no_grad_set=None): self.inner_opt.minimize(loss, startup_program, parameter_list, no_grad_set) + strategy = self._get_distributed_strategy() _origin_main_program = loss.block.program _origin_startup_program = startup_program @@ -247,11 +236,7 @@ class ParameterServerOptimizer(MetaOptimizerBase): compiled_config = public.CompileTimeStrategy(_origin_main_program, _origin_startup_program, - None, self.role_maker) - - self._try_auto_apply_geo(_origin_main_program, compiled_config) - - strategy = self._get_distributed_strategy() + strategy, self.role_maker) compiled_config.strategy = strategy if self.role_maker.is_worker() or self.role_maker._is_heter_worker(): @@ -267,9 +252,24 @@ class ParameterServerOptimizer(MetaOptimizerBase): return None, None def _disable_strategy(self, dist_strategy): - dist_strategy.a_sync_configs = {} - self.user_defined_strategy.a_sync_configs = {} + dist_strategy.a_sync = False + a_sync_configs = dist_strategy.a_sync_configs + a_sync_configs["k_steps"] = -1 + dist_strategy.a_sync_configs = a_sync_configs + + def _enable_strategy(self, dist_strategy, context): + a_sync_configs = dist_strategy.a_sync_configs + if a_sync_configs["k_steps"] >= 0: + return - def _enable_strategy(self, dist_strategy): dist_strategy.a_sync = True - dist_strategy.a_sync_configs = {} + a_sync_configs = dist_strategy.a_sync_configs + + is_geo = self._can_apply_geo(dist_strategy, + context["origin_main_program"]) + + if is_geo: + a_sync_configs["k_steps"] = 800 + else: + a_sync_configs["k_steps"] = 0 + dist_strategy.a_sync_configs = a_sync_configs diff --git a/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py index 32c54d4486..87fa707791 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py @@ -103,6 +103,9 @@ class PipelineOptimizer(MetaOptimizerBase): self.wrapped_opt = PO(self.inner_opt, num_microbatches=num_microbatches) def _can_apply(self): + if not self.role_maker._is_collective: + return False + if self.user_defined_strategy.pipeline == True: return True return False @@ -111,7 +114,7 @@ class PipelineOptimizer(MetaOptimizerBase): dist_strategy.pipeline = False dist_strategy.pipeline_configs = {} - def _enable_strategy(self, dist_strategy): + def _enable_strategy(self, dist_strategy, context): # we do not support enable pipeline automatically right now return @@ -180,7 +183,7 @@ class PipelineOptimizer(MetaOptimizerBase): grad = None for idx, op in reversed(list(enumerate(block.ops))): if is_backward_op(op) and \ - OP_ROLE_VAR_KEY in op.attr_names: + OP_ROLE_VAR_KEY in op.attr_names: op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY] if len(op_role_var) == 0: continue diff --git a/python/paddle/distributed/fleet/meta_optimizers/recompute_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/recompute_optimizer.py index 267656824c..8f95954869 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/recompute_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/recompute_optimizer.py @@ -38,6 +38,9 @@ class RecomputeOptimizer(MetaOptimizerBase): list(user_defined_strategy.recompute_configs["checkpoints"])) def _can_apply(self): + if self.role_maker._is_collective: + return False + if self.user_defined_strategy.recompute == True: if len(self.user_defined_strategy.recompute_configs[ "checkpoints"]) == 0: @@ -49,7 +52,7 @@ class RecomputeOptimizer(MetaOptimizerBase): dist_strategy.recompute = False dist_strategy.recompute_configs = {} - def _enable_strategy(self, dist_strategy): + def _enable_strategy(self, dist_strategy, context): # we do not support automatically recompute checkpoints currently return diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 8c9dbba2d0..b496b7953a 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -441,8 +441,6 @@ if(WITH_DISTRIBUTE) # FIXME(seiriosX) will fix this list(REMOVE_ITEM DIST_TEST_OPS "test_dist_fleet_sparse_embedding_ctr") list(REMOVE_ITEM DIST_TEST_OPS "test_dist_fleet_gloo") - list(REMOVE_ITEM DIST_TEST_OPS "test_dist_fleet_a_sync_optimizer_auto") - list(REMOVE_ITEM DIST_TEST_OPS "test_dist_fleet_ctr") py_test_modules(test_recv_save_op MODULES test_recv_save_op ENVS ${dist_ENVS}) py_test_modules(test_transpiler_ops MODULES test_transpiler_ops ENVS ${dist_ENVS}) diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_a_sync_optimizer_auto.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_a_sync_optimizer_auto.py index ab47659a88..5a5d8afc55 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_fleet_a_sync_optimizer_auto.py +++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_a_sync_optimizer_auto.py @@ -62,82 +62,6 @@ class TestFleetGradientMergeMetaOptimizer(unittest.TestCase): a_sync_configs = optimizer.user_defined_strategy.a_sync_configs self.assertTrue(a_sync_configs['k_steps'] == 0) - def test_a_sync_optimizer2(self): - os.environ["TRAINING_ROLE"] = "TRAINER" - import paddle.distributed.fleet as fleet - - main_program = paddle.fluid.Program() - startup_program = paddle.fluid.Program() - - paddle.fluid.framework.switch_main_program(main_program) - paddle.fluid.framework.switch_startup_program(startup_program) - - fleet.init(role_maker.PaddleCloudRoleMaker()) - input_x = paddle.fluid.layers.data( - name="x", shape=[32], dtype='float32') - input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64') - - fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh') - fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh') - prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax') - cost = paddle.fluid.layers.cross_entropy( - input=prediction, label=input_y) - avg_cost = paddle.fluid.layers.mean(x=cost) - - strategy = paddle.distributed.fleet.DistributedStrategy() - strategy.auto = True - optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01) - optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) - optimizer.minimize(avg_cost) - - self.assertTrue(optimizer.user_defined_strategy.a_sync) - a_sync_configs = optimizer.user_defined_strategy.a_sync_configs - self.assertTrue(a_sync_configs['k_steps'] == 800) - - def test_a_sync_optimizer3(self): - os.environ["TRAINING_ROLE"] = "TRAINER" - import paddle.distributed.fleet as fleet - - main_program = paddle.fluid.Program() - startup_program = paddle.fluid.Program() - - paddle.fluid.framework.switch_main_program(main_program) - paddle.fluid.framework.switch_startup_program(startup_program) - - fleet.init(role_maker.PaddleCloudRoleMaker()) - input_x = paddle.fluid.layers.data( - name="x", - shape=[-1, 1], - dtype="int64", - lod_level=1, - append_batch_size=False) - x_embedding = paddle.fluid.layers.embedding( - is_distributed=False, - input=input_x, - size=[1000000000, 100000], - param_attr=paddle.fluid.ParamAttr( - name="embedding", - initializer=paddle.fluid.initializer.Constant(value=0.01)), - is_sparse=True) - input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64') - - fc_1 = paddle.fluid.layers.fc(input=x_embedding, size=64, act='tanh') - fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh') - prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax') - cost = paddle.fluid.layers.cross_entropy( - input=prediction, label=input_y) - avg_cost = paddle.fluid.layers.mean(x=cost) - - strategy = paddle.distributed.fleet.DistributedStrategy() - strategy.auto = True - optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01) - optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) - optimizer.minimize(avg_cost) - - self.assertTrue(optimizer.user_defined_strategy.a_sync) - a_sync_configs = optimizer.user_defined_strategy.a_sync_configs - self.assertTrue(a_sync_configs['k_steps'] == 0) - if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_a_sync_optimizer_auto_async.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_a_sync_optimizer_auto_async.py new file mode 100644 index 0000000000..9085556c04 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_a_sync_optimizer_auto_async.py @@ -0,0 +1,79 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +import os +import paddle.distributed.fleet.base.role_maker as role_maker +import time + + +class TestFleetGradientMergeMetaOptimizer(unittest.TestCase): + def setUp(self): + os.environ["PADDLE_PSERVER_NUMS"] = "2" + os.environ["PADDLE_TRAINERS_NUM"] = "2" + os.environ["POD_IP"] = "127.0.0.1" + os.environ["PADDLE_PORT"] = "36001" + os.environ["PADDLE_TRAINER_ID"] = "0" + os.environ["PADDLE_TRAINERS_NUM"] = "2" + os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = \ + "127.0.0.1:36001,127.0.0.2:36001" + + def test_a_sync_optimizer3(self): + os.environ["TRAINING_ROLE"] = "TRAINER" + import paddle.distributed.fleet as fleet + + main_program = paddle.fluid.Program() + startup_program = paddle.fluid.Program() + + paddle.fluid.framework.switch_main_program(main_program) + paddle.fluid.framework.switch_startup_program(startup_program) + + fleet.init(role_maker.PaddleCloudRoleMaker()) + input_x = paddle.fluid.layers.data( + name="x", + shape=[-1, 1], + dtype="int64", + lod_level=1, + append_batch_size=False) + x_embedding = paddle.fluid.layers.embedding( + is_distributed=False, + input=input_x, + size=[1000000000, 100000], + param_attr=paddle.fluid.ParamAttr( + name="embedding", + initializer=paddle.fluid.initializer.Constant(value=0.01)), + is_sparse=True) + input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64') + + fc_1 = paddle.fluid.layers.fc(input=x_embedding, size=64, act='tanh') + fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh') + prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax') + cost = paddle.fluid.layers.cross_entropy( + input=prediction, label=input_y) + avg_cost = paddle.fluid.layers.mean(x=cost) + + strategy = paddle.distributed.fleet.DistributedStrategy() + strategy.auto = True + optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01) + optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) + optimizer.minimize(avg_cost) + + self.assertTrue(optimizer.user_defined_strategy.a_sync) + a_sync_configs = optimizer.user_defined_strategy.a_sync_configs + self.assertTrue(a_sync_configs['k_steps'] == 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_a_sync_optimizer_auto_geo.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_a_sync_optimizer_auto_geo.py new file mode 100644 index 0000000000..4787d048bd --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_a_sync_optimizer_auto_geo.py @@ -0,0 +1,67 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +import os +import paddle.distributed.fleet.base.role_maker as role_maker +import time + + +class TestFleetGradientMergeMetaOptimizer(unittest.TestCase): + def setUp(self): + os.environ["PADDLE_PSERVER_NUMS"] = "2" + os.environ["PADDLE_TRAINERS_NUM"] = "2" + os.environ["POD_IP"] = "127.0.0.1" + os.environ["PADDLE_PORT"] = "36001" + os.environ["PADDLE_TRAINER_ID"] = "0" + os.environ["PADDLE_TRAINERS_NUM"] = "2" + os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = \ + "127.0.0.1:36001,127.0.0.2:36001" + + def test_a_sync_optimizer2(self): + os.environ["TRAINING_ROLE"] = "TRAINER" + import paddle.distributed.fleet as fleet + + main_program = paddle.fluid.Program() + startup_program = paddle.fluid.Program() + + paddle.fluid.framework.switch_main_program(main_program) + paddle.fluid.framework.switch_startup_program(startup_program) + + fleet.init(role_maker.PaddleCloudRoleMaker()) + input_x = paddle.fluid.layers.data( + name="x", shape=[32], dtype='float32') + input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64') + + fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh') + fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh') + prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax') + cost = paddle.fluid.layers.cross_entropy( + input=prediction, label=input_y) + avg_cost = paddle.fluid.layers.mean(x=cost) + + strategy = paddle.distributed.fleet.DistributedStrategy() + strategy.auto = True + optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01) + optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) + optimizer.minimize(avg_cost) + + self.assertTrue(optimizer.user_defined_strategy.a_sync) + a_sync_configs = optimizer.user_defined_strategy.a_sync_configs + self.assertTrue(a_sync_configs['k_steps'] == 800) + + +if __name__ == "__main__": + unittest.main() -- GitLab