Commit 0970bd9e authored by Yancey1989

use optimize blocks attr to record optimize block id

Parent e02cbf35
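For orientation before the diff: previously the transpiler passed a single block through the "OptimizeBlock" attribute (plus a "skip_sub_blks" list), and RunSyncLoop rebuilt the list of optimize block ids by scanning every block and filtering out prefetch and skipped sub-block ids. After this commit the transpiler records each pserver optimize block as it creates it and hands the whole list to the op through the new "optimize_blocks" attribute, so the server takes the block ids directly from it. Below is a minimal, framework-free Python sketch of that contract; the Block/Program classes are illustrative stand-ins, not Paddle's real classes.

# Illustrative stand-ins only (not Paddle's classes), just to show the contract.
class Block(object):
    def __init__(self, idx, parent):
        self.idx, self.parent = idx, parent

    def ID(self):
        return self.idx

class Program(object):
    def __init__(self):
        self.blocks = [Block(0, -1)]  # block 0 holds the listen_and_serv op itself

    def create_block(self, parent_idx):
        blk = Block(len(self.blocks), parent_idx)
        self.blocks.append(blk)
        return blk

pserver_program = Program()

# Transpiler side: record every optimize block as it is created ...
optimize_blocks = []
lr_decay_block = pserver_program.create_block(0)
optimize_blocks.append(lr_decay_block)
per_opt_block = pserver_program.create_block(0)
optimize_blocks.append(per_opt_block)

# ... and pass the list as one attribute instead of "OptimizeBlock"/"skip_sub_blks".
attrs = {"optimize_blocks": optimize_blocks}

# Server side (RunSyncLoop) then derives the block ids directly from the attribute.
optimize_block_id_list = [block.ID() for block in attrs["optimize_blocks"]]
assert optimize_block_id_list == [1, 2]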
@@ -106,13 +106,8 @@ void ListenAndServOp::RunSyncLoop(
                     "server program should have at least 2 blocks");
   std::vector<int> optimize_block_id_list;
-  for (int blkid = 1; blkid < num_blocks; ++blkid) {
-    if (std::find(prefetch_block_id_list.begin(), prefetch_block_id_list.end(),
-                  blkid) == prefetch_block_id_list.end() &&
-        std::find(skip_sub_blks.begin(), skip_sub_blks.end(), blkid) ==
-            skip_sub_blks.end()) {
-      optimize_block_id_list.push_back(blkid);
-    }
+  for (auto *block : optimize_blocks) {
+    optimize_block_id_list.push_back(block->ID());
   }
   auto optimize_prepared = executor->Prepare(*program, optimize_block_id_list);
   // Insert placeholder for block0 which holds current op itself.

@@ -137,9 +132,9 @@ void ListenAndServOp::RunSyncLoop(
   // and this will still work.
   // The optimize blocks which have the same parent ID would run parallel
   // TODO(Yancey1989): need to use ParallelExecutor for future
-  int32_t last_parent_blkid = program->Block(1).Parent();
+  int32_t last_parent_blkid = optimize_blocks[0]->Parent();
   std::vector<size_t> parallel_blkids;
-  parallel_blkids.push_back(1);
+  parallel_blkids.push_back(optimize_blocks[0]->ID());
   double ts = GetTimestamp();
   for (size_t i = 1; i < optimize_block_id_list.size(); ++i) {
     // skip the first optimize block because it is already in the
@@ -262,8 +257,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
   rpc_service_->RegisterRPC(detail::kRequestPrefetch,
                             request_prefetch_handler_.get());
-  auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
-  auto *program = optimize_block->Program();
+  auto optimize_blocks =
+      Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
+  PADDLE_ENFORCE(optimize_blocks.size() >= 1,
+                 "optimize blocks should be 1 at least on the pserver side.");
+  auto *program = optimize_blocks[0]->Program();
   framework::Executor executor(dev_place);
   // prepare for prefetch

@@ -340,18 +338,13 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
                        "a map from grad name to it's optimize block id")
         .SetDefault({});
     AddAttr<bool>("sync_mode", "if works at sync_mode or not").SetDefault(true);
-    AddAttr<framework::BlockDesc *>(kOptimizeBlock,
-                                    "BlockID to run on server side.");
+    AddAttr<std::vector<framework::BlockDesc *>>(
+        kOptimizeBlocks, "Optimize blocks to run on server side.");
     AddAttr<std::vector<std::string>>(kPrefetchVarNameToBlockId,
                                       "prefetch blocks to run on server side.")
         .SetDefault({});
     AddAttr<int>("Fanin", "How many clients send to this server.")
         .SetDefault(1);
-    AddAttr<std::vector<int>>("skip_sub_blks",
-                              "do not parallel execute the specify sub blocks, "
-                              "it's used for the op which has"
-                              "condition blocks")
-        .SetDefault({});
   }
 };
......
@@ -30,7 +30,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
-constexpr char kOptimizeBlock[] = "OptimizeBlock";
+constexpr char kOptimizeBlocks[] = "optimize_blocks";
 constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id";
 void RunServer(std::shared_ptr<detail::RPCServer> service);
......
@@ -424,10 +424,12 @@ class DistributeTranspiler(object):
         # append lr decay ops to the child block if exists
         lr_ops = self._get_lr_ops()
-        skip_sub_blks = []
+        # record optimize blocks and we can run them on pserver parallel
+        optimize_blocks = []
         if len(lr_ops) > 0:
             lr_decay_block = pserver_program.create_block(
                 pserver_program.num_blocks - 1)
+            optimize_blocks.append(lr_decay_block)
             for _, op in enumerate(lr_ops):
                 self._append_pserver_non_opt_ops(lr_decay_block, op)
                 # append sub blocks to pserver_program in lr_decay_op

@@ -439,6 +441,7 @@ class DistributeTranspiler(object):
         pre_block_idx = pserver_program.num_blocks - 1
         for idx, opt_op in enumerate(opt_op_on_pserver):
             per_opt_block = pserver_program.create_block(pre_block_idx)
+            optimize_blocks.append(per_opt_block)
             # append grad merging ops before clip and weight decay
             for _, op in enumerate(self.optimize_ops):
                 # find the origin @GRAD var before clipping

@@ -457,6 +460,7 @@ class DistributeTranspiler(object):
         if global_ops:
             opt_state_block = pserver_program.create_block(
                 pserver_program.num_blocks - 1)
+            optimize_blocks.append(opt_state_block)
             for glb_op in global_ops:
                 __append_optimize_op__(glb_op, opt_state_block,
                                        grad_to_block_id, None)

@@ -478,12 +482,11 @@ class DistributeTranspiler(object):
             assert len(prefetch_var_name_to_block_id) == 0
         attrs = {
-            "OptimizeBlock": pserver_program.block(1),
+            "optimize_blocks": optimize_blocks,
             "endpoint": endpoint,
             "Fanin": self.trainer_num,
             "sync_mode": self.sync_mode,
             "grad_to_block_id": grad_to_block_id,
-            "skip_sub_blks": skip_sub_blks
         }
         if len(prefetch_var_name_to_block_id) > 0:
             attrs['prefetch_var_name_to_block_id'] \
......