diff --git a/paddle/fluid/framework/block_desc.cc b/paddle/fluid/framework/block_desc.cc index fd409ed4c0f7a504686765909e9c71692aab8824..e7842e9b8130d35e511e02dfb1dc27f307d17f38 100644 --- a/paddle/fluid/framework/block_desc.cc +++ b/paddle/fluid/framework/block_desc.cc @@ -200,7 +200,7 @@ BlockDesc::BlockDesc(ProgramDesc *prog, proto::BlockDesc *desc) vars_[var_desc.name()].reset(new VarDesc(var_desc)); } for (const proto::OpDesc &op_desc : desc_->ops()) { - ops_.emplace_back(new OpDesc(op_desc, prog, this)); + ops_.emplace_back(new OpDesc(op_desc, this)); } } @@ -209,7 +209,7 @@ BlockDesc::BlockDesc(const BlockDesc &other, proto::BlockDesc *desc, : prog_(prog), desc_(desc) { need_update_ = true; for (auto &op : other.ops_) { - ops_.emplace_back(new OpDesc(*op->Proto(), prog, this)); + ops_.emplace_back(new OpDesc(*op, this)); } for (auto &it : other.vars_) { auto *var = new VarDesc(*it.second); diff --git a/paddle/fluid/framework/block_desc.h b/paddle/fluid/framework/block_desc.h index 600601669c5d56a3ffc2fb9c804ffad5fde58f0b..189dd6c52f85b5bf623b98c64c07c0c7269505d4 100644 --- a/paddle/fluid/framework/block_desc.h +++ b/paddle/fluid/framework/block_desc.h @@ -105,7 +105,7 @@ class BlockDesc { size_t OpSize() const { return ops_.size(); } - OpDesc *Op(int idx) { return ops_.at(idx).get(); } + OpDesc *Op(int idx) const { return ops_.at(idx).get(); } void Flush(); diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index 09b67e5a1741c68c5f5487340e8fc86ff31e00a4..f92769192c218eb7cdc2350ff6e4721b45005806 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -103,7 +103,7 @@ void OpDesc::CopyFrom(const OpDesc &op_desc) { need_update_ = true; } -OpDesc::OpDesc(const proto::OpDesc &desc, ProgramDesc *prog, BlockDesc *block) +OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block) : desc_(desc), need_update_(false) { // restore inputs_ int input_size = desc_.inputs_size(); diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h index 1a330db7cc5555a939950043ac90a321573b292d..a02d3e269129596f65a2fb346e76c1af7fbead95 100644 --- a/paddle/fluid/framework/op_desc.h +++ b/paddle/fluid/framework/op_desc.h @@ -33,13 +33,14 @@ class OpDesc { OpDesc(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const AttributeMap &attrs); - OpDesc(const proto::OpDesc &desc, ProgramDesc *prog, BlockDesc *block); + OpDesc(const proto::OpDesc &desc, BlockDesc *block); explicit OpDesc(BlockDesc *block) : block_(block) {} OpDesc(const OpDesc &other, BlockDesc *block) { *this = other; block_ = block; + need_update_ = true; } void CopyFrom(const OpDesc &op_desc); diff --git a/paddle/fluid/framework/program_desc.cc b/paddle/fluid/framework/program_desc.cc index 64fb028f83a539d17885186d5d8ee6ef26f095e9..1e01a6e900404990e16674755367d2fc6d832725 100644 --- a/paddle/fluid/framework/program_desc.cc +++ b/paddle/fluid/framework/program_desc.cc @@ -51,12 +51,15 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) { auto *block = desc_.mutable_blocks(i); blocks_.emplace_back(new BlockDesc(*o.blocks_[i], block, this)); } - for (auto &block : blocks_) { - for (auto *op : block->AllOps()) { - for (const auto &attr : op->Proto()->attrs()) { - if (attr.type() == proto::AttrType::BLOCK) { - size_t blk_idx = attr.block_idx(); - op->SetBlockAttr(attr.name(), this->MutableBlock(blk_idx)); + for (size_t block_id = 0; block_id < blocks_.size(); ++block_id) { + auto all_ops = blocks_[block_id]->AllOps(); + for (size_t op_id = 0; op_id < all_ops.size(); ++op_id) { + auto &op = all_ops[op_id]; + for (const std::string &attr_name : op->AttrNames()) { + if (op->GetAttrType(attr_name) == proto::AttrType::BLOCK) { + int sub_block_id = + o.Block(block_id).Op(op_id)->GetBlockAttr(attr_name); + op->SetBlockAttr(attr_name, MutableBlock(sub_block_id)); } } } @@ -86,6 +89,16 @@ ProgramDesc::ProgramDesc(const std::string &binary_str) { for (auto &block_desc : *desc_.mutable_blocks()) { blocks_.emplace_back(new BlockDesc(this, &block_desc)); } + for (auto &block : blocks_) { + for (auto *op : block->AllOps()) { + for (const auto &attr : op->Proto()->attrs()) { + if (attr.type() == proto::AttrType::BLOCK) { + size_t blk_idx = attr.block_idx(); + op->SetBlockAttr(attr.name(), this->MutableBlock(blk_idx)); + } + } + } + } } const std::vector ProgramDesc::GetFeedTargetNames() { diff --git a/paddle/fluid/inference/tensorrt/convert/activation_op.cc b/paddle/fluid/inference/tensorrt/convert/activation_op.cc index 6297051e5a30f1daa512d25d5aa3ab3b2f79f1d1..79d01b640a214ed5eb86173a36d5e85a6626066f 100644 --- a/paddle/fluid/inference/tensorrt/convert/activation_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/activation_op.cc @@ -24,7 +24,7 @@ class ReluOpConverter : public OpConverter { void operator()(const framework::proto::OpDesc& op) override { // Here the two nullptr looks strange, that's because the // framework::OpDesc's constructor is strange. - framework::OpDesc op_desc(op, nullptr, nullptr); + framework::OpDesc op_desc(op, nullptr); LOG(INFO) << "convert a fluid relu op to tensorrt activation layer whose " "type is Relu"; const nvinfer1::ITensor* input_tensor = diff --git a/paddle/fluid/inference/tensorrt/convert/mul_op.cc b/paddle/fluid/inference/tensorrt/convert/mul_op.cc index ed09f54bde00d12aaec829ba90cc08ebfef57e92..aa8e66490f7e40038b0de4da32655f1b168ca332 100644 --- a/paddle/fluid/inference/tensorrt/convert/mul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/mul_op.cc @@ -27,7 +27,7 @@ class MulOpConverter : public OpConverter { void operator()(const framework::proto::OpDesc& op) override { VLOG(4) << "convert a fluid mul op to tensorrt fc layer without bias"; - framework::OpDesc op_desc(op, nullptr, nullptr); + framework::OpDesc op_desc(op, nullptr); // Declare inputs auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]); auto* input2 = engine_->GetITensor(op_desc.Input("Y")[0]); diff --git a/paddle/fluid/inference/tensorrt/convert/ut_helper.h b/paddle/fluid/inference/tensorrt/convert/ut_helper.h index e46c577cdae145c0d4ceb6bfa307f03d313514ce..684bbc208fc1cb02d2a36b4de720309ea6bed173 100644 --- a/paddle/fluid/inference/tensorrt/convert/ut_helper.h +++ b/paddle/fluid/inference/tensorrt/convert/ut_helper.h @@ -104,7 +104,7 @@ class TRTConvertValidation { engine_->FreezeNetwork(); // Declare outputs. - op_desc_.reset(new framework::OpDesc(desc, nullptr, nullptr)); + op_desc_.reset(new framework::OpDesc(desc, nullptr)); // Set Inputs. for (const auto& input : op_desc_->InputArgumentNames()) {