提交 0970bd9e 编写于 作者: Y Yancey1989

use optimize blocks attr to record optimize block id

上级 e02cbf35
......@@ -106,13 +106,8 @@ void ListenAndServOp::RunSyncLoop(
"server program should have at least 2 blocks");
std::vector<int> optimize_block_id_list;
for (int blkid = 1; blkid < num_blocks; ++blkid) {
if (std::find(prefetch_block_id_list.begin(), prefetch_block_id_list.end(),
blkid) == prefetch_block_id_list.end() &&
std::find(skip_sub_blks.begin(), skip_sub_blks.end(), blkid) ==
skip_sub_blks.end()) {
optimize_block_id_list.push_back(blkid);
}
for (auto *block : optimize_blocks) {
optimize_block_id_list.push_back(block->ID());
}
auto optimize_prepared = executor->Prepare(*program, optimize_block_id_list);
// Insert placeholder for block0 which holds current op itself.
......@@ -137,9 +132,9 @@ void ListenAndServOp::RunSyncLoop(
// and this will still work.
// The optimize blocks which have the same parent ID would run parallel
// TODO(Yancey1989): need to use ParallelExecutor for future
int32_t last_parent_blkid = program->Block(1).Parent();
int32_t last_parent_blkid = optimize_blocks[0]->Parent();
std::vector<size_t> parallel_blkids;
parallel_blkids.push_back(1);
parallel_blkids.push_back(optimize_blocks[0]->ID());
double ts = GetTimestamp();
for (size_t i = 1; i < optimize_block_id_list.size(); ++i) {
// skip the first optimize block because it is already in the
......@@ -262,8 +257,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
rpc_service_->RegisterRPC(detail::kRequestPrefetch,
request_prefetch_handler_.get());
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *program = optimize_block->Program();
auto optimize_blocks =
Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
PADDLE_ENFORCE(optimize_blocks.size() > 1,
"optimize blocks should be 1 at least on the pserver side.");
auto *program = optimize_block[0]->Program();
framework::Executor executor(dev_place);
// prepare for prefetch
......@@ -340,18 +338,13 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
"a map from grad name to it's optimize block id")
.SetDefault({});
AddAttr<bool>("sync_mode", "if works at sync_mode or not").SetDefault(true);
AddAttr<framework::BlockDesc *>(kOptimizeBlock,
"BlockID to run on server side.");
AddAttr<framework::BlockDesc *>(kOptimizeBlocks,
"Optimize blocks to run on server side.");
AddAttr<std::vector<std::string>>(kPrefetchVarNameToBlockId,
"prefetch blocks to run on server side.")
.SetDefault({});
AddAttr<int>("Fanin", "How many clients send to this server.")
.SetDefault(1);
AddAttr<std::vector<int>>("skip_sub_blks",
"do not parallel execute the specify sub blocks, "
"it's used for the op which has"
"condition blocks")
.SetDefault({});
}
};
......
......@@ -30,7 +30,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
constexpr char kOptimizeBlock[] = "OptimizeBlock";
constexpr char kOptimizeBlocks[] = "optimize_blocks";
constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id";
void RunServer(std::shared_ptr<detail::RPCServer> service);
......
......@@ -424,10 +424,12 @@ class DistributeTranspiler(object):
# append lr decay ops to the child block if exists
lr_ops = self._get_lr_ops()
skip_sub_blks = []
# record optimize blocks and we can run them on pserver parallel
optimize_blocks = []
if len(lr_ops) > 0:
lr_decay_block = pserver_program.create_block(
pserver_program.num_blocks - 1)
optimize_blocks.append(lr_decay_block)
for _, op in enumerate(lr_ops):
self._append_pserver_non_opt_ops(lr_decay_block, op)
# append sub blocks to pserver_program in lr_decay_op
......@@ -439,6 +441,7 @@ class DistributeTranspiler(object):
pre_block_idx = pserver_program.num_blocks - 1
for idx, opt_op in enumerate(opt_op_on_pserver):
per_opt_block = pserver_program.create_block(pre_block_idx)
optimize_blocks.append(per_opt_block)
# append grad merging ops before clip and weight decay
for _, op in enumerate(self.optimize_ops):
# find the origin @GRAD var before clipping
......@@ -457,6 +460,7 @@ class DistributeTranspiler(object):
if global_ops:
opt_state_block = pserver_program.create_block(
pserver_program.num_blocks - 1)
optimize_blocks.append(opt_state_block)
for glb_op in global_ops:
__append_optimize_op__(glb_op, opt_state_block,
grad_to_block_id, None)
......@@ -478,12 +482,11 @@ class DistributeTranspiler(object):
assert len(prefetch_var_name_to_block_id) == 0
attrs = {
"OptimizeBlock": pserver_program.block(1),
"optimize_blocks": optimize_blocks,
"endpoint": endpoint,
"Fanin": self.trainer_num,
"sync_mode": self.sync_mode,
"grad_to_block_id": grad_to_block_id,
"skip_sub_blks": skip_sub_blks
}
if len(prefetch_var_name_to_block_id) > 0:
attrs['prefetch_var_name_to_block_id'] \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册