From 6c5c547ef8716e783ea1162c8f274ce36b3e2c69 Mon Sep 17 00:00:00 2001
From: mapingshuo
Date: Wed, 23 Sep 2020 11:44:34 +0800
Subject: [PATCH] correct role_maker usage

---
 .../fleet/meta_optimizers/zero_optimizer.py   | 39 +++++++++----------
 1 file changed, 19 insertions(+), 20 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_optimizers/zero_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/zero_optimizer.py
index a5c865a3a5..11adf124a2 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/zero_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/zero_optimizer.py
@@ -243,7 +243,7 @@ class ZeroOptimizer(MetaOptimizerBase):
             param2mem.append((param.name, mem))
             # print(param.name, mem)
         # print("total_param_mem: ", total_param_mem)
-        device_num = self.role_maker.worker_num()
+        device_num = self.role_maker._worker_num()
         # print("device_num: ", device_num)
         device2params = {x: [] for x in range(device_num)}
         device_idx = 0
@@ -327,7 +327,7 @@ class ZeroOptimizer(MetaOptimizerBase):
                     if input_name != broadcast_name:
                         op._rename_input(input_name, broadcast_name)
                     continue
-                if root_device == self.role_maker.worker_index():
+                if root_device == self.role_maker._worker_index():
                     broadcast_var_name = input_name
                 else:
                     broadcast_var_name = unique_name.generate(input_name +
@@ -357,7 +357,7 @@ class ZeroOptimizer(MetaOptimizerBase):
                 fp32_param = op.desc.input_arg_names()[0]
                 fp16_param = op.desc.output_arg_names()[0]
                 if self._param2device[
-                        fp32_param] == self.role_maker.worker_index():
+                        fp32_param] == self.role_maker._worker_index():
                     sub_prog._cast_ops[fp16_param] = fp32_param
 
         if sub_prog._param_mem > 0:
@@ -406,7 +406,7 @@ class ZeroOptimizer(MetaOptimizerBase):
         params = []
         for var_name, _ in block.vars.items():
             if self._is_opti_var(var_name) and \
-                self._var_device_id(var_name) != self.role_maker.worker_index():
+                self._var_device_id(var_name) != self.role_maker._worker_index():
                 params.append(var_name)
 
         program_deps = ProgramDeps(block, reduced_grads, params)
@@ -428,7 +428,7 @@ class ZeroOptimizer(MetaOptimizerBase):
                         reduce_var = var_to_reduce_var[input_name]
                         param_name = self._reduced_grads_to_param[reduce_var]
                         if self._param2device[
-                                param_name] != self.role_maker.worker_index():
+                                param_name] != self.role_maker._worker_index():
                             program_deps.crop_input_var_from_op(idx, input_name)
                         else:
                             reversed_input_vars.append(input_name)
@@ -726,20 +726,20 @@ class ZeroOptimizer(MetaOptimizerBase):
         for idx, op in reversed(list(enumerate(block.ops))):
             for output_name in op.desc.output_arg_names():
                 var_device_id = self._var_device_id(output_name)
-                if var_device_id == -1 or var_device_id == self.role_maker.worker_index(
+                if var_device_id == -1 or var_device_id == self.role_maker._worker_index(
                 ):
                     continue
                 print("%d: startup_block remove op %s" %
-                      (self.role_maker.worker_index(), op.type))
+                      (self.role_maker._worker_index(), op.type))
                 block._remove_op(idx)
                 break
 
         for var_name, _ in block.vars.items():
             var_device_id = self._var_device_id(var_name)
-            if var_device_id == -1 or var_device_id == self.role_maker.worker_index(
+            if var_device_id == -1 or var_device_id == self.role_maker._worker_index(
             ):
                 continue
             print("%d: startup_block remove var %s" %
-                  (self.role_maker.worker_index(), var_name))
+                  (self.role_maker._worker_index(), var_name))
             block._remove_var(var_name)
         block._sync_with_cpp()
@@ -775,15 +775,14 @@ class ZeroOptimizer(MetaOptimizerBase):
 
     def _set_up(self, params_grads):
         # step 1: initialize nccl
-        # TODO(mapingshuo) fix get_trainer_endpoints
-        print("work idx: ", self.role_maker.worker_index())
-        endpoints = self.role_maker.get_trainer_endpoints()
-        current_endpoint = endpoints[self.role_maker.worker_index()]
+        print("work idx: ", self.role_maker._worker_index())
+        endpoints = self.role_maker._get_trainer_endpoints()
+        current_endpoint = endpoints[self.role_maker._worker_index()]
         collective_helper = CollectiveHelper(self.role_maker, self._nrings)
         for ring_id in range(self._nrings):
             collective_helper._init_communicator(
                 self._startup_program, current_endpoint, endpoints,
-                self.role_maker.worker_index(), ring_id, '6174')
+                self.role_maker._worker_index(), ring_id, '6174')
         startup_block = self._startup_program.global_block()
         startup_block._sync_with_cpp()
 
@@ -846,7 +845,7 @@ class ZeroOptimizer(MetaOptimizerBase):
 
         # step4: insert reduce_sum for grad
         self._insert_scale_loss_grad_ops(
-            main_block, scale=1.0 / self.role_maker.worker_num())
+            main_block, scale=1.0 / self.role_maker._worker_num())
         main_block._sync_with_cpp()
 
         # step5: remove unneeded ops and vars from block
@@ -1194,21 +1193,21 @@ class ZeroOptimizer(MetaOptimizerBase):
         if startup_program is None:
             startup_program = default_startup_program()
 
-        print("work idx: ", self.role_maker.worker_index())
-        endpoints = self.role_maker.get_trainer_endpoints()
-        current_endpoint = endpoints[self.role_maker.worker_index()]
+        print("work idx: ", self.role_maker._worker_index())
+        endpoints = self.role_maker._get_trainer_endpoints()
+        current_endpoint = endpoints[self.role_maker._worker_index()]
         collective_helper = CollectiveHelper(self.role_maker, self._nrings)
         for ring_id in range(self._nrings):
             collective_helper._init_communicator(
                 startup_program, current_endpoint, endpoints,
-                self.role_maker.worker_index(), ring_id, '6174')
+                self.role_maker._worker_index(), ring_id, '6174')
 
         main_block = loss.block
         startup_block = startup_program.global_block()
         self._broadcast_params(startup_block)
 
         self._insert_scale_loss_grad_ops(
-            main_block, scale=1.0 / self.role_maker.worker_num())
+            main_block, scale=1.0 / self.role_maker._worker_num())
         self._insert_allreduce_ops_tmp(main_block)
         print("insert allreduce done")
         return optimize_ops, params_grads
-- 
GitLab