Unverified commit 6a1db204, authored by tangwei12, committed by GitHub

fix sync_with_distributed_lookup_table, test=develop (#19737)

fix the wrong insertion position of the distributed_lookup_table op
Parent 38f1c2fe
@@ -388,49 +388,84 @@ class DistributeTranspiler(object):
                 sparse_update_ops.append(op)
         return sparse_update_ops
 
-    def _update_remote_sparse_update_op(self, program, param_varname,
-                                        height_sections, endpoints,
-                                        table_names):
-        ops = []
-        op_type = ""
-        for op in self.sparse_update_ops:
-            if param_varname in op.input_arg_names and op_type == "":
-                op_type = op.type
-                ops.append(op)
-            elif param_varname in op.input_arg_names and op_type == op.type:
-                ops.append(op)
-
-        if op_type == "lookup_table":
-            all_ops = program.global_block().ops
-            op_idxs = [all_ops.index(op) for op in ops]
-            inputs = [
-                program.global_block().vars[op.input("Ids")[0]] for op in ops
-            ]
-            w = program.global_block().vars[ops[0].input("W")[0]]
-            padding_idx = ops[0].attr("padding_idx")
-            outputs = [
-                program.global_block().vars[op.output("Out")[0]] for op in ops
-            ]
-
-            for idx in op_idxs[::-1]:
-                program.global_block()._remove_op(idx)
-
-            program.global_block()._insert_op(
-                index=op_idxs[0],
-                type="distributed_lookup_table",
-                inputs={"Ids": inputs,
-                        'W': w},
-                outputs={"Outputs": outputs},
-                attrs={
-                    "table_names": table_names,
-                    "height_sections": height_sections,
-                    "endpoints": endpoints,
-                    "padding_idx": padding_idx,
-                    "trainer_id": self.trainer_id
-                })
+    def _update_remote_sparse_update_op(self, program,
+                                        need_sparse_update_params):
+        for param_varname, attrs in need_sparse_update_params.items():
+            height_sections = self.sparse_param_to_height_sections[
+                param_varname]
+            endpoints = attrs[0]
+            table_names = attrs[1]
+
+            ops = []
+            op_type = ""
+            used_ops = []
+
+            for idx, op in enumerate(self.sparse_update_ops):
+                if param_varname in op.input_arg_names and op_type == "":
+                    op_type = op.type
+                    ops.append(op)
+                    used_ops.append(idx)
+
+                elif param_varname in op.input_arg_names and op_type == op.type:
+                    ops.append(op)
+                    used_ops.append(idx)
+
+            if op_type == "lookup_table":
+                all_ops = program.global_block().ops
+                op_idxs = [all_ops.index(op) for op in ops]
+
+                inputs = [
+                    program.global_block().vars[op.input("Ids")[0]]
+                    for op in ops
+                ]
+                w = program.global_block().vars[ops[0].input("W")[0]]
+                padding_idx = ops[0].attr("padding_idx")
+                outputs = [
+                    program.global_block().vars[op.output("Out")[0]]
+                    for op in ops
+                ]
+
+                for idx in op_idxs[::-1]:
+                    program.global_block()._remove_op(idx)
+
+                inputs_idxs = [-1] * len(inputs)
+                outputs_idxs = [-1] * len(outputs)
+
+                for idx, op in enumerate(program.global_block().ops):
+                    for i in range(0, len(op.output_names)):
+                        outs = op.output(op.output_names[i])
+                        for in_id, in_var in enumerate(inputs):
+                            if in_var.name in outs:
+                                inputs_idxs[in_id] = idx
+                    for i in range(0, len(op.input_names)):
+                        ins = op.input(op.input_names[i])
+                        for out_id, out_var in enumerate(outputs):
+                            if out_var.name in ins:
+                                outputs_idxs[out_id] = idx
+
+                if min(outputs_idxs) - max(inputs_idxs) >= 1:
+                    distributed_idx = max(inputs_idxs) + 1
+
+                    program.global_block()._insert_op(
+                        index=distributed_idx,
+                        type="distributed_lookup_table",
+                        inputs={"Ids": inputs,
+                                'W': w},
+                        outputs={"Outputs": outputs},
+                        attrs={
+                            "table_names": table_names,
+                            "height_sections": height_sections,
+                            "endpoints": endpoints,
+                            "padding_idx": padding_idx,
+                            "trainer_id": self.trainer_id
+                        })
+                else:
+                    raise ValueError(
+                        "something wrong with distribute_transpiler, submit a issue is recommended"
+                    )
+
+                for idx in used_ops[::-1]:
+                    self.sparse_update_ops.pop(idx)
 
     def _is_input_of_remote_sparse_update_op(self, param_name):
         for op in self.sparse_update_ops:
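As an illustration of the placement check introduced in this hunk (not part of the commit; the helper and the toy op list below are hypothetical, while the real code walks Paddle Operator objects in the global block), the rule is: record at which index each lookup "Ids" input is produced and at which index each lookup output is consumed, verify that every consumer comes strictly after every producer, and insert the fused distributed_lookup_table op right after the last producer.

# Minimal sketch of the index selection, assuming each op is modelled as a
# (produced_var_names, consumed_var_names) pair; all names are hypothetical.
def pick_insert_index(block_ops, input_names, output_names):
    inputs_idxs = [-1] * len(input_names)    # where each Ids var is produced
    outputs_idxs = [-1] * len(output_names)  # where each Out var is consumed

    for idx, (produced, consumed) in enumerate(block_ops):
        for in_id, name in enumerate(input_names):
            if name in produced:
                inputs_idxs[in_id] = idx
        for out_id, name in enumerate(output_names):
            if name in consumed:
                outputs_idxs[out_id] = idx

    # every consumer must come strictly after every producer
    if min(outputs_idxs) - max(inputs_idxs) >= 1:
        return max(inputs_idxs) + 1  # directly after the last producer
    raise ValueError("cannot place distributed_lookup_table safely")

# toy block: two ops emit the Ids vars, a later op consumes both embeddings
ops = [
    ({"ids_0"}, set()),
    ({"ids_1"}, set()),
    (set(), {"emb_0", "emb_1"}),
]
print(pick_insert_index(ops, ["ids_0", "ids_1"], ["emb_0", "emb_1"]))  # -> 2

The old code instead reused op_idxs[0], the position of the first removed lookup_table op, which is presumably how the fused op could end up ahead of the ops that produce its Ids inputs, the "wrong place" named in the commit message.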
@@ -681,6 +716,8 @@ class DistributeTranspiler(object):
                         recv_vars[i].name)
                     distributed_var.endpoint = ep
 
+        need_sparse_update_params = {}
+
         # step4: Concat the parameters splits together after recv.
         all_recv_outputs = []
         for param_varname, splited_var in six.iteritems(self.param_var_mapping):
@@ -712,10 +749,7 @@ class DistributeTranspiler(object):
                         table_name)
                     distributed_var.vtype = "RemotePrefetch"
 
-                height_sections = self.sparse_param_to_height_sections[
-                    param_varname]
-                self._update_remote_sparse_update_op(
-                    program, param_varname, height_sections, eps, table_names)
+                need_sparse_update_params[param_varname] = (eps, table_names)
             else:
                 recv_varnames = []
                 if self.config.runtime_split_send_recv:
@@ -764,6 +798,9 @@ class DistributeTranspiler(object):
                     RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
                 })
 
+        self._update_remote_sparse_update_op(program,
+                                             need_sparse_update_params)
+
         self._get_trainer_startup_program(recv_vars=recv_vars, eplist=eplist)
 
         if self.has_distributed_lookup_table:
...
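Taken together, the hunks change the call pattern from "rewrite each sparse parameter inside the recv loop" to "record every sparse parameter, then rewrite them all once after the recv/concat ops exist". A toy sketch of that deferred flow (hypothetical class and names, not the commit's code):

# Record every remote-prefetch parameter while emitting recv ops, then rewrite
# all of their lookup ops in one deferred pass, mirroring how
# _update_remote_sparse_update_op(program, need_sparse_update_params) is now
# invoked once after the recv loop.
class MiniTranspiler:
    def __init__(self):
        # param name -> (pserver endpoints, remote table names)
        self.need_sparse_update_params = {}

    def record_remote_prefetch(self, param_varname, eps, table_names):
        # called from inside the per-parameter recv loop
        self.need_sparse_update_params[param_varname] = (eps, table_names)

    def finalize(self):
        # called once after all recv/concat ops have been appended,
        # so the producer/consumer scan sees the final op list
        for name, (eps, tables) in self.need_sparse_update_params.items():
            print("fuse lookups of %s -> distributed_lookup_table "
                  "(endpoints=%s, tables=%s)" % (name, eps, tables))

t = MiniTranspiler()
t.record_remote_prefetch("emb.w_0",
                         ["127.0.0.1:6170", "127.0.0.1:6171"],
                         ["emb.w_0.block0", "emb.w_0.block1"])
t.finalize()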