diff --git a/lite/core/mir/ssa_graph.cc b/lite/core/mir/ssa_graph.cc
index 8f22022789046900c3c09cfb122c914968d8d87f..2b5b65ce5903ede41137311c585c0e87eaaa0e9d 100644
--- a/lite/core/mir/ssa_graph.cc
+++ b/lite/core/mir/ssa_graph.cc
@@ -123,6 +123,9 @@ void SSAGraph::Build(const Program &program,
     return true;
   };
 
+  std::unordered_map<std::string, PrecisionType> var_types =
+      program.var_data_type();
+
   std::unordered_map<std::string, mir::Node *> arg_update_node_map_;
   for (auto &op : program.ops()) {
     VLOG(3) << op->op_info()->Type();
@@ -137,6 +140,10 @@ void SSAGraph::Build(const Program &program,
         arg_node->AsArg(name, node_storage_.size() - 1);
         arg_update_node_map_[name] = arg_node;
       }
+      if (var_types.count(name) && !arg_node->arg()->type) {
+        arg_node->arg()->type = LiteType::GetTensorTy(
+            TARGET(kUnk), var_types[name], DATALAYOUT(kUnk));
+      }
       if (is_weights(name)) arg_node->AsArg().is_weight = true;
       CHECK(arg_node->IsRoleSet());
       DirectedLink(arg_node, op_node);
@@ -146,6 +153,10 @@ void SSAGraph::Build(const Program &program,
       auto *arg_node = &node_storage_.back();
       arg_node->AsArg(name, node_storage_.size() - 1);
       arg_update_node_map_[name] = arg_node;
+      if (var_types.count(name) && !arg_node->arg()->type) {
+        arg_node->arg()->type = LiteType::GetTensorTy(
+            TARGET(kUnk), var_types[name], DATALAYOUT(kUnk));
+      }
       if (is_weights(name)) arg_node->AsArg().is_weight = true;
       CHECK(arg_node->IsRoleSet());
diff --git a/lite/core/mir/static_kernel_pick_pass.cc b/lite/core/mir/static_kernel_pick_pass.cc
index c49e4497099c5f04a39bf91e70ca8f48900e7ba7..1cc8942d611db389a44cbf6a244775a5b666b587 100644
--- a/lite/core/mir/static_kernel_pick_pass.cc
+++ b/lite/core/mir/static_kernel_pick_pass.cc
@@ -14,7 +14,10 @@
 #include "lite/core/mir/static_kernel_pick_pass.h"
 #include <algorithm>
+#include <list>
 #include <memory>
+#include <string>
+#include <unordered_map>
 #include <vector>
 #include "lite/core/mir/graph_visualize_pass.h"
@@ -43,13 +46,33 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
     if (!node.IsStmt()) continue;
     auto& instruct = node.AsStmt();
 
+    std::unordered_map<std::string, PrecisionType> in_types;
+    std::unordered_map<std::string, PrecisionType> out_types;
+    for (std::list<Node*>::iterator i = node.inlinks.begin();
+         i != node.inlinks.end();
+         ++i) {
+      if ((*i)->arg()->type)
+        in_types[(*i)->arg()->name] = (*i)->arg()->type->precision();
+    }
+    for (std::list<Node*>::iterator i = node.outlinks.begin();
+         i != node.outlinks.end();
+         ++i) {
+      if ((*i)->arg()->type)
+        out_types[(*i)->arg()->name] = (*i)->arg()->type->precision();
+    }
     // Get candidate kernels
     std::vector<std::pair<float, std::unique_ptr<KernelBase>>> scored;
     CHECK(!instruct.kernels().empty()) << "No kernels found for "
                                        << instruct.op_type();
     VLOG(4) << "instruct.kernels().size():" << instruct.kernels().size();
     for (auto&& kernel : instruct.kernels()) {
-      float score = KernelGrade(instruct, *kernel, graph->valid_places());
+      float score = KernelGrade(instruct,
+                                *kernel,
+                                graph->valid_places(),
+                                in_types,
+                                out_types,
+                                instruct.op_info()->input_names(),
+                                instruct.op_info()->output_names());
       VLOG(4) << "kernel->summary():" << kernel->summary()
               << " score:" << score;
       scored.emplace_back(score, std::move(kernel));
@@ -99,7 +122,13 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
       instruct.ResetOp(update_desc, graph->valid_places());
       scored.clear();
       for (auto&& kernel : instruct.kernels()) {
-        float score = KernelGrade(instruct, *kernel, graph->valid_places());
+        float score = KernelGrade(instruct,
+                                  *kernel,
+                                  graph->valid_places(),
+                                  in_types,
+                                  out_types,
+                                  instruct.op_info()->input_names(),
+                                  instruct.op_info()->output_names());
         scored.emplace_back(score, std::move(kernel));
       }
       std::sort(scored.begin(), scored.end(), KernelScoreCmp);
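Annotation: the pass above first walks a statement node's `inlinks`/`outlinks` and records the precision of every argument whose type has already been inferred, then hands those maps to `KernelGrade`. Below is a minimal standalone sketch of that collection step; `Arg`, `Node`, and `CollectPrecisions` are simplified stand-ins, not the real PaddleLite classes.

```cpp
#include <cstdio>
#include <list>
#include <string>
#include <unordered_map>

// Simplified stand-ins for mir::Node and its Arg (the real classes live in
// lite/core/mir/node.h and carry a full LiteType, not just a precision).
enum class Precision { kUnk, kFloat, kInt8 };

struct Arg {
  std::string name;
  const Precision* type = nullptr;  // null until a pass has inferred it
};

struct Node {
  Arg arg;
};

// Mirrors the loops added to StaticKernelPickPass::Apply: only links whose
// type is already set contribute a name -> precision entry.
std::unordered_map<std::string, Precision> CollectPrecisions(
    const std::list<Node*>& links) {
  std::unordered_map<std::string, Precision> types;
  for (const Node* n : links) {
    if (n->arg.type) types[n->arg.name] = *n->arg.type;
  }
  return types;
}

int main() {
  Precision fp = Precision::kFloat;
  Node typed{{"x", &fp}};
  Node untyped{{"y", nullptr}};
  std::list<Node*> inlinks{&typed, &untyped};
  auto in_types = CollectPrecisions(inlinks);
  std::printf("collected %zu typed input(s)\n", in_types.size());  // 1
}
```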
diff --git a/lite/core/mir/static_kernel_pick_pass.h b/lite/core/mir/static_kernel_pick_pass.h
index cd54e2654c22b98cbacc9a73bef7770a029c0b30..f655b298bf2d800f4adf142ad14b8ac05ca00482 100644
--- a/lite/core/mir/static_kernel_pick_pass.h
+++ b/lite/core/mir/static_kernel_pick_pass.h
@@ -16,6 +16,8 @@
 
 #include <limits>
 #include <memory>
+#include <string>
+#include <unordered_map>
 #include <vector>
 #include "lite/core/mir/pass.h"
 #include "lite/core/types.h"
@@ -48,9 +50,14 @@ class StaticKernelPickPass : public mir::StmtPass {
 
  private:
   // Score the kernel.
-  size_t KernelGrade(const lite::mir::Node::Stmt& instruct,
-                     const lite::KernelBase& kernel,
-                     const std::vector<Place>& places) {
+  size_t KernelGrade(
+      const lite::mir::Node::Stmt& instruct,
+      const lite::KernelBase& kernel,
+      const std::vector<Place>& places,
+      const std::unordered_map<std::string, PrecisionType>& in_types,
+      const std::unordered_map<std::string, PrecisionType>& out_types,
+      const std::vector<std::string>& in_names,
+      const std::vector<std::string>& out_names) {
     CHECK_GT(places.size(), 0) << "valid_places is empty.";
     float final_score{-1.};
     Place winner_place{places[0]};
@@ -100,6 +107,37 @@ class StaticKernelPickPass : public mir::StmtPass {
             core::KernelPickFactor::Factor::DataLayoutFirst);
       }
       VLOG(4) << "[score s3]:" << score;
+
+      // New rule for precision: when the input precisions recorded in the
+      // graph match the kernel's declared input precisions, and likewise for
+      // the outputs, prefer that kernel. This rule is not compatible with
+      // quantization, so quantized (enable_int8) ops are skipped.
+      if (!instruct.op_info()->HasAttr("enable_int8")) {
+        bool type_match = true;
+        for (size_t i = 0; i < in_names.size(); ++i) {
+          std::string tmp;
+          CHECK(instruct.op_info()->GetInputArgname(in_names[i], &tmp));
+          if (in_types.count(in_names[i]) &&
+              in_types.at(in_names[i]) !=
+                  kernel.GetInputDeclType(tmp)->precision()) {
+            type_match = false;
+          }
+        }
+        for (size_t i = 0; i < out_names.size(); ++i) {
+          std::string tmp;
+          CHECK(instruct.op_info()->GetOutputArgname(out_names[i], &tmp));
+          if (out_types.count(out_names[i]) &&
+              out_types.at(out_names[i]) !=
+                  kernel.GetOutputDeclType(tmp)->precision()) {
+            type_match = false;
+          }
+        }
+        if (type_match) {
+          score *= 2;
+        }
+        VLOG(4) << "[score s4]:" << score;
+      }
+
       if (weight * score > final_score) {
         final_score = weight * score;
         winner_place = place;
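Annotation: the "s4" rule above doubles a kernel's score when every declared precision agrees with the precisions recorded for the surrounding variables. A minimal sketch of that rule, with hypothetical simplified types (the real code keys the maps by variable name and resolves the kernel's declaration through `GetInputArgname`/`GetOutputArgname`; this sketch keys both by argument name):

```cpp
#include <cstdio>
#include <string>
#include <unordered_map>

enum class Precision { kFloat, kInt8 };

// Hypothetical stand-in for a kernel's declared in/out precisions.
struct KernelDecl {
  std::unordered_map<std::string, Precision> in_decl;
  std::unordered_map<std::string, Precision> out_decl;
};

// A kernel matches when no recorded precision contradicts a declared one;
// arguments with no recorded precision are ignored, as in the patch.
bool TypeMatch(const KernelDecl& k,
               const std::unordered_map<std::string, Precision>& in_types,
               const std::unordered_map<std::string, Precision>& out_types) {
  for (const auto& kv : in_types) {
    auto it = k.in_decl.find(kv.first);
    if (it != k.in_decl.end() && it->second != kv.second) return false;
  }
  for (const auto& kv : out_types) {
    auto it = k.out_decl.find(kv.first);
    if (it != k.out_decl.end() && it->second != kv.second) return false;
  }
  return true;
}

int main() {
  KernelDecl fp32_kernel{{{"X", Precision::kFloat}},
                         {{"Out", Precision::kFloat}}};
  std::unordered_map<std::string, Precision> in{{"X", Precision::kFloat}};
  std::unordered_map<std::string, Precision> out{{"Out", Precision::kFloat}};
  float score = 10.f;
  if (TypeMatch(fp32_kernel, in, out)) score *= 2;  // the s4 bonus
  std::printf("score = %.1f\n", score);             // 20.0
}
```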
diff --git a/lite/core/mir/subgraph/subgraph_program_pass_test.cc b/lite/core/mir/subgraph/subgraph_program_pass_test.cc
index 22e20b81d831ff25df090a7565e671b9139122f7..8949eeeca20aa1446520348e0203b29115e6e051 100644
--- a/lite/core/mir/subgraph/subgraph_program_pass_test.cc
+++ b/lite/core/mir/subgraph/subgraph_program_pass_test.cc
@@ -84,23 +84,22 @@ std::vector<std::string> AddFCDesc(
   static int id = 0;
   std::string prefix = "fc_" + std::to_string(id);
   auto* op_desc = block_desc->AddOp<cpp::OpDesc>();
 
-  auto* wgt = block_desc->AddVar<cpp::VarDesc>();
-  auto* bias = block_desc->AddVar<cpp::VarDesc>();
-  auto* out = block_desc->AddVar<cpp::VarDesc>();
+  auto* wgt = block_desc->AddVar<cpp::VarDesc>();
   wgt->SetName(prefix + "_W");
-  bias->SetName(prefix + "_Bias");
-  out->SetName(prefix + "_Out");
-  std::vector<std::string> out_var_names{prefix + "_Out"};
-
   auto* wtensor = scope->Var(prefix + "_W")->GetMutable<lite::Tensor>();
   wtensor->Resize(wshape);
   wtensor->mutable_data<float>();
 
+  auto* bias = block_desc->AddVar<cpp::VarDesc>();
+  bias->SetName(prefix + "_Bias");
   auto* btensor = scope->Var(prefix + "_Bias")->GetMutable<lite::Tensor>();
   btensor->Resize({wshape[1]});
   btensor->mutable_data<float>();
 
+  auto* out = block_desc->AddVar<cpp::VarDesc>();
+  out->SetName(prefix + "_Out");
+  std::vector<std::string> out_var_names{prefix + "_Out"};
   scope->Var(prefix + "_Out")->GetMutable<lite::Tensor>();
 
   op_desc->SetType("fc");
@@ -192,7 +191,6 @@ std::unique_ptr<mir::SSAGraph> BuildSimpleNet(
   auto* block_desc = program_desc->AddBlock<cpp::BlockDesc>();
   block_desc->ClearOps();
   block_desc->ClearVars();
-
   auto* var_desc = block_desc->AddVar<cpp::VarDesc>();
   var_desc->SetName("feed_var");
   auto* feed_var = scope->Var("feed_var")->GetMutable<lite::Tensor>();
diff --git a/lite/core/mir/variable_place_inference_pass.h b/lite/core/mir/variable_place_inference_pass.h
index 3f5d161a56aafa7fd9d058fd404e65cb04572116..b3609230b993b73edd386cc3aea55081d25538a2 100644
--- a/lite/core/mir/variable_place_inference_pass.h
+++ b/lite/core/mir/variable_place_inference_pass.h
@@ -129,6 +129,17 @@ class VariablePlaceInferencePass : public DebugPass {
           } else {
             x_in->AsArg().type = type;
           }
+        } else if (x_in->AsArg().type->target() == TARGET(kUnk) &&
+                   x_in->AsArg().type->precision() != PRECISION(kUnk) &&
+                   x_in->AsArg().type->layout() == DATALAYOUT(kUnk)) {
+          // If quantized, adopt the kernel's int8 type; else keep the var's precision.
+          if (type->precision() == PRECISION(kInt8)) {
+            x_in->AsArg().type = type;
+          } else {
+            PrecisionType tmp_ptype = x_in->AsArg().type->precision();
+            x_in->AsArg().type = LiteType::GetTensorTy(
+                type->target(), tmp_ptype, type->layout());
+          }
         }
       }
 
@@ -149,6 +160,17 @@ class VariablePlaceInferencePass : public DebugPass {
           } else {
             x_out->AsArg().type = type;
           }
+        } else if (x_out->AsArg().type->target() == TARGET(kUnk) &&
+                   x_out->AsArg().type->precision() != PRECISION(kUnk) &&
+                   x_out->AsArg().type->layout() == DATALAYOUT(kUnk)) {
+          // If quantized, adopt the kernel's int8 type; else keep the var's precision.
+          if (type->precision() == PRECISION(kInt8)) {
+            x_out->AsArg().type = type;
+          } else {
+            PrecisionType tmp_ptype = x_out->AsArg().type->precision();
+            x_out->AsArg().type = LiteType::GetTensorTy(
+                type->target(), tmp_ptype, type->layout());
+          }
         }
       }
     }
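Annotation: the inference branch added above handles a var that only knows its precision (read from the model file) while its target and layout are still unknown. A minimal sketch of that merge rule under a simplified stand-in `Type` (the real `LiteType` in lite/core/type_system.h also carries device details):

```cpp
#include <cstdio>

enum class Target { kUnk, kARM };
enum class Precision { kUnk, kFloat, kInt8 };
enum class Layout { kUnk, kNCHW };

// Simplified stand-in for LiteType.
struct Type {
  Target target;
  Precision precision;
  Layout layout;
};

// Mirrors the new branch: take target and layout from the kernel's declared
// type while keeping the var's own precision - unless the kernel is int8
// (quantization), in which case the kernel's full type wins outright.
Type Merge(const Type& var, const Type& kernel) {
  if (kernel.precision == Precision::kInt8) return kernel;
  return Type{kernel.target, var.precision, kernel.layout};
}

int main() {
  Type var{Target::kUnk, Precision::kFloat, Layout::kUnk};
  Type kernel{Target::kARM, Precision::kFloat, Layout::kNCHW};
  Type merged = Merge(var, kernel);
  std::printf("precision kept from var: %d\n",
              merged.precision == Precision::kFloat);  // 1
}
```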
diff --git a/lite/core/program.cc b/lite/core/program.cc
index 109a96b187595e973fe37227e31a54f0bce86334..c891fbfe92b6b50ba33affa8b7e506aaa3e35018 100644
--- a/lite/core/program.cc
+++ b/lite/core/program.cc
@@ -168,6 +168,27 @@ void Program::PrepareWorkspace(const cpp::ProgramDesc& prog) {
   tmp_vars_.push_back("feed");
   tmp_vars_.push_back("fetch");
 
+  auto VarPrecision2KernlPrecision =
+      [](const lite::VarDescAPI::Type& type) -> PrecisionType {
+    switch (type) {
+      case lite::VarDescAPI::Type::FP32:
+        return PRECISION(kFloat);
+      case lite::VarDescAPI::Type::FP16:
+        return PRECISION(kFP16);
+      case lite::VarDescAPI::Type::INT8:
+        return PRECISION(kInt8);
+      case lite::VarDescAPI::Type::INT16:
+        return PRECISION(kInt16);
+      case lite::VarDescAPI::Type::INT32:
+        return PRECISION(kInt32);
+      case lite::VarDescAPI::Type::INT64:
+        return PRECISION(kInt64);
+      default:
+        // LOG(FATAL) << "not supported type: " << static_cast<int>(type);
+        return PRECISION(kUnk);
+    }
+  };
+
   auto program = prog;
   CHECK(program.BlocksSize());
   for (size_t b = 0; b < program.BlocksSize(); ++b) {
@@ -175,7 +196,16 @@ void Program::PrepareWorkspace(const cpp::ProgramDesc& prog) {
     for (size_t i = 0; i < main_block.VarsSize(); ++i) {
       auto& var_desc = *main_block.GetVar(i);
       if (!var_desc.Persistable()) {
+        if (var_desc.GetType() == lite::VarDescAPI::Type::LOD_TENSOR &&
+            VarPrecision2KernlPrecision(var_desc.GetDataType()) !=
+                PRECISION(kUnk)) {
+          var_data_type_[var_desc.Name()] =
+              VarPrecision2KernlPrecision(var_desc.GetDataType());
+        }
         tmp_vars_.push_back(var_desc.Name());
+        VLOG(4) << "var name: " << var_desc.Name() << " type is "
+                << static_cast<int>(var_desc.GetType()) << " data type is "
+                << static_cast<int>(var_desc.GetDataType());
         exec_scope_->Var(var_desc.Name());
         if (b > 0) {
           VLOG(4) << "var: " << var_desc.Name();
diff --git a/lite/core/program.h b/lite/core/program.h
index 1c1e4975c3a13bcfa9a22999a705f3a78b0fc68e..9d3f8a97bc40961bf4cbef30cffdfcd429f34397 100644
--- a/lite/core/program.h
+++ b/lite/core/program.h
@@ -16,6 +16,7 @@
 #include <list>
 #include <memory>
 #include <string>
+#include <unordered_map>
 #include <utility>
 #include <vector>
 #include "lite/core/kernel.h"
@@ -63,6 +64,10 @@ struct Program {
   lite::Scope* exec_scope() { return exec_scope_; }
   lite::Scope* scope() { return scope_.get(); }
 
+  const std::unordered_map<std::string, PrecisionType>& var_data_type() const {
+    return var_data_type_;
+  }
+
  private:
   // Build from a program and scope.
   void Build(const cpp::ProgramDesc& program);
@@ -70,6 +75,7 @@ struct Program {
   void PrepareWorkspace(const cpp::ProgramDesc& program);
 
  private:
+  std::unordered_map<std::string, PrecisionType> var_data_type_;
   std::list<std::string> tmp_vars_;
   std::list<std::string> weights_;
   std::list<std::shared_ptr<OpLite>> ops_;
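Annotation: `PrepareWorkspace` maps the model's var-desc data types onto kernel precisions, using `kUnk` as a "no equivalent" sentinel and recording only non-persistable LOD tensors that map cleanly. A standalone sketch of that filtering pattern, with illustrative stand-in enums rather than the real `VarDescAPI`:

```cpp
#include <cstdio>
#include <string>
#include <unordered_map>

// Stand-ins for lite::VarDescAPI::Type and PrecisionType.
enum class VarType { FP32, INT8, BOOL };
enum class Precision { kUnk, kFloat, kInt8 };

int main() {
  // Same shape as the lambda in PrepareWorkspace: unsupported types map to
  // the kUnk sentinel instead of aborting, and kUnk entries are skipped.
  auto ToPrecision = [](VarType t) -> Precision {
    switch (t) {
      case VarType::FP32:
        return Precision::kFloat;
      case VarType::INT8:
        return Precision::kInt8;
      default:
        return Precision::kUnk;  // e.g. BOOL: no kernel precision equivalent
    }
  };

  std::unordered_map<std::string, Precision> var_data_type;
  std::unordered_map<std::string, VarType> vars{{"x", VarType::FP32},
                                                {"mask", VarType::BOOL}};
  for (const auto& kv : vars) {
    if (ToPrecision(kv.second) != Precision::kUnk)
      var_data_type[kv.first] = ToPrecision(kv.second);
  }
  std::printf("recorded %zu var(s)\n", var_data_type.size());  // 1 ("x")
}
```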
diff --git a/lite/model_parser/compatible_pb.cc b/lite/model_parser/compatible_pb.cc
index 2df4a92270466b1f3b56dec8deecf8e9a8e62390..d1131539bf30abba22feeba8abf009f95ab70a00 100644
--- a/lite/model_parser/compatible_pb.cc
+++ b/lite/model_parser/compatible_pb.cc
@@ -32,14 +32,6 @@ namespace lite {
 /// For VarDesc transform
 #define TRANS_VAR_ANY_WITH_CPP_IMPL(T)                             \
   template <>                                                      \
-  void TransformVarDescAnyToCpp<T>(const T &any_desc,              \
-                                   cpp::VarDesc *cpp_desc) {       \
-    cpp_desc->SetName(any_desc.Name());                            \
-    cpp_desc->SetType(any_desc.GetType());                         \
-    cpp_desc->SetPersistable(any_desc.Persistable());              \
-  }                                                                \
-                                                                   \
-  template <>                                                      \
   void TransformVarDescCppToAny<T>(const cpp::VarDesc &cpp_desc,   \
                                    T *any_desc) {                  \
     any_desc->SetName(cpp_desc.Name());                            \
@@ -47,6 +39,25 @@ namespace lite {
     any_desc->SetPersistable(cpp_desc.Persistable());              \
   }
 
+#ifndef LITE_ON_TINY_PUBLISH
+template <>
+void TransformVarDescAnyToCpp<pb::VarDesc>(const pb::VarDesc &any_desc,
+                                           cpp::VarDesc *cpp_desc) {
+  cpp_desc->SetName(any_desc.Name());
+  cpp_desc->SetType(any_desc.GetType());
+  cpp_desc->SetPersistable(any_desc.Persistable());
+  cpp_desc->SetDataType(any_desc.GetDataType());
+}
+#endif
+
+template <>
+void TransformVarDescAnyToCpp<naive_buffer::VarDesc>(
+    const naive_buffer::VarDesc &any_desc, cpp::VarDesc *cpp_desc) {
+  cpp_desc->SetName(any_desc.Name());
+  cpp_desc->SetType(any_desc.GetType());
+  cpp_desc->SetPersistable(any_desc.Persistable());
+}
+
 /// For OpDesc transform
 template <typename OpDescType>
 void OpInputsAnyToCpp(const OpDescType &any_desc, cpp::OpDesc *cpp_desc) {
diff --git a/lite/model_parser/cpp/var_desc.h b/lite/model_parser/cpp/var_desc.h
index c346934dfd721bcd6424fcf2b9d22a0ded9dab14..9232bba3e8620b2e5e769c9f7a0f50969abe8421 100644
--- a/lite/model_parser/cpp/var_desc.h
+++ b/lite/model_parser/cpp/var_desc.h
@@ -42,9 +42,14 @@ class VarDesc : public VarDescAPI {
 
   void SetPersistable(bool persistable) override { persistable_ = persistable; }
 
+  Type GetDataType() const { return data_type_; }
+
+  void SetDataType(Type data_type) { data_type_ = data_type; }
+
  private:
   std::string name_;
   Type type_;
+  Type data_type_;
   bool persistable_;
 };
diff --git a/lite/model_parser/naive_buffer/var_desc.cc b/lite/model_parser/naive_buffer/var_desc.cc
index cccf7582912d1edff2c91fbfa5ed602f028be648..86b6dd72844c694dee1781d322491bf922f32d09 100644
--- a/lite/model_parser/naive_buffer/var_desc.cc
+++ b/lite/model_parser/naive_buffer/var_desc.cc
@@ -99,6 +99,32 @@ const proto::VarType& VarDesc::GetVarType() const {
   return desc_->GetField<proto::VarType>("type");
 }
 
+VarDescAPI::VarDataType VarDesc::GetDataType() const {
+  using data_type_builder_t = EnumBuilder<proto::VarDataType>;
+
+  auto data_type = desc_->GetField<proto::VarType::TensorDesc>("tensor_desc")
+                       .GetField<data_type_builder_t>("data_type")
+                       .data();
+#define GET_DATA_TYPE_CASE_ITEM(type__) \
+  case proto::VarDataType::type__:      \
+    return VarDescAPI::VarDataType::type__
+
+  switch (data_type) {
+    // Only primitive data types are supported for now.
+    GET_DATA_TYPE_CASE_ITEM(UINT8);
+    GET_DATA_TYPE_CASE_ITEM(INT8);
+    GET_DATA_TYPE_CASE_ITEM(INT16);
+    GET_DATA_TYPE_CASE_ITEM(INT32);
+    GET_DATA_TYPE_CASE_ITEM(INT64);
+    GET_DATA_TYPE_CASE_ITEM(FP32);
+    GET_DATA_TYPE_CASE_ITEM(FP64);
+    default:
+      LOG(FATAL) << "Unknown var data type";
+  }
+  return VarDescAPI::VarDataType();
+#undef GET_DATA_TYPE_CASE_ITEM
+}
+
 proto::VarType* VarDesc::GetMutableVarType() {
   auto* builder = desc_->GetMutableField<proto::VarType>("type");
   CHECK(builder);
diff --git a/lite/model_parser/naive_buffer/var_desc.h b/lite/model_parser/naive_buffer/var_desc.h
index 92a0cfe3cdc5bf8a397bb1b8140dba0312791730..b638afd79d085e64ef7f1174f0d27975b827e76a 100644
--- a/lite/model_parser/naive_buffer/var_desc.h
+++ b/lite/model_parser/naive_buffer/var_desc.h
@@ -51,6 +51,8 @@ class VarDesc : public VarDescAPI {
 
   void SetPersistable(bool persistable) override;
 
+  VarDescAPI::VarDataType GetDataType() const;
+
  private:
   const proto::VarType &GetVarType() const;
   proto::VarType *GetMutableVarType();
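Annotation: `GET_DATA_TYPE_CASE_ITEM` above relies on the two enums sharing enumerator names, so token pasting can generate each `case`/`return` pair. A minimal standalone version of the same trick, with illustrative enum names:

```cpp
#include <cstdio>

enum class ProtoType { FP32, INT8 };
enum class ApiType { FP32, INT8, UNK };

ApiType Convert(ProtoType t) {
// One macro expansion per shared enumerator name: the ## paste builds the
// matching enumerator on both sides of the conversion.
#define CASE_ITEM(name__) \
  case ProtoType::name__: \
    return ApiType::name__
  switch (t) {
    CASE_ITEM(FP32);
    CASE_ITEM(INT8);
  }
#undef CASE_ITEM
  return ApiType::UNK;  // unreachable fallback, mirrors the patch's style
}

int main() {
  std::printf("%d\n", static_cast<int>(Convert(ProtoType::INT8)));  // 1
}
```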
diff --git a/lite/model_parser/pb/var_desc.cc b/lite/model_parser/pb/var_desc.cc
index 517f4cc6dcefbb5e517b6f84ac1b695dbbbc5925..a3f28d00b94054addd728775e9373d73f9b7b729 100644
--- a/lite/model_parser/pb/var_desc.cc
+++ b/lite/model_parser/pb/var_desc.cc
@@ -151,8 +151,36 @@ void VarDesc::SetDataTypes(
   }
 }
 
-proto::VarType::Type VarDesc::GetDataType() const {
-  return tensor_desc().data_type();
-}
+// proto::VarType::Type VarDesc::GetDataType() const {
+//   return tensor_desc().data_type();
+// }
+VarDescAPI::VarDataType VarDesc::GetDataType() const {
+  CHECK(desc_->has_type()) << "The var's type hasn't been set.";
+  CHECK(desc_->type().has_type()) << "The var type hasn't been set.";
+  if (desc_->type().type() != proto::VarType::LOD_TENSOR) {
+    return VarDescAPI::Type();
+  }
+  auto type = tensor_desc().data_type();
+#define GET_DATA_TYPE_CASE_ITEM(type__)             \
+  case proto::VarType::Type::VarType_Type_##type__: \
+    return VarDescAPI::Type::type__
+
+  switch (type) {
+    GET_DATA_TYPE_CASE_ITEM(BOOL);
+    GET_DATA_TYPE_CASE_ITEM(SIZE_T);
+    GET_DATA_TYPE_CASE_ITEM(UINT8);
+    GET_DATA_TYPE_CASE_ITEM(INT8);
+    GET_DATA_TYPE_CASE_ITEM(INT16);
+    GET_DATA_TYPE_CASE_ITEM(INT32);
+    GET_DATA_TYPE_CASE_ITEM(INT64);
+    GET_DATA_TYPE_CASE_ITEM(FP16);
+    GET_DATA_TYPE_CASE_ITEM(FP32);
+    GET_DATA_TYPE_CASE_ITEM(FP64);
+    default:
+      LOG(FATAL) << "Unknown var type: " << static_cast<int>(type);
+      return VarDescAPI::Type();
+  }
+#undef GET_DATA_TYPE_CASE_ITEM
+}
 
 std::vector<framework::proto::VarType::Type> VarDesc::GetDataTypes() const {
diff --git a/lite/model_parser/pb/var_desc.h b/lite/model_parser/pb/var_desc.h
index c0ac6316016df3cdb06ddede9c78d58540a40864..bbf78b75d3f1b1a4a6488e28380f2587ca77bbc4 100644
--- a/lite/model_parser/pb/var_desc.h
+++ b/lite/model_parser/pb/var_desc.h
@@ -89,7 +89,8 @@ class VarDesc : public VarDescAPI {
   void SetDataTypes(
       const std::vector<framework::proto::VarType::Type> &multiple_data_type);
 
-  framework::proto::VarType::Type GetDataType() const;
+  // framework::proto::VarType::Type GetDataType() const;
+  VarDescAPI::VarDataType GetDataType() const;
 
   std::vector<framework::proto::VarType::Type> GetDataTypes() const;
diff --git a/lite/tools/debug/debug_utils.h b/lite/tools/debug/debug_utils.h
index ff08c47e524cacee37e95572a7f7a2fb444d4d16..9cc283ecc2ed75032cc8102be28e8a1fcd882395 100644
--- a/lite/tools/debug/debug_utils.h
+++ b/lite/tools/debug/debug_utils.h
@@ -148,8 +148,8 @@ void PrepareModelInputTensor(const DebugConfig& conf,
   auto* input_tensor = &feed_var->at(item.first);
   input_tensor->Resize(DDim(dim));
   switch (val_type) {
-#define FILL_TENSOR_BY_TYPE_ONCE(pb_type__, type__)         \
-  case framework::proto::VarType::pb_type__:                \
+#define FILL_TENSOR_BY_TYPE_ONCE(var_type__, type__)        \
+  case VarDescAPI::Type::var_type__:                        \
     FillTensorData<type__>(input_tensor, conf, item.first); \
     break
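Annotation: `FILL_TENSOR_BY_TYPE_ONCE` pairs a runtime type tag with a C++ type so each switch case calls the right template instantiation; the patch only swaps the tag enum from the protobuf type to `VarDescAPI::Type`. A hypothetical standalone version of that dispatch pattern (names like `Tag` and `FillByTag` are illustrative):

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

enum class Tag { FP32, INT32 };

// The templated worker the macro dispatches to; here it just sizes a raw
// buffer for n elements of T.
template <typename T>
void Fill(std::vector<unsigned char>* buf, size_t n) {
  buf->assign(n * sizeof(T), 0);
}

void FillByTag(Tag tag, std::vector<unsigned char>* buf, size_t n) {
// One case per (enum tag, C++ type) pair, as in debug_utils.h.
#define FILL_ONCE(tag__, type__) \
  case Tag::tag__:               \
    Fill<type__>(buf, n);        \
    break
  switch (tag) {
    FILL_ONCE(FP32, float);
    FILL_ONCE(INT32, int);
  }
#undef FILL_ONCE
}

int main() {
  std::vector<unsigned char> buf;
  FillByTag(Tag::FP32, &buf, 4);
  std::printf("%zu bytes\n", buf.size());  // 16 bytes
}
```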