未验证 提交 7ef0e7fe 编写于 作者: W Wilber 提交者: GitHub

modify static_kernel_pass to support select the kernel according to input type (#2488)

修改了选kernel的逻辑,默认从模型文件中读取出lod_tensor的data type,在static_kernel_pick pass中如果kernel输入输出的类型与读取的data type完全一致,则选择该Kernel的概率增大。

- 增加 从模型文件__model__读取lod_tensor的data type到cpp::vardesc

- program中增加unordered_map<string, type>字段,并在 Program::PrepareWorkspace中对该字段赋值

- 修改了node.h文件,将const Type* 更改为Type*,并在SSAGraph::Build过程中为符合条件的type*赋值

- static_kernel_pick_pass中添加新规则,如果kernel的输入类型输出类型与__model__中存储的类型的一致,则score*=2。

- 支持模型中用到sequence_reverse_float kernel(输入输出均为float)和sequence_reverse_int64 kernel(输入输出均为int64),能够根据输入输出type选kernel
上级 9a3552db
...@@ -123,6 +123,9 @@ void SSAGraph::Build(const Program &program, ...@@ -123,6 +123,9 @@ void SSAGraph::Build(const Program &program,
return true; return true;
}; };
std::unordered_map<std::string, PrecisionType> var_types =
program.var_data_type();
std::unordered_map<std::string, mir::Node *> arg_update_node_map_; std::unordered_map<std::string, mir::Node *> arg_update_node_map_;
for (auto &op : program.ops()) { for (auto &op : program.ops()) {
VLOG(3) << op->op_info()->Type(); VLOG(3) << op->op_info()->Type();
...@@ -137,6 +140,10 @@ void SSAGraph::Build(const Program &program, ...@@ -137,6 +140,10 @@ void SSAGraph::Build(const Program &program,
arg_node->AsArg(name, node_storage_.size() - 1); arg_node->AsArg(name, node_storage_.size() - 1);
arg_update_node_map_[name] = arg_node; arg_update_node_map_[name] = arg_node;
} }
if (var_types.count(name) && !arg_node->arg()->type) {
arg_node->arg()->type = LiteType::GetTensorTy(
TARGET(kUnk), var_types[name], DATALAYOUT(kUnk));
}
if (is_weights(name)) arg_node->AsArg().is_weight = true; if (is_weights(name)) arg_node->AsArg().is_weight = true;
CHECK(arg_node->IsRoleSet()); CHECK(arg_node->IsRoleSet());
DirectedLink(arg_node, op_node); DirectedLink(arg_node, op_node);
...@@ -146,6 +153,10 @@ void SSAGraph::Build(const Program &program, ...@@ -146,6 +153,10 @@ void SSAGraph::Build(const Program &program,
auto *arg_node = &node_storage_.back(); auto *arg_node = &node_storage_.back();
arg_node->AsArg(name, node_storage_.size() - 1); arg_node->AsArg(name, node_storage_.size() - 1);
arg_update_node_map_[name] = arg_node; arg_update_node_map_[name] = arg_node;
if (var_types.count(name) && !arg_node->arg()->type) {
arg_node->arg()->type = LiteType::GetTensorTy(
TARGET(kUnk), var_types[name], DATALAYOUT(kUnk));
}
if (is_weights(name)) arg_node->AsArg().is_weight = true; if (is_weights(name)) arg_node->AsArg().is_weight = true;
CHECK(arg_node->IsRoleSet()); CHECK(arg_node->IsRoleSet());
......
...@@ -14,7 +14,10 @@ ...@@ -14,7 +14,10 @@
#include "lite/core/mir/static_kernel_pick_pass.h" #include "lite/core/mir/static_kernel_pick_pass.h"
#include <algorithm> #include <algorithm>
#include <list>
#include <memory> #include <memory>
#include <string>
#include <unordered_map>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/mir/graph_visualize_pass.h" #include "lite/core/mir/graph_visualize_pass.h"
...@@ -43,13 +46,33 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -43,13 +46,33 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
if (!node.IsStmt()) continue; if (!node.IsStmt()) continue;
auto& instruct = node.AsStmt(); auto& instruct = node.AsStmt();
std::unordered_map<std::string, PrecisionType> in_types;
std::unordered_map<std::string, PrecisionType> out_types;
for (std::list<Node*>::iterator i = node.inlinks.begin();
i != node.inlinks.end();
++i) {
if ((*i)->arg()->type)
in_types[(*i)->arg()->name] = (*i)->arg()->type->precision();
}
for (std::list<Node*>::iterator i = node.outlinks.begin();
i != node.outlinks.end();
++i) {
if ((*i)->arg()->type)
out_types[(*i)->arg()->name] = (*i)->arg()->type->precision();
}
// Get candidate kernels // Get candidate kernels
std::vector<std::pair<float, std::unique_ptr<KernelBase>>> scored; std::vector<std::pair<float, std::unique_ptr<KernelBase>>> scored;
CHECK(!instruct.kernels().empty()) << "No kernels found for " CHECK(!instruct.kernels().empty()) << "No kernels found for "
<< instruct.op_type(); << instruct.op_type();
VLOG(4) << "instruct.kernels().size():" << instruct.kernels().size(); VLOG(4) << "instruct.kernels().size():" << instruct.kernels().size();
for (auto&& kernel : instruct.kernels()) { for (auto&& kernel : instruct.kernels()) {
float score = KernelGrade(instruct, *kernel, graph->valid_places()); float score = KernelGrade(instruct,
*kernel,
graph->valid_places(),
in_types,
out_types,
instruct.op_info()->input_names(),
instruct.op_info()->output_names());
VLOG(4) << "kernel->summary():" << kernel->summary() VLOG(4) << "kernel->summary():" << kernel->summary()
<< " score:" << score; << " score:" << score;
scored.emplace_back(score, std::move(kernel)); scored.emplace_back(score, std::move(kernel));
...@@ -99,7 +122,13 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -99,7 +122,13 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
instruct.ResetOp(update_desc, graph->valid_places()); instruct.ResetOp(update_desc, graph->valid_places());
scored.clear(); scored.clear();
for (auto&& kernel : instruct.kernels()) { for (auto&& kernel : instruct.kernels()) {
float score = KernelGrade(instruct, *kernel, graph->valid_places()); float score = KernelGrade(instruct,
*kernel,
graph->valid_places(),
in_types,
out_types,
instruct.op_info()->input_names(),
instruct.op_info()->output_names());
scored.emplace_back(score, std::move(kernel)); scored.emplace_back(score, std::move(kernel));
} }
std::sort(scored.begin(), scored.end(), KernelScoreCmp); std::sort(scored.begin(), scored.end(), KernelScoreCmp);
......
...@@ -16,6 +16,8 @@ ...@@ -16,6 +16,8 @@
#include <limits> #include <limits>
#include <memory> #include <memory>
#include <string>
#include <unordered_map>
#include <vector> #include <vector>
#include "lite/core/mir/pass.h" #include "lite/core/mir/pass.h"
#include "lite/core/types.h" #include "lite/core/types.h"
...@@ -48,9 +50,14 @@ class StaticKernelPickPass : public mir::StmtPass { ...@@ -48,9 +50,14 @@ class StaticKernelPickPass : public mir::StmtPass {
private: private:
// Score the kernel. // Score the kernel.
size_t KernelGrade(const lite::mir::Node::Stmt& instruct, size_t KernelGrade(
const lite::mir::Node::Stmt& instruct,
const lite::KernelBase& kernel, const lite::KernelBase& kernel,
const std::vector<Place>& places) { const std::vector<Place>& places,
const std::unordered_map<std::string, PrecisionType>& in_types,
const std::unordered_map<std::string, PrecisionType>& out_types,
const std::vector<std::string>& in_names,
const std::vector<std::string>& out_names) {
CHECK_GT(places.size(), 0) << "valid_places is empty."; CHECK_GT(places.size(), 0) << "valid_places is empty.";
float final_score{-1.}; float final_score{-1.};
Place winner_place{places[0]}; Place winner_place{places[0]};
...@@ -100,6 +107,37 @@ class StaticKernelPickPass : public mir::StmtPass { ...@@ -100,6 +107,37 @@ class StaticKernelPickPass : public mir::StmtPass {
core::KernelPickFactor::Factor::DataLayoutFirst); core::KernelPickFactor::Factor::DataLayoutFirst);
} }
VLOG(4) << "[score s3]:" << score; VLOG(4) << "[score s3]:" << score;
// add new rules for precision: When the input types are consistent with
// kernel's input types and the output types are consistent with kernel's
// output types. Select the kernel of the precision. Note that this
// strategy is not compatible with quantization, so skip quantization op.
if (!instruct.op_info()->HasAttr("enable_int8")) {
bool type_match = true;
for (size_t i = 0; i < in_names.size(); ++i) {
std::string tmp;
CHECK(instruct.op_info()->GetInputArgname(in_names[i], &tmp));
if (in_types.count(in_names[i]) &&
in_types.at(in_names[i]) !=
kernel.GetInputDeclType(tmp)->precision()) {
type_match = false;
}
}
for (size_t i = 0; i < out_names.size(); ++i) {
std::string tmp;
CHECK(instruct.op_info()->GetOutputArgname(out_names[i], &tmp));
if (out_types.count(out_names[i]) &&
out_types.at(out_names[i]) !=
kernel.GetOutputDeclType(tmp)->precision()) {
type_match = false;
}
}
if (type_match) {
score *= 2;
}
VLOG(4) << "[score s4]:" << score;
}
if (weight * score > final_score) { if (weight * score > final_score) {
final_score = weight * score; final_score = weight * score;
winner_place = place; winner_place = place;
......
...@@ -84,23 +84,22 @@ std::vector<std::string> AddFCDesc( ...@@ -84,23 +84,22 @@ std::vector<std::string> AddFCDesc(
static int id = 0; static int id = 0;
std::string prefix = "fc_" + std::to_string(id); std::string prefix = "fc_" + std::to_string(id);
auto* op_desc = block_desc->AddOp<cpp::OpDesc>(); auto* op_desc = block_desc->AddOp<cpp::OpDesc>();
auto* wgt = block_desc->AddVar<cpp::VarDesc>();
auto* bias = block_desc->AddVar<cpp::VarDesc>();
auto* out = block_desc->AddVar<cpp::VarDesc>();
auto* wgt = block_desc->AddVar<cpp::VarDesc>();
wgt->SetName(prefix + "_W"); wgt->SetName(prefix + "_W");
bias->SetName(prefix + "_Bias");
out->SetName(prefix + "_Out");
std::vector<std::string> out_var_names{prefix + "_Out"};
auto* wtensor = scope->Var(prefix + "_W")->GetMutable<lite::Tensor>(); auto* wtensor = scope->Var(prefix + "_W")->GetMutable<lite::Tensor>();
wtensor->Resize(wshape); wtensor->Resize(wshape);
wtensor->mutable_data<float>(); wtensor->mutable_data<float>();
auto* bias = block_desc->AddVar<cpp::VarDesc>();
bias->SetName(prefix + "_Bias");
auto* btensor = scope->Var(prefix + "_Bias")->GetMutable<lite::Tensor>(); auto* btensor = scope->Var(prefix + "_Bias")->GetMutable<lite::Tensor>();
btensor->Resize({wshape[1]}); btensor->Resize({wshape[1]});
btensor->mutable_data<float>(); btensor->mutable_data<float>();
auto* out = block_desc->AddVar<cpp::VarDesc>();
out->SetName(prefix + "_Out");
std::vector<std::string> out_var_names{prefix + "_Out"};
scope->Var(prefix + "_Out")->GetMutable<lite::Tensor>(); scope->Var(prefix + "_Out")->GetMutable<lite::Tensor>();
op_desc->SetType("fc"); op_desc->SetType("fc");
...@@ -192,7 +191,6 @@ std::unique_ptr<mir::SSAGraph> BuildSimpleNet( ...@@ -192,7 +191,6 @@ std::unique_ptr<mir::SSAGraph> BuildSimpleNet(
auto* block_desc = program_desc->AddBlock<cpp::BlockDesc>(); auto* block_desc = program_desc->AddBlock<cpp::BlockDesc>();
block_desc->ClearOps(); block_desc->ClearOps();
block_desc->ClearVars(); block_desc->ClearVars();
auto* var_desc = block_desc->AddVar<cpp::VarDesc>(); auto* var_desc = block_desc->AddVar<cpp::VarDesc>();
var_desc->SetName("feed_var"); var_desc->SetName("feed_var");
auto* feed_var = scope->Var("feed_var")->GetMutable<lite::Tensor>(); auto* feed_var = scope->Var("feed_var")->GetMutable<lite::Tensor>();
......
...@@ -129,6 +129,17 @@ class VariablePlaceInferencePass : public DebugPass { ...@@ -129,6 +129,17 @@ class VariablePlaceInferencePass : public DebugPass {
} else { } else {
x_in->AsArg().type = type; x_in->AsArg().type = type;
} }
} else if (x_in->AsArg().type->target() == TARGET(kUnk) &&
x_in->AsArg().type->precision() != PRECISION(kUnk) &&
x_in->AsArg().type->layout() == DATALAYOUT(kUnk)) {
// If is quantization, infer the Int8 type.
if (type->precision() == PRECISION(kInt8)) {
x_in->AsArg().type = type;
} else {
PrecisionType tmp_ptype = x_in->AsArg().type->precision();
x_in->AsArg().type = LiteType::GetTensorTy(
type->target(), tmp_ptype, type->layout());
}
} }
} }
...@@ -149,6 +160,17 @@ class VariablePlaceInferencePass : public DebugPass { ...@@ -149,6 +160,17 @@ class VariablePlaceInferencePass : public DebugPass {
} else { } else {
x_out->AsArg().type = type; x_out->AsArg().type = type;
} }
} else if (x_out->AsArg().type->target() == TARGET(kUnk) &&
x_out->AsArg().type->precision() != PRECISION(kUnk) &&
x_out->AsArg().type->layout() == DATALAYOUT(kUnk)) {
// If is quantization, infer the Int8 type.
if (type->precision() == PRECISION(kInt8)) {
x_out->AsArg().type = type;
} else {
PrecisionType tmp_ptype = x_out->AsArg().type->precision();
x_out->AsArg().type = LiteType::GetTensorTy(
type->target(), tmp_ptype, type->layout());
}
} }
} }
} }
......
...@@ -168,6 +168,27 @@ void Program::PrepareWorkspace(const cpp::ProgramDesc& prog) { ...@@ -168,6 +168,27 @@ void Program::PrepareWorkspace(const cpp::ProgramDesc& prog) {
tmp_vars_.push_back("feed"); tmp_vars_.push_back("feed");
tmp_vars_.push_back("fetch"); tmp_vars_.push_back("fetch");
auto VarPrecision2KernlPrecision =
[](const lite::VarDescAPI::Type& type) -> PrecisionType {
switch (type) {
case lite::VarDescAPI::Type::FP32:
return PRECISION(kFloat);
case lite::VarDescAPI::Type::FP16:
return PRECISION(kFP16);
case lite::VarDescAPI::Type::INT8:
return PRECISION(kInt8);
case lite::VarDescAPI::Type::INT16:
return PRECISION(kInt16);
case lite::VarDescAPI::Type::INT32:
return PRECISION(kInt32);
case lite::VarDescAPI::Type::INT64:
return PRECISION(kInt64);
default:
// LOG(FATAL) << "not supported type: " << static_cast<int>(type);
return PRECISION(kUnk);
}
};
auto program = prog; auto program = prog;
CHECK(program.BlocksSize()); CHECK(program.BlocksSize());
for (size_t b = 0; b < program.BlocksSize(); ++b) { for (size_t b = 0; b < program.BlocksSize(); ++b) {
...@@ -175,7 +196,16 @@ void Program::PrepareWorkspace(const cpp::ProgramDesc& prog) { ...@@ -175,7 +196,16 @@ void Program::PrepareWorkspace(const cpp::ProgramDesc& prog) {
for (size_t i = 0; i < main_block.VarsSize(); ++i) { for (size_t i = 0; i < main_block.VarsSize(); ++i) {
auto& var_desc = *main_block.GetVar<cpp::VarDesc>(i); auto& var_desc = *main_block.GetVar<cpp::VarDesc>(i);
if (!var_desc.Persistable()) { if (!var_desc.Persistable()) {
if (var_desc.GetType() == lite::VarDescAPI::Type::LOD_TENSOR &&
VarPrecision2KernlPrecision(var_desc.GetDataType()) !=
PRECISION(kUnk)) {
var_data_type_[var_desc.Name()] =
VarPrecision2KernlPrecision(var_desc.GetDataType());
}
tmp_vars_.push_back(var_desc.Name()); tmp_vars_.push_back(var_desc.Name());
VLOG(4) << "var name: " << var_desc.Name() << " type is "
<< static_cast<int>(var_desc.GetType()) << " data type is "
<< static_cast<int>(var_desc.GetDataType());
exec_scope_->Var(var_desc.Name()); exec_scope_->Var(var_desc.Name());
if (b > 0) { if (b > 0) {
VLOG(4) << "var: " << var_desc.Name(); VLOG(4) << "var: " << var_desc.Name();
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <list> #include <list>
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
...@@ -63,6 +64,10 @@ struct Program { ...@@ -63,6 +64,10 @@ struct Program {
lite::Scope* exec_scope() { return exec_scope_; } lite::Scope* exec_scope() { return exec_scope_; }
lite::Scope* scope() { return scope_.get(); } lite::Scope* scope() { return scope_.get(); }
const std::unordered_map<std::string, PrecisionType>& var_data_type() const {
return var_data_type_;
}
private: private:
// Build from a program and scope. // Build from a program and scope.
void Build(const cpp::ProgramDesc& program); void Build(const cpp::ProgramDesc& program);
...@@ -70,6 +75,7 @@ struct Program { ...@@ -70,6 +75,7 @@ struct Program {
void PrepareWorkspace(const cpp::ProgramDesc& program); void PrepareWorkspace(const cpp::ProgramDesc& program);
private: private:
std::unordered_map<std::string, PrecisionType> var_data_type_;
std::list<std::string> tmp_vars_; std::list<std::string> tmp_vars_;
std::list<std::string> weights_; std::list<std::string> weights_;
std::list<std::shared_ptr<OpLite>> ops_; std::list<std::shared_ptr<OpLite>> ops_;
......
...@@ -32,14 +32,6 @@ namespace lite { ...@@ -32,14 +32,6 @@ namespace lite {
/// For VarDesc transfrom /// For VarDesc transfrom
#define TRANS_VAR_ANY_WITH_CPP_IMPL(T) \ #define TRANS_VAR_ANY_WITH_CPP_IMPL(T) \
template <> \ template <> \
void TransformVarDescAnyToCpp<T>(const T &any_desc, \
cpp::VarDesc *cpp_desc) { \
cpp_desc->SetName(any_desc.Name()); \
cpp_desc->SetType(any_desc.GetType()); \
cpp_desc->SetPersistable(any_desc.Persistable()); \
} \
\
template <> \
void TransformVarDescCppToAny<T>(const cpp::VarDesc &cpp_desc, \ void TransformVarDescCppToAny<T>(const cpp::VarDesc &cpp_desc, \
T *any_desc) { \ T *any_desc) { \
any_desc->SetName(cpp_desc.Name()); \ any_desc->SetName(cpp_desc.Name()); \
...@@ -47,6 +39,25 @@ namespace lite { ...@@ -47,6 +39,25 @@ namespace lite {
any_desc->SetPersistable(cpp_desc.Persistable()); \ any_desc->SetPersistable(cpp_desc.Persistable()); \
} }
#ifndef LITE_ON_TINY_PUBLISH
template <>
void TransformVarDescAnyToCpp<pb::VarDesc>(const pb::VarDesc &any_desc,
cpp::VarDesc *cpp_desc) {
cpp_desc->SetName(any_desc.Name());
cpp_desc->SetType(any_desc.GetType());
cpp_desc->SetPersistable(any_desc.Persistable());
cpp_desc->SetDataType(any_desc.GetDataType());
}
#endif
template <>
void TransformVarDescAnyToCpp<naive_buffer::VarDesc>(
const naive_buffer::VarDesc &any_desc, cpp::VarDesc *cpp_desc) {
cpp_desc->SetName(any_desc.Name());
cpp_desc->SetType(any_desc.GetType());
cpp_desc->SetPersistable(any_desc.Persistable());
}
/// For OpDesc transform /// For OpDesc transform
template <typename OpDescType> template <typename OpDescType>
void OpInputsAnyToCpp(const OpDescType &any_desc, cpp::OpDesc *cpp_desc) { void OpInputsAnyToCpp(const OpDescType &any_desc, cpp::OpDesc *cpp_desc) {
......
...@@ -42,9 +42,14 @@ class VarDesc : public VarDescAPI { ...@@ -42,9 +42,14 @@ class VarDesc : public VarDescAPI {
void SetPersistable(bool persistable) override { persistable_ = persistable; } void SetPersistable(bool persistable) override { persistable_ = persistable; }
Type GetDataType() const { return data_type_; }
void SetDataType(Type data_type) { data_type_ = data_type; }
private: private:
std::string name_; std::string name_;
Type type_; Type type_;
Type data_type_;
bool persistable_; bool persistable_;
}; };
......
...@@ -99,6 +99,32 @@ const proto::VarType& VarDesc::GetVarType() const { ...@@ -99,6 +99,32 @@ const proto::VarType& VarDesc::GetVarType() const {
return desc_->GetField<proto::VarType>("type"); return desc_->GetField<proto::VarType>("type");
} }
VarDescAPI::VarDataType VarDesc::GetDataType() const {
using data_type_builder_t = EnumBuilder<proto::VarDataType>;
auto data_type = desc_->GetField<proto::TensorDesc>("tensor_desc")
.GetField<data_type_builder_t>("data_type")
.data();
#define GET_DATA_TYPE_CASE_ITEM(type__) \
case proto::VarDataType::type__: \
return VarDescAPI::VarDataType::type__
switch (data_type) {
// Only support primary data type now.
GET_DATA_TYPE_CASE_ITEM(UINT8);
GET_DATA_TYPE_CASE_ITEM(INT8);
GET_DATA_TYPE_CASE_ITEM(INT16);
GET_DATA_TYPE_CASE_ITEM(INT32);
GET_DATA_TYPE_CASE_ITEM(INT64);
GET_DATA_TYPE_CASE_ITEM(FP32);
GET_DATA_TYPE_CASE_ITEM(FP64);
default:
LOG(FATAL) << "Unknown var data type";
}
return VarDescAPI::VarDataType();
#undef GET_DATA_TYPE_CASE_ITEM
}
proto::VarType* VarDesc::GetMutableVarType() { proto::VarType* VarDesc::GetMutableVarType() {
auto* builder = desc_->GetMutableField<proto::VarType>("type"); auto* builder = desc_->GetMutableField<proto::VarType>("type");
CHECK(builder); CHECK(builder);
......
...@@ -51,6 +51,8 @@ class VarDesc : public VarDescAPI { ...@@ -51,6 +51,8 @@ class VarDesc : public VarDescAPI {
void SetPersistable(bool persistable) override; void SetPersistable(bool persistable) override;
VarDescAPI::VarDataType GetDataType() const;
private: private:
const proto::VarType &GetVarType() const; const proto::VarType &GetVarType() const;
proto::VarType *GetMutableVarType(); proto::VarType *GetMutableVarType();
......
...@@ -151,8 +151,36 @@ void VarDesc::SetDataTypes( ...@@ -151,8 +151,36 @@ void VarDesc::SetDataTypes(
} }
} }
proto::VarType::Type VarDesc::GetDataType() const { // proto::VarType::Type VarDesc::GetDataType() const {
return tensor_desc().data_type(); // return tensor_desc().data_type();
// }
VarDescAPI::VarDataType VarDesc::GetDataType() const {
CHECK(desc_->has_type()) << "The var's type hasn't been set.";
CHECK(desc_->type().has_type()) << "The var type hasn't been set.";
if (desc_->type().type() != proto::VarType::LOD_TENSOR) {
return VarDescAPI::Type();
}
auto type = tensor_desc().data_type();
#define GET_DATA_TYPE_CASE_ITEM(type__) \
case proto::VarType::Type::VarType_Type_##type__: \
return VarDescAPI::Type::type__
switch (type) {
GET_DATA_TYPE_CASE_ITEM(BOOL);
GET_DATA_TYPE_CASE_ITEM(SIZE_T);
GET_DATA_TYPE_CASE_ITEM(UINT8);
GET_DATA_TYPE_CASE_ITEM(INT8);
GET_DATA_TYPE_CASE_ITEM(INT16);
GET_DATA_TYPE_CASE_ITEM(INT32);
GET_DATA_TYPE_CASE_ITEM(INT64);
GET_DATA_TYPE_CASE_ITEM(FP16);
GET_DATA_TYPE_CASE_ITEM(FP32);
GET_DATA_TYPE_CASE_ITEM(FP64);
default:
LOG(FATAL) << "Unknown var type: " << static_cast<int>(type);
return VarDescAPI::Type();
}
#undef GET_DATA_TYPE_CASE_ITEM
} }
std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const { std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
......
...@@ -89,7 +89,8 @@ class VarDesc : public VarDescAPI { ...@@ -89,7 +89,8 @@ class VarDesc : public VarDescAPI {
void SetDataTypes( void SetDataTypes(
const std::vector<framework::proto::VarType::Type> &multiple_data_type); const std::vector<framework::proto::VarType::Type> &multiple_data_type);
framework::proto::VarType::Type GetDataType() const; // framework::proto::VarType::Type GetDataType() const;
VarDescAPI::VarDataType GetDataType() const;
std::vector<framework::proto::VarType::Type> GetDataTypes() const; std::vector<framework::proto::VarType::Type> GetDataTypes() const;
......
...@@ -148,8 +148,8 @@ void PrepareModelInputTensor(const DebugConfig& conf, ...@@ -148,8 +148,8 @@ void PrepareModelInputTensor(const DebugConfig& conf,
auto* input_tensor = &feed_var->at(item.first); auto* input_tensor = &feed_var->at(item.first);
input_tensor->Resize(DDim(dim)); input_tensor->Resize(DDim(dim));
switch (val_type) { switch (val_type) {
#define FILL_TENSOR_BY_TYPE_ONCE(pb_type__, type__) \ #define FILL_TENSOR_BY_TYPE_ONCE(var_type__, type__) \
case framework::proto::VarType::pb_type__: \ case VarDescAPI::Type::var_type__: \
FillTensorData<type__>(input_tensor, conf, item.first); \ FillTensorData<type__>(input_tensor, conf, item.first); \
break break
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册