未验证 提交 42ab4d55 编写于 作者: 石晓伟 提交者: GitHub

update desc interfaces, test=develop (#3926)

* update desc interfaces, test=develop

* update desc interfaces, test=develop

* update compatible_pb.cc, test=develop

* fix build errors, test=develop

* remove the fstream to shrink the size of library, test=develop
上级 3d0a45c3
......@@ -97,7 +97,7 @@ function(compile_flatbuffers_schema_to_cpp_opt TARGET SRC_FBS OPT)
OUTPUT ${GEN_HEADER}
COMMAND "${FLATBUFFERS_FLATC_EXECUTABLE}"
--cpp --gen-mutable --gen-object-api --reflect-names
--cpp-ptr-type flatbuffers::unique_ptr # Used to test with C++98 STLs
--force-empty --force-empty-vectors
${OPT}
-I "${CMAKE_CURRENT_SOURCE_DIR}/tests/include_test"
-o "${CMAKE_CURRENT_SOURCE_DIR}/${SRC_FBS_DIR}"
......
......@@ -192,7 +192,8 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
} else if (is_weight_quantization) {
std::string scale_name = conv_weight_name + "_quant_scale";
if (conv_op_desc->HasAttr(scale_name)) {
auto scale = conv_op_desc->GetAttr<std::vector<float>>(scale_name);
std::vector<float> scale =
conv_op_desc->GetAttr<std::vector<float>>(scale_name);
CHECK_EQ(scale.size(), alpha_tensor.numel());
for (size_t i = 0; i < scale.size(); i++) {
scale[i] *= alpha_data[i];
......
......@@ -84,11 +84,12 @@ cpp::OpDesc TransposeSoftmaxTransposeFuser::GenOpDesc(
op_desc.SetInput("X", {matched.at("x1")->arg()->name});
op_desc.SetOutput("Out", {matched.at("out")->arg()->name});
op_desc.SetAttr("axis",
matched.at("transpose1")
->stmt()
->op_info()
->GetAttr<std::vector<int>>("axis")
.back());
*(matched.at("transpose1")
->stmt()
->op_info()
->GetAttr<std::vector<int>>("axis")
.end() -
1));
return op_desc;
}
......
......@@ -62,15 +62,17 @@ std::string Visualize(mir::SSAGraph* graph) {
<< string_trunc(op_info->GetAttr<std::string>(attr_name)) << "\"";
break;
case AttrType::FLOATS: {
auto vals = op_info->GetAttr<std::vector<float>>(attr_name);
std::vector<float> vals =
op_info->GetAttr<std::vector<float>>(attr_name);
os << ":floats: {" + Join(vals, ",") << "}";
} break;
case AttrType::INTS: {
auto vals = op_info->GetAttr<std::vector<int>>(attr_name);
std::vector<int> vals = op_info->GetAttr<std::vector<int>>(attr_name);
os << ":ints: {" + Join(vals, ",") + "}";
} break;
case AttrType::STRINGS: {
auto vals = op_info->GetAttr<std::vector<std::string>>(attr_name);
std::vector<std::string> vals =
op_info->GetAttr<std::vector<std::string>>(attr_name);
os << ":strings: {" + string_trunc(Join(vals, ",")) << "}";
} break;
default:
......
......@@ -195,7 +195,7 @@ void Program::Build(const cpp::ProgramDesc& prog) {
CHECK(ops_.empty()) << "Executor duplicate Build found";
// Create operators.
auto program = prog;
auto& program = prog;
CHECK(program.BlocksSize());
auto& main_block = *program.GetBlock<cpp::BlockDesc>(0);
for (size_t i = 0; i < main_block.OpsSize(); ++i) {
......@@ -262,7 +262,7 @@ void Program::PrepareWorkspace(const cpp::ProgramDesc& prog,
}
};
auto program = prog;
auto& program = prog;
CHECK(program.BlocksSize());
for (size_t b = 0; b < program.BlocksSize(); ++b) {
auto& main_block = *program.GetBlock<cpp::BlockDesc>(b);
......
......@@ -46,7 +46,8 @@ struct Program {
const std::shared_ptr<Scope>& root,
const std::vector<Place>& valid_places,
const std::vector<std::string>& var_names = {})
: scope_(root), valid_places_(valid_places), desc_(desc) {
: scope_(root), valid_places_(valid_places) {
desc_.CopyFrom(desc);
CHECK(scope_) << "scope should be init first";
VLOG(4) << "prepare work";
PrepareWorkspace(desc, var_names);
......
......@@ -54,10 +54,16 @@ class BlockDescWriteAPI {
virtual void SetForwardBlockIdx(int32_t idx) { NotImplemented(); }
template <typename T>
T* AddVar();
T* AddVar() {
NotImplemented();
return nullptr;
}
template <typename T>
T* AddOp();
T* AddOp() {
NotImplemented();
return nullptr;
}
virtual ~BlockDescWriteAPI() = default;
......
......@@ -73,7 +73,9 @@ class OpDescWriteAPI {
}
template <typename T>
void SetAttr(const std::string& name, const T& v);
void SetAttr(const std::string& name, const T& v) {
NotImplemented();
}
virtual ~OpDescWriteAPI() = default;
......
......@@ -40,7 +40,10 @@ class ProgramDescWriteAPI {
virtual void SetVersion(int64_t version) { NotImplemented(); }
template <typename T>
T* AddBlock();
T* AddBlock() {
NotImplemented();
return nullptr;
}
virtual ~ProgramDescWriteAPI() = default;
......
......@@ -57,6 +57,7 @@ class VectorView {
public:
typedef vector_view::VectorTraits<T, U> Traits;
// Wraps a non-owning, read-only view over a flatbuffers vector.
// The backing vector must outlive this view.
explicit VectorView(typename Traits::vector_type const* cvec) {
  CHECK(cvec);  // a null backing vector is rejected up front
  cvec_ = cvec;
}
typename Traits::subscript_return_type operator[](size_t i) const {
......
......@@ -277,7 +277,7 @@ void OpAttrsCppToAny(const cpp::OpDesc &cpp_desc, OpDescType *any_desc) {
template <> \
void TransformProgramDescCppToAny<NT::T>(const cpp::T &cpp_desc, \
NT::T *any_desc) { \
auto desc = cpp_desc; \
auto &desc = cpp_desc; \
if (desc.HasVersion()) { \
any_desc->SetVersion(desc.Version()); \
} \
......
......@@ -8,9 +8,6 @@ endfunction()
lite_fbs_library(fbs_op_desc SRCS op_desc.cc FBS_DEPS framework_fbs_header)
lite_fbs_library(fbs_var_desc SRCS var_desc.cc FBS_DEPS framework_fbs_header)
lite_fbs_library(fbs_block_desc SRCS block_desc.cc FBS_DEPS framework_fbs_header)
lite_fbs_library(fbs_program_desc SRCS program_desc.cc FBS_DEPS framework_fbs_header)
lite_cc_test(test_vector_view SRCS vector_view_test.cc)
if (TARGET test_vector_view)
add_dependencies(test_vector_view framework_fbs_header)
endif()
lite_cc_library(fbs_program_desc SRCS program_desc.cc DEPS fbs_op_desc fbs_var_desc fbs_block_desc)
lite_cc_library(fbs_io SRCS io.cc DEPS fbs_program_desc)
lite_cc_test(test_vector_view SRCS vector_view_test.cc DEPS fbs_program_desc)
......@@ -19,15 +19,27 @@ namespace lite {
namespace fbs {
template <>
proto::VarDesc* BlockDesc::GetVar<proto::VarDesc>(int32_t idx) {
proto::VarDesc const* BlockDesc::GetVar<proto::VarDesc>(int32_t idx) const {
CHECK_LT(idx, VarsSize()) << "idx >= vars.size()";
return const_cast<proto::VarDesc*>(desc_->vars()->Get(idx));
return desc_->vars()->Get(idx);
}
template <>
proto::OpDesc* BlockDesc::GetOp<proto::OpDesc>(int32_t idx) {
proto::OpDesc const* BlockDesc::GetOp<proto::OpDesc>(int32_t idx) const {
CHECK_LT(idx, OpsSize()) << "idx >= ops.size()";
return const_cast<proto::OpDesc*>(desc_->ops()->Get(idx));
return desc_->ops()->Get(idx);
}
// Read-only access to the idx-th cached VarDesc view.
// Aborts (CHECK failure) when idx is out of range.
template <>
VarDesc const* BlockDesc::GetVar<VarDesc>(int32_t idx) const {
  CHECK_LT(idx, VarsSize()) << "idx >= vars.size()";
  return &vars_[idx];
}
// Read-only access to the idx-th cached OpDesc view.
// Aborts (CHECK failure) when idx is out of range.
template <>
OpDesc const* BlockDesc::GetOp<OpDesc>(int32_t idx) const {
  CHECK_LT(idx, OpsSize()) << "idx >= ops.size()";
  return &ops_[idx];
}
} // namespace fbs
......
......@@ -14,8 +14,11 @@
#pragma once
#include <vector>
#include "lite/model_parser/base/block_desc.h"
#include "lite/model_parser/flatbuffers/framework_generated.h"
#include "lite/model_parser/flatbuffers/op_desc.h"
#include "lite/model_parser/flatbuffers/var_desc.h"
#include "lite/utils/all.h"
namespace paddle {
......@@ -24,7 +27,17 @@ namespace fbs {
class BlockDesc : public BlockDescAPI {
public:
explicit BlockDesc(proto::BlockDesc* desc) : desc_(desc) { CHECK(desc_); }
// Wraps a read-only flatbuffers block and eagerly builds the VarDesc/OpDesc
// views, so GetVar/GetOp can hand out stable pointers into the two vectors.
explicit BlockDesc(proto::BlockDesc const* desc) : desc_(desc) {
  CHECK(desc_);
  const size_t var_num = VarsSize();
  const size_t op_num = OpsSize();
  vars_.reserve(var_num);
  ops_.reserve(op_num);
  for (size_t i = 0; i < var_num; ++i) {
    vars_.emplace_back(desc_->vars()->Get(i));
  }
  for (size_t i = 0; i < op_num; ++i) {
    ops_.emplace_back(desc_->ops()->Get(i));
  }
}
int32_t Idx() const override { return desc_->idx(); }
......@@ -33,11 +46,12 @@ class BlockDesc : public BlockDescAPI {
size_t VarsSize() const override { return desc_->vars()->size(); }
template <typename T>
T* GetVar(int32_t idx);
T const* GetVar(int32_t idx) const;
template <typename T>
T const* GetVar(int32_t idx) const {
return GetVar<T>(idx);
T* GetVar(int32_t idx) {
NotImplemented();
return nullptr;
}
size_t OpsSize() const override {
......@@ -47,21 +61,32 @@ class BlockDesc : public BlockDescAPI {
}
template <typename T>
T* GetOp(int32_t idx);
T const* GetOp(int32_t idx) const;
template <typename T>
T const* GetOp(int32_t idx) const {
return GetOp<T>(idx);
T* GetOp(int32_t idx) {
NotImplemented();
return nullptr;
}
const std::vector<VarDesc>& GetVars() const { return vars_; }
int32_t ForwardBlockIdx() const override {
return desc_->forward_block_idx();
}
BlockDesc() = delete;
BlockDesc() { NotImplemented(); }
private:
proto::BlockDesc* desc_; // not_own
proto::BlockDesc const* desc_; // not_own
std::vector<VarDesc> vars_;
std::vector<OpDesc> ops_;
private:
void NotImplemented() const {
LOG(FATAL) << "The additional interfaces of BlockDesc is temporarily "
"unavailable in read-only mode.";
}
};
} // namespace fbs
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/model_parser/flatbuffers/io.h"
#include <memory>
#include <utility>
namespace paddle {
namespace lite {
namespace fbs {
void LoadModel(const std::string& path, ProgramDesc* prog) {
FILE* file = fopen(path.c_str(), "rb");
fseek(file, 0, SEEK_END);
int64_t size = ftell(file);
rewind(file);
char* data = new char[size];
size = fread(data, 1, size, file);
fclose(file);
std::unique_ptr<char[]> buf(data);
prog->Init(std::move(buf));
}
} // namespace fbs
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "lite/model_parser/flatbuffers/program_desc.h"
namespace paddle {
namespace lite {
namespace fbs {
void LoadModel(const std::string& path, ProgramDesc* prog);
} // namespace fbs
} // namespace lite
} // namespace paddle
......@@ -30,7 +30,7 @@ namespace fbs {
class OpDesc : public OpDescAPI {
public:
explicit OpDesc(proto::OpDesc* desc) : desc_(desc) { CHECK(desc_); }
explicit OpDesc(proto::OpDesc const* desc) : desc_(desc) { CHECK(desc_); }
std::string Type() const override { return desc_->type()->str(); }
......@@ -95,7 +95,7 @@ class OpDesc : public OpDescAPI {
OpDescAPI::AttrType GetAttrType(const std::string& name) const override {
const auto& attr = desc_->attrs()->LookupByKey(name.c_str());
CHECK(attr);
CHECK(attr) << "Can not find attr: " << name;
return static_cast<OpDescAPI::AttrType>(attr->type());
}
......@@ -124,10 +124,8 @@ class OpDesc : public OpDescAPI {
template <typename T>
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT GetAttr(size_t idx) const;
OpDesc() = delete;
private:
proto::OpDesc* desc_;
proto::OpDesc const* desc_;
// To reduce overhead, we expect to use namespace aliasing to make cpp::Desc
// and flatbuffers::Desc replace each other. However, there is no direct
......@@ -138,6 +136,7 @@ class OpDesc : public OpDescAPI {
// caused by different building options.
public:
OpDesc() { NotImplemented(); }
bool HasInput(const std::string& param) const {
return desc_->inputs()->LookupByKey(param.c_str()) != nullptr;
}
......
......@@ -19,9 +19,16 @@ namespace lite {
namespace fbs {
template <>
proto::BlockDesc* ProgramDesc::GetBlock<proto::BlockDesc>(int32_t idx) {
proto::BlockDesc const* ProgramDesc::GetBlock<proto::BlockDesc>(
int32_t idx) const {
CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()";
return const_cast<proto::BlockDesc*>(desc_->blocks()->Get(idx));
return desc_->blocks()->Get(idx);
}
// Read-only access to the idx-th cached BlockDesc view.
// Aborts (CHECK failure) when idx is out of range.
template <>
BlockDesc const* ProgramDesc::GetBlock<BlockDesc>(int32_t idx) const {
  CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()";
  return &blocks_[idx];
}
} // namespace fbs
......
......@@ -15,7 +15,10 @@
#pragma once
#include <memory>
#include <utility>
#include <vector>
#include "lite/model_parser/base/program_desc.h"
#include "lite/model_parser/flatbuffers/block_desc.h"
#include "lite/model_parser/flatbuffers/framework_generated.h"
#include "lite/utils/all.h"
......@@ -26,18 +29,40 @@ namespace fbs {
class ProgramDesc : public ProgramDescAPI {
public:
ProgramDesc() = default;
explicit ProgramDesc(proto::ProgramDesc *desc) : desc_(desc) { CHECK(desc); }
explicit ProgramDesc(std::unique_ptr<const char[]> buf) {
Init(std::move(buf));
}
size_t BlocksSize() const override { return desc_->blocks()->size(); }
// Takes ownership of a serialized flatbuffers program buffer, points the
// raw descriptor at it, and rebuilds the cached BlockDesc views over it.
void Init(std::unique_ptr<const char[]> buf) {
  CHECK(buf.get() != nullptr);
  buf_ = std::move(buf);
  desc_ = proto::GetProgramDesc(buf_.get());
  const size_t block_num = BlocksSize();
  blocks_.reserve(block_num);
  for (size_t i = 0; i < block_num; ++i) {
    blocks_.emplace_back(desc_->blocks()->Get(i));
  }
}
// Deep-copies the serialized buffer of `other` and re-initializes this
// descriptor over the fresh copy.
// NOTE(review): strlen() treats the buffer as a NUL-terminated C string,
// but a flatbuffers buffer is binary and may contain embedded zero bytes,
// which would truncate the copy. The buffer length should be tracked
// explicitly alongside buf_ — confirm and fix.
void CopyFrom(const ProgramDesc& other) {
  size_t length = strlen(static_cast<const char*>(other.raw_buf()));
  std::unique_ptr<char[]> buf(new char[length]);
  memcpy(buf.get(), other.raw_buf(), length);
  Init(std::move(buf));
}
template <typename T>
T *GetBlock(int32_t idx);
T const* GetBlock(int32_t idx) const;
template <typename T>
T const *GetBlock(int32_t idx) const {
return GetBlock<T>(idx);
T* GetBlock(int32_t idx) {
NotImplemented();
return nullptr;
}
const std::vector<BlockDesc>& GetBlocks() const { return blocks_; }
bool HasVersion() const override { return desc_->version() != nullptr; }
int64_t Version() const override {
......@@ -45,8 +70,22 @@ class ProgramDesc : public ProgramDescAPI {
return desc_->version()->version();
}
proto::ProgramDesc const* raw_desc() const { return desc_; }
const void* raw_buf() const { return buf_.get(); }
private:
proto::ProgramDesc *desc_; // not_own
proto::ProgramDesc const* desc_;
std::unique_ptr<const char[]> buf_;
std::vector<BlockDesc> blocks_;
private:
ProgramDesc& operator=(const ProgramDesc&) = delete;
ProgramDesc(const ProgramDesc&) = delete;
void NotImplemented() const {
LOG(FATAL) << "The additional interfaces of ProgramDesc is temporarily "
"unavailable in read-only mode.";
}
};
} // namespace fbs
......
......@@ -27,7 +27,7 @@ namespace fbs {
class VarDesc : public VarDescAPI {
public:
explicit VarDesc(proto::VarDesc* desc) : desc_(desc) {}
explicit VarDesc(proto::VarDesc const* desc) : desc_(desc) {}
std::string Name() const override { return desc_->name()->str(); }
......@@ -48,10 +48,14 @@ class VarDesc : public VarDescAPI {
return dims_vec;
}
VarDesc() = delete;
VarDescAPI::Type GetDataType() const {
CHECK(GetType() == VarDescAPI::Type::LOD_TENSOR);
return static_cast<VarDescAPI::Type>(
desc_->type()->lod_tensor()->tensor()->data_type());
}
private:
proto::VarDesc* desc_;
proto::VarDesc const* desc_;
// To reduce overhead, we expect to use namespace aliasing to make cpp::Desc
// and flatbuffers::Desc replace each other. However, there is no direct
......@@ -62,10 +66,7 @@ class VarDesc : public VarDescAPI {
// caused by different building options.
public:
VarDescAPI::Type GetDataType() const {
NotImplemented();
return data_type_;
}
VarDesc() { NotImplemented(); }
void SetDataType(Type data_type) { NotImplemented(); }
void SetShape(const std::vector<int64_t>& dims) { NotImplemented(); }
......@@ -74,7 +75,6 @@ class VarDesc : public VarDescAPI {
LOG(FATAL) << "The additional interfaces of VarDesc is temporarily "
"unavailable in read-only mode.";
}
Type data_type_;
std::vector<int64_t> shape_;
};
......
......@@ -104,20 +104,32 @@ class VectorView<std::string, Flatbuffers> {
explicit VectorView(typename Traits::vector_type const* cvec) {
cvec_ = cvec;
}
std::string operator[](size_t i) const { return cvec_->operator[](i)->str(); }
// Returns a copy of the i-th string. The backing vector must be non-null;
// `i` itself is not bounds-checked here.
std::string operator[](size_t i) const {
  CHECK(cvec_);
  return cvec_->operator[](i)->str();
}
// Iterator to the first string; requires a non-null backing vector.
vector_view::FBSStrIterator begin() const {
  CHECK(cvec_);
  return vector_view::FBSStrIterator(cvec_->begin());
}
// Past-the-end iterator; requires a non-null backing vector.
vector_view::FBSStrIterator end() const {
  CHECK(cvec_);
  return vector_view::FBSStrIterator(cvec_->end());
}
size_t size() const { return cvec_->size(); }
// Number of elements; a null backing vector is treated as empty.
size_t size() const { return cvec_ == nullptr ? 0 : cvec_->size(); }
operator std::vector<std::string>() const {
VLOG(5) << "Copying elements out of VectorView will damage performance.";
std::vector<std::string> tmp;
tmp.reserve(cvec_->size());
for (auto val : *cvec_) {
tmp.push_back(val->str());
tmp.reserve(size());
if (cvec_ != nullptr) {
for (auto val : *cvec_) {
tmp.push_back(val->str());
}
}
return tmp;
}
......
......@@ -24,6 +24,12 @@ VarDesc* BlockDesc::GetVar<VarDesc>(int32_t idx) {
return &vars_[idx];
}
// Const counterpart of GetVar<VarDesc>: read-only access to the idx-th
// variable descriptor. Aborts (CHECK failure) when idx is out of range.
template <>
VarDesc const* BlockDesc::GetVar<VarDesc>(int32_t idx) const {
  CHECK_LT(idx, VarsSize()) << "idx >= vars.size()";
  return &vars_[idx];
}
template <>
VarDesc* BlockDesc::AddVar<VarDesc>() {
vars_.emplace_back();
......@@ -36,6 +42,12 @@ OpDesc* BlockDesc::GetOp<OpDesc>(int32_t idx) {
return &ops_[idx];
}
// Const counterpart of GetOp<OpDesc>: read-only access to the idx-th
// operator descriptor. Aborts (CHECK failure) when idx is out of range.
template <>
OpDesc const* BlockDesc::GetOp<OpDesc>(int32_t idx) const {
  CHECK_LT(idx, OpsSize()) << "idx >= ops.size()";
  return &ops_[idx];
}
template <>
OpDesc* BlockDesc::AddOp<OpDesc>() {
ops_.emplace_back();
......
......@@ -46,12 +46,10 @@ class BlockDesc : public BlockDescAPI {
template <typename T>
T* GetVar(int32_t idx);
std::vector<VarDesc>& GetVars() { return vars_; }
template <typename T>
T const* GetVar(int32_t idx) const {
return GetVar<T>(idx);
}
T const* GetVar(int32_t idx) const;
std::vector<VarDesc>& GetVars() { return vars_; }
template <typename T>
T* AddVar();
......@@ -64,9 +62,7 @@ class BlockDesc : public BlockDescAPI {
T* GetOp(int32_t idx);
template <typename T>
T const* GetOp(int32_t idx) const {
return GetOp<T>(idx);
}
T const* GetOp(int32_t idx) const;
template <typename T>
T* AddOp();
......
......@@ -24,6 +24,12 @@ BlockDesc* ProgramDesc::GetBlock<BlockDesc>(int32_t idx) {
return &blocks_[idx];
}
// Const counterpart of GetBlock<BlockDesc>: read-only access to the idx-th
// block descriptor. Aborts (CHECK failure) when idx is out of range.
template <>
BlockDesc const* ProgramDesc::GetBlock<BlockDesc>(int32_t idx) const {
  CHECK_LT(idx, BlocksSize()) << "idx >= blocks.size()";
  return &blocks_[idx];
}
template <>
BlockDesc* ProgramDesc::AddBlock<BlockDesc>() {
blocks_.emplace_back();
......
......@@ -30,6 +30,13 @@ class ProgramDesc : public ProgramDescAPI {
public:
ProgramDesc() = default;
// Replaces this program with a copy of `other`'s version and block list.
// NOTE(review): other.Version() is read unconditionally — confirm its
// behavior when `other` carries no version (HasVersion() is false).
void CopyFrom(const ProgramDesc& other) {
  version_ = other.Version();
  blocks_ = other.blocks();
}
const std::vector<BlockDesc>& blocks() const { return blocks_; }
size_t BlocksSize() const override { return blocks_.size(); }
void ClearBlocks() override { blocks_.clear(); }
......@@ -37,12 +44,10 @@ class ProgramDesc : public ProgramDescAPI {
template <typename T>
T* GetBlock(int32_t idx);
std::vector<BlockDesc>& GetBlocks() { return blocks_; }
template <typename T>
T const* GetBlock(int32_t idx) const {
return GetBlock<T>(idx);
}
T const* GetBlock(int32_t idx) const;
std::vector<BlockDesc>& GetBlocks() { return blocks_; }
template <typename T>
T* AddBlock();
......
......@@ -176,7 +176,7 @@ void LoadCombinedParamsPb(const std::string &path,
const cpp::ProgramDesc &cpp_prog,
bool params_from_memory) {
CHECK(scope);
auto prog = cpp_prog;
auto &prog = cpp_prog;
auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
// Get vars
......@@ -310,7 +310,7 @@ void SaveModelPb(const std::string &model_dir,
void SaveCombinedParamsPb(const std::string &path,
const lite::Scope &exec_scope,
const cpp::ProgramDesc &cpp_prog) {
auto prog = cpp_prog;
auto &prog = cpp_prog;
auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
// Get vars
......@@ -526,7 +526,7 @@ void SaveCombinedParamsNaive(const std::string &path,
naive_buffer::proto::CombinedParamsDesc pt_desc(&table);
naive_buffer::CombinedParamsDesc desc(&pt_desc);
auto prog = cpp_prog;
auto &prog = cpp_prog;
auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
// set unique_var_names to avoid saving shared params repeatedly
std::set<std::string> unique_var_names;
......@@ -681,7 +681,7 @@ void LoadCombinedParamsNaive(const std::string &path,
}
// Check all params loaded
auto prog = cpp_prog;
auto &prog = cpp_prog;
auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
for (size_t i = 0; i < main_block_desc.VarsSize(); ++i) {
auto &var = *main_block_desc.GetVar<cpp::VarDesc>(i);
......
......@@ -55,11 +55,6 @@ class BlockDesc : public BlockDescAPI {
template <typename T>
T* GetVar(int32_t idx);
template <typename T>
T const* GetVar(int32_t idx) const {
return GetVar<T>(idx);
}
template <typename T>
T* AddVar();
......@@ -70,11 +65,6 @@ class BlockDesc : public BlockDescAPI {
template <typename T>
T* GetOp(int32_t idx);
template <typename T>
T const* GetOp(int32_t idx) const {
return GetOp<T>(idx);
}
template <typename T>
T* AddOp();
......
......@@ -45,11 +45,6 @@ class ProgramDesc : public ProgramDescAPI {
template <typename T>
T *GetBlock(int32_t idx);
template <typename T>
T const *GetBlock(int32_t idx) const {
return GetBlock<T>(idx);
}
template <typename T>
T *AddBlock();
......
......@@ -83,7 +83,7 @@ class DeformableConvOpLite : public OpLite {
param_.conv_param.filter =
scope->FindVar(Filter)->GetMutable<lite::Tensor>();
param_.conv_param.strides = op_desc.GetAttr<std::vector<int>>("strides");
auto paddings = op_desc.GetAttr<std::vector<int>>("paddings");
std::vector<int> paddings = op_desc.GetAttr<std::vector<int>>("paddings");
auto dilations = op_desc.GetAttr<std::vector<int>>("dilations");
param_.conv_param.groups = op_desc.GetAttr<int>("groups");
param_.conv_param.dilations = std::make_shared<std::vector<int>>(dilations);
......
......@@ -54,7 +54,7 @@ class MaxPoolWithIndexOpLite : public OpLite {
param_.ksize = op_desc.GetAttr<std::vector<int>>("ksize");
param_.global_pooling = op_desc.GetAttr<bool>("global_pooling");
param_.strides = op_desc.GetAttr<std::vector<int>>("strides");
auto paddings = op_desc.GetAttr<std::vector<int>>("paddings");
std::vector<int> paddings = op_desc.GetAttr<std::vector<int>>("paddings");
if (op_desc.HasAttr("adaptive")) {
param_.adaptive = op_desc.GetAttr<bool>("adaptive");
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册