未验证 提交 42ab4d55 编写于 作者: 石晓伟 提交者: GitHub

update desc interfaces, test=develop (#3926)

* update desc interfaces, test=develop

* update desc interfaces, test=develop

* update compatible_pb.cc, test=develop

* fix build errors, test=develop

* remove the fstream to shrink the size of library, test=develop
上级 3d0a45c3
...@@ -97,7 +97,7 @@ function(compile_flatbuffers_schema_to_cpp_opt TARGET SRC_FBS OPT) ...@@ -97,7 +97,7 @@ function(compile_flatbuffers_schema_to_cpp_opt TARGET SRC_FBS OPT)
OUTPUT ${GEN_HEADER} OUTPUT ${GEN_HEADER}
COMMAND "${FLATBUFFERS_FLATC_EXECUTABLE}" COMMAND "${FLATBUFFERS_FLATC_EXECUTABLE}"
--cpp --gen-mutable --gen-object-api --reflect-names --cpp --gen-mutable --gen-object-api --reflect-names
--cpp-ptr-type flatbuffers::unique_ptr # Used to test with C++98 STLs --force-empty --force-empty-vectors
${OPT} ${OPT}
-I "${CMAKE_CURRENT_SOURCE_DIR}/tests/include_test" -I "${CMAKE_CURRENT_SOURCE_DIR}/tests/include_test"
-o "${CMAKE_CURRENT_SOURCE_DIR}/${SRC_FBS_DIR}" -o "${CMAKE_CURRENT_SOURCE_DIR}/${SRC_FBS_DIR}"
......
...@@ -192,7 +192,8 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { ...@@ -192,7 +192,8 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
} else if (is_weight_quantization) { } else if (is_weight_quantization) {
std::string scale_name = conv_weight_name + "_quant_scale"; std::string scale_name = conv_weight_name + "_quant_scale";
if (conv_op_desc->HasAttr(scale_name)) { if (conv_op_desc->HasAttr(scale_name)) {
auto scale = conv_op_desc->GetAttr<std::vector<float>>(scale_name); std::vector<float> scale =
conv_op_desc->GetAttr<std::vector<float>>(scale_name);
CHECK_EQ(scale.size(), alpha_tensor.numel()); CHECK_EQ(scale.size(), alpha_tensor.numel());
for (size_t i = 0; i < scale.size(); i++) { for (size_t i = 0; i < scale.size(); i++) {
scale[i] *= alpha_data[i]; scale[i] *= alpha_data[i];
......
...@@ -84,11 +84,12 @@ cpp::OpDesc TransposeSoftmaxTransposeFuser::GenOpDesc( ...@@ -84,11 +84,12 @@ cpp::OpDesc TransposeSoftmaxTransposeFuser::GenOpDesc(
op_desc.SetInput("X", {matched.at("x1")->arg()->name}); op_desc.SetInput("X", {matched.at("x1")->arg()->name});
op_desc.SetOutput("Out", {matched.at("out")->arg()->name}); op_desc.SetOutput("Out", {matched.at("out")->arg()->name});
op_desc.SetAttr("axis", op_desc.SetAttr("axis",
matched.at("transpose1") *(matched.at("transpose1")
->stmt() ->stmt()
->op_info() ->op_info()
->GetAttr<std::vector<int>>("axis") ->GetAttr<std::vector<int>>("axis")
.back()); .end() -
1));
return op_desc; return op_desc;
} }
......
...@@ -62,15 +62,17 @@ std::string Visualize(mir::SSAGraph* graph) { ...@@ -62,15 +62,17 @@ std::string Visualize(mir::SSAGraph* graph) {
<< string_trunc(op_info->GetAttr<std::string>(attr_name)) << "\""; << string_trunc(op_info->GetAttr<std::string>(attr_name)) << "\"";
break; break;
case AttrType::FLOATS: { case AttrType::FLOATS: {
auto vals = op_info->GetAttr<std::vector<float>>(attr_name); std::vector<float> vals =
op_info->GetAttr<std::vector<float>>(attr_name);
os << ":floats: {" + Join(vals, ",") << "}"; os << ":floats: {" + Join(vals, ",") << "}";
} break; } break;
case AttrType::INTS: { case AttrType::INTS: {
auto vals = op_info->GetAttr<std::vector<int>>(attr_name); std::vector<int> vals = op_info->GetAttr<std::vector<int>>(attr_name);
os << ":ints: {" + Join(vals, ",") + "}"; os << ":ints: {" + Join(vals, ",") + "}";
} break; } break;
case AttrType::STRINGS: { case AttrType::STRINGS: {
auto vals = op_info->GetAttr<std::vector<std::string>>(attr_name); std::vector<std::string> vals =
op_info->GetAttr<std::vector<std::string>>(attr_name);
os << ":strings: {" + string_trunc(Join(vals, ",")) << "}"; os << ":strings: {" + string_trunc(Join(vals, ",")) << "}";
} break; } break;
default: default:
......
...@@ -195,7 +195,7 @@ void Program::Build(const cpp::ProgramDesc& prog) { ...@@ -195,7 +195,7 @@ void Program::Build(const cpp::ProgramDesc& prog) {
CHECK(ops_.empty()) << "Executor duplicate Build found"; CHECK(ops_.empty()) << "Executor duplicate Build found";
// Create operators. // Create operators.
auto program = prog; auto& program = prog;
CHECK(program.BlocksSize()); CHECK(program.BlocksSize());
auto& main_block = *program.GetBlock<cpp::BlockDesc>(0); auto& main_block = *program.GetBlock<cpp::BlockDesc>(0);
for (size_t i = 0; i < main_block.OpsSize(); ++i) { for (size_t i = 0; i < main_block.OpsSize(); ++i) {
...@@ -262,7 +262,7 @@ void Program::PrepareWorkspace(const cpp::ProgramDesc& prog, ...@@ -262,7 +262,7 @@ void Program::PrepareWorkspace(const cpp::ProgramDesc& prog,
} }
}; };
auto program = prog; auto& program = prog;
CHECK(program.BlocksSize()); CHECK(program.BlocksSize());
for (size_t b = 0; b < program.BlocksSize(); ++b) { for (size_t b = 0; b < program.BlocksSize(); ++b) {
auto& main_block = *program.GetBlock<cpp::BlockDesc>(b); auto& main_block = *program.GetBlock<cpp::BlockDesc>(b);
......
...@@ -46,7 +46,8 @@ struct Program { ...@@ -46,7 +46,8 @@ struct Program {
const std::shared_ptr<Scope>& root, const std::shared_ptr<Scope>& root,
const std::vector<Place>& valid_places, const std::vector<Place>& valid_places,
const std::vector<std::string>& var_names = {}) const std::vector<std::string>& var_names = {})
: scope_(root), valid_places_(valid_places), desc_(desc) { : scope_(root), valid_places_(valid_places) {
desc_.CopyFrom(desc);
CHECK(scope_) << "scope should be init first"; CHECK(scope_) << "scope should be init first";
VLOG(4) << "prepare work"; VLOG(4) << "prepare work";
PrepareWorkspace(desc, var_names); PrepareWorkspace(desc, var_names);
......
...@@ -54,10 +54,16 @@ class BlockDescWriteAPI { ...@@ -54,10 +54,16 @@ class BlockDescWriteAPI {
virtual void SetForwardBlockIdx(int32_t idx) { NotImplemented(); } virtual void SetForwardBlockIdx(int32_t idx) { NotImplemented(); }
template <typename T> template <typename T>
T* AddVar(); T* AddVar() {
NotImplemented();
return nullptr;
}
template <typename T> template <typename T>
T* AddOp(); T* AddOp() {
NotImplemented();
return nullptr;
}
virtual ~BlockDescWriteAPI() = default; virtual ~BlockDescWriteAPI() = default;
......
...@@ -73,7 +73,9 @@ class OpDescWriteAPI { ...@@ -73,7 +73,9 @@ class OpDescWriteAPI {
} }
template <typename T> template <typename T>
void SetAttr(const std::string& name, const T& v); void SetAttr(const std::string& name, const T& v) {
NotImplemented();
}
virtual ~OpDescWriteAPI() = default; virtual ~OpDescWriteAPI() = default;
......
...@@ -40,7 +40,10 @@ class ProgramDescWriteAPI { ...@@ -40,7 +40,10 @@ class ProgramDescWriteAPI {
virtual void SetVersion(int64_t version) { NotImplemented(); } virtual void SetVersion(int64_t version) { NotImplemented(); }
template <typename T> template <typename T>
T* AddBlock(); T* AddBlock() {
NotImplemented();
return nullptr;
}
virtual ~ProgramDescWriteAPI() = default; virtual ~ProgramDescWriteAPI() = default;
......
...@@ -57,6 +57,7 @@ class VectorView { ...@@ -57,6 +57,7 @@ class VectorView {
public: public:
typedef vector_view::VectorTraits<T, U> Traits; typedef vector_view::VectorTraits<T, U> Traits;
explicit VectorView(typename Traits::vector_type const* cvec) { explicit VectorView(typename Traits::vector_type const* cvec) {
CHECK(cvec);
cvec_ = cvec; cvec_ = cvec;
} }
typename Traits::subscript_return_type operator[](size_t i) const { typename Traits::subscript_return_type operator[](size_t i) const {
......
...@@ -277,7 +277,7 @@ void OpAttrsCppToAny(const cpp::OpDesc &cpp_desc, OpDescType *any_desc) { ...@@ -277,7 +277,7 @@ void OpAttrsCppToAny(const cpp::OpDesc &cpp_desc, OpDescType *any_desc) {
template <> \ template <> \
void TransformProgramDescCppToAny<NT::T>(const cpp::T &cpp_desc, \ void TransformProgramDescCppToAny<NT::T>(const cpp::T &cpp_desc, \
NT::T *any_desc) { \ NT::T *any_desc) { \
auto desc = cpp_desc; \ auto &desc = cpp_desc; \
if (desc.HasVersion()) { \ if (desc.HasVersion()) { \
any_desc->SetVersion(desc.Version()); \ any_desc->SetVersion(desc.Version()); \
} \ } \
......
...@@ -8,9 +8,6 @@ endfunction() ...@@ -8,9 +8,6 @@ endfunction()
lite_fbs_library(fbs_op_desc SRCS op_desc.cc FBS_DEPS framework_fbs_header) lite_fbs_library(fbs_op_desc SRCS op_desc.cc FBS_DEPS framework_fbs_header)
lite_fbs_library(fbs_var_desc SRCS var_desc.cc FBS_DEPS framework_fbs_header) lite_fbs_library(fbs_var_desc SRCS var_desc.cc FBS_DEPS framework_fbs_header)
lite_fbs_library(fbs_block_desc SRCS block_desc.cc FBS_DEPS framework_fbs_header) lite_fbs_library(fbs_block_desc SRCS block_desc.cc FBS_DEPS framework_fbs_header)
lite_fbs_library(fbs_program_desc SRCS program_desc.cc FBS_DEPS framework_fbs_header) lite_cc_library(fbs_program_desc SRCS program_desc.cc DEPS fbs_op_desc fbs_var_desc fbs_block_desc)
lite_cc_library(fbs_io SRCS io.cc DEPS fbs_program_desc)
lite_cc_test(test_vector_view SRCS vector_view_test.cc) lite_cc_test(test_vector_view SRCS vector_view_test.cc DEPS fbs_program_desc)
if (TARGET test_vector_view)
add_dependencies(test_vector_view framework_fbs_header)
endif()
...@@ -19,15 +19,27 @@ namespace lite { ...@@ -19,15 +19,27 @@ namespace lite {
namespace fbs { namespace fbs {
template <> template <>
proto::VarDesc* BlockDesc::GetVar<proto::VarDesc>(int32_t idx) { proto::VarDesc const* BlockDesc::GetVar<proto::VarDesc>(int32_t idx) const {
CHECK_LT(idx, VarsSize()) << "idx >= vars.size()"; CHECK_LT(idx, VarsSize()) << "idx >= vars.size()";
return const_cast<proto::VarDesc*>(desc_->vars()->Get(idx)); return desc_->vars()->Get(idx);
} }
template <> template <>
proto::OpDesc* BlockDesc::GetOp<proto::OpDesc>(int32_t idx) { proto::OpDesc const* BlockDesc::GetOp<proto::OpDesc>(int32_t idx) const {
CHECK_LT(idx, OpsSize()) << "idx >= ops.size()"; CHECK_LT(idx, OpsSize()) << "idx >= ops.size()";
return const_cast<proto::OpDesc*>(desc_->ops()->Get(idx)); return desc_->ops()->Get(idx);
}
template <>
VarDesc const* BlockDesc::GetVar<VarDesc>(int32_t idx) const {
CHECK_LT(idx, VarsSize()) << "idx >= vars.size()";
return &vars_[idx];
}
template <>
OpDesc const* BlockDesc::GetOp<OpDesc>(int32_t idx) const {
CHECK_LT(idx, OpsSize()) << "idx >= ops.size()";
return &ops_[idx];
} }
} // namespace fbs } // namespace fbs
......
...@@ -14,8 +14,11 @@ ...@@ -14,8 +14,11 @@
#pragma once #pragma once
#include <vector>
#include "lite/model_parser/base/block_desc.h" #include "lite/model_parser/base/block_desc.h"
#include "lite/model_parser/flatbuffers/framework_generated.h" #include "lite/model_parser/flatbuffers/framework_generated.h"
#include "lite/model_parser/flatbuffers/op_desc.h"
#include "lite/model_parser/flatbuffers/var_desc.h"
#include "lite/utils/all.h" #include "lite/utils/all.h"
namespace paddle { namespace paddle {
...@@ -24,7 +27,17 @@ namespace fbs { ...@@ -24,7 +27,17 @@ namespace fbs {
class BlockDesc : public BlockDescAPI { class BlockDesc : public BlockDescAPI {
public: public:
explicit BlockDesc(proto::BlockDesc* desc) : desc_(desc) { CHECK(desc_); } explicit BlockDesc(proto::BlockDesc const* desc) : desc_(desc) {
CHECK(desc_);
vars_.reserve(VarsSize());
ops_.reserve(OpsSize());
for (size_t idx = 0; idx < VarsSize(); ++idx) {
vars_.push_back(VarDesc(desc_->vars()->Get(idx)));
}
for (size_t idx = 0; idx < OpsSize(); ++idx) {
ops_.push_back(OpDesc(desc_->ops()->Get(idx)));
}
}
int32_t Idx() const override { return desc_->idx(); } int32_t Idx() const override { return desc_->idx(); }
...@@ -33,11 +46,12 @@ class BlockDesc : public BlockDescAPI { ...@@ -33,11 +46,12 @@ class BlockDesc : public BlockDescAPI {
size_t VarsSize() const override { return desc_->vars()->size(); } size_t VarsSize() const override { return desc_->vars()->size(); }
template <typename T> template <typename T>
T* GetVar(int32_t idx); T const* GetVar(int32_t idx) const;
template <typename T> template <typename T>
T const* GetVar(int32_t idx) const { T* GetVar(int32_t idx) {
return GetVar<T>(idx); NotImplemented();
return nullptr;
} }
size_t OpsSize() const override { size_t OpsSize() const override {
...@@ -47,21 +61,32 @@ class BlockDesc : public BlockDescAPI { ...@@ -47,21 +61,32 @@ class BlockDesc : public BlockDescAPI {
} }
template <typename T> template <typename T>
T* GetOp(int32_t idx); T const* GetOp(int32_t idx) const;
template <typename T> template <typename T>
T const* GetOp(int32_t idx) const { T* GetOp(int32_t idx) {
return GetOp<T>(idx); NotImplemented();
return nullptr;
} }
const std::vector<VarDesc>& GetVars() const { return vars_; }
int32_t ForwardBlockIdx() const override { int32_t ForwardBlockIdx() const override {
return desc_->forward_block_idx(); return desc_->forward_block_idx();
} }
BlockDesc() = delete; BlockDesc() { NotImplemented(); }
private: private:
proto::BlockDesc* desc_; // not_own proto::BlockDesc const* desc_; // not_own
std::vector<VarDesc> vars_;
std::vector<OpDesc> ops_;
private:
void NotImplemented() const {
LOG(FATAL) << "The additional interfaces of BlockDesc is temporarily "
"unavailable in read-only mode.";
}
}; };
} // namespace fbs } // namespace fbs
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/model_parser/flatbuffers/io.h"
#include <memory>
#include <utility>
namespace paddle {
namespace lite {
namespace fbs {
void LoadModel(const std::string& path, ProgramDesc* prog) {
FILE* file = fopen(path.c_str(), "rb");
fseek(file, 0, SEEK_END);
int64_t size = ftell(file);
rewind(file);
char* data = new char[size];
size = fread(data, 1, size, file);
fclose(file);
std::unique_ptr<char[]> buf(data);
prog->Init(std::move(buf));
}
} // namespace fbs
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "lite/model_parser/flatbuffers/program_desc.h"
namespace paddle {
namespace lite {
namespace fbs {
void LoadModel(const std::string& path, ProgramDesc* prog);
} // namespace fbs
} // namespace lite
} // namespace paddle
...@@ -30,7 +30,7 @@ namespace fbs { ...@@ -30,7 +30,7 @@ namespace fbs {
class OpDesc : public OpDescAPI { class OpDesc : public OpDescAPI {
public: public:
explicit OpDesc(proto::OpDesc* desc) : desc_(desc) { CHECK(desc_); } explicit OpDesc(proto::OpDesc const* desc) : desc_(desc) { CHECK(desc_); }
std::string Type() const override { return desc_->type()->str(); } std::string Type() const override { return desc_->type()->str(); }
...@@ -95,7 +95,7 @@ class OpDesc : public OpDescAPI { ...@@ -95,7 +95,7 @@ class OpDesc : public OpDescAPI {
OpDescAPI::AttrType GetAttrType(const std::string& name) const override { OpDescAPI::AttrType GetAttrType(const std::string& name) const override {
const auto& attr = desc_->attrs()->LookupByKey(name.c_str()); const auto& attr = desc_->attrs()->LookupByKey(name.c_str());
CHECK(attr); CHECK(attr) << "Can not find attr: " << name;
return static_cast<OpDescAPI::AttrType>(attr->type()); return static_cast<OpDescAPI::AttrType>(attr->type());
} }
...@@ -124,10 +124,8 @@ class OpDesc : public OpDescAPI { ...@@ -124,10 +124,8 @@ class OpDesc : public OpDescAPI {
template <typename T> template <typename T>
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT GetAttr(size_t idx) const; typename lite::OpDataTypeTrait<T, Flatbuffers>::RT GetAttr(size_t idx) const;
OpDesc() = delete;
private: private:
proto::OpDesc* desc_; proto::OpDesc const* desc_;
// To reduce overhead, we expect to use namespace aliasing to make cpp::Desc // To reduce overhead, we expect to use namespace aliasing to make cpp::Desc
// and flatbuffers::Desc replace each other. However, there is no direct // and flatbuffers::Desc replace each other. However, there is no direct
...@@ -138,6 +136,7 @@ class OpDesc : public OpDescAPI { ...@@ -138,6 +136,7 @@ class OpDesc : public OpDescAPI {
// caused by different building options. // caused by different building options.
public: public:
OpDesc() { NotImplemented(); }
bool HasInput(const std::string& param) const { bool HasInput(const std::string& param) const {
return desc_->inputs()->LookupByKey(param.c_str()) != nullptr; return desc_->inputs()->LookupByKey(param.c_str()) != nullptr;
} }
......
...@@ -19,9 +19,16 @@ namespace lite { ...@@ -19,9 +19,16 @@ namespace lite {
namespace fbs { namespace fbs {
template <> template <>
proto::BlockDesc* ProgramDesc::GetBlock<proto::BlockDesc>(int32_t idx) { proto::BlockDesc const* ProgramDesc::GetBlock<proto::BlockDesc>(
int32_t idx) const {
CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()"; CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()";
return const_cast<proto::BlockDesc*>(desc_->blocks()->Get(idx)); return desc_->blocks()->Get(idx);
}
template <>
BlockDesc const* ProgramDesc::GetBlock<BlockDesc>(int32_t idx) const {
CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()";
return &blocks_[idx];
} }
} // namespace fbs } // namespace fbs
......
...@@ -15,7 +15,10 @@ ...@@ -15,7 +15,10 @@
#pragma once #pragma once
#include <memory> #include <memory>
#include <utility>
#include <vector>
#include "lite/model_parser/base/program_desc.h" #include "lite/model_parser/base/program_desc.h"
#include "lite/model_parser/flatbuffers/block_desc.h"
#include "lite/model_parser/flatbuffers/framework_generated.h" #include "lite/model_parser/flatbuffers/framework_generated.h"
#include "lite/utils/all.h" #include "lite/utils/all.h"
...@@ -26,18 +29,40 @@ namespace fbs { ...@@ -26,18 +29,40 @@ namespace fbs {
class ProgramDesc : public ProgramDescAPI { class ProgramDesc : public ProgramDescAPI {
public: public:
ProgramDesc() = default; ProgramDesc() = default;
explicit ProgramDesc(proto::ProgramDesc *desc) : desc_(desc) { CHECK(desc); } explicit ProgramDesc(std::unique_ptr<const char[]> buf) {
Init(std::move(buf));
}
size_t BlocksSize() const override { return desc_->blocks()->size(); } size_t BlocksSize() const override { return desc_->blocks()->size(); }
void Init(std::unique_ptr<const char[]> buf) {
CHECK(buf.get() != nullptr);
buf_ = std::move(buf);
desc_ = proto::GetProgramDesc(buf_.get());
blocks_.reserve(BlocksSize());
for (size_t idx = 0; idx < BlocksSize(); ++idx) {
blocks_.push_back(BlockDesc(desc_->blocks()->Get(idx)));
}
}
void CopyFrom(const ProgramDesc& other) {
size_t length = strlen(static_cast<const char*>(other.raw_buf()));
std::unique_ptr<char[]> buf(new char[length]);
memcpy(buf.get(), other.raw_buf(), length);
Init(std::move(buf));
}
template <typename T> template <typename T>
T *GetBlock(int32_t idx); T const* GetBlock(int32_t idx) const;
template <typename T> template <typename T>
T const *GetBlock(int32_t idx) const { T* GetBlock(int32_t idx) {
return GetBlock<T>(idx); NotImplemented();
return nullptr;
} }
const std::vector<BlockDesc>& GetBlocks() const { return blocks_; }
bool HasVersion() const override { return desc_->version() != nullptr; } bool HasVersion() const override { return desc_->version() != nullptr; }
int64_t Version() const override { int64_t Version() const override {
...@@ -45,8 +70,22 @@ class ProgramDesc : public ProgramDescAPI { ...@@ -45,8 +70,22 @@ class ProgramDesc : public ProgramDescAPI {
return desc_->version()->version(); return desc_->version()->version();
} }
proto::ProgramDesc const* raw_desc() const { return desc_; }
const void* raw_buf() const { return buf_.get(); }
private: private:
proto::ProgramDesc *desc_; // not_own proto::ProgramDesc const* desc_;
std::unique_ptr<const char[]> buf_;
std::vector<BlockDesc> blocks_;
private:
ProgramDesc& operator=(const ProgramDesc&) = delete;
ProgramDesc(const ProgramDesc&) = delete;
void NotImplemented() const {
LOG(FATAL) << "The additional interfaces of ProgramDesc is temporarily "
"unavailable in read-only mode.";
}
}; };
} // namespace fbs } // namespace fbs
......
...@@ -27,7 +27,7 @@ namespace fbs { ...@@ -27,7 +27,7 @@ namespace fbs {
class VarDesc : public VarDescAPI { class VarDesc : public VarDescAPI {
public: public:
explicit VarDesc(proto::VarDesc* desc) : desc_(desc) {} explicit VarDesc(proto::VarDesc const* desc) : desc_(desc) {}
std::string Name() const override { return desc_->name()->str(); } std::string Name() const override { return desc_->name()->str(); }
...@@ -48,10 +48,14 @@ class VarDesc : public VarDescAPI { ...@@ -48,10 +48,14 @@ class VarDesc : public VarDescAPI {
return dims_vec; return dims_vec;
} }
VarDesc() = delete; VarDescAPI::Type GetDataType() const {
CHECK(GetType() == VarDescAPI::Type::LOD_TENSOR);
return static_cast<VarDescAPI::Type>(
desc_->type()->lod_tensor()->tensor()->data_type());
}
private: private:
proto::VarDesc* desc_; proto::VarDesc const* desc_;
// To reduce overhead, we expect to use namespace aliasing to make cpp::Desc // To reduce overhead, we expect to use namespace aliasing to make cpp::Desc
// and flatbuffers::Desc replace each other. However, there is no direct // and flatbuffers::Desc replace each other. However, there is no direct
...@@ -62,10 +66,7 @@ class VarDesc : public VarDescAPI { ...@@ -62,10 +66,7 @@ class VarDesc : public VarDescAPI {
// caused by different building options. // caused by different building options.
public: public:
VarDescAPI::Type GetDataType() const { VarDesc() { NotImplemented(); }
NotImplemented();
return data_type_;
}
void SetDataType(Type data_type) { NotImplemented(); } void SetDataType(Type data_type) { NotImplemented(); }
void SetShape(const std::vector<int64_t>& dims) { NotImplemented(); } void SetShape(const std::vector<int64_t>& dims) { NotImplemented(); }
...@@ -74,7 +75,6 @@ class VarDesc : public VarDescAPI { ...@@ -74,7 +75,6 @@ class VarDesc : public VarDescAPI {
LOG(FATAL) << "The additional interfaces of VarDesc is temporarily " LOG(FATAL) << "The additional interfaces of VarDesc is temporarily "
"unavailable in read-only mode."; "unavailable in read-only mode.";
} }
Type data_type_;
std::vector<int64_t> shape_; std::vector<int64_t> shape_;
}; };
......
...@@ -104,20 +104,32 @@ class VectorView<std::string, Flatbuffers> { ...@@ -104,20 +104,32 @@ class VectorView<std::string, Flatbuffers> {
explicit VectorView(typename Traits::vector_type const* cvec) { explicit VectorView(typename Traits::vector_type const* cvec) {
cvec_ = cvec; cvec_ = cvec;
} }
std::string operator[](size_t i) const { return cvec_->operator[](i)->str(); } std::string operator[](size_t i) const {
CHECK(cvec_);
return cvec_->operator[](i)->str();
}
vector_view::FBSStrIterator begin() const { vector_view::FBSStrIterator begin() const {
CHECK(cvec_);
return vector_view::FBSStrIterator(cvec_->begin()); return vector_view::FBSStrIterator(cvec_->begin());
} }
vector_view::FBSStrIterator end() const { vector_view::FBSStrIterator end() const {
CHECK(cvec_);
return vector_view::FBSStrIterator(cvec_->end()); return vector_view::FBSStrIterator(cvec_->end());
} }
size_t size() const { return cvec_->size(); } size_t size() const {
if (cvec_ == nullptr) {
return 0;
}
return cvec_->size();
}
operator std::vector<std::string>() const { operator std::vector<std::string>() const {
VLOG(5) << "Copying elements out of VectorView will damage performance."; VLOG(5) << "Copying elements out of VectorView will damage performance.";
std::vector<std::string> tmp; std::vector<std::string> tmp;
tmp.reserve(cvec_->size()); tmp.reserve(size());
for (auto val : *cvec_) { if (cvec_ != nullptr) {
tmp.push_back(val->str()); for (auto val : *cvec_) {
tmp.push_back(val->str());
}
} }
return tmp; return tmp;
} }
......
...@@ -24,6 +24,12 @@ VarDesc* BlockDesc::GetVar<VarDesc>(int32_t idx) { ...@@ -24,6 +24,12 @@ VarDesc* BlockDesc::GetVar<VarDesc>(int32_t idx) {
return &vars_[idx]; return &vars_[idx];
} }
template <>
VarDesc const* BlockDesc::GetVar<VarDesc>(int32_t idx) const {
CHECK_LT(idx, VarsSize()) << "idx >= vars.size()";
return &vars_[idx];
}
template <> template <>
VarDesc* BlockDesc::AddVar<VarDesc>() { VarDesc* BlockDesc::AddVar<VarDesc>() {
vars_.emplace_back(); vars_.emplace_back();
...@@ -36,6 +42,12 @@ OpDesc* BlockDesc::GetOp<OpDesc>(int32_t idx) { ...@@ -36,6 +42,12 @@ OpDesc* BlockDesc::GetOp<OpDesc>(int32_t idx) {
return &ops_[idx]; return &ops_[idx];
} }
template <>
OpDesc const* BlockDesc::GetOp<OpDesc>(int32_t idx) const {
CHECK_LT(idx, OpsSize()) << "idx >= ops.size()";
return &ops_[idx];
}
template <> template <>
OpDesc* BlockDesc::AddOp<OpDesc>() { OpDesc* BlockDesc::AddOp<OpDesc>() {
ops_.emplace_back(); ops_.emplace_back();
......
...@@ -46,12 +46,10 @@ class BlockDesc : public BlockDescAPI { ...@@ -46,12 +46,10 @@ class BlockDesc : public BlockDescAPI {
template <typename T> template <typename T>
T* GetVar(int32_t idx); T* GetVar(int32_t idx);
std::vector<VarDesc>& GetVars() { return vars_; }
template <typename T> template <typename T>
T const* GetVar(int32_t idx) const { T const* GetVar(int32_t idx) const;
return GetVar<T>(idx);
} std::vector<VarDesc>& GetVars() { return vars_; }
template <typename T> template <typename T>
T* AddVar(); T* AddVar();
...@@ -64,9 +62,7 @@ class BlockDesc : public BlockDescAPI { ...@@ -64,9 +62,7 @@ class BlockDesc : public BlockDescAPI {
T* GetOp(int32_t idx); T* GetOp(int32_t idx);
template <typename T> template <typename T>
T const* GetOp(int32_t idx) const { T const* GetOp(int32_t idx) const;
return GetOp<T>(idx);
}
template <typename T> template <typename T>
T* AddOp(); T* AddOp();
......
...@@ -24,6 +24,12 @@ BlockDesc* ProgramDesc::GetBlock<BlockDesc>(int32_t idx) { ...@@ -24,6 +24,12 @@ BlockDesc* ProgramDesc::GetBlock<BlockDesc>(int32_t idx) {
return &blocks_[idx]; return &blocks_[idx];
} }
template <>
BlockDesc const* ProgramDesc::GetBlock<BlockDesc>(int32_t idx) const {
CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()";
return &blocks_[idx];
}
template <> template <>
BlockDesc* ProgramDesc::AddBlock<BlockDesc>() { BlockDesc* ProgramDesc::AddBlock<BlockDesc>() {
blocks_.emplace_back(); blocks_.emplace_back();
......
...@@ -30,6 +30,13 @@ class ProgramDesc : public ProgramDescAPI { ...@@ -30,6 +30,13 @@ class ProgramDesc : public ProgramDescAPI {
public: public:
ProgramDesc() = default; ProgramDesc() = default;
void CopyFrom(const ProgramDesc& other) {
version_ = other.Version();
blocks_ = other.blocks();
}
const std::vector<BlockDesc>& blocks() const { return blocks_; }
size_t BlocksSize() const override { return blocks_.size(); } size_t BlocksSize() const override { return blocks_.size(); }
void ClearBlocks() override { blocks_.clear(); } void ClearBlocks() override { blocks_.clear(); }
...@@ -37,12 +44,10 @@ class ProgramDesc : public ProgramDescAPI { ...@@ -37,12 +44,10 @@ class ProgramDesc : public ProgramDescAPI {
template <typename T> template <typename T>
T* GetBlock(int32_t idx); T* GetBlock(int32_t idx);
std::vector<BlockDesc>& GetBlocks() { return blocks_; }
template <typename T> template <typename T>
T const* GetBlock(int32_t idx) const { T const* GetBlock(int32_t idx) const;
return GetBlock<T>(idx);
} std::vector<BlockDesc>& GetBlocks() { return blocks_; }
template <typename T> template <typename T>
T* AddBlock(); T* AddBlock();
......
...@@ -176,7 +176,7 @@ void LoadCombinedParamsPb(const std::string &path, ...@@ -176,7 +176,7 @@ void LoadCombinedParamsPb(const std::string &path,
const cpp::ProgramDesc &cpp_prog, const cpp::ProgramDesc &cpp_prog,
bool params_from_memory) { bool params_from_memory) {
CHECK(scope); CHECK(scope);
auto prog = cpp_prog; auto &prog = cpp_prog;
auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0); auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
// Get vars // Get vars
...@@ -310,7 +310,7 @@ void SaveModelPb(const std::string &model_dir, ...@@ -310,7 +310,7 @@ void SaveModelPb(const std::string &model_dir,
void SaveCombinedParamsPb(const std::string &path, void SaveCombinedParamsPb(const std::string &path,
const lite::Scope &exec_scope, const lite::Scope &exec_scope,
const cpp::ProgramDesc &cpp_prog) { const cpp::ProgramDesc &cpp_prog) {
auto prog = cpp_prog; auto &prog = cpp_prog;
auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0); auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
// Get vars // Get vars
...@@ -526,7 +526,7 @@ void SaveCombinedParamsNaive(const std::string &path, ...@@ -526,7 +526,7 @@ void SaveCombinedParamsNaive(const std::string &path,
naive_buffer::proto::CombinedParamsDesc pt_desc(&table); naive_buffer::proto::CombinedParamsDesc pt_desc(&table);
naive_buffer::CombinedParamsDesc desc(&pt_desc); naive_buffer::CombinedParamsDesc desc(&pt_desc);
auto prog = cpp_prog; auto &prog = cpp_prog;
auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0); auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
// set unique_var_names to avoid saving shared params repeatedly // set unique_var_names to avoid saving shared params repeatedly
std::set<std::string> unique_var_names; std::set<std::string> unique_var_names;
...@@ -681,7 +681,7 @@ void LoadCombinedParamsNaive(const std::string &path, ...@@ -681,7 +681,7 @@ void LoadCombinedParamsNaive(const std::string &path,
} }
// Check all params loaded // Check all params loaded
auto prog = cpp_prog; auto &prog = cpp_prog;
auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0); auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
for (size_t i = 0; i < main_block_desc.VarsSize(); ++i) { for (size_t i = 0; i < main_block_desc.VarsSize(); ++i) {
auto &var = *main_block_desc.GetVar<cpp::VarDesc>(i); auto &var = *main_block_desc.GetVar<cpp::VarDesc>(i);
......
...@@ -55,11 +55,6 @@ class BlockDesc : public BlockDescAPI { ...@@ -55,11 +55,6 @@ class BlockDesc : public BlockDescAPI {
template <typename T> template <typename T>
T* GetVar(int32_t idx); T* GetVar(int32_t idx);
template <typename T>
T const* GetVar(int32_t idx) const {
return GetVar<T>(idx);
}
template <typename T> template <typename T>
T* AddVar(); T* AddVar();
...@@ -70,11 +65,6 @@ class BlockDesc : public BlockDescAPI { ...@@ -70,11 +65,6 @@ class BlockDesc : public BlockDescAPI {
template <typename T> template <typename T>
T* GetOp(int32_t idx); T* GetOp(int32_t idx);
template <typename T>
T const* GetOp(int32_t idx) const {
return GetOp<T>(idx);
}
template <typename T> template <typename T>
T* AddOp(); T* AddOp();
......
...@@ -45,11 +45,6 @@ class ProgramDesc : public ProgramDescAPI { ...@@ -45,11 +45,6 @@ class ProgramDesc : public ProgramDescAPI {
template <typename T> template <typename T>
T *GetBlock(int32_t idx); T *GetBlock(int32_t idx);
template <typename T>
T const *GetBlock(int32_t idx) const {
return GetBlock<T>(idx);
}
template <typename T> template <typename T>
T *AddBlock(); T *AddBlock();
......
...@@ -83,7 +83,7 @@ class DeformableConvOpLite : public OpLite { ...@@ -83,7 +83,7 @@ class DeformableConvOpLite : public OpLite {
param_.conv_param.filter = param_.conv_param.filter =
scope->FindVar(Filter)->GetMutable<lite::Tensor>(); scope->FindVar(Filter)->GetMutable<lite::Tensor>();
param_.conv_param.strides = op_desc.GetAttr<std::vector<int>>("strides"); param_.conv_param.strides = op_desc.GetAttr<std::vector<int>>("strides");
auto paddings = op_desc.GetAttr<std::vector<int>>("paddings"); std::vector<int> paddings = op_desc.GetAttr<std::vector<int>>("paddings");
auto dilations = op_desc.GetAttr<std::vector<int>>("dilations"); auto dilations = op_desc.GetAttr<std::vector<int>>("dilations");
param_.conv_param.groups = op_desc.GetAttr<int>("groups"); param_.conv_param.groups = op_desc.GetAttr<int>("groups");
param_.conv_param.dilations = std::make_shared<std::vector<int>>(dilations); param_.conv_param.dilations = std::make_shared<std::vector<int>>(dilations);
......
...@@ -54,7 +54,7 @@ class MaxPoolWithIndexOpLite : public OpLite { ...@@ -54,7 +54,7 @@ class MaxPoolWithIndexOpLite : public OpLite {
param_.ksize = op_desc.GetAttr<std::vector<int>>("ksize"); param_.ksize = op_desc.GetAttr<std::vector<int>>("ksize");
param_.global_pooling = op_desc.GetAttr<bool>("global_pooling"); param_.global_pooling = op_desc.GetAttr<bool>("global_pooling");
param_.strides = op_desc.GetAttr<std::vector<int>>("strides"); param_.strides = op_desc.GetAttr<std::vector<int>>("strides");
auto paddings = op_desc.GetAttr<std::vector<int>>("paddings"); std::vector<int> paddings = op_desc.GetAttr<std::vector<int>>("paddings");
if (op_desc.HasAttr("adaptive")) { if (op_desc.HasAttr("adaptive")) {
param_.adaptive = op_desc.GetAttr<bool>("adaptive"); param_.adaptive = op_desc.GetAttr<bool>("adaptive");
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册