// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <utility>
#include <vector>
#include "lite/model_parser/base/program_desc.h"
#include "lite/model_parser/flatbuffers/block_desc.h"
#include "lite/model_parser/flatbuffers/framework_generated.h"
#include "lite/utils/all.h"

namespace paddle {
namespace lite {
namespace fbs {

class ProgramDesc : public ProgramDescAPI {
30
 public:
31
  ProgramDesc() = default;
32 33 34
  explicit ProgramDesc(std::unique_ptr<const char[]> buf) {
    Init(std::move(buf));
  }
35 36 37

  size_t BlocksSize() const override { return desc_->blocks()->size(); }

38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
  void Init(std::unique_ptr<const char[]> buf) {
    CHECK(buf.get() != nullptr);
    buf_ = std::move(buf);
    desc_ = proto::GetProgramDesc(buf_.get());
    blocks_.reserve(BlocksSize());
    for (size_t idx = 0; idx < BlocksSize(); ++idx) {
      blocks_.push_back(BlockDesc(desc_->blocks()->Get(idx)));
    }
  }

  void CopyFrom(const ProgramDesc& other) {
    size_t length = strlen(static_cast<const char*>(other.raw_buf()));
    std::unique_ptr<char[]> buf(new char[length]);
    memcpy(buf.get(), other.raw_buf(), length);
    Init(std::move(buf));
  }

55
  template <typename T>
56
  T const* GetBlock(int32_t idx) const;
57 58

  template <typename T>
59 60 61
  T* GetBlock(int32_t idx) {
    NotImplemented();
    return nullptr;
62 63
  }

64 65
  const std::vector<BlockDesc>& GetBlocks() const { return blocks_; }

66 67 68 69 70 71 72
  bool HasVersion() const override { return desc_->version() != nullptr; }

  int64_t Version() const override {
    CHECK(HasVersion());
    return desc_->version()->version();
  }

73 74 75 76
  proto::ProgramDesc const* raw_desc() const { return desc_; }

  const void* raw_buf() const { return buf_.get(); }

77
 private:
78 79 80 81 82 83 84 85 86 87 88
  proto::ProgramDesc const* desc_;
  std::unique_ptr<const char[]> buf_;
  std::vector<BlockDesc> blocks_;

 private:
  ProgramDesc& operator=(const ProgramDesc&) = delete;
  ProgramDesc(const ProgramDesc&) = delete;
  void NotImplemented() const {
    LOG(FATAL) << "The additional interfaces of ProgramDesc is temporarily "
                  "unavailable in read-only mode.";
  }
89 90 91 92 93
};

}  // namespace fbs
}  // namespace lite
}  // namespace paddle