提交 56a903d3 编写于 作者: Y Yancey1989

use optimize block list instead of first optimize block

上级 3a37e142
...@@ -46,6 +46,7 @@ message OpDesc { ...@@ -46,6 +46,7 @@ message OpDesc {
repeated bool bools = 11; repeated bool bools = 11;
optional int32 block_idx = 12; optional int32 block_idx = 12;
optional int64 l = 13; optional int64 l = 13;
repeated int32 blocks_idx = 14;
}; };
message Var { message Var {
......
...@@ -211,6 +211,12 @@ void OpDesc::SetBlockAttr(const std::string &name, BlockDesc *block) { ...@@ -211,6 +211,12 @@ void OpDesc::SetBlockAttr(const std::string &name, BlockDesc *block) {
need_update_ = true; need_update_ = true;
} }
void OpDesc::SetBlocksAttr(const std::string &name,
std::vector<BlockDesc *> blocks) {
this->attrs_[name] = blocks;
need_update_ = true;
}
void OpDesc::SetAttrMap( void OpDesc::SetAttrMap(
const std::unordered_map<std::string, Attribute> &attr_map) { const std::unordered_map<std::string, Attribute> &attr_map) {
attrs_ = attr_map; attrs_ = attr_map;
...@@ -305,6 +311,13 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> { ...@@ -305,6 +311,13 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
void operator()(const std::vector<bool> &v) const { void operator()(const std::vector<bool> &v) const {
VectorToRepeated(v, attr_->mutable_bools()); VectorToRepeated(v, attr_->mutable_bools());
} }
void operator()(const std::vector<BlockDesc *> &v) const {
std::vector<int> blocks_idx;
for (auto blk : v) {
blocks_idx.push_back(blk->ID());
}
VectorToRepeated(blocks_idx, attr_->mutable_blocks_idx());
}
void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->ID()); } void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->ID()); }
void operator()(int64_t v) const { attr_->set_l(v); } void operator()(int64_t v) const { attr_->set_l(v); }
void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); } void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
......
...@@ -77,6 +77,8 @@ class OpDesc { ...@@ -77,6 +77,8 @@ class OpDesc {
void SetBlockAttr(const std::string &name, BlockDesc *block); void SetBlockAttr(const std::string &name, BlockDesc *block);
void SetBlocksAttr(const std::string &name, std::vector<BlockDesc *> blocks);
Attribute GetAttr(const std::string &name) const; Attribute GetAttr(const std::string &name) const;
Attribute GetNullableAttr(const std::string &name) const; Attribute GetNullableAttr(const std::string &name) const;
......
...@@ -35,7 +35,8 @@ using VariableNameMap = std::map<std::string, std::vector<std::string>>; ...@@ -35,7 +35,8 @@ using VariableNameMap = std::map<std::string, std::vector<std::string>>;
using Attribute = using Attribute =
boost::variant<boost::blank, int, float, std::string, std::vector<int>, boost::variant<boost::blank, int, float, std::string, std::vector<int>,
std::vector<float>, std::vector<std::string>, bool, std::vector<float>, std::vector<std::string>, bool,
std::vector<bool>, BlockDesc*, int64_t>; std::vector<bool>, BlockDesc*, int64_t,
std::vector<BlockDesc*>>;
using AttributeMap = std::unordered_map<std::string, Attribute>; using AttributeMap = std::unordered_map<std::string, Attribute>;
......
...@@ -101,14 +101,11 @@ void ListenAndServOp::RunSyncLoop( ...@@ -101,14 +101,11 @@ void ListenAndServOp::RunSyncLoop(
framework::Scope *recv_scope, framework::Scope *recv_scope,
const std::vector<int> &prefetch_block_id_list) const { const std::vector<int> &prefetch_block_id_list) const {
size_t num_blocks = program->Size(); size_t num_blocks = program->Size();
auto skip_sub_blks = Attr<std::vector<int>>("skip_sub_blks"); auto optimize_blocks =
Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
PADDLE_ENFORCE_GE(num_blocks, 2, PADDLE_ENFORCE_GE(num_blocks, 2,
"server program should have at least 2 blocks"); "server program should have at least 2 blocks");
std::vector<int> optimize_block_id_list;
for (auto *block : optimize_blocks) {
optimize_block_id_list.push_back(block->ID());
}
auto optimize_prepared = executor->Prepare(*program, optimize_block_id_list); auto optimize_prepared = executor->Prepare(*program, optimize_block_id_list);
// Insert placeholder for block0 which holds current op itself. // Insert placeholder for block0 which holds current op itself.
optimize_prepared.insert( optimize_prepared.insert(
...@@ -136,10 +133,10 @@ void ListenAndServOp::RunSyncLoop( ...@@ -136,10 +133,10 @@ void ListenAndServOp::RunSyncLoop(
std::vector<size_t> parallel_blkids; std::vector<size_t> parallel_blkids;
parallel_blkids.push_back(optimize_blocks[0]->ID()); parallel_blkids.push_back(optimize_blocks[0]->ID());
double ts = GetTimestamp(); double ts = GetTimestamp();
for (size_t i = 1; i < optimize_block_id_list.size(); ++i) { for (size_t i = 1; i < optimize_blocks.size(); ++i) {
// skip the first optimize block because it is already in the // skip the first optimize block because it is already in the
// parallel_blkids. // parallel_blkids.
int blkid = optimize_block_id_list[i]; int blkid = optimize_blocks[i]->ID();
if (program->Block(blkid).Parent() != last_parent_blkid) { if (program->Block(blkid).Parent() != last_parent_blkid) {
ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared, ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared,
program, recv_scope); program, recv_scope);
...@@ -263,7 +260,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -263,7 +260,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks); Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
PADDLE_ENFORCE(optimize_blocks.size() > 1, PADDLE_ENFORCE(optimize_blocks.size() > 1,
"optimize blocks should be 1 at least on the pserver side."); "optimize blocks should be 1 at least on the pserver side.");
auto *program = optimize_block[0]->Program(); auto *program = optimize_blocks[0]->Program();
framework::Executor executor(dev_place); framework::Executor executor(dev_place);
// prepare for prefetch // prepare for prefetch
...@@ -340,8 +337,8 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -340,8 +337,8 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
"a map from grad name to it's optimize block id") "a map from grad name to it's optimize block id")
.SetDefault({}); .SetDefault({});
AddAttr<bool>("sync_mode", "if works at sync_mode or not").SetDefault(true); AddAttr<bool>("sync_mode", "if works at sync_mode or not").SetDefault(true);
AddAttr<framework::BlockDesc *>(kOptimizeBlocks, AddAttr<std::vector<framework::BlockDesc *>>(
"Optimize blocks to run on server side."); kOptimizeBlocks, "Optimize blocks to run on server side.");
AddAttr<std::vector<std::string>>(kPrefetchVarNameToBlockId, AddAttr<std::vector<std::string>>(kPrefetchVarNameToBlockId,
"prefetch blocks to run on server side.") "prefetch blocks to run on server side.")
.SetDefault({}); .SetDefault({});
......
...@@ -293,6 +293,7 @@ void BindOpDesc(pybind11::module *m) { ...@@ -293,6 +293,7 @@ void BindOpDesc(pybind11::module *m) {
.def("set_attr", &pd::OpDesc::SetAttr) .def("set_attr", &pd::OpDesc::SetAttr)
.def("attr", &pd::OpDesc::GetAttr) .def("attr", &pd::OpDesc::GetAttr)
.def("set_block_attr", &pd::OpDesc::SetBlockAttr) .def("set_block_attr", &pd::OpDesc::SetBlockAttr)
.def("set_blocks_attr", &pd::OpDesc::SetBlocksAttr)
.def("set_serialized_attr", .def("set_serialized_attr",
[](pd::OpDesc &self, const std::string &name, [](pd::OpDesc &self, const std::string &name,
const pybind11::bytes &seriralized) { const pybind11::bytes &seriralized) {
......
...@@ -561,6 +561,10 @@ class Operator(object): ...@@ -561,6 +561,10 @@ class Operator(object):
if isinstance(self.attrs[attr_name], Block): if isinstance(self.attrs[attr_name], Block):
self.desc.set_block_attr(attr_name, self.desc.set_block_attr(attr_name,
self.attrs[attr_name].desc) self.attrs[attr_name].desc)
elif isinstance(self.attrs[attr_name], list) and \
all(isinstance(v, Block) for v in self.attrs[attr_name]):
self.desc.set_blocks_attr(
attr_name, [v.desc for v in self.attrs[attr_name]])
elif isinstance(self.attrs[attr_name], core.BlockDesc) or \ elif isinstance(self.attrs[attr_name], core.BlockDesc) or \
isinstance(self.attrs[attr_name], core.ProgramDesc): isinstance(self.attrs[attr_name], core.ProgramDesc):
self.desc.set_serialized_attr( self.desc.set_serialized_attr(
...@@ -715,6 +719,8 @@ class Operator(object): ...@@ -715,6 +719,8 @@ class Operator(object):
self.attrs[name] = val self.attrs[name] = val
if isinstance(val, Block): if isinstance(val, Block):
self.desc.set_block_attr(name, val.desc) self.desc.set_block_attr(name, val.desc)
elif isinstance(val, list) and all(isinstance(v, Block) for v in val):
self.desc.set_blocks_attr(name, [v.desc for v in val])
elif isinstance(val, core.BlockDesc) or \ elif isinstance(val, core.BlockDesc) or \
isinstance(val, core.ProgramDesc): isinstance(val, core.ProgramDesc):
self.desc.set_serialized_attr(name, val.serialize_to_string()) self.desc.set_serialized_attr(name, val.serialize_to_string())
......
...@@ -396,7 +396,7 @@ class DistributeTranspiler(object): ...@@ -396,7 +396,7 @@ class DistributeTranspiler(object):
return varname return varname
return "" return ""
def __clone_lr_op_sub_block__(op, program, new_block, skip_sub_blks): def __clone_lr_op_sub_block__(op, program, new_block):
if not op.has_attr('sub_block'): if not op.has_attr('sub_block'):
return return
...@@ -406,7 +406,6 @@ class DistributeTranspiler(object): ...@@ -406,7 +406,6 @@ class DistributeTranspiler(object):
# we put the new sub block to new block to follow the block # we put the new sub block to new block to follow the block
# hierarchy of the original blocks # hierarchy of the original blocks
new_sub_block = program.create_block(new_block.idx) new_sub_block = program.create_block(new_block.idx)
skip_sub_blks.append(new_sub_block.idx)
# clone vars # clone vars
for var in origin_block.vars: for var in origin_block.vars:
...@@ -416,8 +415,7 @@ class DistributeTranspiler(object): ...@@ -416,8 +415,7 @@ class DistributeTranspiler(object):
for op in origin_block.ops: for op in origin_block.ops:
self._clone_lr_op(program, new_sub_block, op) self._clone_lr_op(program, new_sub_block, op)
# clone sub_block of op # clone sub_block of op
__clone_lr_op_sub_block__(op, program, new_sub_block, __clone_lr_op_sub_block__(op, program, new_sub_block)
skip_sub_blks)
# reset the block of op # reset the block of op
op.set_attr('sub_block', new_sub_block) op.set_attr('sub_block', new_sub_block)
...@@ -433,8 +431,7 @@ class DistributeTranspiler(object): ...@@ -433,8 +431,7 @@ class DistributeTranspiler(object):
for _, op in enumerate(lr_ops): for _, op in enumerate(lr_ops):
self._append_pserver_non_opt_ops(lr_decay_block, op) self._append_pserver_non_opt_ops(lr_decay_block, op)
# append sub blocks to pserver_program in lr_decay_op # append sub blocks to pserver_program in lr_decay_op
__clone_lr_op_sub_block__(op, pserver_program, lr_decay_block, __clone_lr_op_sub_block__(op, pserver_program, lr_decay_block)
skip_sub_blks)
# append op to the current block # append op to the current block
grad_to_block_id = [] grad_to_block_id = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册