Unverified commit 4c35f515, authored by seemingwang, committed by GitHub

fix distributed ops combining problems (#35942)

* graph engine demo

* upload unsaved changes

* fix dependency error

* fix shard_num problem

* py client

* remove lock and graph-type

* add load direct graph

* add load direct graph

* add load direct graph

* batch random_sample

* batch_sample_k

* fix num_nodes size

* batch brpc

* batch brpc

* add test

* add test

* add load_nodes; change add_node function

* change sample return type to pair

* resolve conflict

* resolved conflict

* resolved conflict

* separate server and client

* merge pair type

* fix

* resolved conflict

* fixed segmentation fault; high-level VLOG for load edges and load nodes

* random_sample return 0

* rm useless loop

* test: load edge

* fix ret -1

* test: rm sample

* rm sample

* random_sample return future

* random_sample return int

* test fake node

* fixed here

* memory leak

* remove test code

* fix return problem

* add common_graph_table

* random sample node & test & change data structure from linked list to vector

* add common_graph_table

* sample with srand

* add node_types

* optimize nodes sample

* recover test

* random sample

* destruct weighted sampler

* GraphEdgeBlob

* WeightedGraphEdgeBlob to GraphEdgeBlob

* WeightedGraphEdgeBlob to GraphEdgeBlob

* pybind sample nodes api

* pull nodes with step

* fixed pull_graph_list bug; add test for pull_graph_list by step

* add graph table; name

* add graph table; name

* add pybind

* add pybind

* add FeatureNode

* add FeatureNode

* add FeatureNode Serialize

* add FeatureNode Serialize

* get_feat_node

* avoid local rpc

* fix get_node_feat

* fix get_node_feat

* remove log

* get_node_feat returns py::bytes

* merge develop with graph_engine

* fix threadpool.h head

* fix

* fix typo

* resolve conflict

* fix conflict

* recover lost content

* fix pybind of FeatureNode

* recover cmake

* recover tools

* resolve conflict

* resolve linking problem

* code style

* change test_server port

* fix code problems

* remove shard_num config

* remove redundant threads

* optimize start server

* remove logs

* fix code problems following reviewers' suggestions

* move graph files into a folder

* code style change

* remove graph operations from base table

* optimize get_feat function of graph engine

* fix long long count problem

* remove redundant graph files

* remove unused shell

* recover dropout_op_pass.h

* fix potential stack overflow when request number is too large & node add & node clear & node remove

* when sample k is larger than neighbor num, return directly

* using random seed generator of paddle to speed up

* fix bug of random sample k

* fix code style

* fix code style

* add remove graph to fleet_py.cc

* fix blocking_queue problem

* fix style

* fix

* recover capacity check

* add remove graph node; add set_feature

* add remove graph node; add set_feature

* add remove graph node; add set_feature

* add remove graph node; add set_feature

* fix distributed op combining problems

* optimize

* remove logs
Co-authored-by: Huang Zhengjie <270018958@qq.com>
Co-authored-by: Weiyue Su <weiyue.su@gmail.com>
Co-authored-by: suweiyue <suweiyue@baidu.com>
Co-authored-by: luobin06 <luobin06@baidu.com>
Co-authored-by: liweibin02 <liweibin02@baidu.com>
Co-authored-by: tangwei12 <tangwei12@baidu.com>
Parent 7c3567ea
@@ -128,9 +128,113 @@ class DistributedInfer:
             return pull_sparse_ops
 
         def _pull_sparse_fuse(_program, pull_sparse_ops):
+            def dag_check_up_and_reorder(program, inputs, outputs):
+                global_block = program.global_block()
+                min_output_index = len(global_block.ops)
+                max_input_index = -1
+                input_indexes = [0] * len(global_block.ops)
+                output_indexes = [0] * len(global_block.ops)
+                for idx, op in enumerate(global_block.ops):
+                    for i in range(0, len(op.output_names)):
+                        if input_indexes[idx] == 1:
+                            break
+                        outs = op.output(op.output_names[i])
+                        for in_id, in_var in enumerate(inputs):
+                            if in_var.name in outs:
+                                input_indexes[idx] = 1
+                                max_input_index = max(max_input_index, idx)
+                                break
+                    for i in range(0, len(op.input_names)):
+                        if output_indexes[idx] == 1:
+                            break
+                        ins = op.input(op.input_names[i])
+                        for out_id, out_var in enumerate(outputs):
+                            if out_var.name in ins:
+                                output_indexes[idx] = 1
+                                min_output_index = min(min_output_index, idx)
+
+                for i in range(len(global_block.ops)):
+                    if input_indexes[i] == 1 and output_indexes[i] == 1:
+                        warnings.warn(
+                            "unable to re-arrange dags order to combine distributed embedding ops because a op both needs embedding table's output as input and produces ids as the same embedding table's input"
+                        )
+                        return
+
+                if min_output_index < max_input_index:
+                    move_ops = []
+                    for i in range(min_output_index + 1, len(input_indexes)):
+                        if input_indexes[i] == 1:
+                            move_ops.append((global_block.ops[i], i))
+                    for i, op in enumerate(move_ops):
+                        queue = list()
+                        visited = set()
+                        queue.append(op[1])
+                        visited.add(op[0])
+                        start = 0
+                        while start < len(queue):
+                            pos = queue[start]
+                            op = global_block.ops[pos]
+                            op_inputs = []
+                            for k in range(0, len(op.input_names)):
+                                ins = op.input(op.input_names[k])
+                                op_inputs.append(ins)
+                            for j in range(pos - 1, min_output_index - 1, -1):
+                                op1 = global_block.ops[j]
+                                if op1 in visited:
+                                    continue
+                                found = False
+                                for k in range(0, len(op1.output_names)):
+                                    outs = op1.output(op1.output_names[k])
+                                    for t in range(len(op_inputs)):
+                                        for y in op_inputs[t]:
+                                            if y in outs:
+                                                found = True
+                                                break
+                                        if found:
+                                            break
+                                    if found:
+                                        break
+                                if found:
+                                    if output_indexes[j] == True:
+                                        warnings.warn(
+                                            "unable to re-arrange dags order to combine distributed embedding ops"
+                                        )
+                                        return
+                                    queue.append(j)
+                                    visited.add(global_block.ops[j])
+                            start = start + 1
+                        queue.sort()
+                        for index in queue:
+                            desc = global_block.desc._insert_op(min_output_index)
+                            desc.copy_from(global_block.ops[index].desc)
+                            global_block.desc._remove_op(index + 1, index + 2)
+                            global_block.ops[index].desc = desc
+                            insert_op = global_block.ops.pop(index)
+                            input_state = input_indexes.pop(index)
+                            output_state = output_indexes.pop(index)
+                            global_block.ops.insert(min_output_index, insert_op)
+                            input_indexes.insert(min_output_index, input_state)
+                            output_indexes.insert(min_output_index, output_state)
+                            min_output_index = min_output_index + 1
+
+                    assert global_block.desc.op_size() == len(global_block.ops)
+                    for i in range(len(global_block.ops)):
+                        assert global_block.desc.op(i) == global_block.ops[i].desc
+
             for param, ops in pull_sparse_ops.items():
                 all_ops = program.global_block().ops
-                op_idxs = [all_ops.index(op) for op in ops]
+
                 inputs = [
                     program.global_block().vars[op.input("Ids")[0]]
@@ -155,23 +259,29 @@ class DistributedInfer:
                     for op in ops
                 ]
 
+                dag_check_up_and_reorder(program, inputs, outputs)
+
+                op_idxs = [all_ops.index(op) for op in ops]
+
                 for idx in op_idxs[::-1]:
                     program.global_block()._remove_op(idx)
 
                 inputs_idxs = [-1] * len(inputs)
-                outputs_idxs = [-1] * len(outputs)
+                outputs_idxs = [len(program.global_block().ops) + 1] * len(outputs)
 
                 for idx, op in enumerate(program.global_block().ops):
                     for i in range(0, len(op.output_names)):
                         outs = op.output(op.output_names[i])
                         for in_id, in_var in enumerate(inputs):
                             if in_var.name in outs:
-                                inputs_idxs[in_id] = idx
+                                inputs_idxs[in_id] = max(idx, inputs_idxs[in_id])
                     for i in range(0, len(op.input_names)):
                         ins = op.input(op.input_names[i])
                         for out_id, out_var in enumerate(outputs):
                             if out_var.name in ins:
-                                outputs_idxs[out_id] = idx
+                                outputs_idxs[out_id] = min(idx, outputs_idxs[out_id])
 
                 if min(outputs_idxs) - max(inputs_idxs) >= 1:
                     distributed_idx = max(inputs_idxs) + 1
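The new dag_check_up_and_reorder pass above is easier to follow on a toy block. What follows is a minimal sketch of the same idea on plain records rather than Paddle's global_block() and op descs; Op and reorder_for_fusion are illustrative names, not fleet API. It finds the first op that consumes an embedding lookup's output, collects every op after it that still produces lookup ids, pulls in their upstream dependencies, and hoists that group above the first consumer so one fused lookup op can later be inserted at a single index. The real pass additionally warns and bails out when an op both consumes a lookup result and produces ids for the same table (a cycle), which this sketch simply avoids in its input.

# A minimal sketch of the reorder idea (illustrative names, not Paddle API).
from collections import namedtuple

Op = namedtuple("Op", ["name", "inputs", "outputs"])

def reorder_for_fusion(ops, lookup_ids, lookup_outs):
    """Hoist ops that produce lookup ids (plus their upstream deps) above
    the first op that consumes a lookup result, so a single fused lookup
    op can later be inserted between producers and consumers."""
    # index of the earliest consumer of a lookup output
    min_out = min((i for i, op in enumerate(ops)
                   if set(op.inputs) & lookup_outs), default=len(ops))
    # id-producers that currently sit below the first consumer
    moved = {i for i in range(min_out + 1, len(ops))
             if set(ops[i].outputs) & lookup_ids}
    frontier = list(moved)
    while frontier:  # upstream closure, mirrors the BFS over op inputs
        i = frontier.pop()
        needed = set(ops[i].inputs)
        for j in range(i - 1, min_out, -1):  # real pass also checks j == min_out
            if j not in moved and set(ops[j].outputs) & needed:
                moved.add(j)
                frontier.append(j)
    hoisted = [ops[i] for i in sorted(moved)]
    rest = [op for i, op in enumerate(ops) if i not in moved]
    return rest[:min_out] + hoisted + rest[min_out:]

ops = [
    Op("slice", ["raw"], ["user_ids"]),             # ids for table A, early
    Op("concat", ["user_emb", "item_emb"], ["z"]),  # consumes lookup results
    Op("hash", ["raw"], ["item_ids"]),              # ids for table B, too late
]
print([o.name for o in reorder_for_fusion(
    ops, {"user_ids", "item_ids"}, {"user_emb", "item_emb"})])
# -> ['slice', 'hash', 'concat']: a fused lookup now fits between them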
@@ -111,9 +111,104 @@ def distributed_ops_pass(program, config, use_ps_gpu=False):
         return pull_sparse_ops
 
     def _pull_sparse_fuse(_program, pull_sparse_ops, use_ps_gpu):
+        def dag_check_up_and_reorder(program, inputs, outputs):
+            global_block = program.global_block()
+            min_output_index = len(global_block.ops)
+            max_input_index = -1
+            input_indexes = [0] * len(global_block.ops)
+            output_indexes = [0] * len(global_block.ops)
+            for idx, op in enumerate(global_block.ops):
+                for i in range(0, len(op.output_names)):
+                    if input_indexes[idx] == 1:
+                        break
+                    outs = op.output(op.output_names[i])
+                    for in_id, in_var in enumerate(inputs):
+                        if in_var.name in outs:
+                            input_indexes[idx] = 1
+                            max_input_index = max(max_input_index, idx)
+                            break
+                for i in range(0, len(op.input_names)):
+                    if output_indexes[idx] == 1:
+                        break
+                    ins = op.input(op.input_names[i])
+                    for out_id, out_var in enumerate(outputs):
+                        if out_var.name in ins:
+                            output_indexes[idx] = 1
+                            min_output_index = min(min_output_index, idx)
+
+            for i in range(len(global_block.ops)):
+                if input_indexes[i] == 1 and output_indexes[i] == 1:
+                    warnings.warn(
+                        "unable to re-arrange dags order to combine distributed embedding ops because a op both needs embedding table's output as input and produces ids as the same embedding table's input"
+                    )
+                    return
+
+            if min_output_index < max_input_index:
+                move_ops = []
+                for i in range(min_output_index + 1, len(input_indexes)):
+                    if input_indexes[i] == 1:
+                        move_ops.append((global_block.ops[i], i))
+                for i, op in enumerate(move_ops):
+                    queue = list()
+                    visited = set()
+                    queue.append(op[1])
+                    visited.add(op[0])
+                    start = 0
+                    while start < len(queue):
+                        pos = queue[start]
+                        op = global_block.ops[pos]
+                        op_inputs = []
+                        for k in range(0, len(op.input_names)):
+                            ins = op.input(op.input_names[k])
+                            op_inputs.append(ins)
+                        for j in range(pos - 1, min_output_index - 1, -1):
+                            op1 = global_block.ops[j]
+                            if op1 in visited:
+                                continue
+                            found = False
+                            for k in range(0, len(op1.output_names)):
+                                outs = op1.output(op1.output_names[k])
+                                for t in range(len(op_inputs)):
+                                    for y in op_inputs[t]:
+                                        if y in outs:
+                                            found = True
+                                            break
+                                    if found:
+                                        break
+                                if found:
+                                    break
+                            if found:
+                                if output_indexes[j] == True:
+                                    warnings.warn(
+                                        "unable to re-arrange dags order to combine distributed embedding ops"
+                                    )
+                                    return
+                                queue.append(j)
+                                visited.add(global_block.ops[j])
+                        start = start + 1
+                    queue.sort()
+                    for index in queue:
+                        desc = global_block.desc._insert_op(min_output_index)
+                        desc.copy_from(global_block.ops[index].desc)
+                        global_block.desc._remove_op(index + 1, index + 2)
+                        global_block.ops[index].desc = desc
+                        insert_op = global_block.ops.pop(index)
+                        input_state = input_indexes.pop(index)
+                        output_state = output_indexes.pop(index)
+                        global_block.ops.insert(min_output_index, insert_op)
+                        input_indexes.insert(min_output_index, input_state)
+                        output_indexes.insert(min_output_index, output_state)
+                        min_output_index = min_output_index + 1
+
+                assert global_block.desc.op_size() == len(global_block.ops)
+                for i in range(len(global_block.ops)):
+                    assert global_block.desc.op(i) == global_block.ops[i].desc
+
         for param, ops in pull_sparse_ops.items():
             all_ops = program.global_block().ops
-            op_idxs = [all_ops.index(op) for op in ops]
+
             inputs = [
                 program.global_block().vars[op.input("Ids")[0]] for op in ops
             ]
@@ -139,23 +234,28 @@ def distributed_ops_pass(program, config, use_ps_gpu=False):
                 program.global_block().vars[op.output("Out")[0]] for op in ops
             ]
 
+            dag_check_up_and_reorder(program, inputs, outputs)
+
+            op_idxs = [all_ops.index(op) for op in ops]
+
            for idx in op_idxs[::-1]:
                 program.global_block()._remove_op(idx)
 
             inputs_idxs = [-1] * len(inputs)
-            outputs_idxs = [-1] * len(outputs)
+            outputs_idxs = [len(program.global_block().ops) + 1] * len(outputs)
 
             for idx, op in enumerate(program.global_block().ops):
                 for i in range(0, len(op.output_names)):
                     outs = op.output(op.output_names[i])
                     for in_id, in_var in enumerate(inputs):
                         if in_var.name in outs:
-                            inputs_idxs[in_id] = idx
+                            inputs_idxs[in_id] = max(idx, inputs_idxs[in_id])
                 for i in range(0, len(op.input_names)):
                     ins = op.input(op.input_names[i])
                     for out_id, out_var in enumerate(outputs):
                         if out_var.name in ins:
-                            outputs_idxs[out_id] = idx
+                            outputs_idxs[out_id] = min(idx, outputs_idxs[out_id])
 
             if min(outputs_idxs) - max(inputs_idxs) >= 1:
                 distributed_idx = max(inputs_idxs) + 1
@@ -187,7 +287,7 @@ def distributed_ops_pass(program, config, use_ps_gpu=False):
                     })
             else:
                 for i in range(len(inputs_idxs)):
-                    distributed_idx = op_idxs[i] + 1
+                    distributed_idx = op_idxs[i]
 
                     program.global_block()._insert_op(
                         index=distributed_idx,
@@ -557,7 +657,6 @@ def find_heter_ops(program, default_device="cpu"):
 
-
 def create_heter_program(program, config, heter_program, heter_ops,
                          block_var_detail, current_device):
     # This function mainly includes the following contents:
     # 1. For every heter block:
     #     a) copy heter device op from origin program
@@ -1029,7 +1128,6 @@ def insert_send_concat_op(program, block, index, var_name_list, new_var_name,
 
-
 def insert_recv_slice_op(program, block, index, var_name, var_shape, dtype,
                          type, new_var_name_list, new_var_shape_list):
     if var_name not in program.global_block().vars:
         input_var = program.global_block().create_var(
             name=var_name, shape=var_shape, dtype=dtype, type=type)
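Beyond the reorder pass, the inputs_idxs/outputs_idxs rewrite is the core of the combining fix. The old scan used plain assignment, so the last op visited won on both sides, while fusion safety needs the latest producer of the Ids vars versus the earliest consumer of the lookup outputs; an output with no consumer also kept its -1 initializer and made min(outputs_idxs) reject every fusion, which is why the new code initializes with len(ops) + 1. A tiny numeric check of the difference (hypothetical op indices, not the Paddle API):

# ids produced by ops 2 and 5; lookup result consumed by ops 4 and 7
producers, consumers = [2, 5], [4, 7]

# old logic: plain assignment in an ascending scan -> last index wins
old_in, old_out = producers[-1], consumers[-1]   # 5, 7
assert old_out - old_in >= 1  # passes, would insert the fused op at 6,
                              # i.e. after the consumer at 4 (stale read)

# fixed logic: latest producer vs. earliest consumer
new_in, new_out = max(producers), min(consumers)  # 5, 4
assert not (new_out - new_in >= 1)  # correctly rejected; this is the case
                                    # dag_check_up_and_reorder then untangles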