未验证 提交 6a1db204 编写于 作者: T tangwei12 提交者: GitHub

fix sync_with_distributed_lookup_table, test=develop (#19737)

fix wrong place with distributed_lookup_table
上级 38f1c2fe
......@@ -388,49 +388,84 @@ class DistributeTranspiler(object):
sparse_update_ops.append(op)
return sparse_update_ops
def _update_remote_sparse_update_op(self, program, param_varname,
height_sections, endpoints,
table_names):
ops = []
op_type = ""
for op in self.sparse_update_ops:
if param_varname in op.input_arg_names and op_type == "":
op_type = op.type
ops.append(op)
elif param_varname in op.input_arg_names and op_type == op.type:
ops.append(op)
if op_type == "lookup_table":
all_ops = program.global_block().ops
op_idxs = [all_ops.index(op) for op in ops]
inputs = [
program.global_block().vars[op.input("Ids")[0]] for op in ops
]
w = program.global_block().vars[ops[0].input("W")[0]]
padding_idx = ops[0].attr("padding_idx")
outputs = [
program.global_block().vars[op.output("Out")[0]] for op in ops
]
def _update_remote_sparse_update_op(self, program,
need_sparse_update_params):
for param_varname, attrs in need_sparse_update_params.items():
height_sections = self.sparse_param_to_height_sections[
param_varname]
endpoints = attrs[0]
table_names = attrs[1]
ops = []
op_type = ""
used_ops = []
for idx, op in enumerate(self.sparse_update_ops):
if param_varname in op.input_arg_names and op_type == "":
op_type = op.type
ops.append(op)
used_ops.append(idx)
elif param_varname in op.input_arg_names and op_type == op.type:
ops.append(op)
used_ops.append(idx)
if op_type == "lookup_table":
all_ops = program.global_block().ops
op_idxs = [all_ops.index(op) for op in ops]
inputs = [
program.global_block().vars[op.input("Ids")[0]]
for op in ops
]
w = program.global_block().vars[ops[0].input("W")[0]]
padding_idx = ops[0].attr("padding_idx")
outputs = [
program.global_block().vars[op.output("Out")[0]]
for op in ops
]
for idx in op_idxs[::-1]:
program.global_block()._remove_op(idx)
for idx in op_idxs[::-1]:
program.global_block()._remove_op(idx)
inputs_idxs = [-1] * len(inputs)
outputs_idxs = [-1] * len(outputs)
for idx, op in enumerate(program.global_block().ops):
for i in range(0, len(op.output_names)):
outs = op.output(op.output_names[i])
for in_id, in_var in enumerate(inputs):
if in_var.name in outs:
inputs_idxs[in_id] = idx
for i in range(0, len(op.input_names)):
ins = op.input(op.input_names[i])
for out_id, out_var in enumerate(outputs):
if out_var.name in ins:
outputs_idxs[out_id] = idx
if min(outputs_idxs) - max(inputs_idxs) >= 1:
distributed_idx = max(inputs_idxs) + 1
program.global_block()._insert_op(
index=distributed_idx,
type="distributed_lookup_table",
inputs={"Ids": inputs,
'W': w},
outputs={"Outputs": outputs},
attrs={
"table_names": table_names,
"height_sections": height_sections,
"endpoints": endpoints,
"padding_idx": padding_idx,
"trainer_id": self.trainer_id
})
else:
raise ValueError(
"something wrong with distribute_transpiler, submit a issue is recommended"
)
program.global_block()._insert_op(
index=op_idxs[0],
type="distributed_lookup_table",
inputs={"Ids": inputs,
'W': w},
outputs={"Outputs": outputs},
attrs={
"table_names": table_names,
"height_sections": height_sections,
"endpoints": endpoints,
"padding_idx": padding_idx,
"trainer_id": self.trainer_id
})
for idx in used_ops[::-1]:
self.sparse_update_ops.pop(idx)
def _is_input_of_remote_sparse_update_op(self, param_name):
for op in self.sparse_update_ops:
......@@ -681,6 +716,8 @@ class DistributeTranspiler(object):
recv_vars[i].name)
distributed_var.endpoint = ep
need_sparse_update_params = {}
# step4: Concat the parameters splits together after recv.
all_recv_outputs = []
for param_varname, splited_var in six.iteritems(self.param_var_mapping):
......@@ -712,10 +749,7 @@ class DistributeTranspiler(object):
table_name)
distributed_var.vtype = "RemotePrefetch"
height_sections = self.sparse_param_to_height_sections[
param_varname]
self._update_remote_sparse_update_op(
program, param_varname, height_sections, eps, table_names)
need_sparse_update_params[param_varname] = (eps, table_names)
else:
recv_varnames = []
if self.config.runtime_split_send_recv:
......@@ -764,6 +798,9 @@ class DistributeTranspiler(object):
RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
})
self._update_remote_sparse_update_op(program,
need_sparse_update_params)
self._get_trainer_startup_program(recv_vars=recv_vars, eplist=eplist)
if self.has_distributed_lookup_table:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册