Commit 6f0e630c authored by Kexin Zhao

fix prune and program desc constructor

Parent c3d27b15
@@ -155,6 +155,8 @@ BlockDesc::BlockDesc(ProgramDesc *prog, proto::BlockDesc *desc)
   for (const proto::OpDesc &op_desc : desc_->ops()) {
     ops_.emplace_back(new OpDesc(op_desc, prog, this));
   }
+  std::cout << "Constructed block idx " << desc->idx() << " from protobuf str"
+            << std::endl;
 }

 BlockDesc::BlockDesc(const BlockDesc &other, proto::BlockDesc *desc,
...
@@ -124,11 +124,24 @@ OpDesc::OpDesc(const proto::OpDesc &desc, ProgramDesc *prog, BlockDesc *block)
   // restore attrs_
   for (const proto::OpDesc::Attr &attr : desc_.attrs()) {
     std::string attr_name = attr.name();
+    // we use a trick to handle the case where attr.type() is BLOCK: at this
+    // moment the sub_block hasn't been added to ProgramDesc's vector<Block>,
+    // so we cast the block_idx to a dummy BlockDesc pointer
     if (attr.type() != proto::AttrType::BLOCK) {
       attrs_[attr_name] = GetAttrValue(attr);
     } else {
-      auto bid = attr.block_idx();
-      attrs_[attr_name] = prog->MutableBlock(bid);
+      size_t blk_idx = attr.block_idx();
+      if (blk_idx < prog->Size()) {
+        attrs_[attr_name] = prog->MutableBlock(blk_idx);
+      } else {
+        std::cout << "Setting blockdesc attribute for id " << blk_idx
+                  << std::endl;
+        attrs_[attr_name] = reinterpret_cast<BlockDesc *>(blk_idx);
+        std::cout << "Testing reinterpret_cast result is "
+                  << reinterpret_cast<size_t>(
+                         boost::get<BlockDesc *>(attrs_[attr_name]))
+                  << std::endl;
+      }
     }
   }
   this->block_ = block;
...
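
The else branch above temporarily smuggles a raw block index through the BlockDesc * attribute slot, because the sub-block it refers to has not been constructed yet. Below is a minimal standalone sketch of that index-through-pointer round trip; the BlockDesc struct here is an empty stand-in, not the Paddle class, and the lookup step is omitted.

#include <cstddef>
#include <iostream>

// Empty stand-in type; only its pointer type matters for this sketch.
struct BlockDesc {};

int main() {
  // Phase 1: the real block does not exist yet, so stash its index in the
  // pointer-typed attribute slot.
  std::size_t blk_idx = 3;
  BlockDesc *placeholder = reinterpret_cast<BlockDesc *>(blk_idx);

  // Phase 2: once every block has been constructed, recover the index and
  // replace the placeholder with the real pointer (lookup omitted here).
  std::size_t recovered = reinterpret_cast<std::size_t>(placeholder);
  std::cout << "recovered block idx = " << recovered << std::endl;  // prints 3
  return 0;
}

The round trip only relies on a pointer being wide enough to hold the index, which is true on the usual platforms, but it is still a deliberate hack; the fix-up pass in the ProgramDesc constructor (next hunk) replaces the placeholder before anything dereferences it.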
@@ -52,9 +52,27 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) {

 ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) {
   desc_ = desc;
+  std::cout << std::endl << "starting in ProgDesc constructor" << std::endl;
   for (auto &block_desc : *desc_.mutable_blocks()) {
     blocks_.emplace_back(new BlockDesc(this, &block_desc));
+    std::cout << "Done constructing block idx " << block_desc.idx()
+              << " parent idx " << block_desc.parent_idx() << std::endl;
   }
+  for (auto &block : blocks_) {
+    for (auto *op : block->AllOps()) {
+      for (auto &name : op->AttrNames()) {
+        if (op->GetAttrType(name) == proto::AttrType::BLOCK) {
+          auto attr = op->GetAttr(name);
+          size_t blk_idx =
+              reinterpret_cast<size_t>(boost::get<BlockDesc *>(attr));
+          op->SetBlockAttr(name, *this->MutableBlock(blk_idx));
+          std::cout << "Update attr name " << name << " for block idx "
+                    << blk_idx << std::endl;
+        }
+      }
+    }
+  }
+  std::cout << "Done ProgDesc construction" << std::endl << std::endl;
 }

 ProgramDesc::ProgramDesc(const std::string &binary_str) {
...
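
Taken together, the two constructors implement a two-phase resolution: first deserialize every block, then patch each BLOCK-typed attribute from a stored index into a real pointer. The following self-contained C++17 sketch shows the same pattern with a std::variant in place of Paddle's attribute type; Block, Op, and the field names are invented for illustration.

#include <cstddef>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <variant>
#include <vector>

// Invented stand-ins for BlockDesc / OpDesc; not the Paddle classes.
struct Block {
  int idx;
};

// An attribute is either an unresolved block index or a resolved pointer.
using Attribute = std::variant<std::size_t, Block *>;

struct Op {
  std::unordered_map<std::string, Attribute> attrs;
};

int main() {
  std::vector<std::unique_ptr<Block>> blocks;
  std::vector<Op> ops;

  // Phase 1: while deserializing block 0, its op refers to block 1, which
  // does not exist yet, so only the index is recorded.
  blocks.emplace_back(new Block{0});
  ops.push_back(Op{{{"sub_block", std::size_t{1}}}});
  blocks.emplace_back(new Block{1});

  // Phase 2: every block now exists, so patch index-valued attributes into
  // real pointers -- the same shape as the fix-up loop in the constructor
  // above.
  for (auto &op : ops) {
    for (auto &kv : op.attrs) {
      if (auto *idx = std::get_if<std::size_t>(&kv.second)) {
        kv.second = blocks.at(*idx).get();
      }
    }
  }

  std::cout << "sub_block resolved to block idx "
            << std::get<Block *>(ops[0].attrs["sub_block"])->idx << std::endl;
  return 0;
}

Keeping the unresolved and resolved states as distinct variant alternatives avoids the reinterpret_cast trick entirely, at the cost of touching the attribute type; the commit instead keeps the existing boost variant and relies on the fix-up pass running before any use.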
@@ -109,15 +109,14 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
   // we reverse the should_run vector
   std::reverse(should_run.begin(), should_run.end());

-  //*output = input;
   // copy the current block from input to output
   auto* block_field = output->mutable_blocks();
   *block_field->Add() = input.blocks(block_id);
   int output_block_id = output->blocks_size() - 1;
   auto* output_block = output->mutable_blocks(output_block_id);
-  output_block->set_idx = output_block_id;
-  output_block->set_parent_idx = parent_block_id;
+  output_block->set_idx(output_block_id);
+  output_block->set_parent_idx(parent_block_id);

   auto* op_field = output_block->mutable_ops();
   op_field->Clear();
@@ -128,17 +127,18 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
       if (HasSubBlock(*op)) {
         // create sub_block_dependent_vars here to help prune the sub block
         std::set<std::string> sub_block_dependent_vars;
-        for (auto& var : op.inputs()) {
+        for (auto& var : op->inputs()) {
           for (auto& argu : var.arguments()) {
             sub_block_dependent_vars.insert(argu);
           }
         }
-        for (auto& var : op.outputs()) {
+        for (auto& var : op->outputs()) {
           for (auto& argu : var.arguments()) {
             sub_block_dependent_vars.insert(argu);
           }
         }
+        std::cout << "pruning the next block, the current output_block_id is "
+                  << output_block_id << std::endl;
         // GetSubBlockIndex(*op) is the idx of the sub_block in the input desc
         // output_block_id is the idx of the current block in the output desc
         prune_impl(input, output, GetSubBlockIndex(*op), output_block_id,
@@ -147,6 +147,8 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
     }
   }

+  std::cout << "Starting to remove unreferenced variables"
+            << " for block idx " << output_block_id << std::endl;
   // remove the VarDescs in BlockDesc that are not referenced in
   // the pruned OpDescs
   std::unordered_map<std::string, proto::VarDesc> var_map;
@@ -155,28 +157,38 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
     var_map[var.name()] = var;
   }

-  var_field->Clear();
+  std::set<std::string> var_names;
   for (const auto& op : *op_field) {
+    // add VarDescs of all input arguments for each OpDesc
     auto& input_field = op.inputs();
     for (auto& input_var : input_field) {
       for (auto& arg : input_var.arguments()) {
-        *var_field->Add() = var_map.at(arg);
+        if (var_map.count(arg) != 0) {
+          var_names.insert(arg);
+        }
       }
     }
+    // add VarDescs of all output arguments for each OpDesc
    auto& output_field = op.outputs();
     for (auto& output_var : output_field) {
       for (auto& arg : output_var.arguments()) {
-        *var_field->Add() = var_map.at(arg);
+        if (var_map.count(arg) != 0) {
+          var_names.insert(arg);
+        }
       }
     }
   }
+  var_field->Clear();
+  for (const auto& name : var_names) {
+    *var_field->Add() = var_map[name];
+  }
 }

 // TODO(fengjiayi): Prune() could be inplaced to avoid unnecessary copies
 void Prune(const proto::ProgramDesc& input, proto::ProgramDesc* output) {
-  prune_impl(input, output, 0, -1, {});
+  std::set<std::string> dependent_vars;
+  std::cout << std::endl << "Start C++ framework::prune" << std::endl;
+  prune_impl(input, output, 0, -1, dependent_vars);
+  std::cout << "Finished C++ framework::prune" << std::endl << std::endl;
 }

 void inference_optimize_impl(const proto::ProgramDesc& input,
...
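
The prune change above replaces the direct var_field appends with a two-step "collect names, then copy descriptions" pass: the std::set both deduplicates variables referenced by several ops and skips arguments that have no VarDesc entry in the block's var_map (presumably variables declared elsewhere, such as in a parent block). A toy sketch of that filtering idea with plain standard containers follows; ToyOp and the sample names are made up.

#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

// Toy stand-in: an op just lists the argument names it touches.
struct ToyOp {
  std::vector<std::string> args;
};

int main() {
  // name -> description, standing in for the block's VarDesc map.
  std::map<std::string, std::string> var_map = {
      {"x", "x-desc"}, {"w", "w-desc"}, {"tmp", "tmp-desc"}};
  std::vector<ToyOp> pruned_ops = {{{"x", "w"}}, {{"w", "y"}}};  // "y" unknown

  // Collect referenced names first; the set deduplicates "w" and the
  // count() check drops "y", which has no description in this block.
  std::set<std::string> var_names;
  for (const auto& op : pruned_ops) {
    for (const auto& arg : op.args) {
      if (var_map.count(arg) != 0) {
        var_names.insert(arg);
      }
    }
  }

  // Rebuild the variable list from the collected names only.
  std::vector<std::string> kept_vars;
  for (const auto& name : var_names) {
    kept_vars.push_back(var_map[name]);
  }
  std::cout << "kept " << kept_vars.size() << " of " << var_map.size()
            << " variables" << std::endl;  // kept 2 of 3 variables
  return 0;
}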
@@ -342,6 +342,12 @@ def save_inference_model(dirname,
     prepend_feed_ops(inference_program, feeded_var_names)
     append_fetch_ops(inference_program, fetch_var_names)

+    # save for checking
+    curstr = inference_program.to_string(True)
+    f = open("save_inf_prog_after_feed_fetch.txt", 'w')
+    f.write(curstr)
+    f.close()
+
     model_file_name = dirname + "/__model__"
     with open(model_file_name, "wb") as f:
         f.write(inference_program.desc.serialize_to_string())
...
...@@ -197,14 +197,15 @@ def train(save_dirname=None): ...@@ -197,14 +197,15 @@ def train(save_dirname=None):
" avg_cost=" + str(avg_cost_val)) " avg_cost=" + str(avg_cost_val))
if batch_id > 3: if batch_id > 3:
if save_dirname is not None: if save_dirname is not None:
fluid.io.save_inference_model(save_dirname, [ fluid.io.save_inference_model(
'source_sequence', 'target_sequence', 'label_sequence' save_dirname, ['source_sequence',
], [prediction], exe) 'target_sequence'], [prediction], exe)
return
exit(0) exit(0)
batch_id += 1 batch_id += 1
def inference(save_dirname=None): def infer(save_dirname=None):
if save_dirname is None: if save_dirname is None:
return return
@@ -221,24 +222,32 @@ def inference(save_dirname=None):
         data = [[0, 1, 0, 1], [0, 1, 1, 0, 0, 1]]
         word_data = to_lodtensor(data, place)
         trg_word = to_lodtensor(data, place)
-        trg_word_next = to_lodtensor(data, place)

         # Construct feed as a dictionary of {feed_target_name: feed_target_data}
         # and results will contain a list of data corresponding to fetch_targets.
+        print("Print feed fetch target names as follows")
         print(feed_target_names)
         assert feed_target_names[0] == 'source_sequence'
         assert feed_target_names[1] == 'target_sequence'
-        assert feed_target_names[2] == 'label_sequence'
+        print([var.name for var in fetch_targets])
+
+        # save for checking
+        curstr = inference_program.to_string(True)
+        f = open("loaded_infer_prog.txt", 'w')
+        f.write(curstr)
+        f.close()

         results = exe.run(inference_program,
                           feed={
                               feed_target_names[0]: word_data,
                               feed_target_names[1]: trg_word,
-                              feed_target_names[2]: trg_word_next
                           },
-                          fetch_list=fetch_targets)
-        print("Inference Shape: ", results[0].shape)
-        print("infer results: ", results[0])
+                          fetch_list=fetch_targets,
+                          return_numpy=False)
+        print(results[0].lod())
+        np_data = np.array(results[0])
+        print("Inference shape: ", np_data.shape)
+        print("Inference results: ", np_data)


 if __name__ == '__main__':
...