Unverified commit 08a3ed12, authored by hong19860320, committed by GitHub

[CORE] Support the fully quantized model for MTK and RK NPU (#3096)

Parent 89da9953
@@ -24,7 +24,7 @@ USE_MIR_PASS(generate_program_pass);
USE_MIR_PASS(io_copy_kernel_pick_pass);
USE_MIR_PASS(argument_type_display_pass);
USE_MIR_PASS(runtime_context_assign_pass);
-USE_MIR_PASS(graph_visualze);
USE_MIR_PASS(graph_visualize_pass);
USE_MIR_PASS(lite_conv_bn_fuse_pass);
USE_MIR_PASS(lite_fc_fuse_pass);
@@ -46,3 +46,4 @@ USE_MIR_PASS(elementwise_mul_constant_eliminate_pass)
USE_MIR_PASS(npu_subgraph_pass);
USE_MIR_PASS(xpu_subgraph_pass);
USE_MIR_PASS(weight_quantization_preprocess_pass);
USE_MIR_PASS(quantized_op_attributes_inference_pass);
@@ -36,6 +36,7 @@ lite_cc_library(mir_passes
    runtime_context_assign_pass.cc
    memory_optimize_pass.cc
    weight_quantization_preprocess_pass.cc
    quantized_op_attributes_inference_pass.cc
    DEPS mir_pass types context ${mir_fusers} ${mir_subgraphs})
# lite_cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS
......
@@ -18,6 +18,7 @@
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "lite/core/mir/pass_registry.h"
#include "lite/utils/string.h"
@@ -28,56 +29,98 @@ namespace mir {
using inference::analysis::Dot;

void GraphVisualizePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
-  Visualize(graph.get());
  VLOG(5) << "\n" << Visualize(graph.get());
}

-std::string Visualize(mir::SSAGraph* graph) {
-  inference::analysis::Dot dot;
-  int id = 0;
-  std::set<std::string> exists_args;
-  for (auto& node : graph->mutable_nodes()) {
-    std::string key;
-    if (node.IsArg()) {
-      key = node.AsArg().name;
-    } else {
-      key = string_format("%s%d", node.AsStmt().op_type().c_str(), id++);
-    }
-    if (node.IsStmt()) {
-      dot.AddNode(key,
-                  {Dot::Attr("shape", "box"),
-                   Dot::Attr("style", "filled"),
-                   Dot::Attr("color", "black"),
-                   Dot::Attr("fillcolor", "yellow")});
-      for (auto& x : node.inlinks) {
-        auto name = x->AsArg().name;
-        if (!exists_args.count(name)) {
-          dot.AddNode(name, {});
-        }
-        dot.AddEdge(name, key, {});
-        exists_args.insert(name);
-      }
-      for (auto& x : node.outlinks) {
-        auto name = x->AsArg().name;
-        if (!exists_args.count(name)) {
-          dot.AddNode(name, {});
-        }
-        dot.AddEdge(key, name, {});
-        exists_args.insert(name);
-      }
-    }
-  }
-  auto res = dot.Build();
-  // If we use VLOG here, we can not type all graph out.
-  // So we change VLOG to std::cout.
-  std::cout << "dot:\n" << res << std::endl;
-  return res;
-}

std::string Visualize(mir::SSAGraph* graph) {
  std::ostringstream os;
  inference::analysis::Dot dot;
  auto string_trunc = [](const std::string& str) -> std::string {
    const int max_disp_size = 100;
    if (str.length() > max_disp_size)
      return str.substr(0, max_disp_size) + "...";
    return str;
  };
  auto attr_repr = [&](const OpInfo* op_info,
                       const std::string& attr_name) -> std::string {
    std::ostringstream os;
    using AttrType = cpp::OpDesc::AttrType;
    auto attr_type = op_info->GetAttrType(attr_name);
    switch (attr_type) {
      case AttrType::INT:
        os << ":int:" << std::to_string(op_info->GetAttr<int>(attr_name));
        break;
      case AttrType::FLOAT:
        os << ":float:" << std::to_string(op_info->GetAttr<float>(attr_name));
        break;
      case AttrType::BOOLEAN:
        os << ":int:" << std::to_string(op_info->GetAttr<bool>(attr_name));
        break;
      case AttrType::STRING:
        os << ":string: \""
           << string_trunc(op_info->GetAttr<std::string>(attr_name)) << "\"";
        break;
      case AttrType::FLOATS: {
        auto vals = op_info->GetAttr<std::vector<float>>(attr_name);
        os << ":floats: {" + Join(vals, ",") << "}";
      } break;
      case AttrType::INTS: {
        auto vals = op_info->GetAttr<std::vector<int>>(attr_name);
        os << ":ints: {" + Join(vals, ",") + "}";
      } break;
      case AttrType::STRINGS: {
        auto vals = op_info->GetAttr<std::vector<std::string>>(attr_name);
        os << ":strings: {" + string_trunc(Join(vals, ",")) << "}";
      } break;
      default:
        os << ":Unknow type(" << static_cast<int>(attr_type) << ")";
        break;
    }
    return os.str();
  };
  int op_idx = 0;
  std::set<std::string> exists_var_names;
  for (auto& node : graph->StmtTopologicalOrder()) {
    if (!node->IsStmt()) continue;
    auto op_info = node->AsStmt().op_info();
    auto op_type = op_info->Type();
    std::string op_name = string_format("%s%d", op_type.c_str(), op_idx++);
    // Add its input&output variables as the Dot nodes
    dot.AddNode(op_name,
                {Dot::Attr("shape", "box"),
                 Dot::Attr("style", "filled"),
                 Dot::Attr("color", "black"),
                 Dot::Attr("fillcolor", "yellow")});
    for (auto& x : node->inlinks) {
      auto var_name = x->AsArg().name;
      if (!exists_var_names.count(var_name)) {
        dot.AddNode(var_name, {});
        exists_var_names.insert(var_name);
      }
      dot.AddEdge(var_name, op_name, {});
    }
    for (auto& x : node->outlinks) {
      auto var_name = x->AsArg().name;
      if (!exists_var_names.count(var_name)) {
        dot.AddNode(var_name, {});
        exists_var_names.insert(var_name);
      }
      dot.AddEdge(op_name, var_name, {});
    }
    // Output its all of attributes(name and values)
    os << "* " << op_name << "\n";
    const auto& attr_names = op_info->AttrNames();
    for (auto& attr_name : attr_names) {
      os << " - " << attr_name << attr_repr(op_info, attr_name) << "\n";
    }
  }
  os << dot.Build();
  return os.str();
}

} // namespace mir
} // namespace lite
} // namespace paddle

-REGISTER_MIR_PASS(graph_visualze, paddle::lite::mir::GraphVisualizePass)
REGISTER_MIR_PASS(graph_visualize_pass, paddle::lite::mir::GraphVisualizePass)
    .BindTargets({TARGET(kAny)});
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/quantized_op_attributes_inference_pass.h"
#include <algorithm>
#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "lite/core/mir/graph_visualize_pass.h"
#include "lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void QuantizedOpAttributesInferencePass::Apply(
    const std::unique_ptr<SSAGraph>& graph) {
  // Only for fully quantized model which is only supported by MTK and RK NPU.
  // Replace the output_scale with the input_scale of the adjacent quantized
  // ops, and fix the missing of the attribute 'enable_int8'.
  for (auto& op_node : graph->StmtTopologicalOrder()) {
    if (!op_node->IsStmt()) continue;
    auto& inst = op_node->AsStmt();
    auto op_info = inst.op_info();
    auto op_type = op_info->Type();
    if (!op_info->HasAttr("input_scale")) continue;
    bool found = false;
    float output_scale;
    for (auto out_var_node : op_node->outlinks) {
      CHECK(out_var_node->IsArg());
      for (auto out_op_node : out_var_node->outlinks) {
        CHECK(out_op_node->IsStmt());
        auto& out_inst = out_op_node->AsStmt();
        auto out_op_info = out_inst.op_info();
        if (!out_op_info->HasAttr("input_scale")) continue;
        auto input_scale = out_op_info->GetAttr<float>("input_scale");
        if (!found) {
          found = true;
          output_scale = input_scale;
        } else {
          CHECK_EQ(output_scale, input_scale);
        }
      }
    }
    if (found) {
      inst.mutable_op_info()->SetAttr("output_scale", output_scale);
    }
    if (op_info->HasAttr("output_scale")) {
      inst.mutable_op_info()->SetAttr("enable_int8", true);
    }
  }
  VLOG(5) << "\n" << Visualize(graph.get());
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(quantized_op_attributes_inference_pass,
paddle::lite::mir::QuantizedOpAttributesInferencePass)
.BindTargets({TARGET(kNPU)});
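For illustration only (a reviewer sketch, not code from this commit): the rule the pass applies, reduced to a toy op structure. ToyOp, InferQuantAttrs and the main() driver are hypothetical names; the real pass walks SSAGraph nodes and also CHECKs that all consumers report the same scale.

// Sketch of the rule: an op's output_scale is taken from the input_scale of
// the op(s) consuming its output, and enable_int8 is set once a scale exists.
#include <cassert>
#include <map>
#include <string>
#include <vector>

struct ToyOp {
  std::string type;
  std::map<std::string, float> attrs;  // e.g. {"input_scale", 0.25f}
  std::vector<ToyOp*> consumers;       // ops that read this op's output
  bool enable_int8;
};

void InferQuantAttrs(std::vector<ToyOp*>* ops) {
  for (auto* op : *ops) {
    if (!op->attrs.count("input_scale")) continue;
    for (auto* consumer : op->consumers) {
      if (consumer->attrs.count("input_scale")) {
        // The consumer's input_scale becomes this op's output_scale.
        // (The real pass also checks that all consumers agree on the value.)
        op->attrs["output_scale"] = consumer->attrs["input_scale"];
      }
    }
    if (op->attrs.count("output_scale")) op->enable_int8 = true;
  }
}

int main() {
  ToyOp relu{"relu", {{"input_scale", 0.25f}}, {}, false};
  ToyOp conv{"conv2d", {{"input_scale", 0.5f}}, {&relu}, false};
  std::vector<ToyOp*> ops{&conv, &relu};
  InferQuantAttrs(&ops);
  assert(conv.attrs["output_scale"] == 0.25f && conv.enable_int8);
  return 0;
}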
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <limits>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "lite/core/mir/pass.h"
#include "lite/core/types.h"
namespace paddle {
namespace lite {
namespace mir {
class QuantizedOpAttributesInferencePass : public mir::StmtPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
@@ -140,9 +140,18 @@ void SSAGraph::Build(const Program &program,
      arg_node->AsArg(name, node_storage_.size() - 1);
      arg_update_node_map_[name] = arg_node;
    }
-   if (var_types.count(name) && !arg_node->arg()->type) {
-     arg_node->arg()->type = LiteType::GetTensorTy(
-         TARGET(kUnk), var_types[name], DATALAYOUT(kUnk));
    if (var_types.count(name)) {
      if (!arg_node->arg()->type) {
        arg_node->arg()->type = LiteType::GetTensorTy(
            TARGET(kUnk), var_types[name], DATALAYOUT(kUnk));
      }
      // Store the original data type of the output tensors for
      // type_precision_cast_pass, to keep the consistency between the
      // output types of original graph and optimized graph's
      if (op->op_info()->Type() == "fetch") {
        op->mutable_op_info()->SetAttr<int>(
            "data_type", static_cast<int>(var_types[name]));
      }
    }
    if (is_weights(name)) arg_node->AsArg().is_weight = true;
    CHECK(arg_node->IsRoleSet());
......
@@ -372,6 +372,37 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph,
  subgraph_op_desc.SetAttr<std::vector<std::string>>("output_data_names",
                                                     output_var_names);
  // Set input/output scale values of input/output var nodes for
  // type_precision_cast_pass.
  std::vector<float> input_data_scales;
  std::vector<float> output_data_scales;
  for (auto &var_node : input_var_nodes) {
    auto any_op_node = var_node->outlinks.front();
    CHECK(any_op_node->IsStmt());
    auto &any_inst = any_op_node->AsStmt();
    if (any_inst.op_info()->HasAttr("input_scale")) {
      input_data_scales.push_back(
          any_inst.op_info()->GetAttr<float>("input_scale"));
    }
  }
  for (auto &var_node : output_var_nodes) {
    auto any_op_node = var_node->inlinks.front();
    CHECK(any_op_node->IsStmt());
    auto &any_inst = any_op_node->AsStmt();
    if (any_inst.op_info()->HasAttr("output_scale")) {
      output_data_scales.push_back(
          any_inst.op_info()->GetAttr<float>("output_scale"));
    }
  }
  if (input_data_scales.size() > 0) {
    subgraph_op_desc.SetAttr<std::vector<float>>("input_data_scales",
                                                 input_data_scales);
  }
  if (output_data_scales.size() > 0) {
    subgraph_op_desc.SetAttr<std::vector<float>>("output_data_scales",
                                                 output_data_scales);
  }
  // Set all of the inputs and outputs to the target subgraph op
  // To prevent vars are removed in RuntimeProgram::UpdateVarsOfProgram()
  for (auto &var_node : weight_var_nodes) {
......
@@ -20,6 +20,8 @@
#include <vector>
#include "lite/core/mir/graph_visualize_pass.h"
#include "lite/core/mir/pass_registry.h"
#include "lite/core/mir/type_precision_cast_pass.h"
#include "lite/operators/subgraph_op.h"
#include "lite/utils/string.h"

namespace paddle {
@@ -170,9 +172,8 @@ void TypeLayoutTransformPass::AddLayoutInst(
  DirectedLink(layout_output_arg, inst_node);

  // reset opdesc and update kernel information
-  UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(),
-                in->AsArg().name,
-                layout_output_name);
  UpdateInputs(
      inst_node->AsStmt().op().get(), in->AsArg().name, layout_output_name);
  auto original_selected_kernel =
      std::move(inst_node->AsStmt().kernels().front());
  auto update_op_info = *inst_node->AsStmt().op_info();
......
@@ -24,18 +24,6 @@ namespace paddle {
namespace lite {
namespace mir {

-static void UpdateInputTo(cpp::OpDesc* desc,
-                          const std::string& from,
-                          const std::string& to) {
-  for (auto& item : *desc->mutable_inputs()) {
-    for (auto& input : item.second) {
-      if (input == from) {
-        input = to;
-      }
-    }
-  }
-}

class TypeLayoutTransformPass : public ProgramPass {
 public:
  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
......
@@ -20,11 +20,115 @@
#include <vector>
#include "lite/core/mir/graph_visualize_pass.h"
#include "lite/core/mir/pass_registry.h"
#include "lite/operators/subgraph_op.h"

namespace paddle {
namespace lite {
namespace mir {
// For the subgraph op, we also need to update the attr 'input_data_names' and
// the input variables names of the Ops in the subblock.
void UpdateInputsForSubgraph(OpLite* op,
const std::string& from,
const std::string& to) {
auto* op_desc = op->mutable_op_info();
auto input_data_names =
op_desc->GetAttr<std::vector<std::string>>("input_data_names");
std::replace(input_data_names.begin(), input_data_names.end(), from, to);
op_desc->SetAttr("input_data_names", input_data_names);
auto* subblock_desc = static_cast<operators::SubgraphOp*>(op)->GetSubBlock();
CHECK(subblock_desc);
for (size_t i = 0; i < subblock_desc->OpsSize(); i++) {
auto* subblock_op_desc = subblock_desc->GetOp<cpp::OpDesc>(i);
for (auto& subblock_op_input : *subblock_op_desc->mutable_inputs()) {
for (auto& subblock_var_name : subblock_op_input.second) {
if (subblock_var_name == from) {
subblock_var_name = to;
}
}
}
}
}
// Update the input variable names from 'from' to 'to' for the target Op
void UpdateInputs(OpLite* op, const std::string& from, const std::string& to) {
auto* op_desc = op->mutable_op_info();
auto op_type = op_desc->Type();
for (auto& op_input : *op_desc->mutable_inputs()) {
for (auto& var_name : op_input.second) {
if (var_name == from) {
var_name = to;
}
}
}
if (op_type == "subgraph") {
UpdateInputsForSubgraph(op, from, to);
}
}
// Infer the scale value for the new calib op from the subgraph op
static bool InferScaleFromSubgraph(std::string var_name,
const OpInfo* op_info,
float* scale,
bool reverse = false) {
bool found = false;
auto input_or_output_names = op_info->GetAttr<std::vector<std::string>>(
reverse ? "output_data_names" : "input_data_names");
auto input_or_output_scales = op_info->GetAttr<std::vector<float>>(
reverse ? "output_data_scales" : "input_data_scales");
auto size = input_or_output_names.size();
CHECK(size == input_or_output_scales.size());
for (int i = 0; i < size; i++) {
if (input_or_output_names[i] == var_name) {
*scale = input_or_output_scales[i];
found = true;
break;
}
}
return found;
}
// Infer the scale value for the new calib op from the input_scale of the
// current op and output_scale of the previous op.
// case 1: prev_op->var_node->op_node(int8->any op, with input_scale).
// case 2: prev_op->var_node->op_node(subgraph op, int8->any, with
// input_data_scales).
// case 3: prev_op(any->int8, with output_scale)->var_node->op_node(fp32->any,
// without input_scale).
// case 4: prev_op(any->int8, subgraph_op, with
// output_data_scales)->var_node->op_node(fp32->any, without input_scale).
static bool InferScale(Node* var_node, Node* op_node, float* scale) {
bool found = false;
auto& inst = op_node->AsStmt();
auto op_info = inst.op_info();
auto op_type = op_info->Type();
auto var_name = var_node->AsArg().name;
if (op_type == "subgraph") {
found = InferScaleFromSubgraph(var_name, op_info, scale, false);
} else {
if (op_info->HasAttr("input_scale")) {
*scale = op_info->GetAttr<float>("input_scale");
found = true;
} else {
// Obtain the output_scale from one of its previous Ops
auto prev_op_node = var_node->inlinks.front();
CHECK(prev_op_node->IsStmt());
auto& prev_inst = prev_op_node->AsStmt();
auto prev_op_info = prev_inst.op_info();
auto prev_op_type = prev_op_info->Type();
if (prev_op_type == "subgraph") {
found = InferScaleFromSubgraph(var_name, prev_op_info, scale, true);
} else {
if (prev_op_info->HasAttr("output_scale")) {
*scale = prev_op_info->GetAttr<float>("output_scale");
found = true;
}
}
}
}
return found;
}
void PrecisionCastPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// Start from inputs of the graph, those should have place set.
std::list<Node*> nodes;
@@ -59,6 +163,14 @@ void PrecisionCastPass::ComplementInputs(SSAGraph* graph,
auto decl_arg_type = inst.picked_kernel().GetInputDeclType(tmp);
CHECK(in->AsArg().type);
VLOG(4) << inst.picked_kernel().name();
if (inst.op_info()->Type() == "fetch") {
if (inst.op_info()->HasAttr("data_type")) {
auto data_type =
static_cast<PrecisionType>(inst.op_info()->GetAttr<int>("data_type"));
decl_arg_type = LiteType::GetTensorTy(
decl_arg_type->target(), data_type, decl_arg_type->layout());
}
}
// if (!in->AsArg().is_weight && !PrecisionCompatibleTo(*in->AsArg().type,
// *decl_arg_type)) {
if (!PrecisionCompatibleTo(*in->AsArg().type, *decl_arg_type)) {
@@ -109,10 +221,11 @@ void PrecisionCastPass::AddCastInst(const Type& from,
op_desc.SetType(cast_type);
op_desc.SetInput("Input", {in->AsArg().name});
op_desc.SetOutput("Out", {cast_op_output_name});
-if (inst_node->AsStmt().op_info()->HasAttr("input_scale")) {
-  op_desc.SetAttr(
-      "scale", inst_node->AsStmt().op_info()->GetAttr<float>("input_scale"));
float scale;
if (InferScale(in, inst_node, &scale)) {
  op_desc.SetAttr("scale", scale);
}
cast_op->Attach(op_desc, inst_node->AsStmt().op()->scope());
auto kernels = cast_op->CreateKernels(valid_places);
std::vector<std::unique_ptr<KernelBase>> selected_kernels;
@@ -146,9 +259,8 @@ void PrecisionCastPass::AddCastInst(const Type& from,
DirectedLink(cast_op_output_arg, inst_node);
// reset opdesc and update kernel information
-UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(),
-              in->AsArg().name,
-              cast_op_output_name);
UpdateInputs(
    inst_node->AsStmt().op().get(), in->AsArg().name, cast_op_output_name);
// recreate the op
auto original_selected_kernel =
......
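The name-to-scale lookup that InferScaleFromSubgraph performs above can be read in isolation: "input_data_names"/"input_data_scales" (and the output_* pair) are the parallel arrays attached to the subgraph op by the subgraph fuser earlier in this commit. A minimal standalone sketch of that lookup (hypothetical function name, not code from the commit):

#include <string>
#include <vector>

// Returns true and writes *scale if var_name is found in the parallel arrays.
bool LookUpScale(const std::vector<std::string>& names,
                 const std::vector<float>& scales,
                 const std::string& var_name,
                 float* scale) {
  for (size_t i = 0; i < names.size() && i < scales.size(); ++i) {
    if (names[i] == var_name) {
      *scale = scales[i];
      return true;
    }
  }
  return false;
}

// Usage: LookUpScale({"x", "y"}, {0.5f, 0.25f}, "y", &s) sets s to 0.25f.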
@@ -24,17 +24,7 @@ namespace paddle {
namespace lite {
namespace mir {

-static void UpdateInputTo(cpp::OpDesc* desc,
-                          const std::string& from,
-                          const std::string& to) {
-  for (auto& item : *desc->mutable_inputs()) {
-    for (auto& input : item.second) {
-      if (input == from) {
-        input = to;
-      }
-    }
-  }
-}
void UpdateInputs(OpLite* op, const std::string& from, const std::string& to);

/*
 * The pass complement the necessary instruction to make data
......
@@ -21,6 +21,7 @@
#include <vector>
#include "lite/core/mir/graph_visualize_pass.h"
#include "lite/core/mir/pass_registry.h"
#include "lite/core/mir/type_precision_cast_pass.h"
#include "lite/utils/string.h"

namespace paddle {
@@ -240,9 +241,8 @@ void TypeTargetTransformPass::UpdateInstNode(Node* in,
                                             Node* inst_node,
                                             std::string io_copy_output_name) {
  // reset opdesc and update kernel information
-  UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(),
-                in->AsArg().name,
-                io_copy_output_name);
  UpdateInputs(
      inst_node->AsStmt().op().get(), in->AsArg().name, io_copy_output_name);
  auto original_selected_kernel =
      std::move(inst_node->AsStmt().kernels().front());
  auto update_op_info = *inst_node->AsStmt().op_info();
......
@@ -25,18 +25,6 @@ namespace paddle {
namespace lite {
namespace mir {

-static void UpdateInputTo(cpp::OpDesc* desc,
-                          const std::string& from,
-                          const std::string& to) {
-  for (auto& item : *desc->mutable_inputs()) {
-    for (auto& input : item.second) {
-      if (input == from) {
-        input = to;
-      }
-    }
-  }
-}

/*
 * IoComplementPass complement the necessary instruction to make data
 * transferring or transformation between different places.
......
@@ -154,7 +154,9 @@ KernelRegistry::KernelRegistry()
  INIT_FOR(kX86, kInt64, kNCHW);

  INIT_FOR(kARM, kFloat, kNCHW);
  INIT_FOR(kARM, kFloat, kNHWC);
  INIT_FOR(kARM, kInt8, kNCHW);
  INIT_FOR(kARM, kInt8, kNHWC);
  INIT_FOR(kARM, kAny, kNCHW);
  INIT_FOR(kARM, kAny, kAny);
  INIT_FOR(kARM, kInt32, kNCHW);
@@ -180,8 +182,11 @@ KernelRegistry::KernelRegistry()
  INIT_FOR(kOpenCL, kAny, kImageNW);

  INIT_FOR(kNPU, kFloat, kNCHW);
  INIT_FOR(kNPU, kFloat, kNHWC);
  INIT_FOR(kNPU, kInt8, kNCHW);
  INIT_FOR(kNPU, kInt8, kNHWC);
  INIT_FOR(kNPU, kAny, kNCHW);
  INIT_FOR(kNPU, kAny, kNHWC);
  INIT_FOR(kNPU, kAny, kAny);

  INIT_FOR(kXPU, kFloat, kNCHW);
......
@@ -75,6 +75,12 @@ class Optimizer {
    (defined LITE_WITH_ARM)
             "lite_elementwise_add_activation_fuse_pass",  //
#endif
             "quantized_op_attributes_inference_pass",  // Only for fully
                                                        // quantized model, infer
                                                        // the output scale and
                                                        // fix the attribute
                                                        // 'enable_int8' for all
                                                        // of the quantized ops.
             "npu_subgraph_pass",
             "xpu_subgraph_pass",
             "bm_subgraph_pass",
......
@@ -75,6 +75,7 @@ void TensorLite::ShareDataWith(const TensorLite &other) {
  target_ = other.target_;
  lod_ = other.lod_;
  memory_size_ = other.memory_size_;
  precision_ = other.precision_;
}

void TensorLite::CopyDataFrom(const TensorLite &other) {
@@ -82,7 +83,7 @@ void TensorLite::CopyDataFrom(const TensorLite &other) {
  target_ = other.target_;
  lod_ = other.lod_;
  memory_size_ = other.memory_size_;
-  precision_ = other.precision();
  precision_ = other.precision_;
  buffer_->CopyDataFrom(*other.buffer_, memory_size_);
}
......
@@ -23,24 +23,24 @@ namespace lite {
namespace kernels {
namespace arm {

-void CalibComputeFp32ToInt8::Run() {
-  auto& param = this->Param<operators::CalibParam>();
template <DataLayoutType DLType>
void CalibComputeFp32ToInt8<DLType>::Run() {
  auto& param = this->template Param<operators::CalibParam>();
  std::vector<float> scale = {param.scale};
-  const auto* din = param.input->data<float>();
-  auto* dout = param.output->mutable_data<signed char>();
  const auto* din = param.input->template data<float>();
  auto* dout = param.output->template mutable_data<signed char>();
  lite::arm::math::fp32_to_int8(
      din, dout, scale.data(), 1, 1, param.input->numel());
-  return;
}

-void CalibComputeInt8ToFp32::Run() {
-  auto& param = this->Param<operators::CalibParam>();
-  const auto* din = param.input->data<signed char>();
template <DataLayoutType DLType>
void CalibComputeInt8ToFp32<DLType>::Run() {
  auto& param = this->template Param<operators::CalibParam>();
  const auto* din = param.input->template data<signed char>();
  std::vector<float> scale = {param.scale};
-  auto* dout = param.output->mutable_data<float>();
  auto* dout = param.output->template mutable_data<float>();
  lite::arm::math::int8_to_fp32(
      din, dout, scale.data(), 1, 1, param.input->numel());
-  return;
}

} // namespace arm
@@ -48,43 +48,116 @@ void CalibComputeInt8ToFp32::Run() {
} // namespace lite
} // namespace paddle

-REGISTER_LITE_KERNEL(calib,
-                     kARM,
-                     kInt8,
-                     kNCHW,
-                     paddle::lite::kernels::arm::CalibComputeFp32ToInt8,
-                     fp32_to_int8)
REGISTER_LITE_KERNEL(
    calib,
    kARM,
    kInt8,
    kNCHW,
    paddle::lite::kernels::arm::CalibComputeFp32ToInt8<DATALAYOUT(kNCHW)>,
    fp32_to_int8)
    .BindInput("Input",
               {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
    .Finalize();

-REGISTER_LITE_KERNEL(calib,
-                     kARM,
-                     kInt8,
-                     kNCHW,
-                     paddle::lite::kernels::arm::CalibComputeInt8ToFp32,
-                     int8_to_fp32)
REGISTER_LITE_KERNEL(
    calib,
    kARM,
    kInt8,
    kNCHW,
    paddle::lite::kernels::arm::CalibComputeInt8ToFp32<DATALAYOUT(kNCHW)>,
    int8_to_fp32)
    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
    .Finalize();
REGISTER_LITE_KERNEL(
    calib,
    kARM,
    kInt8,
    kNHWC,
    paddle::lite::kernels::arm::CalibComputeFp32ToInt8<DATALAYOUT(kNHWC)>,
    fp32_to_int8)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kInt8),
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(
calib,
kARM,
kInt8,
kNHWC,
paddle::lite::kernels::arm::CalibComputeInt8ToFp32<DATALAYOUT(kNHWC)>,
int8_to_fp32)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kInt8),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
-REGISTER_LITE_KERNEL(calib_once,
-                     kARM,
-                     kInt8,
-                     kNCHW,
-                     paddle::lite::kernels::arm::CalibComputeFp32ToInt8,
-                     fp32_to_int8)
REGISTER_LITE_KERNEL(
    calib_once,
    kARM,
    kInt8,
    kNCHW,
    paddle::lite::kernels::arm::CalibComputeFp32ToInt8<DATALAYOUT(kNCHW)>,
    fp32_to_int8)
    .BindInput("Input",
               {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
    .Finalize();

-REGISTER_LITE_KERNEL(calib_once,
-                     kARM,
-                     kInt8,
-                     kNCHW,
-                     paddle::lite::kernels::arm::CalibComputeInt8ToFp32,
-                     int8_to_fp32)
REGISTER_LITE_KERNEL(
    calib_once,
    kARM,
    kInt8,
    kNCHW,
    paddle::lite::kernels::arm::CalibComputeInt8ToFp32<DATALAYOUT(kNCHW)>,
    int8_to_fp32)
    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
    .Finalize();
REGISTER_LITE_KERNEL(
calib_once,
kARM,
kInt8,
kNHWC,
paddle::lite::kernels::arm::CalibComputeFp32ToInt8<DATALAYOUT(kNHWC)>,
fp32_to_int8)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kInt8),
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(
calib_once,
kARM,
kInt8,
kNHWC,
paddle::lite::kernels::arm::CalibComputeInt8ToFp32<DATALAYOUT(kNHWC)>,
int8_to_fp32)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kInt8),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
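The arithmetic behind fp32_to_int8/int8_to_fp32 lives in lite::arm::math and is not part of this diff; assuming the usual convention that float_value ≈ int8_value * scale, a scalar sketch of what a calib kernel computes (illustrative only; the real kernels are vectorized and their exact rounding/saturation may differ):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Quantize one value: round(x / scale), saturated to the int8 range.
int8_t Fp32ToInt8(float x, float scale) {
  int v = static_cast<int>(std::round(x / scale));
  return static_cast<int8_t>(std::min(std::max(v, -127), 127));
}

// Dequantize one value back to float.
float Int8ToFp32(int8_t q, float scale) { return q * scale; }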
@@ -21,8 +21,9 @@ namespace lite {
namespace kernels {
namespace arm {

template <DataLayoutType DLType>
class CalibComputeFp32ToInt8
-    : public KernelLite<TARGET(kARM), PRECISION(kInt8)> {
    : public KernelLite<TARGET(kARM), PRECISION(kInt8), DLType> {
 public:
  using param_t = operators::CalibParam;
@@ -33,8 +34,9 @@ class CalibComputeFp32ToInt8
 private:
};

template <DataLayoutType DLType>
class CalibComputeInt8ToFp32
-    : public KernelLite<TARGET(kARM), PRECISION(kInt8)> {
    : public KernelLite<TARGET(kARM), PRECISION(kInt8), DLType> {
 public:
  using param_t = operators::CalibParam;
......
@@ -20,40 +20,50 @@ namespace lite {
namespace kernels {
namespace arm {

#define NCHWTONHWC(type)                                                    \
  auto& param = this->template Param<param_t>();                            \
  auto input = param.x->template data<type>();                              \
  auto input_dim = param.x->dims();                                         \
-  CHECK(input_dim.size() == 4)                                             \
-      << "NCHW to NHWC should guarantee that the input dims should be 4";  \
  if (input_dim.size() != 4) {                                              \
    LOG(WARNING) << "NCHW to NHWC should guarantee that the input dims "   \
                    "should be 4, but received "                            \
                 << input_dim.size();                                       \
    param.y->ShareDataWith(*param.x);                                       \
    return;                                                                 \
  }                                                                         \
  int n = input_dim[0];                                                     \
  int c = input_dim[1];                                                     \
  int h = input_dim[2];                                                     \
  int w = input_dim[3];                                                     \
  param.y->Resize({n, h, w, c});                                            \
  auto output = param.y->template mutable_data<type>(TARGET(kARM));         \
  if (c == 1) {                                                             \
    memcpy(output, input, sizeof(type) * n * h * w);                        \
    return;                                                                 \
  }                                                                         \
  lite::arm::math::NCHW2NHWC<type>(n, c, h * w, input, output);

#define NHWCTONCHW(type)                                                    \
  auto& param = this->template Param<param_t>();                            \
  auto input = param.x->template data<type>();                              \
  auto input_dim = param.x->dims();                                         \
-  CHECK(input_dim.size() == 4)                                             \
-      << "NHWC to NCHW should guarantee that the input dims should be 4";  \
  if (input_dim.size() != 4) {                                              \
    LOG(WARNING) << "NHWC to NCHW should guarantee that the input dims "   \
                    "should be 4, but received "                            \
                 << input_dim.size();                                       \
    param.y->ShareDataWith(*param.x);                                       \
    return;                                                                 \
  }                                                                         \
  int n = input_dim[0];                                                     \
  int h = input_dim[1];                                                     \
  int w = input_dim[2];                                                     \
  int c = input_dim[3];                                                     \
  param.y->Resize({n, c, h, w});                                            \
  auto output = param.y->template mutable_data<type>(TARGET(kARM));         \
  if (c == 1) {                                                             \
    memcpy(output, input, sizeof(type) * n * h * w);                        \
    return;                                                                 \
  }                                                                         \
  lite::arm::math::NHWC2NCHW<type>(n, c, h * w, input, output);

template <>
......
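The macros above only validate the rank, resize the output, and dispatch; the element shuffle itself is lite::arm::math::NCHW2NHWC / NHWC2NCHW, which this diff does not show. A naive reference version of the NCHW-to-NHWC conversion with the same argument order as the call site (an assumed reading of its semantics, not the library's code):

// n: batch, c: channels, hw: h * w (spatial size), as passed above.
template <typename T>
void NaiveNCHW2NHWC(int n, int c, int hw, const T* input, T* output) {
  for (int i = 0; i < n; ++i) {
    for (int j = 0; j < c; ++j) {
      for (int k = 0; k < hw; ++k) {
        // input index:  ((i * c) + j) * hw + k   (NCHW)
        // output index: ((i * hw) + k) * c + j   (NHWC)
        output[(i * hw + k) * c + j] = input[(i * c + j) * hw + k];
      }
    }
  }
}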
@@ -20,8 +20,7 @@ namespace lite {
namespace kernels {
namespace host {

-class FeedCompute
-    : public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
class FeedCompute : public KernelLite<TARGET(kHost), PRECISION(kAny)> {
 public:
  using param_t = operators::FeedParam;
@@ -40,7 +39,7 @@ class FeedCompute
} // namespace paddle

REGISTER_LITE_KERNEL(
-    feed, kHost, kAny, kAny, paddle::lite::kernels::host::FeedCompute, def)
    feed, kHost, kAny, kNCHW, paddle::lite::kernels::host::FeedCompute, def)
-    .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))})
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))})
-    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))})
    .Finalize();
@@ -20,8 +20,7 @@ namespace lite {
namespace kernels {
namespace host {

-class FetchCompute
-    : public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
class FetchCompute : public KernelLite<TARGET(kHost), PRECISION(kAny)> {
 public:
  using param_t = operators::FeedParam;
@@ -43,11 +42,7 @@ class FetchCompute
} // namespace paddle

REGISTER_LITE_KERNEL(
-    fetch, kHost, kAny, kAny, paddle::lite::kernels::host::FetchCompute, def)
    fetch, kHost, kAny, kNCHW, paddle::lite::kernels::host::FetchCompute, def)
-    .BindInput("X",
-               {LiteType::GetTensorTy(
-                   TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))})
-    .BindOutput("Out",
-                {LiteType::GetTensorTy(
-                    TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))})
    .Finalize();