Commit 0d3d4ae7 authored by qiaolongfei

refine prefetch logic

Parent 17b42fc2
......@@ -248,7 +248,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
request_prefetch_handler_.get());
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
- auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock);
+ auto grad_to_block_id_str = Attr<std::vector<std::string>>(kPrefetchBlock);
+ framework::BlockDesc *prefetch_block = nullptr;
auto *program = optimize_block->Program();
framework::Executor executor(dev_place);
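
The prefetch attribute read in RunImpl is now a list of strings rather than a single BlockDesc. Each entry is built by the transpiler as "<prefetch_var_name>:<block_idx>" (see _create_prefetch_block below), so the server side has to split the strings back into a variable-name-to-block mapping. Below is a minimal Python sketch of that decoding, not the operator's actual C++ code; the example variable name is made up.

    # Sketch only: decode ["<var_name>:<block_idx>", ...] into {var_name: block_idx}.
    def parse_prefetch_var_name_to_block_id(entries):
        mapping = {}
        for entry in entries:
            var_name, block_id = entry.rsplit(":", 1)
            mapping[var_name] = int(block_id)
        return mapping

    # The variable name below is hypothetical.
    print(parse_prefetch_var_name_to_block_id(["ids_prefetch_in_0:2"]))
    # {'ids_prefetch_in_0': 2}
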
......@@ -302,7 +303,7 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("sync_mode", "if works at sync_mode or not").SetDefault(true);
AddAttr<framework::BlockDesc *>(kOptimizeBlock,
"BlockID to run on server side.");
- AddAttr<framework::BlockDesc *>(kPrefetchBlock,
+ AddAttr<std::vector<std::string>>(kPrefetchBlock,
"prefetch block to run on server side.");
AddAttr<int>("Fanin", "How many clients send to this server.")
.SetDefault(1);
......
......@@ -30,7 +30,7 @@ namespace paddle {
namespace operators {
constexpr char kOptimizeBlock[] = "OptimizeBlock";
- constexpr char kPrefetchBlock[] = "PrefetchBlock";
+ constexpr char kPrefetchBlock[] = "prefetch_var_name_to_block_id";
void RunServer(std::shared_ptr<detail::RPCServer> service);
......
......@@ -515,21 +515,20 @@ class DistributeTranspiler:
grad_to_block_id, None)
# process distributed lookup_table
- prefetch_block = None
+ prefetch_var_name_to_block_id = []
if self.has_distributed_lookup_table:
pserver_index = self.pserver_endpoints.index(endpoint)
table_opt_block = self._create_table_optimize_block(
pserver_index, pserver_program, pre_block_idx, grad_to_block_id)
- prefetch_block = self._create_prefetch_block(
+ prefetch_var_name_to_block_id = self._create_prefetch_block(
pserver_index, pserver_program, table_opt_block)
# NOTE: if has_distributed_lookup_table is False, then prefetch_block will
# not be executed, so it's safe to use optimize_block to hold the place
if self.has_distributed_lookup_table:
- assert prefetch_block is not None
+ assert len(prefetch_var_name_to_block_id) > 0
else:
- assert prefetch_block is None
- prefetch_block = pserver_program.global_block()
+ assert len(prefetch_var_name_to_block_id) == 0
# step5 append the listen_and_serv op
pserver_program.global_block().append_op(
......@@ -540,7 +539,7 @@ class DistributeTranspiler:
"OptimizeBlock": pserver_program.block(1),
"endpoint": endpoint,
"Fanin": self.trainer_num,
"PrefetchBlock": prefetch_block,
"prefetch_var_name_to_block_id": prefetch_var_name_to_block_id,
"sync_mode": self.sync_mode,
"grad_to_block_id": grad_to_block_id
})
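
For reference, a rough illustration of the attrs the transpiler now passes to listen_and_serv when one distributed lookup table is present. The endpoint, Fanin value, and variable/block names are placeholders, not values taken from a real program.

    # Illustrative values only; block descriptors are shown as strings.
    attrs = {
        "OptimizeBlock": "<pserver_program.block(1)>",
        "endpoint": "127.0.0.1:6174",   # placeholder endpoint
        "Fanin": 2,                     # placeholder trainer count
        # empty when there is no distributed lookup table, otherwise one
        # "<var_name>:<block_idx>" entry per prefetch input variable
        "prefetch_var_name_to_block_id": ["ids_prefetch_in_0:2"],
        "sync_mode": True,
        "grad_to_block_id": ["<grad_var_name>:1"],
    }
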
......@@ -608,8 +607,15 @@ class DistributeTranspiler:
def _replace_lookup_table_op_with_prefetch(self, program,
pserver_endpoints):
# 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
- self.prefetch_input_vars = None
- self.prefetch_output_vars = None
+ # self.all_prefetch_input_vars =
+ #   [[var0_prefetch_in_pserver0, var0_prefetch_in_pserver1]
+ #    [var1_prefetch_in_pserver0, var1_prefetch_in_pserver1]]
+ self.all_prefetch_input_vars = []
+ # self.all_prefetch_output_vars =
+ #   [[var0_prefetch_out_pserver0, var0_prefetch_out_pserver1]
+ #    [var1_prefetch_out_pserver0, var1_prefetch_out_pserver1]]
+ self.all_prefetch_output_vars = []
continue_search_lookup_table_op = True
while continue_search_lookup_table_op:
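
Since a program can contain more than one distributed lookup-table variable, the single prefetch_input_vars/prefetch_output_vars lists are replaced with nested lists: the outer index is the lookup-table variable, the inner index is the pserver. A small sketch of the intended layout for two variables split across two pservers; the names are illustrative, the real ones come from create_splited_vars with the "_prefetch_in_"/"_prefetch_out_" tags.

    # Illustrative names only.
    all_prefetch_input_vars = [
        ["var0_prefetch_in_pserver0", "var0_prefetch_in_pserver1"],   # lookup var 0
        ["var1_prefetch_in_pserver0", "var1_prefetch_in_pserver1"],   # lookup var 1
    ]
    all_prefetch_output_vars = [
        ["var0_prefetch_out_pserver0", "var0_prefetch_out_pserver1"],
        ["var1_prefetch_out_pserver0", "var1_prefetch_out_pserver1"],
    ]
    # _create_prefetch_block later reads all_prefetch_input_vars[index][pserver_index]
    # for lookup-table variable `index` on pserver `pserver_index`.
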
......@@ -623,18 +629,19 @@ class DistributeTranspiler:
ids_name = op.input("Ids")
out_name = op.output("Out")
- if self.prefetch_input_vars is None:
ids_var = program.global_block().vars[ids_name[0]]
- self.prefetch_input_vars = self.create_splited_vars(
+ prefetch_input_vars = self.create_splited_vars(
source_var=ids_var,
block=program.global_block(),
tag="_prefetch_in_")
- if self.prefetch_output_vars is None:
+ self.all_prefetch_input_vars.append(prefetch_input_vars)
out_var = program.global_block().vars[out_name[0]]
- self.prefetch_output_vars = self.create_splited_vars(
+ prefetch_output_vars = self.create_splited_vars(
source_var=out_var,
block=program.global_block(),
tag="_prefetch_out_")
+ self.all_prefetch_output_vars.append(prefetch_output_vars)
# insert split_ids_op
program.global_block().insert_op(
......@@ -646,14 +653,14 @@ class DistributeTranspiler:
for varname in ids_name
]
},
outputs={"Out": self.prefetch_input_vars})
outputs={"Out": prefetch_input_vars})
# insert prefetch_op
program.global_block().insert_op(
index=op_index + 1,
type="prefetch",
- inputs={'X': self.prefetch_input_vars},
- outputs={"Out": self.prefetch_output_vars},
+ inputs={'X': prefetch_input_vars},
+ outputs={"Out": prefetch_output_vars},
attrs={
"epmap": pserver_endpoints,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
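
The rewrite replaces each lookup_table op on the trainer with the chain split_ids -> prefetch -> concat. Below is a pure-Python sketch of the data flow this chain implements, assuming split_ids shards ids by id % num_pservers and each pserver's prefetch block returns one embedding row per requested id; the table contents and shapes are made up.

    num_pservers = 2
    pserver_tables = [                       # shard 0 and shard 1 of the embedding table
        {0: [0.0, 0.1], 2: [2.0, 2.1]},
        {1: [1.0, 1.1], 3: [3.0, 3.1]},
    ]
    ids = [3, 0, 2]

    # split_ids: one id list per pserver
    split = [[i for i in ids if i % num_pservers == p] for p in range(num_pservers)]
    # prefetch: every pserver looks its ids up in its shard of the table
    fetched = [[pserver_tables[p][i] for i in split[p]] for p in range(num_pservers)]
    # concat: stack the per-pserver results into one output
    out = [row for rows in fetched for row in rows]

    print(split)   # [[0, 2], [3]]
    print(out)     # rows for ids 0, 2 and 3, in per-shard order
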
......@@ -663,7 +670,7 @@ class DistributeTranspiler:
program.global_block().insert_op(
index=op_index + 2,
type="concat",
- inputs={'X': self.prefetch_output_vars},
+ inputs={'X': prefetch_output_vars},
outputs={
"Out": [
program.global_block().vars[varname]
......@@ -709,14 +716,16 @@ class DistributeTranspiler:
optimize_block):
# STEP: create prefetch block
table_var = pserver_program.global_block().vars[self.table_name]
+ prefetch_var_name_to_block_id = []
+ for index in range(len(self.all_prefetch_input_vars)):
prefetch_block = pserver_program.create_block(optimize_block.idx)
- trainer_ids = self.prefetch_input_vars[pserver_index]
+ trainer_ids = self.all_prefetch_input_vars[index][pserver_index]
pserver_ids = pserver_program.global_block().create_var(
name=trainer_ids.name,
type=trainer_ids.type,
shape=trainer_ids.shape,
dtype=trainer_ids.dtype)
- trainer_out = self.prefetch_output_vars[pserver_index]
+ trainer_out = self.all_prefetch_output_vars[index][pserver_index]
pserver_out = pserver_program.global_block().create_var(
name=trainer_out.name,
type=trainer_out.type,
......@@ -732,7 +741,9 @@ class DistributeTranspiler:
"is_distributed": True,
"padding_idx": -1
})
- return prefetch_block
+ prefetch_var_name_to_block_id.append(trainer_ids.name + ":" + str(
+ prefetch_block.idx))
+ return prefetch_var_name_to_block_id
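
_create_prefetch_block now creates one block per distributed lookup-table input and returns the mapping as "<var_name>:<block_idx>" strings instead of a single block. A hedged example of the return value on pserver 0 when two lookup-table variables are handled and the newly created blocks happen to get indices 2 and 3; names and indices are hypothetical.

    prefetch_var_name_to_block_id = [
        "var0_prefetch_in_pserver0:2",   # run block 2 when this variable is prefetched
        "var1_prefetch_in_pserver0:3",   # run block 3 when this variable is prefetched
    ]
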
def _create_table_optimize_block(self, pserver_index, pserver_program,
pre_block_idx, grad_to_block_id):
......