Unverified commit d51324bf, authored by 石晓伟, committed by GitHub

new class: ParamDesc, test=develop (#4009)

* add ParamDesc, test=develop

* serialize tensor funcs, test=develop
Parent 473db814
......@@ -100,7 +100,7 @@ function(compile_flatbuffers_schema_to_cpp_opt TARGET SRC_FBS OPT)
${OPT}
-o "${CMAKE_CURRENT_SOURCE_DIR}/${SRC_FBS_DIR}"
"${CMAKE_CURRENT_SOURCE_DIR}/${SRC_FBS}"
DEPENDS flatbuffers
DEPENDS flatbuffers ${SRC_FBS}
COMMENT "Run generation: '${GEN_HEADER}'")
register_generated_output(${GEN_HEADER})
add_custom_target(${TARGET} ALL DEPENDS ${GEN_HEADER})
......@@ -108,7 +108,10 @@ endfunction()
set(FRAMEWORK_FBS_DIR "lite/model_parser/flatbuffers")
set(FRAMEWORK_SCHEMA_PATH "${FRAMEWORK_FBS_DIR}/framework.fbs")
set(PARAM_SCHEMA_PATH "${FRAMEWORK_FBS_DIR}/param.fbs")
compile_flatbuffers_schema_to_cpp_opt(framework_fbs_header ${FRAMEWORK_SCHEMA_PATH} "--no-includes;--gen-compare;--force-empty")
compile_flatbuffers_schema_to_cpp_opt(param_fbs_header ${PARAM_SCHEMA_PATH} "--no-includes;--gen-compare;--force-empty")
include_directories(${FLATBUFFERS_INCLUDE_DIR})
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/${SRC_FBS_DIR})
add_custom_target(fbs_headers ALL DEPENDS framework_fbs_header param_fbs_header)
......@@ -16,7 +16,7 @@ if ((NOT LITE_ON_TINY_PUBLISH) AND (LITE_WITH_CUDA OR LITE_WITH_X86 OR LITE_WITH
lite_cc_library(paddle_full_api_shared SHARED SRCS paddle_api.cc light_api.cc cxx_api.cc cxx_api_impl.cc light_api_impl.cc
DEPS paddle_api paddle_api_light paddle_api_full)
target_sources(paddle_full_api_shared PUBLIC ${__lite_cc_files})
add_dependencies(paddle_full_api_shared op_list_h kernel_list_h framework_proto op_registry framework_fbs_header)
add_dependencies(paddle_full_api_shared op_list_h kernel_list_h framework_proto op_registry fbs_headers)
target_link_libraries(paddle_full_api_shared framework_proto op_registry)
if(LITE_WITH_X86)
add_dependencies(paddle_full_api_shared xxhash)
......@@ -72,7 +72,7 @@ else()
set(TARGET_COMIPILE_FLAGS "${TARGET_COMIPILE_FLAGS} -flto")
endif()
set_target_properties(paddle_light_api_shared PROPERTIES COMPILE_FLAGS "${TARGET_COMIPILE_FLAGS}")
add_dependencies(paddle_light_api_shared op_list_h kernel_list_h framework_fbs_header)
add_dependencies(paddle_light_api_shared op_list_h kernel_list_h fbs_headers)
if (LITE_WITH_NPU)
# Need to add HIAI runtime libs (libhiai.so) dependency
target_link_libraries(paddle_light_api_shared ${npu_builder_libs} ${npu_runtime_libs})
......
......@@ -17,7 +17,7 @@ if (NOT LITE_ON_TINY_PUBLISH)
# Unlike static library, module library has to link target to be able to work
# as a single .so lib.
target_link_libraries(paddle_lite_jni ${lib_DEPS} ${arm_kernels} ${npu_kernels})
add_dependencies(paddle_lite_jni framework_fbs_header)
add_dependencies(paddle_lite_jni fbs_headers)
if (LITE_WITH_NPU)
# Strips the symbols of our protobuf functions to fix the conflicts during
# loading HIAI builder libs (libhiai_ir.so and libhiai_ir_build.so)
......@@ -32,7 +32,7 @@ else()
endif()
set_target_properties(paddle_lite_jni PROPERTIES COMPILE_FLAGS ${TARGET_COMIPILE_FLAGS})
target_sources(paddle_lite_jni PUBLIC ${__lite_cc_files} paddle_lite_jni.cc tensor_jni.cc)
add_dependencies(paddle_lite_jni op_list_h kernel_list_h framework_fbs_header)
add_dependencies(paddle_lite_jni op_list_h kernel_list_h fbs_headers)
if (LITE_WITH_NPU)
# Need to add HIAI runtime libs (libhiai.so) dependency
target_link_libraries(paddle_lite_jni ${npu_builder_libs} ${npu_runtime_libs})
......
......@@ -16,6 +16,7 @@
#include "lite/model_parser/base/block_desc.h"
#include "lite/model_parser/base/op_desc.h"
#include "lite/model_parser/base/param_desc.h"
#include "lite/model_parser/base/program_desc.h"
#include "lite/model_parser/base/proto_desc.h"
#include "lite/model_parser/base/traits.h"
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/model_parser/base/traits.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
class ParamDescReadAPI {
public:
virtual std::string Name() const = 0;
virtual std::vector<int64_t> Dim() const = 0;
virtual VarDataType GetDataType() const = 0;
virtual const void *GetData() const = 0;
virtual size_t byte_size() const = 0;
virtual ~ParamDescReadAPI() = default;
};
class ParamDescWriteAPI {
public:
virtual void SetName(const std::string &name) { NotImplemented(); }
virtual void SetDim(const std::vector<int64_t> &dim) { NotImplemented(); }
virtual void SetDataType(VarDataType data_type) { NotImplemented(); }
virtual void SetData(const void *data, size_t byte_size) { NotImplemented(); }
virtual ~ParamDescWriteAPI() = default;
private:
void NotImplemented() const {
LOG(FATAL) << "ParamDescWriteAPI is not available in model read-only mode.";
}
};
class CombinedParamsDescReadAPI {
public:
virtual const ParamDescReadAPI *GetParamDesc(size_t idx) const = 0;
virtual size_t GetParamsSize() const = 0;
virtual ~CombinedParamsDescReadAPI() = default;
};
class CombinedParamsDescWriteAPI {
public:
virtual ParamDescWriteAPI *AddParamDesc() {
NotImplemented();
return nullptr;
}
virtual ~CombinedParamsDescWriteAPI() = default;
private:
void NotImplemented() const {
LOG(FATAL) << "CombinedParamsDescWriteAPI is not available in model "
"read-only mode.";
}
};
// Model reading and writing are one-shot operations and are kept separate.
// The interfaces below combine the read and write APIs; they exist only to
// support legacy interfaces.
class ParamDescAPI : public ParamDescReadAPI, public ParamDescWriteAPI {
public:
virtual ~ParamDescAPI() = default;
};
class CombinedParamsDescAPI : public CombinedParamsDescReadAPI,
public CombinedParamsDescWriteAPI {
public:
virtual ~CombinedParamsDescAPI() = default;
};
} // namespace lite
} // namespace paddle
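
For illustration only (not part of this commit), a minimal sketch of how the split read/write interfaces compose: any reader backend can be copied into any writer backend through the abstract API alone. The helper name below is hypothetical.

#include "lite/model_parser/base/param_desc.h"

// Hypothetical helper: copy one parameter between backends using only the
// abstract interfaces declared above.
void CopyParamDesc(const paddle::lite::ParamDescReadAPI& src,
                   paddle::lite::ParamDescWriteAPI* dst) {
  dst->SetName(src.Name());
  dst->SetDim(src.Dim());
  dst->SetDataType(src.GetDataType());
  dst->SetData(src.GetData(), src.byte_size());
}
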
......@@ -16,6 +16,8 @@
#include <string>
#include <vector>
#include "lite/api/paddle_place.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
......@@ -37,6 +39,77 @@ enum class OpAttrType {
UNK,
};
enum class VarDataType {
// Pod Types
BOOL = 0,
INT16,
INT32,
INT64,
FP16,
FP32,
FP64,
// Tensor<size_t> is used in C++.
SIZE_T,
UINT8,
INT8,
// Other types that may need additional descriptions
LOD_TENSOR,
SELECTED_ROWS,
FEED_MINIBATCH,
FETCH_LIST,
STEP_SCOPES,
LOD_RANK_TABLE,
LOD_TENSOR_ARRAY,
PLACE_LIST,
READER,
// Any runtime decided variable type is raw
// raw variables should manage their own allocations
// in operators like nccl_op
RAW,
TUPLE
};
inline VarDataType ConvertPrecisionType(lite_api::PrecisionType type) {
#define CASE(ptype, vtype) \
case lite_api::PrecisionType::k##ptype: \
return lite::VarDataType::vtype; \
break
switch (type) {
CASE(Float, FP32);
CASE(Int8, INT8);
CASE(Int32, INT32);
CASE(FP16, FP16);
CASE(Bool, BOOL);
CASE(Int64, INT64);
CASE(Int16, INT16);
default:
LOG(FATAL) << "Illegal flatbuffer VarType.";
return lite::VarDataType();
}
#undef CASE
}
inline lite_api::PrecisionType ConvertPrecisionType(VarDataType type) {
#define CASE(ptype, vtype) \
case lite::VarDataType::vtype: \
return lite_api::PrecisionType::k##ptype; \
break
switch (type) {
CASE(Float, FP32);
CASE(Int8, INT8);
CASE(Int32, INT32);
CASE(FP16, FP16);
CASE(Bool, BOOL);
CASE(Int64, INT64);
CASE(Int16, INT16);
default:
LOG(FATAL) << "Illegal flatbuffer VarType.";
return lite_api::PrecisionType();
}
#undef CASE
}
struct Standard {};
struct Flatbuffers {};
......
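
As a quick, hedged illustration (not in the commit), the two ConvertPrecisionType overloads above act as inverses for the precisions they cover:

#include "lite/model_parser/base/traits.h"

// kFloat maps to VarDataType::FP32 and back again.
void PrecisionRoundTripExample() {
  const auto vtype =
      paddle::lite::ConvertPrecisionType(paddle::lite_api::PrecisionType::kFloat);
  CHECK(vtype == paddle::lite::VarDataType::FP32);
  CHECK(paddle::lite::ConvertPrecisionType(vtype) ==
        paddle::lite_api::PrecisionType::kFloat);
}
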
......@@ -16,42 +16,12 @@
#include <string>
#include <vector>
#include "lite/model_parser/base/traits.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
enum class VarDataType {
// Pod Types
BOOL = 0,
INT16,
INT32,
INT64,
FP16,
FP32,
FP64,
// Tensor<size_t> is used in C++.
SIZE_T,
UINT8,
INT8,
// Other types that may need additional descriptions
LOD_TENSOR,
SELECTED_ROWS,
FEED_MINIBATCH,
FETCH_LIST,
STEP_SCOPES,
LOD_RANK_TABLE,
LOD_TENSOR_ARRAY,
PLACE_LIST,
READER,
// Any runtime decided variable type is raw
// raw variables should manage their own allocations
// in operators like nccl_op
RAW,
TUPLE
};
class VarDescReadAPI {
public:
virtual std::string Name() const = 0;
......
......@@ -5,9 +5,10 @@ function(lite_fbs_library TARGET)
add_dependencies(${TARGET} ${args_FBS_DEPS})
endfunction()
lite_fbs_library(fbs_op_desc SRCS op_desc.cc FBS_DEPS framework_fbs_header)
lite_fbs_library(fbs_var_desc SRCS var_desc.cc FBS_DEPS framework_fbs_header)
lite_fbs_library(fbs_block_desc SRCS block_desc.cc FBS_DEPS framework_fbs_header)
lite_fbs_library(fbs_op_desc SRCS op_desc.cc FBS_DEPS fbs_headers)
lite_fbs_library(fbs_var_desc SRCS var_desc.cc FBS_DEPS fbs_headers)
lite_fbs_library(fbs_block_desc SRCS block_desc.cc FBS_DEPS fbs_headers)
lite_cc_library(fbs_program_desc SRCS program_desc.cc DEPS fbs_op_desc fbs_var_desc fbs_block_desc)
lite_cc_library(fbs_io SRCS io.cc DEPS fbs_program_desc)
lite_fbs_library(fbs_param_desc SRCS param_desc.cc FBS_DEPS fbs_headers)
lite_cc_library(fbs_io SRCS io.cc DEPS fbs_program_desc fbs_param_desc)
lite_cc_test(test_vector_view SRCS vector_view_test.cc DEPS fbs_program_desc)
......@@ -13,9 +13,11 @@
// limitations under the License.
#include "lite/model_parser/flatbuffers/io.h"
#include <cstring>
#include <memory>
#include <utility>
#include <vector>
#include "lite/model_parser/flatbuffers/traits.h"
namespace paddle {
namespace lite {
......@@ -33,6 +35,43 @@ void LoadModel(const std::string& path, ProgramDesc* prog) {
prog->Init(std::move(buf));
}
void SetParamWithTensor(const std::string& name,
const lite::Tensor& tensor,
ParamDescWriteAPI* prog) {
CHECK(prog);
prog->SetName(name);
prog->SetDim(tensor.dims().Vectorize());
prog->SetDataType(lite::ConvertPrecisionType(tensor.precision()));
prog->SetData(tensor.raw_data(), tensor.memory_size());
}
void SetTensorWithParam(lite::Tensor* tensor, const ParamDescReadAPI& param) {
tensor->Resize(param.Dim());
tensor->set_precision(lite::ConvertPrecisionType(param.GetDataType()));
std::memcpy(tensor->mutable_data(param.byte_size()),
param.GetData(),
param.byte_size());
}
void SetCombinedParamsWithScope(const lite::Scope& scope,
const std::vector<std::string>& params_name,
CombinedParamsDescWriteAPI* params) {
for (const auto& name : params_name) {
auto* param = params->AddParamDesc();
auto& tensor = scope.FindVar(name)->Get<lite::Tensor>();
SetParamWithTensor(name, tensor, param);
}
}
void SetScopeWithCombinedParams(lite::Scope* scope,
const CombinedParamsDescReadAPI& params) {
CHECK(scope);
for (size_t i = 0; i < params.GetParamsSize(); ++i) {
const auto& param = *params.GetParamDesc(i);
auto* tensor = scope->Var(param.Name())->GetMutable<lite::Tensor>();
SetTensorWithParam(tensor, param);
}
}
} // namespace fbs
} // namespace lite
} // namespace paddle
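
A hedged usage sketch of the helpers above (not part of this commit): dump named parameters from one scope into a writable CombinedParamsDesc and restore them into another scope. The weight names and scope contents are assumed for illustration.

#include <string>
#include <vector>
#include "lite/model_parser/flatbuffers/io.h"

// Assumes src_scope already holds lite::Tensor variables with these names.
void RoundTripParamsExample(const paddle::lite::Scope& src_scope,
                            paddle::lite::Scope* dst_scope) {
  const std::vector<std::string> weight_names = {"fc0.w_0", "fc0.b_0"};
  paddle::lite::fbs::CombinedParamsDesc params;  // writable object-API wrapper
  paddle::lite::fbs::SetCombinedParamsWithScope(src_scope, weight_names, &params);
  // Read the same parameters back into a different scope.
  paddle::lite::fbs::SetScopeWithCombinedParams(dst_scope, params);
}
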
......@@ -15,6 +15,10 @@
#pragma once
#include <string>
#include <vector>
#include "lite/core/scope.h"
#include "lite/core/tensor.h"
#include "lite/model_parser/flatbuffers/param_desc.h"
#include "lite/model_parser/flatbuffers/program_desc.h"
namespace paddle {
......@@ -23,6 +27,17 @@ namespace fbs {
void LoadModel(const std::string& path, ProgramDesc* prog);
void SetParamWithTensor(const std::string& name,
const lite::Tensor& tensor,
ParamDescWriteAPI* prog);
void SetTensorWithParam(lite::Tensor* tensor, const ParamDescReadAPI& param);
void SetCombinedParamsWithScope(const lite::Scope& scope,
const std::vector<std::string>& params_name,
CombinedParamsDescWriteAPI* params);
void SetScopeWithCombinedParams(lite::Scope* scope,
const CombinedParamsDescReadAPI& params);
} // namespace fbs
} // namespace lite
} // namespace paddle
......@@ -21,6 +21,7 @@
#include "lite/model_parser/base/op_desc.h"
#include "lite/model_parser/flatbuffers/framework_generated.h"
#include "lite/model_parser/flatbuffers/traits.h"
#include "lite/model_parser/flatbuffers/vector_view.h"
#include "lite/utils/all.h"
......@@ -96,13 +97,13 @@ class OpDesc : public OpDescAPI {
OpDescAPI::AttrType GetAttrType(const std::string& name) const override {
const auto& attr = desc_->attrs()->LookupByKey(name.c_str());
CHECK(attr) << "Can not find attr: " << name;
return static_cast<OpDescAPI::AttrType>(attr->type());
return ConvertAttrType(attr->type());
}
OpDescAPI::AttrType GetAttrType(size_t idx) const {
const auto& attr = desc_->attrs()->Get(idx);
CHECK(attr);
return static_cast<OpDescAPI::AttrType>(attr->type());
return ConvertAttrType(attr->type());
}
std::vector<std::string> AttrNames() const override {
......
include "framework.fbs";
namespace paddle.lite.fbs.proto;
table CombinedParamsDesc {
params:[paddle.lite.fbs.proto.ParamDesc];
}
namespace paddle.lite.fbs.proto.ParamDesc_;
table LoDTensorDesc {
lod_level:int;
lod:[long];
dim:[long];
data_type:paddle.lite.fbs.proto.VarType_.Type;
data:[byte];
}
table VersionDesc {
version:int;
model_version:int;
}
union VariableDesc {
LoDTensorDesc
}
namespace paddle.lite.fbs.proto;
table ParamDesc {
version:paddle.lite.fbs.proto.ParamDesc_.VersionDesc;
name:string (required, key);
variable:paddle.lite.fbs.proto.ParamDesc_.VariableDesc;
}
root_type paddle.lite.fbs.proto.ParamDesc;
root_type paddle.lite.fbs.proto.CombinedParamsDesc;
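
For orientation (hedged, not part of this commit): under the flatbuffers object API the tables above become proto::CombinedParamsDescT, proto::ParamDescT, and proto::ParamDesc_::LoDTensorDescT, which is what the fbs::ParamDesc wrapper below manipulates. The parameter name and shape here are illustrative.

#include "lite/model_parser/flatbuffers/framework_generated.h"
#include "lite/model_parser/flatbuffers/param_generated.h"

void BuildParamDescTExample() {
  namespace proto = paddle::lite::fbs::proto;
  proto::ParamDescT desc;                // object-API form of table ParamDesc
  desc.name = "fc0.w_0";                 // hypothetical parameter name
  desc.variable.Set(proto::ParamDesc_::LoDTensorDescT());  // select the union member
  auto* tensor = desc.variable.AsLoDTensorDesc();
  tensor->dim = {64, 128};               // hypothetical shape
  tensor->data_type = proto::VarType_::Type_FP32;
}
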
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/model_parser/flatbuffers/param_desc.h"
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "lite/model_parser/base/param_desc.h"
#include "lite/model_parser/flatbuffers/framework_generated.h"
#include "lite/model_parser/flatbuffers/param_generated.h"
#include "lite/model_parser/flatbuffers/traits.h"
namespace paddle {
namespace lite {
namespace fbs {
class ParamDescView : public ParamDescReadAPI {
public:
explicit ParamDescView(proto::ParamDesc const* desc) : desc_(desc) {
CHECK(desc_);
CHECK(desc_->variable_type() ==
proto::ParamDesc_::VariableDesc_LoDTensorDesc);
tensor_desc_ = desc_->variable_as<proto::ParamDesc_::LoDTensorDesc>();
}
std::string Name() const override { return desc_->name()->c_str(); }
std::vector<int64_t> Dim() const override {
const auto& dims = tensor_desc_->dim();
std::vector<int64_t> dims_vec;
dims_vec.reserve(dims->size());
for (const auto& dim : *dims) {
dims_vec.push_back(dim);
}
return dims_vec;
}
VarDataType GetDataType() const override {
return ConvertVarType(tensor_desc_->data_type());
}
const void* GetData() const override { return tensor_desc_->data()->Data(); }
size_t byte_size() const override { return tensor_desc_->data()->size(); }
ParamDescView() = delete;
private:
proto::ParamDesc const* desc_;
proto::ParamDesc_::LoDTensorDesc const* tensor_desc_;
};
class CombinedParamsDescView : public CombinedParamsDescReadAPI {
public:
CombinedParamsDescView() = default;
explicit CombinedParamsDescView(const std::vector<char>& buf) { Init(buf); }
explicit CombinedParamsDescView(std::vector<char>&& buf) {
Init(std::forward<std::vector<char>>(buf));
}
void Init(const std::vector<char>& buf) {
CHECK(buf.data());
buf_ = buf;
InitParams();
}
void Init(std::vector<char>&& buf) {
CHECK(buf.data());
buf_ = std::move(buf);
InitParams();
}
void InitParams() {
desc_ = proto::GetCombinedParamsDesc(buf_.data());
params_.reserve(GetParamsSize());
for (size_t idx = 0; idx < GetParamsSize(); ++idx) {
params_.push_back(ParamDescView(desc_->params()->Get(idx)));
}
}
const ParamDescReadAPI* GetParamDesc(size_t idx) const override {
CHECK(idx < GetParamsSize());
return &params_[idx];
}
size_t GetParamsSize() const override { return params_.size(); }
private:
std::vector<ParamDescView> params_;
std::vector<char> buf_;
proto::CombinedParamsDesc const* desc_;
};
class ParamDesc : public ParamDescAPI {
public:
ParamDesc() : owned_(true), desc_(new proto::ParamDescT()) {
desc_->variable.Set(proto::ParamDesc_::LoDTensorDescT());
lod_tensor_ = desc_->variable.AsLoDTensorDesc();
CHECK(lod_tensor_);
}
explicit ParamDesc(proto::ParamDescT* desc) : desc_(desc) {
lod_tensor_ = desc_->variable.AsLoDTensorDesc();
CHECK(lod_tensor_);
}
std::string Name() const override { return desc_->name; }
void SetName(const std::string& name) override { desc_->name = name; }
std::vector<int64_t> Dim() const override { return lod_tensor_->dim; }
void SetDim(const std::vector<int64_t>& dim) override {
lod_tensor_->dim = dim;
}
VarDataType GetDataType() const override {
return ConvertVarType(lod_tensor_->data_type);
}
void SetDataType(VarDataType data_type) override {
lod_tensor_->data_type = ConvertVarType(data_type);
}
const void* GetData() const override { return lod_tensor_->data.data(); }
size_t byte_size() const override { return lod_tensor_->data.size(); }
void SetData(const void* data, size_t byte_size) override {
lod_tensor_->data.resize(byte_size);
std::memcpy(lod_tensor_->data.data(), data, byte_size);
}
const proto::ParamDescT* raw_desc() const { return desc_; }
~ParamDesc() {
if (owned_) {
delete desc_;
}
}
private:
bool owned_{false};
proto::ParamDescT* desc_{nullptr};
proto::ParamDesc_::LoDTensorDescT* lod_tensor_{nullptr};
};
class CombinedParamsDesc : public CombinedParamsDescAPI {
public:
CombinedParamsDesc() = default;
explicit CombinedParamsDesc(const std::vector<char>& buf) {
const auto* raw_buf = proto::GetCombinedParamsDesc(buf.data());
raw_buf->UnPackTo(&desc_);
SyncParams();
}
const ParamDescReadAPI* GetParamDesc(size_t idx) const override {
return &params_[idx];
}
size_t GetParamsSize() const override { return desc_.params.size(); }
ParamDescWriteAPI* AddParamDesc() override {
desc_.params.push_back(std::unique_ptr<proto::ParamDescT>(new proto::ParamDescT()));
desc_.params.back()->variable.Set(proto::ParamDesc_::LoDTensorDescT());
SyncParams();
return &params_[params_.size() - 1];
}
const void* data() {
SyncBuffer();
return buf_.data();
}
size_t buf_size() {
SyncBuffer();
return buf_.size();
}
private:
void SyncParams() {
params_.resize(GetParamsSize());
for (size_t i = 0; i < GetParamsSize(); ++i) {
if (params_[i].raw_desc() != desc_.params[i].get()) {
params_[i] = ParamDesc(desc_.params[i].get());
}
}
}
void SyncBuffer() {
fbb_.Reset();
flatbuffers::Offset<proto::CombinedParamsDesc> desc =
proto::CombinedParamsDesc::Pack(fbb_, &desc_);
fbb_.Finish(desc);
buf_ = fbb_.Release();
}
flatbuffers::DetachedBuffer buf_;
flatbuffers::FlatBufferBuilder fbb_;
proto::CombinedParamsDescT desc_;
std::vector<ParamDesc> params_;
};
} // namespace fbs
} // namespace lite
} // namespace paddle
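
A hedged end-to-end sketch (not part of this commit): write a parameter buffer through the mutable CombinedParamsDesc and reopen it read-only through CombinedParamsDescView, using only the classes declared above. Names and tensor contents are illustrative.

#include <cstring>
#include <vector>
#include "lite/model_parser/flatbuffers/param_desc.h"

void ParamBufferRoundTripExample() {
  namespace fbs = paddle::lite::fbs;
  // Write side: one FP32 parameter with four elements.
  fbs::CombinedParamsDesc writer;
  auto* param = writer.AddParamDesc();
  param->SetName("fc0.w_0");
  param->SetDim({2, 2});
  param->SetDataType(paddle::lite::VarDataType::FP32);
  const float values[4] = {0.f, 1.f, 2.f, 3.f};
  param->SetData(values, sizeof(values));

  // Copy the packed flatbuffer into a plain byte vector.
  std::vector<char> buf(writer.buf_size());
  std::memcpy(buf.data(), writer.data(), buf.size());

  // Read side: a read-only view over the same buffer.
  fbs::CombinedParamsDescView reader(buf);
  CHECK(reader.GetParamsSize() == 1u);
  CHECK(reader.GetParamDesc(0)->byte_size() == sizeof(values));
}
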
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/model_parser/base/traits.h"
#include "lite/model_parser/flatbuffers/framework_generated.h"
namespace paddle {
namespace lite {
namespace fbs {
inline lite::VarDataType ConvertVarType(proto::VarType_::Type type) {
#define CASE(type) \
case proto::VarType_::Type_##type: \
return lite::VarDataType::type; \
break
switch (type) {
CASE(BOOL);
CASE(INT16);
CASE(INT32);
CASE(INT64);
CASE(FP16);
CASE(FP32);
CASE(FP64);
CASE(LOD_TENSOR);
CASE(SELECTED_ROWS);
CASE(FEED_MINIBATCH);
CASE(FETCH_LIST);
CASE(STEP_SCOPES);
CASE(LOD_RANK_TABLE);
CASE(LOD_TENSOR_ARRAY);
CASE(PLACE_LIST);
CASE(READER);
CASE(RAW);
CASE(TUPLE);
CASE(SIZE_T);
CASE(UINT8);
CASE(INT8);
default:
LOG(FATAL) << "Illegal flatbuffer VarType.";
return lite::VarDataType();
}
#undef CASE
}
inline proto::VarType_::Type ConvertVarType(lite::VarDataType type) {
#define CASE(type) \
case lite::VarDataType::type: \
return proto::VarType_::Type_##type; \
break
switch (type) {
CASE(BOOL);
CASE(INT16);
CASE(INT32);
CASE(INT64);
CASE(FP16);
CASE(FP32);
CASE(FP64);
CASE(LOD_TENSOR);
CASE(SELECTED_ROWS);
CASE(FEED_MINIBATCH);
CASE(FETCH_LIST);
CASE(STEP_SCOPES);
CASE(LOD_RANK_TABLE);
CASE(LOD_TENSOR_ARRAY);
CASE(PLACE_LIST);
CASE(READER);
CASE(RAW);
CASE(TUPLE);
CASE(SIZE_T);
CASE(UINT8);
CASE(INT8);
default:
LOG(FATAL) << "Illegal flatbuffer VarType.";
return proto::VarType_::Type();
}
#undef CASE
}
inline lite::OpAttrType ConvertAttrType(proto::AttrType type) {
#define CASE(type) \
case proto::AttrType_##type: \
return lite::OpAttrType::type; \
break
switch (type) {
CASE(INT);
CASE(FLOAT);
CASE(STRING);
CASE(INTS);
CASE(FLOATS);
CASE(STRINGS);
CASE(BOOLEAN);
CASE(BOOLEANS);
CASE(BLOCK);
CASE(LONG);
CASE(BLOCKS);
CASE(LONGS);
default:
LOG(FATAL) << "Illegal flatbuffer AttrType.";
return lite::OpAttrType();
}
#undef CASE
}
inline proto::AttrType ConvertAttrType(lite::OpAttrType type) {
#define CASE(type) \
case lite::OpAttrType::type: \
return proto::AttrType_##type; \
break
switch (type) {
CASE(INT);
CASE(FLOAT);
CASE(STRING);
CASE(INTS);
CASE(FLOATS);
CASE(STRINGS);
CASE(BOOLEAN);
CASE(BOOLEANS);
CASE(BLOCK);
CASE(LONG);
CASE(BLOCKS);
CASE(LONGS);
default:
LOG(FATAL) << "Illegal flatbuffer AttrType.";
return proto::AttrType();
}
#undef CASE
}
} // namespace fbs
} // namespace lite
} // namespace paddle
......@@ -19,6 +19,7 @@
#include <vector>
#include "lite/model_parser/base/var_desc.h"
#include "lite/model_parser/flatbuffers/framework_generated.h"
#include "lite/model_parser/flatbuffers/traits.h"
#include "lite/utils/all.h"
namespace paddle {
......@@ -32,7 +33,7 @@ class VarDesc : public VarDescAPI {
std::string Name() const override { return desc_->name()->str(); }
VarDescAPI::Type GetType() const override {
return static_cast<VarDescAPI::Type>(desc_->type()->type());
return ConvertVarType(desc_->type()->type());
}
bool Persistable() const override { return desc_->persistable(); }
......@@ -50,8 +51,7 @@ class VarDesc : public VarDescAPI {
VarDescAPI::Type GetDataType() const {
CHECK(GetType() == VarDescAPI::Type::LOD_TENSOR);
return static_cast<VarDescAPI::Type>(
desc_->type()->lod_tensor()->tensor()->data_type());
return ConvertVarType(desc_->type()->lod_tensor()->tensor()->data_type());
}
private:
......