提交 b54d1ba9 编写于 作者: Y Yancey1989

fix pserver sub-blocks

上级 59729902
...@@ -101,13 +101,16 @@ void ListenAndServOp::RunSyncLoop( ...@@ -101,13 +101,16 @@ void ListenAndServOp::RunSyncLoop(
framework::Scope *recv_scope, framework::Scope *recv_scope,
const std::vector<int> &prefetch_block_id_list) const { const std::vector<int> &prefetch_block_id_list) const {
size_t num_blocks = program->Size(); size_t num_blocks = program->Size();
auto skip_sub_blks = Attr<std::vector<int>>("skip_sub_blks");
PADDLE_ENFORCE_GE(num_blocks, 2, PADDLE_ENFORCE_GE(num_blocks, 2,
"server program should have at least 2 blocks"); "server program should have at least 2 blocks");
std::vector<int> optimize_block_id_list; std::vector<int> optimize_block_id_list;
for (int blkid = 1; blkid < num_blocks; ++blkid) { for (int blkid = 1; blkid < num_blocks; ++blkid) {
if (std::find(prefetch_block_id_list.begin(), prefetch_block_id_list.end(), if (std::find(prefetch_block_id_list.begin(), prefetch_block_id_list.end(),
blkid) == prefetch_block_id_list.end()) { blkid) == prefetch_block_id_list.end() &&
std::find(skip_sub_blks.begin(), skip_sub_blks.end(), blkid) ==
skip_sub_blks.end()) {
optimize_block_id_list.push_back(blkid); optimize_block_id_list.push_back(blkid);
} }
} }
...@@ -344,6 +347,11 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -344,6 +347,11 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault({}); .SetDefault({});
AddAttr<int>("Fanin", "How many clients send to this server.") AddAttr<int>("Fanin", "How many clients send to this server.")
.SetDefault(1); .SetDefault(1);
AddAttr<std::vector<int>>("skip_sub_blks",
"do not parallel execute the specify sub blocks, "
"it's used for the op which has"
"condition blocks")
.SetDefault({});
} }
}; };
......
...@@ -250,19 +250,14 @@ class DistributeTranspiler: ...@@ -250,19 +250,14 @@ class DistributeTranspiler:
split_method=RoundRobin, split_method=RoundRobin,
sync_mode=True): sync_mode=True):
""" """
:param trainer_id: one unique id for each trainer in a job. Args:
:type trainer_id: int trainer_id(int): one unique id for each trainer in a job.
:param program: program to transpile, default is default_main_program program(Program): program to transpile, default is default_main_program
:type program: Program pservers(string): parameter server endpoints like "m1:6174,m2:6174"
:param pservers: parameter server endpoints like "m1:6174,m2:6174" trainers(int): total number of workers/trainers in the job
:type pservers: string split_method(PSDispatcher): A function to determine how to split variables
:param trainers: total number of workers/trainers in the job
:type trainers: int
:param split_method: A function to determine how to split variables
to different servers equally. to different servers equally.
:type split_method: function sync_mode(boolean): if sync_mode is set True, it means that dist transpiler
:type sync_mode: boolean default True
:param sync_mode: if sync_mode is set True, it means that dist transpiler
will transpile the program into sync_mode pserver and trainer program. will transpile the program into sync_mode pserver and trainer program.
""" """
assert (split_method.__bases__[0] == PSDispatcher) assert (split_method.__bases__[0] == PSDispatcher)
...@@ -403,6 +398,11 @@ class DistributeTranspiler: ...@@ -403,6 +398,11 @@ class DistributeTranspiler:
NOTE: assume blocks of the same variable is not distributed NOTE: assume blocks of the same variable is not distributed
on the same pserver, only change param/grad varnames for on the same pserver, only change param/grad varnames for
trainers to fetch. trainers to fetch.
Args:
endpoint(string): the endpoint for the current pserver instance.
Returns(Program): the pserver program
""" """
# step1 # step1
pserver_program = Program() pserver_program = Program()
...@@ -479,9 +479,9 @@ class DistributeTranspiler: ...@@ -479,9 +479,9 @@ class DistributeTranspiler:
return varname return varname
return "" return ""
def __clone_lr_op_sub_block__(op, program, new_block): def __clone_lr_op_sub_block__(op, program, new_block, skip_sub_blks):
if not op.has_attr('sub_block'): if not op.has_attr('sub_block'):
return return -1
origin_block_desc = op.attr('sub_block') origin_block_desc = op.attr('sub_block')
origin_block = self.origin_program.block(origin_block_desc.id) origin_block = self.origin_program.block(origin_block_desc.id)
...@@ -489,6 +489,7 @@ class DistributeTranspiler: ...@@ -489,6 +489,7 @@ class DistributeTranspiler:
# we put the new sub block to new block to follow the block # we put the new sub block to new block to follow the block
# hierarchy of the original blocks # hierarchy of the original blocks
new_sub_block = program.create_block(new_block.idx) new_sub_block = program.create_block(new_block.idx)
skip_sub_blks(new_sub_block.idx)
# clone vars # clone vars
for var in origin_block.vars: for var in origin_block.vars:
...@@ -498,20 +499,24 @@ class DistributeTranspiler: ...@@ -498,20 +499,24 @@ class DistributeTranspiler:
for op in origin_block.ops: for op in origin_block.ops:
self._clone_lr_op(program, new_sub_block, op) self._clone_lr_op(program, new_sub_block, op)
# clone sub_block of op # clone sub_block of op
__clone_lr_op_sub_block__(op, program, new_sub_block) __clone_lr_op_sub_block__(op, program, new_sub_block,
skip_sub_blks)
# reset the block of op # reset the block of op
op.set_attr('sub_block', new_sub_block) op.set_attr('sub_block', new_sub_block)
return new_sub_block.idx
# append lr decay ops to the child block if exists # append lr decay ops to the child block if exists
lr_ops = self._get_lr_ops() lr_ops = self._get_lr_ops()
skip_sub_blks = []
if len(lr_ops) > 0: if len(lr_ops) > 0:
lr_decay_block = pserver_program.create_block( lr_decay_block = pserver_program.create_block(
pserver_program.num_blocks - 1) pserver_program.num_blocks - 1)
for _, op in enumerate(lr_ops): for _, op in enumerate(lr_ops):
self._append_pserver_non_opt_ops(lr_decay_block, op) self._append_pserver_non_opt_ops(lr_decay_block, op)
# append sub blocks to pserver_program in lr_decay_op # append sub blocks to pserver_program in lr_decay_op
__clone_lr_op_sub_block__(op, pserver_program, lr_decay_block) __clone_lr_op_sub_block__(op, pserver_program, lr_decay_block,
skip_sub_blks)
# append op to the current block # append op to the current block
grad_to_block_id = [] grad_to_block_id = []
...@@ -561,7 +566,8 @@ class DistributeTranspiler: ...@@ -561,7 +566,8 @@ class DistributeTranspiler:
"endpoint": endpoint, "endpoint": endpoint,
"Fanin": self.trainer_num, "Fanin": self.trainer_num,
"sync_mode": self.sync_mode, "sync_mode": self.sync_mode,
"grad_to_block_id": grad_to_block_id "grad_to_block_id": grad_to_block_id,
"skip_sub_blks": skip_sub_blks
} }
if len(prefetch_var_name_to_block_id) > 0: if len(prefetch_var_name_to_block_id) > 0:
attrs['prefetch_var_name_to_block_id'] \ attrs['prefetch_var_name_to_block_id'] \
...@@ -582,6 +588,11 @@ class DistributeTranspiler: ...@@ -582,6 +588,11 @@ class DistributeTranspiler:
Get startup program for current parameter server. Get startup program for current parameter server.
Modify operator input variables if there are variables that Modify operator input variables if there are variables that
were split to several blocks. were split to several blocks.
Args:
endpoint(string): the endpoint for the current pserver instance.
pserver_program(Program): the program for pserver to execute.
Returns(Program): the startup program for pserver
""" """
s_prog = Program() s_prog = Program()
orig_s_prog = default_startup_program() orig_s_prog = default_startup_program()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册