Commit 6c5c547e authored by mapingshuo

correct role_maker usage

Parent e3334f3e
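The change is mechanical: every public `role_maker` accessor used by `ZeroOptimizer` is switched to its underscore-prefixed counterpart (`worker_num()` → `_worker_num()`, `worker_index()` → `_worker_index()`, `get_trainer_endpoints()` → `_get_trainer_endpoints()`). A minimal sketch of the corrected call-site pattern; the helper below is hypothetical and only assumes a role_maker exposing the renamed methods shown in this diff:

```python
# Hypothetical helper illustrating the corrected role_maker usage.
# It assumes a fleet role_maker with the underscore-prefixed accessors
# that this commit switches to; it is not part of the patch itself.
def describe_worker(role_maker):
    endpoints = role_maker._get_trainer_endpoints()  # all trainer endpoints
    idx = role_maker._worker_index()                 # this worker's rank
    num = role_maker._worker_num()                   # total worker count
    return idx, num, endpoints[idx]                  # rank, world size, own endpoint
```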
@@ -243,7 +243,7 @@ class ZeroOptimizer(MetaOptimizerBase):
             param2mem.append((param.name, mem))
             # print(param.name, mem)
         # print("total_param_mem: ", total_param_mem)
-        device_num = self.role_maker.worker_num()
+        device_num = self.role_maker._worker_num()
         # print("device_num: ", device_num)
         device2params = {x: [] for x in range(device_num)}
         device_idx = 0
@@ -327,7 +327,7 @@ class ZeroOptimizer(MetaOptimizerBase):
                     if input_name != broadcast_name:
                         op._rename_input(input_name, broadcast_name)
                     continue
-                if root_device == self.role_maker.worker_index():
+                if root_device == self.role_maker._worker_index():
                     broadcast_var_name = input_name
                 else:
                     broadcast_var_name = unique_name.generate(input_name +
@@ -357,7 +357,7 @@ class ZeroOptimizer(MetaOptimizerBase):
                 fp32_param = op.desc.input_arg_names()[0]
                 fp16_param = op.desc.output_arg_names()[0]
                 if self._param2device[
-                        fp32_param] == self.role_maker.worker_index():
+                        fp32_param] == self.role_maker._worker_index():
                     sub_prog._cast_ops[fp16_param] = fp32_param
         if sub_prog._param_mem > 0:
@@ -406,7 +406,7 @@ class ZeroOptimizer(MetaOptimizerBase):
         params = []
         for var_name, _ in block.vars.items():
             if self._is_opti_var(var_name) and \
-                    self._var_device_id(var_name) != self.role_maker.worker_index():
+                    self._var_device_id(var_name) != self.role_maker._worker_index():
                 params.append(var_name)
         program_deps = ProgramDeps(block, reduced_grads, params)
@@ -428,7 +428,7 @@ class ZeroOptimizer(MetaOptimizerBase):
                     reduce_var = var_to_reduce_var[input_name]
                     param_name = self._reduced_grads_to_param[reduce_var]
                     if self._param2device[
-                            param_name] != self.role_maker.worker_index():
+                            param_name] != self.role_maker._worker_index():
                         program_deps.crop_input_var_from_op(idx, input_name)
                     else:
                         reversed_input_vars.append(input_name)
@@ -726,20 +726,20 @@ class ZeroOptimizer(MetaOptimizerBase):
         for idx, op in reversed(list(enumerate(block.ops))):
             for output_name in op.desc.output_arg_names():
                 var_device_id = self._var_device_id(output_name)
-                if var_device_id == -1 or var_device_id == self.role_maker.worker_index(
+                if var_device_id == -1 or var_device_id == self.role_maker._worker_index(
                 ):
                     continue
                 print("%d: startup_block remove op %s" %
-                      (self.role_maker.worker_index(), op.type))
+                      (self.role_maker._worker_index(), op.type))
                 block._remove_op(idx)
                 break
         for var_name, _ in block.vars.items():
             var_device_id = self._var_device_id(var_name)
-            if var_device_id == -1 or var_device_id == self.role_maker.worker_index(
+            if var_device_id == -1 or var_device_id == self.role_maker._worker_index(
             ):
                 continue
             print("%d: startup_block remove var %s" %
-                  (self.role_maker.worker_index(), var_name))
+                  (self.role_maker._worker_index(), var_name))
             block._remove_var(var_name)
         block._sync_with_cpp()
@@ -775,15 +775,14 @@ class ZeroOptimizer(MetaOptimizerBase):
     def _set_up(self, params_grads):
         # step 1: initialize nccl
         # TODO(mapingshuo) fix get_trainer_endpoints
-        print("work idx: ", self.role_maker.worker_index())
-        endpoints = self.role_maker.get_trainer_endpoints()
-        current_endpoint = endpoints[self.role_maker.worker_index()]
+        print("work idx: ", self.role_maker._worker_index())
+        endpoints = self.role_maker._get_trainer_endpoints()
+        current_endpoint = endpoints[self.role_maker._worker_index()]
         collective_helper = CollectiveHelper(self.role_maker, self._nrings)
         for ring_id in range(self._nrings):
             collective_helper._init_communicator(
                 self._startup_program, current_endpoint, endpoints,
-                self.role_maker.worker_index(), ring_id, '6174')
+                self.role_maker._worker_index(), ring_id, '6174')
         startup_block = self._startup_program.global_block()
         startup_block._sync_with_cpp()
@@ -846,7 +845,7 @@ class ZeroOptimizer(MetaOptimizerBase):
         # step4: insert reduce_sum for grad
         self._insert_scale_loss_grad_ops(
-            main_block, scale=1.0 / self.role_maker.worker_num())
+            main_block, scale=1.0 / self.role_maker._worker_num())
         main_block._sync_with_cpp()
         # step5: remove unneeded ops and vars from block
@@ -1194,21 +1193,21 @@ class ZeroOptimizer(MetaOptimizerBase):
         if startup_program is None:
             startup_program = default_startup_program()
-        print("work idx: ", self.role_maker.worker_index())
-        endpoints = self.role_maker.get_trainer_endpoints()
-        current_endpoint = endpoints[self.role_maker.worker_index()]
+        print("work idx: ", self.role_maker._worker_index())
+        endpoints = self.role_maker._get_trainer_endpoints()
+        current_endpoint = endpoints[self.role_maker._worker_index()]
         collective_helper = CollectiveHelper(self.role_maker, self._nrings)
         for ring_id in range(self._nrings):
             collective_helper._init_communicator(
                 startup_program, current_endpoint, endpoints,
-                self.role_maker.worker_index(), ring_id, '6174')
+                self.role_maker._worker_index(), ring_id, '6174')
         main_block = loss.block
         startup_block = startup_program.global_block()
         self._broadcast_params(startup_block)
         self._insert_scale_loss_grad_ops(
-            main_block, scale=1.0 / self.role_maker.worker_num())
+            main_block, scale=1.0 / self.role_maker._worker_num())
         self._insert_allreduce_ops_tmp(main_block)
         print("insert allreduce done")
         return optimize_ops, params_grads
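For context on the `scale=1.0 / self.role_maker._worker_num()` argument that appears in the last two hunks: after the inserted allreduce ops, each worker holds the sum of all workers' gradients, so scaling by `1/worker_num` turns that sum into the average. A minimal sketch of the arithmetic, in plain Python with hypothetical names:

```python
# Hypothetical sketch: an allreduce(sum) leaves every worker with the
# summed gradient; scaling by 1/worker_num recovers the mean gradient,
# which is the scale _insert_scale_loss_grad_ops applies to the loss grad.
def averaged_gradient(per_worker_grads):
    worker_num = len(per_worker_grads)
    summed = sum(per_worker_grads)      # what allreduce produces on each worker
    return summed * (1.0 / worker_num)  # the 1/worker_num scale from the diff

assert averaged_gradient([2.0, 4.0]) == 3.0
```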