未验证 提交 d6997e5b 编写于 作者: F fengjiayi 提交者: GitHub

Merge pull request #11083 from JiayiFeng/dev_refine_programdesc_copy

Refine ProgramDesc copy
...@@ -200,7 +200,7 @@ BlockDesc::BlockDesc(ProgramDesc *prog, proto::BlockDesc *desc) ...@@ -200,7 +200,7 @@ BlockDesc::BlockDesc(ProgramDesc *prog, proto::BlockDesc *desc)
vars_[var_desc.name()].reset(new VarDesc(var_desc)); vars_[var_desc.name()].reset(new VarDesc(var_desc));
} }
for (const proto::OpDesc &op_desc : desc_->ops()) { for (const proto::OpDesc &op_desc : desc_->ops()) {
ops_.emplace_back(new OpDesc(op_desc, prog, this)); ops_.emplace_back(new OpDesc(op_desc, this));
} }
} }
...@@ -209,7 +209,7 @@ BlockDesc::BlockDesc(const BlockDesc &other, proto::BlockDesc *desc, ...@@ -209,7 +209,7 @@ BlockDesc::BlockDesc(const BlockDesc &other, proto::BlockDesc *desc,
: prog_(prog), desc_(desc) { : prog_(prog), desc_(desc) {
need_update_ = true; need_update_ = true;
for (auto &op : other.ops_) { for (auto &op : other.ops_) {
ops_.emplace_back(new OpDesc(*op->Proto(), prog, this)); ops_.emplace_back(new OpDesc(*op, this));
} }
for (auto &it : other.vars_) { for (auto &it : other.vars_) {
auto *var = new VarDesc(*it.second); auto *var = new VarDesc(*it.second);
......
...@@ -105,7 +105,7 @@ class BlockDesc { ...@@ -105,7 +105,7 @@ class BlockDesc {
size_t OpSize() const { return ops_.size(); } size_t OpSize() const { return ops_.size(); }
OpDesc *Op(int idx) { return ops_.at(idx).get(); } OpDesc *Op(int idx) const { return ops_.at(idx).get(); }
void Flush(); void Flush();
......
...@@ -103,7 +103,7 @@ void OpDesc::CopyFrom(const OpDesc &op_desc) { ...@@ -103,7 +103,7 @@ void OpDesc::CopyFrom(const OpDesc &op_desc) {
need_update_ = true; need_update_ = true;
} }
OpDesc::OpDesc(const proto::OpDesc &desc, ProgramDesc *prog, BlockDesc *block) OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block)
: desc_(desc), need_update_(false) { : desc_(desc), need_update_(false) {
// restore inputs_ // restore inputs_
int input_size = desc_.inputs_size(); int input_size = desc_.inputs_size();
......
...@@ -33,13 +33,14 @@ class OpDesc { ...@@ -33,13 +33,14 @@ class OpDesc {
OpDesc(const std::string &type, const VariableNameMap &inputs, OpDesc(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs); const VariableNameMap &outputs, const AttributeMap &attrs);
OpDesc(const proto::OpDesc &desc, ProgramDesc *prog, BlockDesc *block); OpDesc(const proto::OpDesc &desc, BlockDesc *block);
explicit OpDesc(BlockDesc *block) : block_(block) {} explicit OpDesc(BlockDesc *block) : block_(block) {}
OpDesc(const OpDesc &other, BlockDesc *block) { OpDesc(const OpDesc &other, BlockDesc *block) {
*this = other; *this = other;
block_ = block; block_ = block;
need_update_ = true;
} }
void CopyFrom(const OpDesc &op_desc); void CopyFrom(const OpDesc &op_desc);
......
...@@ -51,12 +51,15 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) { ...@@ -51,12 +51,15 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) {
auto *block = desc_.mutable_blocks(i); auto *block = desc_.mutable_blocks(i);
blocks_.emplace_back(new BlockDesc(*o.blocks_[i], block, this)); blocks_.emplace_back(new BlockDesc(*o.blocks_[i], block, this));
} }
for (auto &block : blocks_) { for (size_t block_id = 0; block_id < blocks_.size(); ++block_id) {
for (auto *op : block->AllOps()) { auto all_ops = blocks_[block_id]->AllOps();
for (const auto &attr : op->Proto()->attrs()) { for (size_t op_id = 0; op_id < all_ops.size(); ++op_id) {
if (attr.type() == proto::AttrType::BLOCK) { auto &op = all_ops[op_id];
size_t blk_idx = attr.block_idx(); for (const std::string &attr_name : op->AttrNames()) {
op->SetBlockAttr(attr.name(), this->MutableBlock(blk_idx)); if (op->GetAttrType(attr_name) == proto::AttrType::BLOCK) {
int sub_block_id =
o.Block(block_id).Op(op_id)->GetBlockAttr(attr_name);
op->SetBlockAttr(attr_name, MutableBlock(sub_block_id));
} }
} }
} }
...@@ -86,6 +89,16 @@ ProgramDesc::ProgramDesc(const std::string &binary_str) { ...@@ -86,6 +89,16 @@ ProgramDesc::ProgramDesc(const std::string &binary_str) {
for (auto &block_desc : *desc_.mutable_blocks()) { for (auto &block_desc : *desc_.mutable_blocks()) {
blocks_.emplace_back(new BlockDesc(this, &block_desc)); blocks_.emplace_back(new BlockDesc(this, &block_desc));
} }
for (auto &block : blocks_) {
for (auto *op : block->AllOps()) {
for (const auto &attr : op->Proto()->attrs()) {
if (attr.type() == proto::AttrType::BLOCK) {
size_t blk_idx = attr.block_idx();
op->SetBlockAttr(attr.name(), this->MutableBlock(blk_idx));
}
}
}
}
} }
const std::vector<std::string> ProgramDesc::GetFeedTargetNames() { const std::vector<std::string> ProgramDesc::GetFeedTargetNames() {
......
...@@ -24,7 +24,7 @@ class ReluOpConverter : public OpConverter { ...@@ -24,7 +24,7 @@ class ReluOpConverter : public OpConverter {
void operator()(const framework::proto::OpDesc& op) override { void operator()(const framework::proto::OpDesc& op) override {
// Here the two nullptr looks strange, that's because the // Here the two nullptr looks strange, that's because the
// framework::OpDesc's constructor is strange. // framework::OpDesc's constructor is strange.
framework::OpDesc op_desc(op, nullptr, nullptr); framework::OpDesc op_desc(op, nullptr);
LOG(INFO) << "convert a fluid relu op to tensorrt activation layer whose " LOG(INFO) << "convert a fluid relu op to tensorrt activation layer whose "
"type is Relu"; "type is Relu";
const nvinfer1::ITensor* input_tensor = const nvinfer1::ITensor* input_tensor =
......
...@@ -27,7 +27,7 @@ class MulOpConverter : public OpConverter { ...@@ -27,7 +27,7 @@ class MulOpConverter : public OpConverter {
void operator()(const framework::proto::OpDesc& op) override { void operator()(const framework::proto::OpDesc& op) override {
VLOG(4) << "convert a fluid mul op to tensorrt fc layer without bias"; VLOG(4) << "convert a fluid mul op to tensorrt fc layer without bias";
framework::OpDesc op_desc(op, nullptr, nullptr); framework::OpDesc op_desc(op, nullptr);
// Declare inputs // Declare inputs
auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]); auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
auto* input2 = engine_->GetITensor(op_desc.Input("Y")[0]); auto* input2 = engine_->GetITensor(op_desc.Input("Y")[0]);
......
...@@ -104,7 +104,7 @@ class TRTConvertValidation { ...@@ -104,7 +104,7 @@ class TRTConvertValidation {
engine_->FreezeNetwork(); engine_->FreezeNetwork();
// Declare outputs. // Declare outputs.
op_desc_.reset(new framework::OpDesc(desc, nullptr, nullptr)); op_desc_.reset(new framework::OpDesc(desc, nullptr));
// Set Inputs. // Set Inputs.
for (const auto& input : op_desc_->InputArgumentNames()) { for (const auto& input : op_desc_->InputArgumentNames()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册