// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include "lite/model_parser/base/block_desc.h" #include "lite/model_parser/flatbuffers/framework_generated.h" #include "lite/model_parser/flatbuffers/op_desc.h" #include "lite/model_parser/flatbuffers/var_desc.h" #include "lite/utils/all.h" namespace paddle { namespace lite { namespace fbs { class BlockDescView : public BlockDescAPI { public: explicit BlockDescView(proto::BlockDesc const* desc) : desc_(desc) { CHECK(desc_); vars_.reserve(VarsSize()); ops_.reserve(OpsSize()); for (size_t idx = 0; idx < VarsSize(); ++idx) { vars_.push_back(VarDescView(desc_->vars()->Get(idx))); } for (size_t idx = 0; idx < OpsSize(); ++idx) { ops_.push_back(OpDescView(desc_->ops()->Get(idx))); } } int32_t Idx() const override { return desc_->idx(); } int32_t ParentIdx() const override { return desc_->parent_idx(); } size_t VarsSize() const override { return desc_->vars()->size(); } template T const* GetVar(int32_t idx) const; template T* GetVar(int32_t idx) { NotImplemented(); return nullptr; } size_t OpsSize() const override { CHECK(desc_); CHECK(desc_->ops()); return desc_->ops()->size(); } template T const* GetOp(int32_t idx) const; template T* GetOp(int32_t idx) { NotImplemented(); return nullptr; } const std::vector& GetVars() const { return vars_; } int32_t ForwardBlockIdx() const override { return desc_->forward_block_idx(); } BlockDescView() { NotImplemented(); } private: proto::BlockDesc const* desc_; // not_own std::vector vars_; std::vector ops_; private: void NotImplemented() const { LOG(FATAL) << "The additional interfaces of BlockDescView is temporarily " "unavailable in read-only mode."; } }; class BlockDesc : public BlockDescAPI { public: BlockDesc() : owned_(true), desc_(new proto::BlockDescT()) {} explicit BlockDesc(proto::BlockDescT* desc) : desc_(desc) { CHECK(desc_); } int32_t Idx() const override { return desc_->idx; } void SetIdx(int32_t idx) override { desc_->idx = idx; } int32_t ParentIdx() const override { return desc_->parent_idx; } void SetParentIdx(int32_t idx) override { desc_->parent_idx = idx; } size_t VarsSize() const override { return desc_->vars.size(); } void ClearVars() override { desc_->vars.clear(); SyncVars(); } size_t OpsSize() const override { return desc_->ops.size(); } void ClearOps() override { desc_->ops.clear(); SyncOps(); } int32_t ForwardBlockIdx() const override { return desc_->forward_block_idx; } void SetForwardBlockIdx(int32_t idx_in) override { desc_->forward_block_idx = idx_in; } proto::BlockDescT* raw_desc() { return desc_; } template T* GetVar(int32_t idx); template T* AddVar(); template T* GetOp(int32_t idx); template T* AddOp(); ~BlockDesc() { if (owned_) { delete desc_; } } private: void SyncVars() { vars_.resize(desc_->vars.size()); for (size_t i = 0; i < desc_->vars.size(); ++i) { if (vars_[i].raw_desc() != desc_->vars[i].get()) { vars_[i] = VarDesc(desc_->vars[i].get()); } } } void SyncOps() { ops_.resize(desc_->ops.size()); for (size_t i = 0; i < desc_->ops.size(); ++i) { if (ops_[i].raw_desc() != desc_->ops[i].get()) { ops_[i] = OpDesc(desc_->ops[i].get()); } } } bool owned_{false}; proto::BlockDescT* desc_{nullptr}; std::vector vars_; std::vector ops_; }; } // namespace fbs } // namespace lite } // namespace paddle