From 08a3ed12af356b075ff9f8e499610bb6270184e4 Mon Sep 17 00:00:00 2001 From: hong19860320 <9973393+hong19860320@users.noreply.github.com> Date: Tue, 10 Mar 2020 22:15:55 +0800 Subject: [PATCH] [CORE] Support the fully quantized model for MTK and RK NPU (#3096) --- lite/api/paddle_use_passes.h | 3 +- lite/core/mir/CMakeLists.txt | 1 + lite/core/mir/graph_visualize_pass.cc | 117 ++++++++++----- .../quantized_op_attributes_inference_pass.cc | 75 ++++++++++ .../quantized_op_attributes_inference_pass.h | 36 +++++ lite/core/mir/ssa_graph.cc | 15 +- lite/core/mir/subgraph/subgraph_detector.cc | 31 ++++ lite/core/mir/type_layout_cast_pass.cc | 7 +- lite/core/mir/type_layout_cast_pass.h | 12 -- lite/core/mir/type_precision_cast_pass.cc | 124 ++++++++++++++- lite/core/mir/type_precision_cast_pass.h | 12 +- lite/core/mir/type_target_cast_pass.cc | 6 +- lite/core/mir/type_target_cast_pass.h | 12 -- lite/core/op_registry.cc | 5 + lite/core/optimizer.h | 6 + lite/core/tensor.cc | 3 +- lite/kernels/arm/calib_compute.cc | 141 +++++++++++++----- lite/kernels/arm/calib_compute.h | 6 +- lite/kernels/arm/layout_compute.cc | 74 +++++---- lite/kernels/host/feed_compute.cc | 9 +- lite/kernels/host/fetch_compute.cc | 13 +- 21 files changed, 537 insertions(+), 171 deletions(-) create mode 100644 lite/core/mir/quantized_op_attributes_inference_pass.cc create mode 100644 lite/core/mir/quantized_op_attributes_inference_pass.h diff --git a/lite/api/paddle_use_passes.h b/lite/api/paddle_use_passes.h index 0cd3e55821..41eca021a9 100644 --- a/lite/api/paddle_use_passes.h +++ b/lite/api/paddle_use_passes.h @@ -24,7 +24,7 @@ USE_MIR_PASS(generate_program_pass); USE_MIR_PASS(io_copy_kernel_pick_pass); USE_MIR_PASS(argument_type_display_pass); USE_MIR_PASS(runtime_context_assign_pass); -USE_MIR_PASS(graph_visualze); +USE_MIR_PASS(graph_visualize_pass); USE_MIR_PASS(lite_conv_bn_fuse_pass); USE_MIR_PASS(lite_fc_fuse_pass); @@ -46,3 +46,4 @@ USE_MIR_PASS(elementwise_mul_constant_eliminate_pass) USE_MIR_PASS(npu_subgraph_pass); USE_MIR_PASS(xpu_subgraph_pass); USE_MIR_PASS(weight_quantization_preprocess_pass); +USE_MIR_PASS(quantized_op_attributes_inference_pass); diff --git a/lite/core/mir/CMakeLists.txt b/lite/core/mir/CMakeLists.txt index 379ef67f29..82b19b030c 100644 --- a/lite/core/mir/CMakeLists.txt +++ b/lite/core/mir/CMakeLists.txt @@ -36,6 +36,7 @@ lite_cc_library(mir_passes runtime_context_assign_pass.cc memory_optimize_pass.cc weight_quantization_preprocess_pass.cc + quantized_op_attributes_inference_pass.cc DEPS mir_pass types context ${mir_fusers} ${mir_subgraphs}) # lite_cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS diff --git a/lite/core/mir/graph_visualize_pass.cc b/lite/core/mir/graph_visualize_pass.cc index 3a27360f94..d3e7a625a7 100644 --- a/lite/core/mir/graph_visualize_pass.cc +++ b/lite/core/mir/graph_visualize_pass.cc @@ -18,6 +18,7 @@ #include #include #include +#include #include "lite/core/mir/pass_registry.h" #include "lite/utils/string.h" @@ -28,56 +29,98 @@ namespace mir { using inference::analysis::Dot; void GraphVisualizePass::Apply(const std::unique_ptr& graph) { - Visualize(graph.get()); + VLOG(5) << "\n" << Visualize(graph.get()); } std::string Visualize(mir::SSAGraph* graph) { + std::ostringstream os; inference::analysis::Dot dot; - - int id = 0; - std::set exists_args; - for (auto& node : graph->mutable_nodes()) { - std::string key; - if (node.IsArg()) { - key = node.AsArg().name; - } else { - key = string_format("%s%d", node.AsStmt().op_type().c_str(), id++); + auto string_trunc = 
[](const std::string& str) -> std::string { + const int max_disp_size = 100; + if (str.length() > max_disp_size) + return str.substr(0, max_disp_size) + "..."; + return str; + }; + auto attr_repr = [&](const OpInfo* op_info, + const std::string& attr_name) -> std::string { + std::ostringstream os; + using AttrType = cpp::OpDesc::AttrType; + auto attr_type = op_info->GetAttrType(attr_name); + switch (attr_type) { + case AttrType::INT: + os << ":int:" << std::to_string(op_info->GetAttr(attr_name)); + break; + case AttrType::FLOAT: + os << ":float:" << std::to_string(op_info->GetAttr(attr_name)); + break; + case AttrType::BOOLEAN: + os << ":int:" << std::to_string(op_info->GetAttr(attr_name)); + break; + case AttrType::STRING: + os << ":string: \"" + << string_trunc(op_info->GetAttr(attr_name)) << "\""; + break; + case AttrType::FLOATS: { + auto vals = op_info->GetAttr>(attr_name); + os << ":floats: {" + Join(vals, ",") << "}"; + } break; + case AttrType::INTS: { + auto vals = op_info->GetAttr>(attr_name); + os << ":ints: {" + Join(vals, ",") + "}"; + } break; + case AttrType::STRINGS: { + auto vals = op_info->GetAttr>(attr_name); + os << ":strings: {" + string_trunc(Join(vals, ",")) << "}"; + } break; + default: + os << ":Unknow type(" << static_cast(attr_type) << ")"; + break; } - if (node.IsStmt()) { - dot.AddNode(key, - {Dot::Attr("shape", "box"), - Dot::Attr("style", "filled"), - Dot::Attr("color", "black"), - Dot::Attr("fillcolor", "yellow")}); - for (auto& x : node.inlinks) { - auto name = x->AsArg().name; - if (!exists_args.count(name)) { - dot.AddNode(name, {}); - } - dot.AddEdge(name, key, {}); - exists_args.insert(name); + return os.str(); + }; + int op_idx = 0; + std::set exists_var_names; + for (auto& node : graph->StmtTopologicalOrder()) { + if (!node->IsStmt()) continue; + auto op_info = node->AsStmt().op_info(); + auto op_type = op_info->Type(); + std::string op_name = string_format("%s%d", op_type.c_str(), op_idx++); + // Add its input&output variables as the Dot nodes + dot.AddNode(op_name, + {Dot::Attr("shape", "box"), + Dot::Attr("style", "filled"), + Dot::Attr("color", "black"), + Dot::Attr("fillcolor", "yellow")}); + for (auto& x : node->inlinks) { + auto var_name = x->AsArg().name; + if (!exists_var_names.count(var_name)) { + dot.AddNode(var_name, {}); + exists_var_names.insert(var_name); } - for (auto& x : node.outlinks) { - auto name = x->AsArg().name; - if (!exists_args.count(name)) { - dot.AddNode(name, {}); - } - dot.AddEdge(key, name, {}); - exists_args.insert(name); + dot.AddEdge(var_name, op_name, {}); + } + for (auto& x : node->outlinks) { + auto var_name = x->AsArg().name; + if (!exists_var_names.count(var_name)) { + dot.AddNode(var_name, {}); + exists_var_names.insert(var_name); } + dot.AddEdge(op_name, var_name, {}); + } + // Output its all of attributes(name and values) + os << "* " << op_name << "\n"; + const auto& attr_names = op_info->AttrNames(); + for (auto& attr_name : attr_names) { + os << " - " << attr_name << attr_repr(op_info, attr_name) << "\n"; } } - - auto res = dot.Build(); - // If we use VLOG here, we can not type all graph out. - // So we change VLOG to std::cout. 
- std::cout << "dot:\n" << res << std::endl; - return res; + os << dot.Build(); + return os.str(); } } // namespace mir } // namespace lite } // namespace paddle -REGISTER_MIR_PASS(graph_visualze, paddle::lite::mir::GraphVisualizePass) +REGISTER_MIR_PASS(graph_visualize_pass, paddle::lite::mir::GraphVisualizePass) .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/quantized_op_attributes_inference_pass.cc b/lite/core/mir/quantized_op_attributes_inference_pass.cc new file mode 100644 index 0000000000..54a4e779c6 --- /dev/null +++ b/lite/core/mir/quantized_op_attributes_inference_pass.cc @@ -0,0 +1,75 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/quantized_op_attributes_inference_pass.h" +#include +#include +#include +#include +#include +#include +#include +#include "lite/core/mir/graph_visualize_pass.h" +#include "lite/core/mir/pass_registry.h" + +namespace paddle { +namespace lite { +namespace mir { + +void QuantizedOpAttributesInferencePass::Apply( + const std::unique_ptr& graph) { + // Only for fully quantized model which is only supported by MTK and RK NPU. + // Replace the output_scale with the input_scale of the adjacent quantized + // ops, and fix the missing of the attribute 'enable_int8'. + for (auto& op_node : graph->StmtTopologicalOrder()) { + if (!op_node->IsStmt()) continue; + auto& inst = op_node->AsStmt(); + auto op_info = inst.op_info(); + auto op_type = op_info->Type(); + if (!op_info->HasAttr("input_scale")) continue; + bool found = false; + float output_scale; + for (auto out_var_node : op_node->outlinks) { + CHECK(out_var_node->IsArg()); + for (auto out_op_node : out_var_node->outlinks) { + CHECK(out_op_node->IsStmt()); + auto& out_inst = out_op_node->AsStmt(); + auto out_op_info = out_inst.op_info(); + if (!out_op_info->HasAttr("input_scale")) continue; + auto input_scale = out_op_info->GetAttr("input_scale"); + if (!found) { + found = true; + output_scale = input_scale; + } else { + CHECK_EQ(output_scale, input_scale); + } + } + } + if (found) { + inst.mutable_op_info()->SetAttr("output_scale", output_scale); + } + if (op_info->HasAttr("output_scale")) { + inst.mutable_op_info()->SetAttr("enable_int8", true); + } + } + VLOG(5) << "\n" << Visualize(graph.get()); +} + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(quantized_op_attributes_inference_pass, + paddle::lite::mir::QuantizedOpAttributesInferencePass) + .BindTargets({TARGET(kNPU)}); diff --git a/lite/core/mir/quantized_op_attributes_inference_pass.h b/lite/core/mir/quantized_op_attributes_inference_pass.h new file mode 100644 index 0000000000..2b475e0b3d --- /dev/null +++ b/lite/core/mir/quantized_op_attributes_inference_pass.h @@ -0,0 +1,36 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include "lite/core/mir/pass.h" +#include "lite/core/types.h" + +namespace paddle { +namespace lite { +namespace mir { + +class QuantizedOpAttributesInferencePass : public mir::StmtPass { + public: + void Apply(const std::unique_ptr& graph) override; +}; + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/ssa_graph.cc b/lite/core/mir/ssa_graph.cc index 2b5b65ce59..6c45ce8282 100644 --- a/lite/core/mir/ssa_graph.cc +++ b/lite/core/mir/ssa_graph.cc @@ -140,9 +140,18 @@ void SSAGraph::Build(const Program &program, arg_node->AsArg(name, node_storage_.size() - 1); arg_update_node_map_[name] = arg_node; } - if (var_types.count(name) && !arg_node->arg()->type) { - arg_node->arg()->type = LiteType::GetTensorTy( - TARGET(kUnk), var_types[name], DATALAYOUT(kUnk)); + if (var_types.count(name)) { + if (!arg_node->arg()->type) { + arg_node->arg()->type = LiteType::GetTensorTy( + TARGET(kUnk), var_types[name], DATALAYOUT(kUnk)); + } + // Store the original data type of the output tensors for + // type_precision_cast_pass, to keep the consistency between the + // output types of original graph and optimized graph's + if (op->op_info()->Type() == "fetch") { + op->mutable_op_info()->SetAttr( + "data_type", static_cast(var_types[name])); + } } if (is_weights(name)) arg_node->AsArg().is_weight = true; CHECK(arg_node->IsRoleSet()); diff --git a/lite/core/mir/subgraph/subgraph_detector.cc b/lite/core/mir/subgraph/subgraph_detector.cc index c46e12c1cd..65fb11ff2c 100644 --- a/lite/core/mir/subgraph/subgraph_detector.cc +++ b/lite/core/mir/subgraph/subgraph_detector.cc @@ -372,6 +372,37 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph, subgraph_op_desc.SetAttr>("output_data_names", output_var_names); + // Set input/output scale values of input/output var nodes for + // type_precision_cast_pass. 
+ std::vector input_data_scales; + std::vector output_data_scales; + for (auto &var_node : input_var_nodes) { + auto any_op_node = var_node->outlinks.front(); + CHECK(any_op_node->IsStmt()); + auto &any_inst = any_op_node->AsStmt(); + if (any_inst.op_info()->HasAttr("input_scale")) { + input_data_scales.push_back( + any_inst.op_info()->GetAttr("input_scale")); + } + } + for (auto &var_node : output_var_nodes) { + auto any_op_node = var_node->inlinks.front(); + CHECK(any_op_node->IsStmt()); + auto &any_inst = any_op_node->AsStmt(); + if (any_inst.op_info()->HasAttr("output_scale")) { + output_data_scales.push_back( + any_inst.op_info()->GetAttr("output_scale")); + } + } + if (input_data_scales.size() > 0) { + subgraph_op_desc.SetAttr>("input_data_scales", + input_data_scales); + } + if (output_data_scales.size() > 0) { + subgraph_op_desc.SetAttr>("output_data_scales", + output_data_scales); + } + // Set all of the inputs and outputs to the target subgraph op // To prevent vars are removed in RuntimeProgram::UpdateVarsOfProgram() for (auto &var_node : weight_var_nodes) { diff --git a/lite/core/mir/type_layout_cast_pass.cc b/lite/core/mir/type_layout_cast_pass.cc index 6cf03ee3b5..f517a04120 100644 --- a/lite/core/mir/type_layout_cast_pass.cc +++ b/lite/core/mir/type_layout_cast_pass.cc @@ -20,6 +20,8 @@ #include #include "lite/core/mir/graph_visualize_pass.h" #include "lite/core/mir/pass_registry.h" +#include "lite/core/mir/type_precision_cast_pass.h" +#include "lite/operators/subgraph_op.h" #include "lite/utils/string.h" namespace paddle { @@ -170,9 +172,8 @@ void TypeLayoutTransformPass::AddLayoutInst( DirectedLink(layout_output_arg, inst_node); // reset opdesc and update kernel information - UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(), - in->AsArg().name, - layout_output_name); + UpdateInputs( + inst_node->AsStmt().op().get(), in->AsArg().name, layout_output_name); auto original_selected_kernel = std::move(inst_node->AsStmt().kernels().front()); auto update_op_info = *inst_node->AsStmt().op_info(); diff --git a/lite/core/mir/type_layout_cast_pass.h b/lite/core/mir/type_layout_cast_pass.h index b6af6b285f..4a3e4c02d1 100644 --- a/lite/core/mir/type_layout_cast_pass.h +++ b/lite/core/mir/type_layout_cast_pass.h @@ -24,18 +24,6 @@ namespace paddle { namespace lite { namespace mir { -static void UpdateInputTo(cpp::OpDesc* desc, - const std::string& from, - const std::string& to) { - for (auto& item : *desc->mutable_inputs()) { - for (auto& input : item.second) { - if (input == from) { - input = to; - } - } - } -} - class TypeLayoutTransformPass : public ProgramPass { public: void Apply(const std::unique_ptr& graph) override; diff --git a/lite/core/mir/type_precision_cast_pass.cc b/lite/core/mir/type_precision_cast_pass.cc index 655fe0d203..86eb43be59 100644 --- a/lite/core/mir/type_precision_cast_pass.cc +++ b/lite/core/mir/type_precision_cast_pass.cc @@ -20,11 +20,115 @@ #include #include "lite/core/mir/graph_visualize_pass.h" #include "lite/core/mir/pass_registry.h" +#include "lite/operators/subgraph_op.h" namespace paddle { namespace lite { namespace mir { +// For the subgraph op, we also need to update the attr 'input_data_names' and +// the input variables names of the Ops in the subblock. 
+void UpdateInputsForSubgraph(OpLite* op, + const std::string& from, + const std::string& to) { + auto* op_desc = op->mutable_op_info(); + auto input_data_names = + op_desc->GetAttr>("input_data_names"); + std::replace(input_data_names.begin(), input_data_names.end(), from, to); + op_desc->SetAttr("input_data_names", input_data_names); + auto* subblock_desc = static_cast(op)->GetSubBlock(); + CHECK(subblock_desc); + for (size_t i = 0; i < subblock_desc->OpsSize(); i++) { + auto* subblock_op_desc = subblock_desc->GetOp(i); + for (auto& subblock_op_input : *subblock_op_desc->mutable_inputs()) { + for (auto& subblock_var_name : subblock_op_input.second) { + if (subblock_var_name == from) { + subblock_var_name = to; + } + } + } + } +} + +// Update the input variable names from 'from' to 'to' for the target Op +void UpdateInputs(OpLite* op, const std::string& from, const std::string& to) { + auto* op_desc = op->mutable_op_info(); + auto op_type = op_desc->Type(); + for (auto& op_input : *op_desc->mutable_inputs()) { + for (auto& var_name : op_input.second) { + if (var_name == from) { + var_name = to; + } + } + } + if (op_type == "subgraph") { + UpdateInputsForSubgraph(op, from, to); + } +} + +// Infer the scale value for the new calib op from the subgraph op +static bool InferScaleFromSubgraph(std::string var_name, + const OpInfo* op_info, + float* scale, + bool reverse = false) { + bool found = false; + auto input_or_output_names = op_info->GetAttr>( + reverse ? "output_data_names" : "input_data_names"); + auto input_or_output_scales = op_info->GetAttr>( + reverse ? "output_data_scales" : "input_data_scales"); + auto size = input_or_output_names.size(); + CHECK(size == input_or_output_scales.size()); + for (int i = 0; i < size; i++) { + if (input_or_output_names[i] == var_name) { + *scale = input_or_output_scales[i]; + found = true; + break; + } + } + return found; +} + +// Infer the scale value for the new calib op from the input_scale of the +// current op and output_scale of the previous op. +// case 1: prev_op->var_node->op_node(int8->any op, with input_scale). +// case 2: prev_op->var_node->op_node(subgraph op, int8->any, with +// input_data_scales). +// case 3: prev_op(any->int8, with output_scale)->var_node->op_node(fp32->any, +// without input_scale). +// case 4: prev_op(any->int8, subgraph_op, with +// output_data_scales)->var_node->op_node(fp32->any, without input_scale). +static bool InferScale(Node* var_node, Node* op_node, float* scale) { + bool found = false; + auto& inst = op_node->AsStmt(); + auto op_info = inst.op_info(); + auto op_type = op_info->Type(); + auto var_name = var_node->AsArg().name; + if (op_type == "subgraph") { + found = InferScaleFromSubgraph(var_name, op_info, scale, false); + } else { + if (op_info->HasAttr("input_scale")) { + *scale = op_info->GetAttr("input_scale"); + found = true; + } else { + // Obtain the output_scale from one of its previous Ops + auto prev_op_node = var_node->inlinks.front(); + CHECK(prev_op_node->IsStmt()); + auto& prev_inst = prev_op_node->AsStmt(); + auto prev_op_info = prev_inst.op_info(); + auto prev_op_type = prev_op_info->Type(); + if (prev_op_type == "subgraph") { + found = InferScaleFromSubgraph(var_name, prev_op_info, scale, true); + } else { + if (prev_op_info->HasAttr("output_scale")) { + *scale = prev_op_info->GetAttr("output_scale"); + found = true; + } + } + } + } + return found; +} + void PrecisionCastPass::Apply(const std::unique_ptr& graph) { // Start from inputs of the graph, those should have place set. 
std::list nodes; @@ -59,6 +163,14 @@ void PrecisionCastPass::ComplementInputs(SSAGraph* graph, auto decl_arg_type = inst.picked_kernel().GetInputDeclType(tmp); CHECK(in->AsArg().type); VLOG(4) << inst.picked_kernel().name(); + if (inst.op_info()->Type() == "fetch") { + if (inst.op_info()->HasAttr("data_type")) { + auto data_type = + static_cast(inst.op_info()->GetAttr("data_type")); + decl_arg_type = LiteType::GetTensorTy( + decl_arg_type->target(), data_type, decl_arg_type->layout()); + } + } // if (!in->AsArg().is_weight && !PrecisionCompatibleTo(*in->AsArg().type, // *decl_arg_type)) { if (!PrecisionCompatibleTo(*in->AsArg().type, *decl_arg_type)) { @@ -109,10 +221,11 @@ void PrecisionCastPass::AddCastInst(const Type& from, op_desc.SetType(cast_type); op_desc.SetInput("Input", {in->AsArg().name}); op_desc.SetOutput("Out", {cast_op_output_name}); - if (inst_node->AsStmt().op_info()->HasAttr("input_scale")) { - op_desc.SetAttr( - "scale", inst_node->AsStmt().op_info()->GetAttr("input_scale")); + float scale; + if (InferScale(in, inst_node, &scale)) { + op_desc.SetAttr("scale", scale); } + cast_op->Attach(op_desc, inst_node->AsStmt().op()->scope()); auto kernels = cast_op->CreateKernels(valid_places); std::vector> selected_kernels; @@ -146,9 +259,8 @@ void PrecisionCastPass::AddCastInst(const Type& from, DirectedLink(cast_op_output_arg, inst_node); // reset opdesc and update kernel information - UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(), - in->AsArg().name, - cast_op_output_name); + UpdateInputs( + inst_node->AsStmt().op().get(), in->AsArg().name, cast_op_output_name); // recreate the op auto original_selected_kernel = diff --git a/lite/core/mir/type_precision_cast_pass.h b/lite/core/mir/type_precision_cast_pass.h index 3f55e52ef9..b5f7c5d902 100644 --- a/lite/core/mir/type_precision_cast_pass.h +++ b/lite/core/mir/type_precision_cast_pass.h @@ -24,17 +24,7 @@ namespace paddle { namespace lite { namespace mir { -static void UpdateInputTo(cpp::OpDesc* desc, - const std::string& from, - const std::string& to) { - for (auto& item : *desc->mutable_inputs()) { - for (auto& input : item.second) { - if (input == from) { - input = to; - } - } - } -} +void UpdateInputs(OpLite* op, const std::string& from, const std::string& to); /* * The pass complement the necessary instruction to make data diff --git a/lite/core/mir/type_target_cast_pass.cc b/lite/core/mir/type_target_cast_pass.cc index ae74bd8d4d..75d8022d5f 100644 --- a/lite/core/mir/type_target_cast_pass.cc +++ b/lite/core/mir/type_target_cast_pass.cc @@ -21,6 +21,7 @@ #include #include "lite/core/mir/graph_visualize_pass.h" #include "lite/core/mir/pass_registry.h" +#include "lite/core/mir/type_precision_cast_pass.h" #include "lite/utils/string.h" namespace paddle { @@ -240,9 +241,8 @@ void TypeTargetTransformPass::UpdateInstNode(Node* in, Node* inst_node, std::string io_copy_output_name) { // reset opdesc and update kernel information - UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(), - in->AsArg().name, - io_copy_output_name); + UpdateInputs( + inst_node->AsStmt().op().get(), in->AsArg().name, io_copy_output_name); auto original_selected_kernel = std::move(inst_node->AsStmt().kernels().front()); auto update_op_info = *inst_node->AsStmt().op_info(); diff --git a/lite/core/mir/type_target_cast_pass.h b/lite/core/mir/type_target_cast_pass.h index e9a275882f..3561a0a7dd 100644 --- a/lite/core/mir/type_target_cast_pass.h +++ b/lite/core/mir/type_target_cast_pass.h @@ -25,18 +25,6 @@ namespace paddle { namespace lite { 
namespace mir { -static void UpdateInputTo(cpp::OpDesc* desc, - const std::string& from, - const std::string& to) { - for (auto& item : *desc->mutable_inputs()) { - for (auto& input : item.second) { - if (input == from) { - input = to; - } - } - } -} - /* * IoComplementPass complement the necessary instruction to make data * transferring or transformation between different places. diff --git a/lite/core/op_registry.cc b/lite/core/op_registry.cc index 38625a7d7f..4b6d3282ed 100644 --- a/lite/core/op_registry.cc +++ b/lite/core/op_registry.cc @@ -154,7 +154,9 @@ KernelRegistry::KernelRegistry() INIT_FOR(kX86, kInt64, kNCHW); INIT_FOR(kARM, kFloat, kNCHW); + INIT_FOR(kARM, kFloat, kNHWC); INIT_FOR(kARM, kInt8, kNCHW); + INIT_FOR(kARM, kInt8, kNHWC); INIT_FOR(kARM, kAny, kNCHW); INIT_FOR(kARM, kAny, kAny); INIT_FOR(kARM, kInt32, kNCHW); @@ -180,8 +182,11 @@ KernelRegistry::KernelRegistry() INIT_FOR(kOpenCL, kAny, kImageNW); INIT_FOR(kNPU, kFloat, kNCHW); + INIT_FOR(kNPU, kFloat, kNHWC); INIT_FOR(kNPU, kInt8, kNCHW); + INIT_FOR(kNPU, kInt8, kNHWC); INIT_FOR(kNPU, kAny, kNCHW); + INIT_FOR(kNPU, kAny, kNHWC); INIT_FOR(kNPU, kAny, kAny); INIT_FOR(kXPU, kFloat, kNCHW); diff --git a/lite/core/optimizer.h b/lite/core/optimizer.h index 8646db3c5b..ca22c86907 100644 --- a/lite/core/optimizer.h +++ b/lite/core/optimizer.h @@ -75,6 +75,12 @@ class Optimizer { (defined LITE_WITH_ARM) "lite_elementwise_add_activation_fuse_pass", // #endif + "quantized_op_attributes_inference_pass", // Only for fully + // quantized model, infer + // the output scale and + // fix the attribute + // 'enable_int8' for all + // of the quantized ops. "npu_subgraph_pass", "xpu_subgraph_pass", "bm_subgraph_pass", diff --git a/lite/core/tensor.cc b/lite/core/tensor.cc index 604c3f5328..7664633077 100644 --- a/lite/core/tensor.cc +++ b/lite/core/tensor.cc @@ -75,6 +75,7 @@ void TensorLite::ShareDataWith(const TensorLite &other) { target_ = other.target_; lod_ = other.lod_; memory_size_ = other.memory_size_; + precision_ = other.precision_; } void TensorLite::CopyDataFrom(const TensorLite &other) { @@ -82,7 +83,7 @@ void TensorLite::CopyDataFrom(const TensorLite &other) { target_ = other.target_; lod_ = other.lod_; memory_size_ = other.memory_size_; - precision_ = other.precision(); + precision_ = other.precision_; buffer_->CopyDataFrom(*other.buffer_, memory_size_); } diff --git a/lite/kernels/arm/calib_compute.cc b/lite/kernels/arm/calib_compute.cc index 525e5aefd6..6dac97dcbc 100644 --- a/lite/kernels/arm/calib_compute.cc +++ b/lite/kernels/arm/calib_compute.cc @@ -23,24 +23,24 @@ namespace lite { namespace kernels { namespace arm { -void CalibComputeFp32ToInt8::Run() { - auto& param = this->Param(); +template +void CalibComputeFp32ToInt8::Run() { + auto& param = this->template Param(); std::vector scale = {param.scale}; - const auto* din = param.input->data(); - auto* dout = param.output->mutable_data(); + const auto* din = param.input->template data(); + auto* dout = param.output->template mutable_data(); lite::arm::math::fp32_to_int8( din, dout, scale.data(), 1, 1, param.input->numel()); - return; } -void CalibComputeInt8ToFp32::Run() { - auto& param = this->Param(); - const auto* din = param.input->data(); +template +void CalibComputeInt8ToFp32::Run() { + auto& param = this->template Param(); + const auto* din = param.input->template data(); std::vector scale = {param.scale}; - auto* dout = param.output->mutable_data(); + auto* dout = param.output->template mutable_data(); lite::arm::math::int8_to_fp32( din, dout, 
scale.data(), 1, 1, param.input->numel()); - return; } } // namespace arm @@ -48,43 +48,116 @@ void CalibComputeInt8ToFp32::Run() { } // namespace lite } // namespace paddle -REGISTER_LITE_KERNEL(calib, - kARM, - kInt8, - kNCHW, - paddle::lite::kernels::arm::CalibComputeFp32ToInt8, - fp32_to_int8) +REGISTER_LITE_KERNEL( + calib, + kARM, + kInt8, + kNCHW, + paddle::lite::kernels::arm::CalibComputeFp32ToInt8, + fp32_to_int8) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) .Finalize(); -REGISTER_LITE_KERNEL(calib, - kARM, - kInt8, - kNCHW, - paddle::lite::kernels::arm::CalibComputeInt8ToFp32, - int8_to_fp32) +REGISTER_LITE_KERNEL( + calib, + kARM, + kInt8, + kNCHW, + paddle::lite::kernels::arm::CalibComputeInt8ToFp32, + int8_to_fp32) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))}) .Finalize(); -REGISTER_LITE_KERNEL(calib_once, - kARM, - kInt8, - kNCHW, - paddle::lite::kernels::arm::CalibComputeFp32ToInt8, - fp32_to_int8) + +REGISTER_LITE_KERNEL( + calib, + kARM, + kInt8, + kNHWC, + paddle::lite::kernels::arm::CalibComputeFp32ToInt8, + fp32_to_int8) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kARM), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kARM), + PRECISION(kInt8), + DATALAYOUT(kNHWC))}) + .Finalize(); + +REGISTER_LITE_KERNEL( + calib, + kARM, + kInt8, + kNHWC, + paddle::lite::kernels::arm::CalibComputeInt8ToFp32, + int8_to_fp32) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kARM), + PRECISION(kInt8), + DATALAYOUT(kNHWC))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kARM), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .Finalize(); + +REGISTER_LITE_KERNEL( + calib_once, + kARM, + kInt8, + kNCHW, + paddle::lite::kernels::arm::CalibComputeFp32ToInt8, + fp32_to_int8) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) .Finalize(); -REGISTER_LITE_KERNEL(calib_once, - kARM, - kInt8, - kNCHW, - paddle::lite::kernels::arm::CalibComputeInt8ToFp32, - int8_to_fp32) +REGISTER_LITE_KERNEL( + calib_once, + kARM, + kInt8, + kNCHW, + paddle::lite::kernels::arm::CalibComputeInt8ToFp32, + int8_to_fp32) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))}) .Finalize(); + +REGISTER_LITE_KERNEL( + calib_once, + kARM, + kInt8, + kNHWC, + paddle::lite::kernels::arm::CalibComputeFp32ToInt8, + fp32_to_int8) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kARM), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kARM), + PRECISION(kInt8), + DATALAYOUT(kNHWC))}) + .Finalize(); + +REGISTER_LITE_KERNEL( + calib_once, + kARM, + kInt8, + kNHWC, + paddle::lite::kernels::arm::CalibComputeInt8ToFp32, + int8_to_fp32) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kARM), + PRECISION(kInt8), + DATALAYOUT(kNHWC))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kARM), + PRECISION(kFloat), + DATALAYOUT(kNHWC))}) + .Finalize(); diff --git a/lite/kernels/arm/calib_compute.h b/lite/kernels/arm/calib_compute.h index 8d9a32bc24..a4c8b4c123 100644 --- a/lite/kernels/arm/calib_compute.h +++ b/lite/kernels/arm/calib_compute.h @@ -21,8 +21,9 @@ namespace lite { namespace kernels { 
namespace arm { +template class CalibComputeFp32ToInt8 - : public KernelLite { + : public KernelLite { public: using param_t = operators::CalibParam; @@ -33,8 +34,9 @@ class CalibComputeFp32ToInt8 private: }; +template class CalibComputeInt8ToFp32 - : public KernelLite { + : public KernelLite { public: using param_t = operators::CalibParam; diff --git a/lite/kernels/arm/layout_compute.cc b/lite/kernels/arm/layout_compute.cc index bc52c5ea3e..d25fdc082f 100644 --- a/lite/kernels/arm/layout_compute.cc +++ b/lite/kernels/arm/layout_compute.cc @@ -20,40 +20,50 @@ namespace lite { namespace kernels { namespace arm { -#define NCHWTONHWC(type) \ - auto& param = this->template Param(); \ - auto input = param.x->template data(); \ - auto input_dim = param.x->dims(); \ - CHECK(input_dim.size() == 4) \ - << "NCHW to NHWC should guarantee that the input dims should be 4"; \ - int n = input_dim[0]; \ - int c = input_dim[1]; \ - int h = input_dim[2]; \ - int w = input_dim[3]; \ - param.y->Resize({n, h, w, c}); \ - auto output = param.y->template mutable_data(TARGET(kARM)); \ - if (c == 1) { \ - memcpy(output, input, sizeof(type) * n * h * w); \ - return; \ - } \ +#define NCHWTONHWC(type) \ + auto& param = this->template Param(); \ + auto input = param.x->template data(); \ + auto input_dim = param.x->dims(); \ + if (input_dim.size() != 4) { \ + LOG(WARNING) << "NCHW to NHWC should guarantee that the input dims " \ + "should be 4, but received " \ + << input_dim.size(); \ + param.y->ShareDataWith(*param.x); \ + return; \ + } \ + int n = input_dim[0]; \ + int c = input_dim[1]; \ + int h = input_dim[2]; \ + int w = input_dim[3]; \ + param.y->Resize({n, h, w, c}); \ + auto output = param.y->template mutable_data(TARGET(kARM)); \ + if (c == 1) { \ + memcpy(output, input, sizeof(type) * n * h * w); \ + return; \ + } \ lite::arm::math::NCHW2NHWC(n, c, h * w, input, output); -#define NHWCTONCHW(type) \ - auto& param = this->template Param(); \ - auto input = param.x->template data(); \ - auto input_dim = param.x->dims(); \ - CHECK(input_dim.size() == 4) \ - << "NHWC to NCHW should guarantee that the input dims should be 4"; \ - int n = input_dim[0]; \ - int h = input_dim[1]; \ - int w = input_dim[2]; \ - int c = input_dim[3]; \ - param.y->Resize({n, c, h, w}); \ - auto output = param.y->template mutable_data(TARGET(kARM)); \ - if (c == 1) { \ - memcpy(output, input, sizeof(type) * n * h * w); \ - return; \ - } \ +#define NHWCTONCHW(type) \ + auto& param = this->template Param(); \ + auto input = param.x->template data(); \ + auto input_dim = param.x->dims(); \ + if (input_dim.size() != 4) { \ + LOG(WARNING) << "NHWC to NCHW should guarantee that the input dims " \ + "should be 4, but received " \ + << input_dim.size(); \ + param.y->ShareDataWith(*param.x); \ + return; \ + } \ + int n = input_dim[0]; \ + int h = input_dim[1]; \ + int w = input_dim[2]; \ + int c = input_dim[3]; \ + param.y->Resize({n, c, h, w}); \ + auto output = param.y->template mutable_data(TARGET(kARM)); \ + if (c == 1) { \ + memcpy(output, input, sizeof(type) * n * h * w); \ + return; \ + } \ lite::arm::math::NHWC2NCHW(n, c, h * w, input, output); template <> diff --git a/lite/kernels/host/feed_compute.cc b/lite/kernels/host/feed_compute.cc index b16be42031..dd71b3efc3 100644 --- a/lite/kernels/host/feed_compute.cc +++ b/lite/kernels/host/feed_compute.cc @@ -20,8 +20,7 @@ namespace lite { namespace kernels { namespace host { -class FeedCompute - : public KernelLite { +class FeedCompute : public KernelLite { public: using param_t = 
operators::FeedParam; @@ -40,7 +39,7 @@ class FeedCompute } // namespace paddle REGISTER_LITE_KERNEL( - feed, kHost, kAny, kAny, paddle::lite::kernels::host::FeedCompute, def) - .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))}) + feed, kHost, kAny, kNCHW, paddle::lite::kernels::host::FeedCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))}) .Finalize(); diff --git a/lite/kernels/host/fetch_compute.cc b/lite/kernels/host/fetch_compute.cc index c53b987b8f..43db74bef9 100644 --- a/lite/kernels/host/fetch_compute.cc +++ b/lite/kernels/host/fetch_compute.cc @@ -20,8 +20,7 @@ namespace lite { namespace kernels { namespace host { -class FetchCompute - : public KernelLite { +class FetchCompute : public KernelLite { public: using param_t = operators::FeedParam; @@ -43,11 +42,7 @@ class FetchCompute } // namespace paddle REGISTER_LITE_KERNEL( - fetch, kHost, kAny, kAny, paddle::lite::kernels::host::FetchCompute, def) - .BindInput("X", - {LiteType::GetTensorTy( - TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)}) - .BindOutput("Out", - {LiteType::GetTensorTy( - TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)}) + fetch, kHost, kAny, kNCHW, paddle::lite::kernels::host::FetchCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))}) .Finalize(); -- GitLab
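For reference, the new quantized_op_attributes_inference_pass boils down to a simple propagation rule: a fully quantized op's missing output_scale is taken from the input_scale of the quantized op(s) that consume its output, and any op that ends up with an output_scale is flagged with enable_int8. Below is a minimal standalone sketch of that rule over a toy graph; ToyOp and InferQuantizedAttributes are hypothetical stand-ins, not the real mir::Node/SSAGraph API.

// Standalone illustration of the scale-propagation idea used by
// quantized_op_attributes_inference_pass. ToyOp is a hypothetical stand-in
// for mir::Node; only the algorithm is mirrored.
#include <cassert>
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct ToyOp {
  std::string type;
  std::map<std::string, float> scales;  // e.g. "input_scale", "output_scale"
  bool enable_int8 = false;
  std::vector<ToyOp*> consumers;        // ops that read this op's output
};

// For every quantized op (one that carries an input_scale), copy the
// input_scale of its consumers into its own output_scale and flag int8.
void InferQuantizedAttributes(std::vector<ToyOp*>& ops) {
  for (ToyOp* op : ops) {
    if (!op->scales.count("input_scale")) continue;
    bool found = false;
    float output_scale = 0.f;
    for (ToyOp* consumer : op->consumers) {
      if (!consumer->scales.count("input_scale")) continue;
      float in_scale = consumer->scales["input_scale"];
      if (!found) {
        output_scale = in_scale;
        found = true;
      } else {
        assert(output_scale == in_scale);  // all consumers must agree
      }
    }
    if (found) op->scales["output_scale"] = output_scale;
    if (op->scales.count("output_scale")) op->enable_int8 = true;
  }
}

int main() {
  ToyOp conv{"conv2d", {{"input_scale", 0.05f}}};
  ToyOp relu{"relu", {{"input_scale", 0.08f}}};
  conv.consumers.push_back(&relu);
  std::vector<ToyOp*> ops{&conv, &relu};
  InferQuantizedAttributes(ops);
  std::cout << "conv2d output_scale = " << conv.scales["output_scale"]
            << ", enable_int8 = " << conv.enable_int8 << "\n";  // 0.08, 1
  return 0;
}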
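The calib kernels registered above convert between fp32 and int8 using a single per-tensor scale. A rough reference for the arithmetic they are expected to implement, assuming symmetric quantization with saturation to [-127, 127]; this is not the actual lite::arm::math::fp32_to_int8 / int8_to_fp32 code.

// Reference math for the calib kernels (fp32 <-> int8 with one scale),
// assuming symmetric per-tensor quantization. Sketch only.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

void Fp32ToInt8(const float* in, int8_t* out, float scale, int size) {
  for (int i = 0; i < size; ++i) {
    // value ~= quantized * scale, so quantized = round(value / scale)
    float q = std::round(in[i] / scale);
    out[i] = static_cast<int8_t>(std::max(-127.f, std::min(127.f, q)));
  }
}

void Int8ToFp32(const int8_t* in, float* out, float scale, int size) {
  for (int i = 0; i < size; ++i) out[i] = in[i] * scale;
}

int main() {
  const float scale = 0.05f;
  std::vector<float> x{0.1f, -0.25f, 1.0f};
  std::vector<int8_t> q(x.size());
  std::vector<float> y(x.size());
  Fp32ToInt8(x.data(), q.data(), scale, static_cast<int>(x.size()));
  Int8ToFp32(q.data(), y.data(), scale, static_cast<int>(q.size()));
  for (size_t i = 0; i < x.size(); ++i)
    std::printf("%g -> %d -> %g\n", x[i], q[i], y[i]);
  return 0;
}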
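Likewise, the NCHWTONHWC/NHWCTONCHW macros delegate to optimized NCHW2NHWC/NHWC2NCHW routines. The index mapping they implement is shown below in naive form; this is a sketch of the expected behavior, not the lite::arm::math implementation.

// Naive reference for the NCHW -> NHWC transform used by the layout kernels:
// element (b, ch, p) of the source maps to (b, p, ch) of the destination,
// where p indexes the flattened h*w spatial positions.
#include <cstdio>
#include <vector>

void NCHW2NHWC(int n, int c, int hw, const float* in, float* out) {
  for (int b = 0; b < n; ++b)
    for (int ch = 0; ch < c; ++ch)
      for (int p = 0; p < hw; ++p)
        out[(b * hw + p) * c + ch] = in[(b * c + ch) * hw + p];
}

int main() {
  // 1 x 2 x 1 x 2 tensor: channel 0 = {1, 2}, channel 1 = {3, 4}
  std::vector<float> nchw{1, 2, 3, 4}, nhwc(4);
  NCHW2NHWC(1, 2, 2, nchw.data(), nhwc.data());
  for (float v : nhwc) std::printf("%g ", v);  // prints: 1 3 2 4
  std::printf("\n");
  return 0;
}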