提交 4663cb2d 编写于 作者: Y Yancey 提交者: Yancey1989

Merge pull request #11585 from Yancey1989/fix_pserver_sub_blocks

fix pserver sub-blocks
上级 fac1d477
...@@ -27,6 +27,7 @@ enum AttrType { ...@@ -27,6 +27,7 @@ enum AttrType {
BOOLEANS = 7; BOOLEANS = 7;
BLOCK = 8; BLOCK = 8;
LONG = 9; LONG = 9;
BLOCKS = 10;
} }
// OpDesc describes an instance of a C++ framework::OperatorBase // OpDesc describes an instance of a C++ framework::OperatorBase
...@@ -46,6 +47,7 @@ message OpDesc { ...@@ -46,6 +47,7 @@ message OpDesc {
repeated bool bools = 11; repeated bool bools = 11;
optional int32 block_idx = 12; optional int32 block_idx = 12;
optional int64 l = 13; optional int64 l = 13;
repeated int32 blocks_idx = 14;
}; };
message Var { message Var {
......
...@@ -211,6 +211,12 @@ void OpDesc::SetBlockAttr(const std::string &name, BlockDesc *block) { ...@@ -211,6 +211,12 @@ void OpDesc::SetBlockAttr(const std::string &name, BlockDesc *block) {
need_update_ = true; need_update_ = true;
} }
void OpDesc::SetBlocksAttr(const std::string &name,
std::vector<BlockDesc *> blocks) {
this->attrs_[name] = blocks;
need_update_ = true;
}
void OpDesc::SetAttrMap( void OpDesc::SetAttrMap(
const std::unordered_map<std::string, Attribute> &attr_map) { const std::unordered_map<std::string, Attribute> &attr_map) {
attrs_ = attr_map; attrs_ = attr_map;
...@@ -305,6 +311,13 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> { ...@@ -305,6 +311,13 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
void operator()(const std::vector<bool> &v) const { void operator()(const std::vector<bool> &v) const {
VectorToRepeated(v, attr_->mutable_bools()); VectorToRepeated(v, attr_->mutable_bools());
} }
void operator()(const std::vector<BlockDesc *> &v) const {
std::vector<int> blocks_idx;
for (auto blk : v) {
blocks_idx.push_back(blk->ID());
}
VectorToRepeated(blocks_idx, attr_->mutable_blocks_idx());
}
void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->ID()); } void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->ID()); }
void operator()(int64_t v) const { attr_->set_l(v); } void operator()(int64_t v) const { attr_->set_l(v); }
void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); } void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
......
...@@ -77,6 +77,8 @@ class OpDesc { ...@@ -77,6 +77,8 @@ class OpDesc {
void SetBlockAttr(const std::string &name, BlockDesc *block); void SetBlockAttr(const std::string &name, BlockDesc *block);
void SetBlocksAttr(const std::string &name, std::vector<BlockDesc *> blocks);
Attribute GetAttr(const std::string &name) const; Attribute GetAttr(const std::string &name) const;
Attribute GetNullableAttr(const std::string &name) const; Attribute GetNullableAttr(const std::string &name) const;
......
...@@ -35,7 +35,8 @@ using VariableNameMap = std::map<std::string, std::vector<std::string>>; ...@@ -35,7 +35,8 @@ using VariableNameMap = std::map<std::string, std::vector<std::string>>;
using Attribute = using Attribute =
boost::variant<boost::blank, int, float, std::string, std::vector<int>, boost::variant<boost::blank, int, float, std::string, std::vector<int>,
std::vector<float>, std::vector<std::string>, bool, std::vector<float>, std::vector<std::string>, bool,
std::vector<bool>, BlockDesc*, int64_t>; std::vector<bool>, BlockDesc*, int64_t,
std::vector<BlockDesc*>>;
using AttributeMap = std::unordered_map<std::string, Attribute>; using AttributeMap = std::unordered_map<std::string, Attribute>;
......
...@@ -101,17 +101,16 @@ void ListenAndServOp::RunSyncLoop( ...@@ -101,17 +101,16 @@ void ListenAndServOp::RunSyncLoop(
framework::Scope *recv_scope, framework::Scope *recv_scope,
const std::vector<int> &prefetch_block_id_list) const { const std::vector<int> &prefetch_block_id_list) const {
size_t num_blocks = program->Size(); size_t num_blocks = program->Size();
auto optimize_blocks =
Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
PADDLE_ENFORCE_GE(num_blocks, 2, PADDLE_ENFORCE_GE(num_blocks, 2,
"server program should have at least 2 blocks"); "server program should have at least 2 blocks");
std::vector<int> optimize_block_id_list; std::vector<int> optimize_blocks_idx;
for (int blkid = 1; blkid < num_blocks; ++blkid) { for (auto blk : optimize_blocks) {
if (std::find(prefetch_block_id_list.begin(), prefetch_block_id_list.end(), optimize_blocks_idx.push_back(blk->ID());
blkid) == prefetch_block_id_list.end()) {
optimize_block_id_list.push_back(blkid);
} }
} auto optimize_prepared = executor->Prepare(*program, optimize_blocks_idx);
auto optimize_prepared = executor->Prepare(*program, optimize_block_id_list);
// Insert placeholder for block0 which holds current op itself. // Insert placeholder for block0 which holds current op itself.
optimize_prepared.insert( optimize_prepared.insert(
optimize_prepared.begin(), optimize_prepared.begin(),
...@@ -134,14 +133,14 @@ void ListenAndServOp::RunSyncLoop( ...@@ -134,14 +133,14 @@ void ListenAndServOp::RunSyncLoop(
// and this will still work. // and this will still work.
// The optimize blocks which have the same parent ID would run parallel // The optimize blocks which have the same parent ID would run parallel
// TODO(Yancey1989): need to use ParallelExecutor for future // TODO(Yancey1989): need to use ParallelExecutor for future
int32_t last_parent_blkid = program->Block(1).Parent(); int32_t last_parent_blkid = optimize_blocks[0]->Parent();
std::vector<size_t> parallel_blkids; std::vector<size_t> parallel_blkids;
parallel_blkids.push_back(1); parallel_blkids.push_back(optimize_blocks[0]->ID());
double ts = GetTimestamp(); double ts = GetTimestamp();
for (size_t i = 1; i < optimize_block_id_list.size(); ++i) { for (size_t i = 1; i < optimize_blocks.size(); ++i) {
// skip the first optimize block because it is already in the // skip the first optimize block because it is already in the
// parallel_blkids. // parallel_blkids.
int blkid = optimize_block_id_list[i]; int blkid = optimize_blocks[i]->ID();
if (program->Block(blkid).Parent() != last_parent_blkid) { if (program->Block(blkid).Parent() != last_parent_blkid) {
ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared, ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared,
program, recv_scope); program, recv_scope);
...@@ -259,8 +258,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -259,8 +258,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
rpc_service_->RegisterRPC(distributed::kRequestPrefetch, rpc_service_->RegisterRPC(distributed::kRequestPrefetch,
request_prefetch_handler_.get()); request_prefetch_handler_.get());
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock); auto optimize_blocks =
auto *program = optimize_block->Program(); Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
PADDLE_ENFORCE(optimize_blocks.size() >= 1,
"optimize blocks should be 1 at least on the pserver side.");
auto *program = optimize_blocks[0]->Program();
framework::Executor executor(dev_place); framework::Executor executor(dev_place);
// prepare for prefetch // prepare for prefetch
...@@ -337,8 +339,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -337,8 +339,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
"a map from grad name to it's optimize block id") "a map from grad name to it's optimize block id")
.SetDefault({}); .SetDefault({});
AddAttr<bool>("sync_mode", "if works at sync_mode or not").SetDefault(true); AddAttr<bool>("sync_mode", "if works at sync_mode or not").SetDefault(true);
AddAttr<framework::BlockDesc *>(kOptimizeBlock, AddAttr<std::vector<framework::BlockDesc *>>(
"BlockID to run on server side."); kOptimizeBlocks, "Optimize blocks to run on server side.")
.SetDefault({});
AddAttr<std::vector<std::string>>(kPrefetchVarNameToBlockId, AddAttr<std::vector<std::string>>(kPrefetchVarNameToBlockId,
"prefetch blocks to run on server side.") "prefetch blocks to run on server side.")
.SetDefault({}); .SetDefault({});
......
...@@ -30,7 +30,7 @@ limitations under the License. */ ...@@ -30,7 +30,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
constexpr char kOptimizeBlock[] = "OptimizeBlock"; constexpr char kOptimizeBlocks[] = "optimize_blocks";
constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id"; constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id";
void RunServer(std::shared_ptr<distributed::RPCServer> service); void RunServer(std::shared_ptr<distributed::RPCServer> service);
......
...@@ -129,7 +129,10 @@ void StartServerNet(bool is_sparse, std::atomic<bool> *initialized) { ...@@ -129,7 +129,10 @@ void StartServerNet(bool is_sparse, std::atomic<bool> *initialized) {
// sub program run in listen_and_serv_op, for simple test we use sum // sub program run in listen_and_serv_op, for simple test we use sum
f::ProgramDesc program; f::ProgramDesc program;
const auto &root_block = program.Block(0); const auto &root_block = program.Block(0);
std::vector<framework::BlockDesc *> optimize_blocks;
auto *optimize_block = program.AppendBlock(root_block); auto *optimize_block = program.AppendBlock(root_block);
optimize_blocks.push_back(optimize_block);
auto *prefetch_block = program.AppendBlock(root_block); auto *prefetch_block = program.AppendBlock(root_block);
// X for server side tensors, RX for received tensors, must be of same shape. // X for server side tensors, RX for received tensors, must be of same shape.
AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, optimize_block, AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, optimize_block,
...@@ -139,7 +142,7 @@ void StartServerNet(bool is_sparse, std::atomic<bool> *initialized) { ...@@ -139,7 +142,7 @@ void StartServerNet(bool is_sparse, std::atomic<bool> *initialized) {
attrs.insert({"Fanin", 1}); attrs.insert({"Fanin", 1});
attrs.insert({"ParamList", std::vector<std::string>({"Out"})}); attrs.insert({"ParamList", std::vector<std::string>({"Out"})});
attrs.insert({"GradList", std::vector<std::string>({"x1"})}); attrs.insert({"GradList", std::vector<std::string>({"x1"})});
attrs.insert({"OptimizeBlock", optimize_block}); attrs.insert({"optimize_blocks", optimize_blocks});
attrs.insert({"PrefetchBlock", prefetch_block}); attrs.insert({"PrefetchBlock", prefetch_block});
attrs.insert({"grad_to_block_id", std::vector<std::string>({""})}); attrs.insert({"grad_to_block_id", std::vector<std::string>({""})});
attrs.insert({"sync_mode", true}); attrs.insert({"sync_mode", true});
......
...@@ -268,7 +268,8 @@ void BindOpDesc(pybind11::module *m) { ...@@ -268,7 +268,8 @@ void BindOpDesc(pybind11::module *m) {
.value("STRINGS", pd::proto::AttrType::STRINGS) .value("STRINGS", pd::proto::AttrType::STRINGS)
.value("BOOL", pd::proto::AttrType::BOOLEAN) .value("BOOL", pd::proto::AttrType::BOOLEAN)
.value("BOOLS", pd::proto::AttrType::BOOLEANS) .value("BOOLS", pd::proto::AttrType::BOOLEANS)
.value("BLOCK", pd::proto::AttrType::BLOCK); .value("BLOCK", pd::proto::AttrType::BLOCK)
.value("BLOCKS", pd::proto::AttrType::BLOCKS);
pybind11::class_<pd::OpDesc> op_desc(*m, "OpDesc", ""); pybind11::class_<pd::OpDesc> op_desc(*m, "OpDesc", "");
op_desc op_desc
...@@ -293,6 +294,7 @@ void BindOpDesc(pybind11::module *m) { ...@@ -293,6 +294,7 @@ void BindOpDesc(pybind11::module *m) {
.def("set_attr", &pd::OpDesc::SetAttr) .def("set_attr", &pd::OpDesc::SetAttr)
.def("attr", &pd::OpDesc::GetAttr) .def("attr", &pd::OpDesc::GetAttr)
.def("set_block_attr", &pd::OpDesc::SetBlockAttr) .def("set_block_attr", &pd::OpDesc::SetBlockAttr)
.def("set_blocks_attr", &pd::OpDesc::SetBlocksAttr)
.def("set_serialized_attr", .def("set_serialized_attr",
[](pd::OpDesc &self, const std::string &name, [](pd::OpDesc &self, const std::string &name,
const pybind11::bytes &seriralized) { const pybind11::bytes &seriralized) {
......
...@@ -558,15 +558,20 @@ class Operator(object): ...@@ -558,15 +558,20 @@ class Operator(object):
if (attr_name not in self.attrs) or ( if (attr_name not in self.attrs) or (
self.attrs[attr_name] is None): self.attrs[attr_name] is None):
continue continue
if isinstance(self.attrs[attr_name], Block): attr_val = self.attrs[attr_name]
if isinstance(attr_val, Block):
self.desc.set_block_attr(attr_name, self.desc.set_block_attr(attr_name,
self.attrs[attr_name].desc) self.attrs[attr_name].desc)
elif isinstance(self.attrs[attr_name], core.BlockDesc) or \ elif isinstance(attr_val, list) and attr_val and \
isinstance(self.attrs[attr_name], core.ProgramDesc): all(isinstance(v, Block) for v in attr_val):
self.desc.set_blocks_attr(attr_name,
[v.desc for v in attr_val])
elif isinstance(attr_val, core.BlockDesc) or \
isinstance(attr_val, core.ProgramDesc):
self.desc.set_serialized_attr( self.desc.set_serialized_attr(
attr_name, self.attrs[attr_name].serialize_to_string()) attr_name, attr_val.serialize_to_string())
else: else:
self.desc.set_attr(attr_name, self.attrs[attr_name]) self.desc.set_attr(attr_name, attr_val)
self.desc.check_attrs() self.desc.check_attrs()
if self.has_kernel(type): if self.has_kernel(type):
self.desc.infer_var_type(self.block.desc) self.desc.infer_var_type(self.block.desc)
...@@ -715,6 +720,9 @@ class Operator(object): ...@@ -715,6 +720,9 @@ class Operator(object):
self.attrs[name] = val self.attrs[name] = val
if isinstance(val, Block): if isinstance(val, Block):
self.desc.set_block_attr(name, val.desc) self.desc.set_block_attr(name, val.desc)
elif isinstance(val, list) and val and all(
isinstance(v, Block) for v in val):
self.desc.set_blocks_attr(name, [v.desc for v in val])
elif isinstance(val, core.BlockDesc) or \ elif isinstance(val, core.BlockDesc) or \
isinstance(val, core.ProgramDesc): isinstance(val, core.ProgramDesc):
self.desc.set_serialized_attr(name, val.serialize_to_string()) self.desc.set_serialized_attr(name, val.serialize_to_string())
......
...@@ -186,7 +186,6 @@ class ListenAndServ(object): ...@@ -186,7 +186,6 @@ class ListenAndServ(object):
main_program = self.helper.main_program main_program = self.helper.main_program
current_block = main_program.current_block() current_block = main_program.current_block()
parent_block = self.parent_block() parent_block = self.parent_block()
empty_block = Program().global_block()
parent_block.append_op( parent_block.append_op(
type='listen_and_serv', type='listen_and_serv',
...@@ -195,8 +194,9 @@ class ListenAndServ(object): ...@@ -195,8 +194,9 @@ class ListenAndServ(object):
attrs={ attrs={
'endpoint': self.endpoint, 'endpoint': self.endpoint,
'Fanin': self.fan_in, 'Fanin': self.fan_in,
'OptimizeBlock': current_block, 'optimize_blocks': [
'PrefetchBlock': empty_block, current_block
], # did not support multiple optimize blocks in layers
'sync_mode': True, # did not support async now in layers 'sync_mode': True, # did not support async now in layers
'grad_to_block_id': [""] 'grad_to_block_id': [""]
}) })
......
...@@ -396,7 +396,7 @@ class DistributeTranspiler(object): ...@@ -396,7 +396,7 @@ class DistributeTranspiler(object):
return varname return varname
return "" return ""
def __clone_lr_op_sub_block__(op, program, new_block): def __clone_lr_op_sub_block__(op, program, lr_block):
if not op.has_attr('sub_block'): if not op.has_attr('sub_block'):
return return
...@@ -405,36 +405,41 @@ class DistributeTranspiler(object): ...@@ -405,36 +405,41 @@ class DistributeTranspiler(object):
assert isinstance(origin_block, Block) assert isinstance(origin_block, Block)
# we put the new sub block to new block to follow the block # we put the new sub block to new block to follow the block
# hierarchy of the original blocks # hierarchy of the original blocks
new_sub_block = program.create_block(new_block.idx) new_sub_block = program.create_block(lr_block.idx)
# clone vars # clone vars
for var in origin_block.vars: for var in origin_block.vars:
new_sub_block.clone_variable(var) new_sub_block.clone_variable(var)
# clone ops # clone ops
for op in origin_block.ops: for origin_op in origin_block.ops:
self._clone_lr_op(program, new_sub_block, op) cloned_op = self._clone_lr_op(program, new_sub_block, origin_op)
# clone sub_block of op # clone sub_block of op
__clone_lr_op_sub_block__(op, program, new_sub_block) __clone_lr_op_sub_block__(cloned_op, program, new_sub_block)
# reset the block of op # reset the block of op
op.set_attr('sub_block', new_sub_block) op.set_attr('sub_block', new_sub_block)
# append lr decay ops to the child block if exists # append lr decay ops to the child block if exists
lr_ops = self._get_lr_ops() lr_ops = self._get_lr_ops()
# record optimize blocks and we can run them on pserver parallel
optimize_blocks = []
if len(lr_ops) > 0: if len(lr_ops) > 0:
lr_decay_block = pserver_program.create_block( lr_decay_block = pserver_program.create_block(
pserver_program.num_blocks - 1) pserver_program.num_blocks - 1)
optimize_blocks.append(lr_decay_block)
for _, op in enumerate(lr_ops): for _, op in enumerate(lr_ops):
self._append_pserver_non_opt_ops(lr_decay_block, op) cloned_op = self._append_pserver_non_opt_ops(lr_decay_block, op)
# append sub blocks to pserver_program in lr_decay_op # append sub blocks to pserver_program in lr_decay_op
__clone_lr_op_sub_block__(op, pserver_program, lr_decay_block) __clone_lr_op_sub_block__(cloned_op, pserver_program,
lr_decay_block)
# append op to the current block # append op to the current block
grad_to_block_id = [] grad_to_block_id = []
pre_block_idx = pserver_program.num_blocks - 1 pre_block_idx = pserver_program.num_blocks - 1
for idx, opt_op in enumerate(opt_op_on_pserver): for idx, opt_op in enumerate(opt_op_on_pserver):
per_opt_block = pserver_program.create_block(pre_block_idx) per_opt_block = pserver_program.create_block(pre_block_idx)
optimize_blocks.append(per_opt_block)
# append grad merging ops before clip and weight decay # append grad merging ops before clip and weight decay
for _, op in enumerate(self.optimize_ops): for _, op in enumerate(self.optimize_ops):
# find the origin @GRAD var before clipping # find the origin @GRAD var before clipping
...@@ -453,6 +458,7 @@ class DistributeTranspiler(object): ...@@ -453,6 +458,7 @@ class DistributeTranspiler(object):
if global_ops: if global_ops:
opt_state_block = pserver_program.create_block( opt_state_block = pserver_program.create_block(
pserver_program.num_blocks - 1) pserver_program.num_blocks - 1)
optimize_blocks.append(opt_state_block)
for glb_op in global_ops: for glb_op in global_ops:
__append_optimize_op__(glb_op, opt_state_block, __append_optimize_op__(glb_op, opt_state_block,
grad_to_block_id, None) grad_to_block_id, None)
...@@ -474,11 +480,11 @@ class DistributeTranspiler(object): ...@@ -474,11 +480,11 @@ class DistributeTranspiler(object):
assert len(prefetch_var_name_to_block_id) == 0 assert len(prefetch_var_name_to_block_id) == 0
attrs = { attrs = {
"OptimizeBlock": pserver_program.block(1), "optimize_blocks": optimize_blocks,
"endpoint": endpoint, "endpoint": endpoint,
"Fanin": self.trainer_num, "Fanin": self.trainer_num,
"sync_mode": self.sync_mode, "sync_mode": self.sync_mode,
"grad_to_block_id": grad_to_block_id "grad_to_block_id": grad_to_block_id,
} }
if len(prefetch_var_name_to_block_id) > 0: if len(prefetch_var_name_to_block_id) > 0:
attrs['prefetch_var_name_to_block_id'] \ attrs['prefetch_var_name_to_block_id'] \
...@@ -1211,7 +1217,7 @@ class DistributeTranspiler(object): ...@@ -1211,7 +1217,7 @@ class DistributeTranspiler(object):
if var not in program.global_block().vars: if var not in program.global_block().vars:
block.clone_variable(var) block.clone_variable(var)
block.append_op( return block.append_op(
type=op.type, inputs=inputs, outputs=outputs, attrs=op.attrs) type=op.type, inputs=inputs, outputs=outputs, attrs=op.attrs)
def _append_pserver_non_opt_ops(self, optimize_block, opt_op): def _append_pserver_non_opt_ops(self, optimize_block, opt_op):
...@@ -1249,7 +1255,7 @@ class DistributeTranspiler(object): ...@@ -1249,7 +1255,7 @@ class DistributeTranspiler(object):
elif not program.global_block().vars.has_key(var.name): elif not program.global_block().vars.has_key(var.name):
program.global_block().clone_variable(var) program.global_block().clone_variable(var)
optimize_block.append_op( return optimize_block.append_op(
type=opt_op.type, type=opt_op.type,
inputs=inputs, inputs=inputs,
outputs=outputs, outputs=outputs,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册