提交 0d3d4ae7 编写于 作者: Q qiaolongfei

refine prefetch logic

上级 17b42fc2
...@@ -248,7 +248,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -248,7 +248,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
request_prefetch_handler_.get()); request_prefetch_handler_.get());
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock); auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock); auto grad_to_block_id_str = Attr<std::vector<std::string>>(kPrefetchBlock);
framework::BlockDesc *prefetch_block = nullptr;
auto *program = optimize_block->Program(); auto *program = optimize_block->Program();
framework::Executor executor(dev_place); framework::Executor executor(dev_place);
...@@ -302,7 +303,7 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -302,7 +303,7 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("sync_mode", "if works at sync_mode or not").SetDefault(true); AddAttr<bool>("sync_mode", "if works at sync_mode or not").SetDefault(true);
AddAttr<framework::BlockDesc *>(kOptimizeBlock, AddAttr<framework::BlockDesc *>(kOptimizeBlock,
"BlockID to run on server side."); "BlockID to run on server side.");
AddAttr<framework::BlockDesc *>(kPrefetchBlock, AddAttr<std::vector<std::string>>(kPrefetchBlock,
"prefetch block to run on server side."); "prefetch block to run on server side.");
AddAttr<int>("Fanin", "How many clients send to this server.") AddAttr<int>("Fanin", "How many clients send to this server.")
.SetDefault(1); .SetDefault(1);
......
...@@ -30,7 +30,7 @@ namespace paddle { ...@@ -30,7 +30,7 @@ namespace paddle {
namespace operators { namespace operators {
constexpr char kOptimizeBlock[] = "OptimizeBlock"; constexpr char kOptimizeBlock[] = "OptimizeBlock";
constexpr char kPrefetchBlock[] = "PrefetchBlock"; constexpr char kPrefetchBlock[] = "prefetch_var_name_to_block_id";
void RunServer(std::shared_ptr<detail::RPCServer> service); void RunServer(std::shared_ptr<detail::RPCServer> service);
......
...@@ -515,21 +515,20 @@ class DistributeTranspiler: ...@@ -515,21 +515,20 @@ class DistributeTranspiler:
grad_to_block_id, None) grad_to_block_id, None)
# process distributed lookup_table # process distributed lookup_table
prefetch_block = None prefetch_var_name_to_block_id = []
if self.has_distributed_lookup_table: if self.has_distributed_lookup_table:
pserver_index = self.pserver_endpoints.index(endpoint) pserver_index = self.pserver_endpoints.index(endpoint)
table_opt_block = self._create_table_optimize_block( table_opt_block = self._create_table_optimize_block(
pserver_index, pserver_program, pre_block_idx, grad_to_block_id) pserver_index, pserver_program, pre_block_idx, grad_to_block_id)
prefetch_block = self._create_prefetch_block( prefetch_var_name_to_block_id = self._create_prefetch_block(
pserver_index, pserver_program, table_opt_block) pserver_index, pserver_program, table_opt_block)
# NOTE: if has_distributed_lookup_table is False, then prefetch_block will # NOTE: if has_distributed_lookup_table is False, then prefetch_block will
# not be executed, so it's safe to use optimize_block to hold the place # not be executed, so it's safe to use optimize_block to hold the place
if self.has_distributed_lookup_table: if self.has_distributed_lookup_table:
assert prefetch_block is not None assert len(prefetch_var_name_to_block_id) > 0
else: else:
assert prefetch_block is None assert len(prefetch_var_name_to_block_id) == 0
prefetch_block = pserver_program.global_block()
# step5 append the listen_and_serv op # step5 append the listen_and_serv op
pserver_program.global_block().append_op( pserver_program.global_block().append_op(
...@@ -540,7 +539,7 @@ class DistributeTranspiler: ...@@ -540,7 +539,7 @@ class DistributeTranspiler:
"OptimizeBlock": pserver_program.block(1), "OptimizeBlock": pserver_program.block(1),
"endpoint": endpoint, "endpoint": endpoint,
"Fanin": self.trainer_num, "Fanin": self.trainer_num,
"PrefetchBlock": prefetch_block, "prefetch_var_name_to_block_id": prefetch_var_name_to_block_id,
"sync_mode": self.sync_mode, "sync_mode": self.sync_mode,
"grad_to_block_id": grad_to_block_id "grad_to_block_id": grad_to_block_id
}) })
...@@ -608,8 +607,15 @@ class DistributeTranspiler: ...@@ -608,8 +607,15 @@ class DistributeTranspiler:
def _replace_lookup_table_op_with_prefetch(self, program, def _replace_lookup_table_op_with_prefetch(self, program,
pserver_endpoints): pserver_endpoints):
# 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op # 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
self.prefetch_input_vars = None # self.all_prefetch_input_vars =
self.prefetch_output_vars = None # [[var0_prefetch_in_pserver0, var0_prefetch_in_pserver1]
# [var1_prefetch_in_pserver0, var1_prefetch_in_pserver1]]
self.all_prefetch_input_vars = []
# self.all_prefetch_input_vars =
# [[var0_prefetch_in_pserver0, var0_prefetch_in_pserver1]
# [var1_prefetch_in_pserver0, var1_prefetch_in_pserver1]]
self.all_prefetch_output_vars = []
continue_search_lookup_table_op = True continue_search_lookup_table_op = True
while continue_search_lookup_table_op: while continue_search_lookup_table_op:
...@@ -623,18 +629,19 @@ class DistributeTranspiler: ...@@ -623,18 +629,19 @@ class DistributeTranspiler:
ids_name = op.input("Ids") ids_name = op.input("Ids")
out_name = op.output("Out") out_name = op.output("Out")
if self.prefetch_input_vars is None:
ids_var = program.global_block().vars[ids_name[0]] ids_var = program.global_block().vars[ids_name[0]]
self.prefetch_input_vars = self.create_splited_vars( prefetch_input_vars = self.create_splited_vars(
source_var=ids_var, source_var=ids_var,
block=program.global_block(), block=program.global_block(),
tag="_prefetch_in_") tag="_prefetch_in_")
if self.prefetch_output_vars is None: self.all_prefetch_input_vars.append(prefetch_input_vars)
out_var = program.global_block().vars[out_name[0]] out_var = program.global_block().vars[out_name[0]]
self.prefetch_output_vars = self.create_splited_vars( prefetch_output_vars = self.create_splited_vars(
source_var=out_var, source_var=out_var,
block=program.global_block(), block=program.global_block(),
tag="_prefetch_out_") tag="_prefetch_out_")
self.all_prefetch_output_vars.append(prefetch_output_vars)
# insert split_ids_op # insert split_ids_op
program.global_block().insert_op( program.global_block().insert_op(
...@@ -646,14 +653,14 @@ class DistributeTranspiler: ...@@ -646,14 +653,14 @@ class DistributeTranspiler:
for varname in ids_name for varname in ids_name
] ]
}, },
outputs={"Out": self.prefetch_input_vars}) outputs={"Out": prefetch_input_vars})
# insert prefetch_op # insert prefetch_op
program.global_block().insert_op( program.global_block().insert_op(
index=op_index + 1, index=op_index + 1,
type="prefetch", type="prefetch",
inputs={'X': self.prefetch_input_vars}, inputs={'X': prefetch_input_vars},
outputs={"Out": self.prefetch_output_vars}, outputs={"Out": prefetch_output_vars},
attrs={ attrs={
"epmap": pserver_endpoints, "epmap": pserver_endpoints,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
...@@ -663,7 +670,7 @@ class DistributeTranspiler: ...@@ -663,7 +670,7 @@ class DistributeTranspiler:
program.global_block().insert_op( program.global_block().insert_op(
index=op_index + 2, index=op_index + 2,
type="concat", type="concat",
inputs={'X': self.prefetch_output_vars}, inputs={'X': prefetch_output_vars},
outputs={ outputs={
"Out": [ "Out": [
program.global_block().vars[varname] program.global_block().vars[varname]
...@@ -709,14 +716,16 @@ class DistributeTranspiler: ...@@ -709,14 +716,16 @@ class DistributeTranspiler:
optimize_block): optimize_block):
# STEP: create prefetch block # STEP: create prefetch block
table_var = pserver_program.global_block().vars[self.table_name] table_var = pserver_program.global_block().vars[self.table_name]
prefetch_var_name_to_block_id = []
for index in range(len(self.all_prefetch_input_vars)):
prefetch_block = pserver_program.create_block(optimize_block.idx) prefetch_block = pserver_program.create_block(optimize_block.idx)
trainer_ids = self.prefetch_input_vars[pserver_index] trainer_ids = self.all_prefetch_input_vars[index][pserver_index]
pserver_ids = pserver_program.global_block().create_var( pserver_ids = pserver_program.global_block().create_var(
name=trainer_ids.name, name=trainer_ids.name,
type=trainer_ids.type, type=trainer_ids.type,
shape=trainer_ids.shape, shape=trainer_ids.shape,
dtype=trainer_ids.dtype) dtype=trainer_ids.dtype)
trainer_out = self.prefetch_output_vars[pserver_index] trainer_out = self.all_prefetch_output_vars[index][pserver_index]
pserver_out = pserver_program.global_block().create_var( pserver_out = pserver_program.global_block().create_var(
name=trainer_out.name, name=trainer_out.name,
type=trainer_out.type, type=trainer_out.type,
...@@ -732,7 +741,9 @@ class DistributeTranspiler: ...@@ -732,7 +741,9 @@ class DistributeTranspiler:
"is_distributed": True, "is_distributed": True,
"padding_idx": -1 "padding_idx": -1
}) })
return prefetch_block prefetch_var_name_to_block_id.append(trainer_ids.name + ":" + str(
prefetch_block.idx))
return prefetch_var_name_to_block_id
def _create_table_optimize_block(self, pserver_index, pserver_program, def _create_table_optimize_block(self, pserver_index, pserver_program,
pre_block_idx, grad_to_block_id): pre_block_idx, grad_to_block_id):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册