提交 6c5c547e 编写于 作者: M mapingshuo

Correct role_maker usage: switch to the private RoleMaker API — replace `worker_num()`, `worker_index()`, and `get_trainer_endpoints()` with `_worker_num()`, `_worker_index()`, and `_get_trainer_endpoints()` throughout ZeroOptimizer.

上级 e3334f3e
...@@ -243,7 +243,7 @@ class ZeroOptimizer(MetaOptimizerBase): ...@@ -243,7 +243,7 @@ class ZeroOptimizer(MetaOptimizerBase):
param2mem.append((param.name, mem)) param2mem.append((param.name, mem))
# print(param.name, mem) # print(param.name, mem)
# print("total_param_mem: ", total_param_mem) # print("total_param_mem: ", total_param_mem)
device_num = self.role_maker.worker_num() device_num = self.role_maker._worker_num()
# print("device_num: ", device_num) # print("device_num: ", device_num)
device2params = {x: [] for x in range(device_num)} device2params = {x: [] for x in range(device_num)}
device_idx = 0 device_idx = 0
...@@ -327,7 +327,7 @@ class ZeroOptimizer(MetaOptimizerBase): ...@@ -327,7 +327,7 @@ class ZeroOptimizer(MetaOptimizerBase):
if input_name != broadcast_name: if input_name != broadcast_name:
op._rename_input(input_name, broadcast_name) op._rename_input(input_name, broadcast_name)
continue continue
if root_device == self.role_maker.worker_index(): if root_device == self.role_maker._worker_index():
broadcast_var_name = input_name broadcast_var_name = input_name
else: else:
broadcast_var_name = unique_name.generate(input_name + broadcast_var_name = unique_name.generate(input_name +
...@@ -357,7 +357,7 @@ class ZeroOptimizer(MetaOptimizerBase): ...@@ -357,7 +357,7 @@ class ZeroOptimizer(MetaOptimizerBase):
fp32_param = op.desc.input_arg_names()[0] fp32_param = op.desc.input_arg_names()[0]
fp16_param = op.desc.output_arg_names()[0] fp16_param = op.desc.output_arg_names()[0]
if self._param2device[ if self._param2device[
fp32_param] == self.role_maker.worker_index(): fp32_param] == self.role_maker._worker_index():
sub_prog._cast_ops[fp16_param] = fp32_param sub_prog._cast_ops[fp16_param] = fp32_param
if sub_prog._param_mem > 0: if sub_prog._param_mem > 0:
...@@ -406,7 +406,7 @@ class ZeroOptimizer(MetaOptimizerBase): ...@@ -406,7 +406,7 @@ class ZeroOptimizer(MetaOptimizerBase):
params = [] params = []
for var_name, _ in block.vars.items(): for var_name, _ in block.vars.items():
if self._is_opti_var(var_name) and \ if self._is_opti_var(var_name) and \
self._var_device_id(var_name) != self.role_maker.worker_index(): self._var_device_id(var_name) != self.role_maker._worker_index():
params.append(var_name) params.append(var_name)
program_deps = ProgramDeps(block, reduced_grads, params) program_deps = ProgramDeps(block, reduced_grads, params)
...@@ -428,7 +428,7 @@ class ZeroOptimizer(MetaOptimizerBase): ...@@ -428,7 +428,7 @@ class ZeroOptimizer(MetaOptimizerBase):
reduce_var = var_to_reduce_var[input_name] reduce_var = var_to_reduce_var[input_name]
param_name = self._reduced_grads_to_param[reduce_var] param_name = self._reduced_grads_to_param[reduce_var]
if self._param2device[ if self._param2device[
param_name] != self.role_maker.worker_index(): param_name] != self.role_maker._worker_index():
program_deps.crop_input_var_from_op(idx, input_name) program_deps.crop_input_var_from_op(idx, input_name)
else: else:
reversed_input_vars.append(input_name) reversed_input_vars.append(input_name)
...@@ -726,20 +726,20 @@ class ZeroOptimizer(MetaOptimizerBase): ...@@ -726,20 +726,20 @@ class ZeroOptimizer(MetaOptimizerBase):
for idx, op in reversed(list(enumerate(block.ops))): for idx, op in reversed(list(enumerate(block.ops))):
for output_name in op.desc.output_arg_names(): for output_name in op.desc.output_arg_names():
var_device_id = self._var_device_id(output_name) var_device_id = self._var_device_id(output_name)
if var_device_id == -1 or var_device_id == self.role_maker.worker_index( if var_device_id == -1 or var_device_id == self.role_maker._worker_index(
): ):
continue continue
print("%d: startup_block remove op %s" % print("%d: startup_block remove op %s" %
(self.role_maker.worker_index(), op.type)) (self.role_maker._worker_index(), op.type))
block._remove_op(idx) block._remove_op(idx)
break break
for var_name, _ in block.vars.items(): for var_name, _ in block.vars.items():
var_device_id = self._var_device_id(var_name) var_device_id = self._var_device_id(var_name)
if var_device_id == -1 or var_device_id == self.role_maker.worker_index( if var_device_id == -1 or var_device_id == self.role_maker._worker_index(
): ):
continue continue
print("%d: startup_block remove var %s" % print("%d: startup_block remove var %s" %
(self.role_maker.worker_index(), var_name)) (self.role_maker._worker_index(), var_name))
block._remove_var(var_name) block._remove_var(var_name)
block._sync_with_cpp() block._sync_with_cpp()
...@@ -775,15 +775,14 @@ class ZeroOptimizer(MetaOptimizerBase): ...@@ -775,15 +775,14 @@ class ZeroOptimizer(MetaOptimizerBase):
def _set_up(self, params_grads): def _set_up(self, params_grads):
# step 1: initialize nccl # step 1: initialize nccl
# TODO(mapingshuo) fix get_trainer_endpoints print("work idx: ", self.role_maker._worker_index())
print("work idx: ", self.role_maker.worker_index()) endpoints = self.role_maker._get_trainer_endpoints()
endpoints = self.role_maker.get_trainer_endpoints() current_endpoint = endpoints[self.role_maker._worker_index()]
current_endpoint = endpoints[self.role_maker.worker_index()]
collective_helper = CollectiveHelper(self.role_maker, self._nrings) collective_helper = CollectiveHelper(self.role_maker, self._nrings)
for ring_id in range(self._nrings): for ring_id in range(self._nrings):
collective_helper._init_communicator( collective_helper._init_communicator(
self._startup_program, current_endpoint, endpoints, self._startup_program, current_endpoint, endpoints,
self.role_maker.worker_index(), ring_id, '6174') self.role_maker._worker_index(), ring_id, '6174')
startup_block = self._startup_program.global_block() startup_block = self._startup_program.global_block()
startup_block._sync_with_cpp() startup_block._sync_with_cpp()
...@@ -846,7 +845,7 @@ class ZeroOptimizer(MetaOptimizerBase): ...@@ -846,7 +845,7 @@ class ZeroOptimizer(MetaOptimizerBase):
# step4: insert reduce_sum for grad # step4: insert reduce_sum for grad
self._insert_scale_loss_grad_ops( self._insert_scale_loss_grad_ops(
main_block, scale=1.0 / self.role_maker.worker_num()) main_block, scale=1.0 / self.role_maker._worker_num())
main_block._sync_with_cpp() main_block._sync_with_cpp()
# step5: remove unneeded ops and vars from block # step5: remove unneeded ops and vars from block
...@@ -1194,21 +1193,21 @@ class ZeroOptimizer(MetaOptimizerBase): ...@@ -1194,21 +1193,21 @@ class ZeroOptimizer(MetaOptimizerBase):
if startup_program is None: if startup_program is None:
startup_program = default_startup_program() startup_program = default_startup_program()
print("work idx: ", self.role_maker.worker_index()) print("work idx: ", self.role_maker._worker_index())
endpoints = self.role_maker.get_trainer_endpoints() endpoints = self.role_maker._get_trainer_endpoints()
current_endpoint = endpoints[self.role_maker.worker_index()] current_endpoint = endpoints[self.role_maker._worker_index()]
collective_helper = CollectiveHelper(self.role_maker, self._nrings) collective_helper = CollectiveHelper(self.role_maker, self._nrings)
for ring_id in range(self._nrings): for ring_id in range(self._nrings):
collective_helper._init_communicator( collective_helper._init_communicator(
startup_program, current_endpoint, endpoints, startup_program, current_endpoint, endpoints,
self.role_maker.worker_index(), ring_id, '6174') self.role_maker._worker_index(), ring_id, '6174')
main_block = loss.block main_block = loss.block
startup_block = startup_program.global_block() startup_block = startup_program.global_block()
self._broadcast_params(startup_block) self._broadcast_params(startup_block)
self._insert_scale_loss_grad_ops( self._insert_scale_loss_grad_ops(
main_block, scale=1.0 / self.role_maker.worker_num()) main_block, scale=1.0 / self.role_maker._worker_num())
self._insert_allreduce_ops_tmp(main_block) self._insert_allreduce_ops_tmp(main_block)
print("insert allreduce done") print("insert allreduce done")
return optimize_ops, params_grads return optimize_ops, params_grads
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册