Unverified commit 08a3ed12 authored by hong19860320, committed by GitHub

[CORE] Support the fully quantized model for MTK and RK NPU (#3096)

Parent 89da9953
......@@ -24,7 +24,7 @@ USE_MIR_PASS(generate_program_pass);
USE_MIR_PASS(io_copy_kernel_pick_pass);
USE_MIR_PASS(argument_type_display_pass);
USE_MIR_PASS(runtime_context_assign_pass);
USE_MIR_PASS(graph_visualze);
USE_MIR_PASS(graph_visualize_pass);
USE_MIR_PASS(lite_conv_bn_fuse_pass);
USE_MIR_PASS(lite_fc_fuse_pass);
......@@ -46,3 +46,4 @@ USE_MIR_PASS(elementwise_mul_constant_eliminate_pass)
USE_MIR_PASS(npu_subgraph_pass);
USE_MIR_PASS(xpu_subgraph_pass);
USE_MIR_PASS(weight_quantization_preprocess_pass);
USE_MIR_PASS(quantized_op_attributes_inference_pass);
......@@ -36,6 +36,7 @@ lite_cc_library(mir_passes
runtime_context_assign_pass.cc
memory_optimize_pass.cc
weight_quantization_preprocess_pass.cc
quantized_op_attributes_inference_pass.cc
DEPS mir_pass types context ${mir_fusers} ${mir_subgraphs})
# lite_cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS
......
......@@ -18,6 +18,7 @@
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "lite/core/mir/pass_registry.h"
#include "lite/utils/string.h"
......@@ -28,56 +29,98 @@ namespace mir {
using inference::analysis::Dot;
void GraphVisualizePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
Visualize(graph.get());
VLOG(5) << "\n" << Visualize(graph.get());
}
std::string Visualize(mir::SSAGraph* graph) {
std::ostringstream os;
inference::analysis::Dot dot;
int id = 0;
std::set<std::string> exists_args;
for (auto& node : graph->mutable_nodes()) {
std::string key;
if (node.IsArg()) {
key = node.AsArg().name;
} else {
key = string_format("%s%d", node.AsStmt().op_type().c_str(), id++);
auto string_trunc = [](const std::string& str) -> std::string {
const int max_disp_size = 100;
if (str.length() > max_disp_size)
return str.substr(0, max_disp_size) + "...";
return str;
};
auto attr_repr = [&](const OpInfo* op_info,
const std::string& attr_name) -> std::string {
std::ostringstream os;
using AttrType = cpp::OpDesc::AttrType;
auto attr_type = op_info->GetAttrType(attr_name);
switch (attr_type) {
case AttrType::INT:
os << ":int:" << std::to_string(op_info->GetAttr<int>(attr_name));
break;
case AttrType::FLOAT:
os << ":float:" << std::to_string(op_info->GetAttr<float>(attr_name));
break;
case AttrType::BOOLEAN:
os << ":int:" << std::to_string(op_info->GetAttr<bool>(attr_name));
break;
case AttrType::STRING:
os << ":string: \""
<< string_trunc(op_info->GetAttr<std::string>(attr_name)) << "\"";
break;
case AttrType::FLOATS: {
auto vals = op_info->GetAttr<std::vector<float>>(attr_name);
os << ":floats: {" + Join(vals, ",") << "}";
} break;
case AttrType::INTS: {
auto vals = op_info->GetAttr<std::vector<int>>(attr_name);
os << ":ints: {" + Join(vals, ",") + "}";
} break;
case AttrType::STRINGS: {
auto vals = op_info->GetAttr<std::vector<std::string>>(attr_name);
os << ":strings: {" + string_trunc(Join(vals, ",")) << "}";
} break;
default:
os << ":Unknow type(" << static_cast<int>(attr_type) << ")";
break;
}
if (node.IsStmt()) {
dot.AddNode(key,
{Dot::Attr("shape", "box"),
Dot::Attr("style", "filled"),
Dot::Attr("color", "black"),
Dot::Attr("fillcolor", "yellow")});
for (auto& x : node.inlinks) {
auto name = x->AsArg().name;
if (!exists_args.count(name)) {
dot.AddNode(name, {});
}
dot.AddEdge(name, key, {});
exists_args.insert(name);
return os.str();
};
int op_idx = 0;
std::set<std::string> exists_var_names;
for (auto& node : graph->StmtTopologicalOrder()) {
if (!node->IsStmt()) continue;
auto op_info = node->AsStmt().op_info();
auto op_type = op_info->Type();
std::string op_name = string_format("%s%d", op_type.c_str(), op_idx++);
// Add the op and its input & output variables as Dot nodes
dot.AddNode(op_name,
{Dot::Attr("shape", "box"),
Dot::Attr("style", "filled"),
Dot::Attr("color", "black"),
Dot::Attr("fillcolor", "yellow")});
for (auto& x : node->inlinks) {
auto var_name = x->AsArg().name;
if (!exists_var_names.count(var_name)) {
dot.AddNode(var_name, {});
exists_var_names.insert(var_name);
}
for (auto& x : node.outlinks) {
auto name = x->AsArg().name;
if (!exists_args.count(name)) {
dot.AddNode(name, {});
}
dot.AddEdge(key, name, {});
exists_args.insert(name);
dot.AddEdge(var_name, op_name, {});
}
for (auto& x : node->outlinks) {
auto var_name = x->AsArg().name;
if (!exists_var_names.count(var_name)) {
dot.AddNode(var_name, {});
exists_var_names.insert(var_name);
}
dot.AddEdge(op_name, var_name, {});
}
// Output all of its attributes (names and values)
os << "* " << op_name << "\n";
const auto& attr_names = op_info->AttrNames();
for (auto& attr_name : attr_names) {
os << " - " << attr_name << attr_repr(op_info, attr_name) << "\n";
}
}
auto res = dot.Build();
// If we use VLOG here, we can not type all graph out.
// So we change VLOG to std::cout.
std::cout << "dot:\n" << res << std::endl;
return res;
os << dot.Build();
return os.str();
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(graph_visualze, paddle::lite::mir::GraphVisualizePass)
REGISTER_MIR_PASS(graph_visualize_pass, paddle::lite::mir::GraphVisualizePass)
.BindTargets({TARGET(kAny)});
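With this change, Visualize() prints every op's attributes ahead of the DOT source, so VLOG(5) in any pass that calls it shows both. A sketch of the emitted text for a single quantized op (the op name and the attribute values below are hypothetical):

* conv2d0
 - enable_int8:int:1
 - input_scale:float:0.023520
 - weight_scale:floats: {0.003100,0.002800}
 - output_scale:float:0.047040
(followed by the DOT source returned by dot.Build())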
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/quantized_op_attributes_inference_pass.h"
#include <algorithm>
#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "lite/core/mir/graph_visualize_pass.h"
#include "lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void QuantizedOpAttributesInferencePass::Apply(
const std::unique_ptr<SSAGraph>& graph) {
// Only for the fully quantized model, which is currently supported only by
// the MTK and RK NPUs. Derive each op's output_scale from the input_scale of
// the adjacent (downstream) quantized ops, and add the missing 'enable_int8'
// attribute.
for (auto& op_node : graph->StmtTopologicalOrder()) {
if (!op_node->IsStmt()) continue;
auto& inst = op_node->AsStmt();
auto op_info = inst.op_info();
auto op_type = op_info->Type();
if (!op_info->HasAttr("input_scale")) continue;
bool found = false;
float output_scale;
for (auto out_var_node : op_node->outlinks) {
CHECK(out_var_node->IsArg());
for (auto out_op_node : out_var_node->outlinks) {
CHECK(out_op_node->IsStmt());
auto& out_inst = out_op_node->AsStmt();
auto out_op_info = out_inst.op_info();
if (!out_op_info->HasAttr("input_scale")) continue;
auto input_scale = out_op_info->GetAttr<float>("input_scale");
if (!found) {
found = true;
output_scale = input_scale;
} else {
CHECK_EQ(output_scale, input_scale);
}
}
}
if (found) {
inst.mutable_op_info()->SetAttr("output_scale", output_scale);
}
if (op_info->HasAttr("output_scale")) {
inst.mutable_op_info()->SetAttr("enable_int8", true);
}
}
VLOG(5) << "\n" << Visualize(graph.get());
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(quantized_op_attributes_inference_pass,
paddle::lite::mir::QuantizedOpAttributesInferencePass)
.BindTargets({TARGET(kNPU)});
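Conceptually, the pass walks the graph in topological order and back-propagates each consumer's input_scale onto its producer as output_scale. A minimal, self-contained sketch of the same idea over a simplified op list (SimpleOp and the linear producer/consumer chain are hypothetical simplifications; the real pass operates on SSAGraph var/op nodes as above):

#include <map>
#include <string>
#include <vector>

struct SimpleOp {
  std::string type;
  std::map<std::string, float> attrs;  // e.g. "input_scale", "output_scale"
  bool enable_int8 = false;
};

// Propagate scales along a linear chain: ops[i]'s output feeds ops[i + 1].
void InferQuantizedAttributes(std::vector<SimpleOp>* ops) {
  for (size_t i = 0; i < ops->size(); i++) {
    auto& op = (*ops)[i];
    if (!op.attrs.count("input_scale")) continue;
    // Take the output_scale from the input_scale of the next quantized op.
    if (i + 1 < ops->size() && (*ops)[i + 1].attrs.count("input_scale")) {
      op.attrs["output_scale"] = (*ops)[i + 1].attrs["input_scale"];
    }
    // An op that ends up with an output_scale is marked as int8-enabled.
    if (op.attrs.count("output_scale")) op.enable_int8 = true;
  }
}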
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <limits>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "lite/core/mir/pass.h"
#include "lite/core/types.h"
namespace paddle {
namespace lite {
namespace mir {
class QuantizedOpAttributesInferencePass : public mir::StmtPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -140,9 +140,18 @@ void SSAGraph::Build(const Program &program,
arg_node->AsArg(name, node_storage_.size() - 1);
arg_update_node_map_[name] = arg_node;
}
if (var_types.count(name) && !arg_node->arg()->type) {
arg_node->arg()->type = LiteType::GetTensorTy(
TARGET(kUnk), var_types[name], DATALAYOUT(kUnk));
if (var_types.count(name)) {
if (!arg_node->arg()->type) {
arg_node->arg()->type = LiteType::GetTensorTy(
TARGET(kUnk), var_types[name], DATALAYOUT(kUnk));
}
// Store the original data type of the output tensors for
// type_precision_cast_pass, so that the output types of the optimized
// graph stay consistent with those of the original graph.
if (op->op_info()->Type() == "fetch") {
op->mutable_op_info()->SetAttr<int>(
"data_type", static_cast<int>(var_types[name]));
}
}
if (is_weights(name)) arg_node->AsArg().is_weight = true;
CHECK(arg_node->IsRoleSet());
......
......@@ -372,6 +372,37 @@ void SubgraphFuser::InsertNewNode(SSAGraph *graph,
subgraph_op_desc.SetAttr<std::vector<std::string>>("output_data_names",
output_var_names);
// Set input/output scale values of input/output var nodes for
// type_precision_cast_pass.
std::vector<float> input_data_scales;
std::vector<float> output_data_scales;
for (auto &var_node : input_var_nodes) {
auto any_op_node = var_node->outlinks.front();
CHECK(any_op_node->IsStmt());
auto &any_inst = any_op_node->AsStmt();
if (any_inst.op_info()->HasAttr("input_scale")) {
input_data_scales.push_back(
any_inst.op_info()->GetAttr<float>("input_scale"));
}
}
for (auto &var_node : output_var_nodes) {
auto any_op_node = var_node->inlinks.front();
CHECK(any_op_node->IsStmt());
auto &any_inst = any_op_node->AsStmt();
if (any_inst.op_info()->HasAttr("output_scale")) {
output_data_scales.push_back(
any_inst.op_info()->GetAttr<float>("output_scale"));
}
}
if (input_data_scales.size() > 0) {
subgraph_op_desc.SetAttr<std::vector<float>>("input_data_scales",
input_data_scales);
}
if (output_data_scales.size() > 0) {
subgraph_op_desc.SetAttr<std::vector<float>>("output_data_scales",
output_data_scales);
}
// Set all of the inputs and outputs to the target subgraph op
// to prevent the vars from being removed in RuntimeProgram::UpdateVarsOfProgram()
for (auto &var_node : weight_var_nodes) {
......
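In a fully quantized model every op at the subgraph boundary carries a scale, so the scale vectors stay aligned with the 'input_data_names'/'output_data_names' attributes set above; InferScaleFromSubgraph in type_precision_cast_pass relies on that pairing. A small sketch of reading the attributes back (illustrative only):

// Illustrative: the i-th scale corresponds to the i-th data name.
auto data_names =
    subgraph_op_desc.GetAttr<std::vector<std::string>>("input_data_names");
auto data_scales =
    subgraph_op_desc.GetAttr<std::vector<float>>("input_data_scales");
CHECK_EQ(data_names.size(), data_scales.size());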
......@@ -20,6 +20,8 @@
#include <vector>
#include "lite/core/mir/graph_visualize_pass.h"
#include "lite/core/mir/pass_registry.h"
#include "lite/core/mir/type_precision_cast_pass.h"
#include "lite/operators/subgraph_op.h"
#include "lite/utils/string.h"
namespace paddle {
......@@ -170,9 +172,8 @@ void TypeLayoutTransformPass::AddLayoutInst(
DirectedLink(layout_output_arg, inst_node);
// reset opdesc and update kernel information
UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(),
in->AsArg().name,
layout_output_name);
UpdateInputs(
inst_node->AsStmt().op().get(), in->AsArg().name, layout_output_name);
auto original_selected_kernel =
std::move(inst_node->AsStmt().kernels().front());
auto update_op_info = *inst_node->AsStmt().op_info();
......
......@@ -24,18 +24,6 @@ namespace paddle {
namespace lite {
namespace mir {
static void UpdateInputTo(cpp::OpDesc* desc,
const std::string& from,
const std::string& to) {
for (auto& item : *desc->mutable_inputs()) {
for (auto& input : item.second) {
if (input == from) {
input = to;
}
}
}
}
class TypeLayoutTransformPass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
......
......@@ -20,11 +20,115 @@
#include <vector>
#include "lite/core/mir/graph_visualize_pass.h"
#include "lite/core/mir/pass_registry.h"
#include "lite/operators/subgraph_op.h"
namespace paddle {
namespace lite {
namespace mir {
// For the subgraph op, we also need to update the attr 'input_data_names' and
// the input variable names of the ops in the subblock.
void UpdateInputsForSubgraph(OpLite* op,
const std::string& from,
const std::string& to) {
auto* op_desc = op->mutable_op_info();
auto input_data_names =
op_desc->GetAttr<std::vector<std::string>>("input_data_names");
std::replace(input_data_names.begin(), input_data_names.end(), from, to);
op_desc->SetAttr("input_data_names", input_data_names);
auto* subblock_desc = static_cast<operators::SubgraphOp*>(op)->GetSubBlock();
CHECK(subblock_desc);
for (size_t i = 0; i < subblock_desc->OpsSize(); i++) {
auto* subblock_op_desc = subblock_desc->GetOp<cpp::OpDesc>(i);
for (auto& subblock_op_input : *subblock_op_desc->mutable_inputs()) {
for (auto& subblock_var_name : subblock_op_input.second) {
if (subblock_var_name == from) {
subblock_var_name = to;
}
}
}
}
}
// Update the input variable names from 'from' to 'to' for the target Op
void UpdateInputs(OpLite* op, const std::string& from, const std::string& to) {
auto* op_desc = op->mutable_op_info();
auto op_type = op_desc->Type();
for (auto& op_input : *op_desc->mutable_inputs()) {
for (auto& var_name : op_input.second) {
if (var_name == from) {
var_name = to;
}
}
}
if (op_type == "subgraph") {
UpdateInputsForSubgraph(op, from, to);
}
}
// Infer the scale value for the new calib op from the subgraph op
static bool InferScaleFromSubgraph(std::string var_name,
const OpInfo* op_info,
float* scale,
bool reverse = false) {
bool found = false;
auto input_or_output_names = op_info->GetAttr<std::vector<std::string>>(
reverse ? "output_data_names" : "input_data_names");
auto input_or_output_scales = op_info->GetAttr<std::vector<float>>(
reverse ? "output_data_scales" : "input_data_scales");
auto size = input_or_output_names.size();
CHECK(size == input_or_output_scales.size());
for (int i = 0; i < size; i++) {
if (input_or_output_names[i] == var_name) {
*scale = input_or_output_scales[i];
found = true;
break;
}
}
return found;
}
// Infer the scale value for the new calib op from the input_scale of the
// current op and output_scale of the previous op.
// case 1: prev_op->var_node->op_node(int8->any op, with input_scale).
// case 2: prev_op->var_node->op_node(subgraph op, int8->any, with
// input_data_scales).
// case 3: prev_op(any->int8, with output_scale)->var_node->op_node(fp32->any,
// without input_scale).
// case 4: prev_op(any->int8, subgraph_op, with
// output_data_scales)->var_node->op_node(fp32->any, without input_scale).
static bool InferScale(Node* var_node, Node* op_node, float* scale) {
bool found = false;
auto& inst = op_node->AsStmt();
auto op_info = inst.op_info();
auto op_type = op_info->Type();
auto var_name = var_node->AsArg().name;
if (op_type == "subgraph") {
found = InferScaleFromSubgraph(var_name, op_info, scale, false);
} else {
if (op_info->HasAttr("input_scale")) {
*scale = op_info->GetAttr<float>("input_scale");
found = true;
} else {
// Obtain the output_scale from one of its previous Ops
auto prev_op_node = var_node->inlinks.front();
CHECK(prev_op_node->IsStmt());
auto& prev_inst = prev_op_node->AsStmt();
auto prev_op_info = prev_inst.op_info();
auto prev_op_type = prev_op_info->Type();
if (prev_op_type == "subgraph") {
found = InferScaleFromSubgraph(var_name, prev_op_info, scale, true);
} else {
if (prev_op_info->HasAttr("output_scale")) {
*scale = prev_op_info->GetAttr<float>("output_scale");
found = true;
}
}
}
}
return found;
}
void PrecisionCastPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// Start from inputs of the graph, those should have place set.
std::list<Node*> nodes;
......@@ -59,6 +163,14 @@ void PrecisionCastPass::ComplementInputs(SSAGraph* graph,
auto decl_arg_type = inst.picked_kernel().GetInputDeclType(tmp);
CHECK(in->AsArg().type);
VLOG(4) << inst.picked_kernel().name();
if (inst.op_info()->Type() == "fetch") {
if (inst.op_info()->HasAttr("data_type")) {
auto data_type =
static_cast<PrecisionType>(inst.op_info()->GetAttr<int>("data_type"));
decl_arg_type = LiteType::GetTensorTy(
decl_arg_type->target(), data_type, decl_arg_type->layout());
}
}
// if (!in->AsArg().is_weight && !PrecisionCompatibleTo(*in->AsArg().type,
// *decl_arg_type)) {
if (!PrecisionCompatibleTo(*in->AsArg().type, *decl_arg_type)) {
......@@ -109,10 +221,11 @@ void PrecisionCastPass::AddCastInst(const Type& from,
op_desc.SetType(cast_type);
op_desc.SetInput("Input", {in->AsArg().name});
op_desc.SetOutput("Out", {cast_op_output_name});
if (inst_node->AsStmt().op_info()->HasAttr("input_scale")) {
op_desc.SetAttr(
"scale", inst_node->AsStmt().op_info()->GetAttr<float>("input_scale"));
float scale;
if (InferScale(in, inst_node, &scale)) {
op_desc.SetAttr("scale", scale);
}
cast_op->Attach(op_desc, inst_node->AsStmt().op()->scope());
auto kernels = cast_op->CreateKernels(valid_places);
std::vector<std::unique_ptr<KernelBase>> selected_kernels;
......@@ -146,9 +259,8 @@ void PrecisionCastPass::AddCastInst(const Type& from,
DirectedLink(cast_op_output_arg, inst_node);
// reset opdesc and update kernel information
UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(),
in->AsArg().name,
cast_op_output_name);
UpdateInputs(
inst_node->AsStmt().op().get(), in->AsArg().name, cast_op_output_name);
// recreate the op
auto original_selected_kernel =
......
......@@ -24,17 +24,7 @@ namespace paddle {
namespace lite {
namespace mir {
static void UpdateInputTo(cpp::OpDesc* desc,
const std::string& from,
const std::string& to) {
for (auto& item : *desc->mutable_inputs()) {
for (auto& input : item.second) {
if (input == from) {
input = to;
}
}
}
}
void UpdateInputs(OpLite* op, const std::string& from, const std::string& to);
/*
* The pass complement the necessary instruction to make data
......
......@@ -21,6 +21,7 @@
#include <vector>
#include "lite/core/mir/graph_visualize_pass.h"
#include "lite/core/mir/pass_registry.h"
#include "lite/core/mir/type_precision_cast_pass.h"
#include "lite/utils/string.h"
namespace paddle {
......@@ -240,9 +241,8 @@ void TypeTargetTransformPass::UpdateInstNode(Node* in,
Node* inst_node,
std::string io_copy_output_name) {
// reset opdesc and update kernel information
UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(),
in->AsArg().name,
io_copy_output_name);
UpdateInputs(
inst_node->AsStmt().op().get(), in->AsArg().name, io_copy_output_name);
auto original_selected_kernel =
std::move(inst_node->AsStmt().kernels().front());
auto update_op_info = *inst_node->AsStmt().op_info();
......
......@@ -25,18 +25,6 @@ namespace paddle {
namespace lite {
namespace mir {
static void UpdateInputTo(cpp::OpDesc* desc,
const std::string& from,
const std::string& to) {
for (auto& item : *desc->mutable_inputs()) {
for (auto& input : item.second) {
if (input == from) {
input = to;
}
}
}
}
/*
* IoComplementPass complement the necessary instruction to make data
* transferring or transformation between different places.
......
......@@ -154,7 +154,9 @@ KernelRegistry::KernelRegistry()
INIT_FOR(kX86, kInt64, kNCHW);
INIT_FOR(kARM, kFloat, kNCHW);
INIT_FOR(kARM, kFloat, kNHWC);
INIT_FOR(kARM, kInt8, kNCHW);
INIT_FOR(kARM, kInt8, kNHWC);
INIT_FOR(kARM, kAny, kNCHW);
INIT_FOR(kARM, kAny, kAny);
INIT_FOR(kARM, kInt32, kNCHW);
......@@ -180,8 +182,11 @@ KernelRegistry::KernelRegistry()
INIT_FOR(kOpenCL, kAny, kImageNW);
INIT_FOR(kNPU, kFloat, kNCHW);
INIT_FOR(kNPU, kFloat, kNHWC);
INIT_FOR(kNPU, kInt8, kNCHW);
INIT_FOR(kNPU, kInt8, kNHWC);
INIT_FOR(kNPU, kAny, kNCHW);
INIT_FOR(kNPU, kAny, kNHWC);
INIT_FOR(kNPU, kAny, kAny);
INIT_FOR(kXPU, kFloat, kNCHW);
......
......@@ -75,6 +75,12 @@ class Optimizer {
(defined LITE_WITH_ARM)
"lite_elementwise_add_activation_fuse_pass", //
#endif
"quantized_op_attributes_inference_pass", // Only for fully
// quantized model, infer
// the output scale and
// fix the attribute
// 'enable_int8' for all
// of the quantized ops.
"npu_subgraph_pass",
"xpu_subgraph_pass",
"bm_subgraph_pass",
......
......@@ -75,6 +75,7 @@ void TensorLite::ShareDataWith(const TensorLite &other) {
target_ = other.target_;
lod_ = other.lod_;
memory_size_ = other.memory_size_;
precision_ = other.precision_;
}
void TensorLite::CopyDataFrom(const TensorLite &other) {
......@@ -82,7 +83,7 @@ void TensorLite::CopyDataFrom(const TensorLite &other) {
target_ = other.target_;
lod_ = other.lod_;
memory_size_ = other.memory_size_;
precision_ = other.precision();
precision_ = other.precision_;
buffer_->CopyDataFrom(*other.buffer_, memory_size_);
}
......
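The added line matters for the fully quantized flow: without copying precision_, a tensor shared from an int8 tensor would report the default precision and confuse precision-aware passes and kernels. A minimal sketch, assuming TensorLite's set_precision()/precision() accessors:

TensorLite src, dst;
src.Resize({1, 8});
src.mutable_data<int8_t>();
src.set_precision(PRECISION(kInt8));
dst.ShareDataWith(src);
// Preserved now that ShareDataWith also copies precision_.
CHECK(dst.precision() == PRECISION(kInt8));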
......@@ -23,24 +23,24 @@ namespace lite {
namespace kernels {
namespace arm {
void CalibComputeFp32ToInt8::Run() {
auto& param = this->Param<operators::CalibParam>();
template <DataLayoutType DLType>
void CalibComputeFp32ToInt8<DLType>::Run() {
auto& param = this->template Param<operators::CalibParam>();
std::vector<float> scale = {param.scale};
const auto* din = param.input->data<float>();
auto* dout = param.output->mutable_data<signed char>();
const auto* din = param.input->template data<float>();
auto* dout = param.output->template mutable_data<signed char>();
lite::arm::math::fp32_to_int8(
din, dout, scale.data(), 1, 1, param.input->numel());
return;
}
void CalibComputeInt8ToFp32::Run() {
auto& param = this->Param<operators::CalibParam>();
const auto* din = param.input->data<signed char>();
template <DataLayoutType DLType>
void CalibComputeInt8ToFp32<DLType>::Run() {
auto& param = this->template Param<operators::CalibParam>();
const auto* din = param.input->template data<signed char>();
std::vector<float> scale = {param.scale};
auto* dout = param.output->mutable_data<float>();
auto* dout = param.output->template mutable_data<float>();
lite::arm::math::int8_to_fp32(
din, dout, scale.data(), 1, 1, param.input->numel());
return;
}
} // namespace arm
......@@ -48,43 +48,116 @@ void CalibComputeInt8ToFp32::Run() {
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(calib,
kARM,
kInt8,
kNCHW,
paddle::lite::kernels::arm::CalibComputeFp32ToInt8,
fp32_to_int8)
REGISTER_LITE_KERNEL(
calib,
kARM,
kInt8,
kNCHW,
paddle::lite::kernels::arm::CalibComputeFp32ToInt8<DATALAYOUT(kNCHW)>,
fp32_to_int8)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.Finalize();
REGISTER_LITE_KERNEL(calib,
kARM,
kInt8,
kNCHW,
paddle::lite::kernels::arm::CalibComputeInt8ToFp32,
int8_to_fp32)
REGISTER_LITE_KERNEL(
calib,
kARM,
kInt8,
kNCHW,
paddle::lite::kernels::arm::CalibComputeInt8ToFp32<DATALAYOUT(kNCHW)>,
int8_to_fp32)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.Finalize();
REGISTER_LITE_KERNEL(calib_once,
kARM,
kInt8,
kNCHW,
paddle::lite::kernels::arm::CalibComputeFp32ToInt8,
fp32_to_int8)
REGISTER_LITE_KERNEL(
calib,
kARM,
kInt8,
kNHWC,
paddle::lite::kernels::arm::CalibComputeFp32ToInt8<DATALAYOUT(kNHWC)>,
fp32_to_int8)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kInt8),
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(
calib,
kARM,
kInt8,
kNHWC,
paddle::lite::kernels::arm::CalibComputeInt8ToFp32<DATALAYOUT(kNHWC)>,
int8_to_fp32)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kInt8),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(
calib_once,
kARM,
kInt8,
kNCHW,
paddle::lite::kernels::arm::CalibComputeFp32ToInt8<DATALAYOUT(kNCHW)>,
fp32_to_int8)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.Finalize();
REGISTER_LITE_KERNEL(calib_once,
kARM,
kInt8,
kNCHW,
paddle::lite::kernels::arm::CalibComputeInt8ToFp32,
int8_to_fp32)
REGISTER_LITE_KERNEL(
calib_once,
kARM,
kInt8,
kNCHW,
paddle::lite::kernels::arm::CalibComputeInt8ToFp32<DATALAYOUT(kNCHW)>,
int8_to_fp32)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt8))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.Finalize();
REGISTER_LITE_KERNEL(
calib_once,
kARM,
kInt8,
kNHWC,
paddle::lite::kernels::arm::CalibComputeFp32ToInt8<DATALAYOUT(kNHWC)>,
fp32_to_int8)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kInt8),
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(
calib_once,
kARM,
kInt8,
kNHWC,
paddle::lite::kernels::arm::CalibComputeInt8ToFp32<DATALAYOUT(kNHWC)>,
int8_to_fp32)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kInt8),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kARM),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
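The calib kernels convert activations between fp32 and int8 using the per-tensor scale attached by type_precision_cast_pass. A standalone sketch of the underlying math (this is the usual symmetric-quantization formula; the actual lite::arm::math routines are optimized implementations and may differ in details such as rounding and saturation):

#include <algorithm>
#include <cmath>
#include <cstdint>

// fp32 -> int8: q = clamp(round(x / scale), -127, 127)
void Fp32ToInt8Ref(const float* in, int8_t* out, float scale, int64_t n) {
  for (int64_t i = 0; i < n; i++) {
    float q = std::round(in[i] / scale);
    out[i] = static_cast<int8_t>(std::max(-127.f, std::min(127.f, q)));
  }
}

// int8 -> fp32: x = q * scale
void Int8ToFp32Ref(const int8_t* in, float* out, float scale, int64_t n) {
  for (int64_t i = 0; i < n; i++) out[i] = in[i] * scale;
}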
......@@ -21,8 +21,9 @@ namespace lite {
namespace kernels {
namespace arm {
template <DataLayoutType DLType>
class CalibComputeFp32ToInt8
: public KernelLite<TARGET(kARM), PRECISION(kInt8)> {
: public KernelLite<TARGET(kARM), PRECISION(kInt8), DLType> {
public:
using param_t = operators::CalibParam;
......@@ -33,8 +34,9 @@ class CalibComputeFp32ToInt8
private:
};
template <DataLayoutType DLType>
class CalibComputeInt8ToFp32
: public KernelLite<TARGET(kARM), PRECISION(kInt8)> {
: public KernelLite<TARGET(kARM), PRECISION(kInt8), DLType> {
public:
using param_t = operators::CalibParam;
......
......@@ -20,40 +20,50 @@ namespace lite {
namespace kernels {
namespace arm {
#define NCHWTONHWC(type) \
auto& param = this->template Param<param_t>(); \
auto input = param.x->template data<type>(); \
auto input_dim = param.x->dims(); \
CHECK(input_dim.size() == 4) \
<< "NCHW to NHWC should guarantee that the input dims should be 4"; \
int n = input_dim[0]; \
int c = input_dim[1]; \
int h = input_dim[2]; \
int w = input_dim[3]; \
param.y->Resize({n, h, w, c}); \
auto output = param.y->template mutable_data<type>(TARGET(kARM)); \
if (c == 1) { \
memcpy(output, input, sizeof(type) * n * h * w); \
return; \
} \
#define NCHWTONHWC(type) \
auto& param = this->template Param<param_t>(); \
auto input = param.x->template data<type>(); \
auto input_dim = param.x->dims(); \
if (input_dim.size() != 4) { \
LOG(WARNING) << "NCHW to NHWC should guarantee that the input dims " \
"should be 4, but received " \
<< input_dim.size(); \
param.y->ShareDataWith(*param.x); \
return; \
} \
int n = input_dim[0]; \
int c = input_dim[1]; \
int h = input_dim[2]; \
int w = input_dim[3]; \
param.y->Resize({n, h, w, c}); \
auto output = param.y->template mutable_data<type>(TARGET(kARM)); \
if (c == 1) { \
memcpy(output, input, sizeof(type) * n * h * w); \
return; \
} \
lite::arm::math::NCHW2NHWC<type>(n, c, h * w, input, output);
#define NHWCTONCHW(type) \
auto& param = this->template Param<param_t>(); \
auto input = param.x->template data<type>(); \
auto input_dim = param.x->dims(); \
CHECK(input_dim.size() == 4) \
<< "NHWC to NCHW should guarantee that the input dims should be 4"; \
int n = input_dim[0]; \
int h = input_dim[1]; \
int w = input_dim[2]; \
int c = input_dim[3]; \
param.y->Resize({n, c, h, w}); \
auto output = param.y->template mutable_data<type>(TARGET(kARM)); \
if (c == 1) { \
memcpy(output, input, sizeof(type) * n * h * w); \
return; \
} \
#define NHWCTONCHW(type) \
auto& param = this->template Param<param_t>(); \
auto input = param.x->template data<type>(); \
auto input_dim = param.x->dims(); \
if (input_dim.size() != 4) { \
LOG(WARNING) << "NHWC to NCHW should guarantee that the input dims " \
"should be 4, but received " \
<< input_dim.size(); \
param.y->ShareDataWith(*param.x); \
return; \
} \
int n = input_dim[0]; \
int h = input_dim[1]; \
int w = input_dim[2]; \
int c = input_dim[3]; \
param.y->Resize({n, c, h, w}); \
auto output = param.y->template mutable_data<type>(TARGET(kARM)); \
if (c == 1) { \
memcpy(output, input, sizeof(type) * n * h * w); \
return; \
} \
lite::arm::math::NHWC2NCHW<type>(n, c, h * w, input, output);
template <>
......
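For reference, the layout transform performed by NCHW2NHWC/NHWC2NCHW maps element (n, c, h, w) to (n, h, w, c) and back. A naive reference sketch of the forward direction (the lite::arm::math versions are optimized; this loop only illustrates the index mapping, with 'hw' the flattened spatial size h * w as at the call sites above):

// Copy a 4-D tensor from NCHW to NHWC element order.
template <typename T>
void NCHW2NHWCRef(int n, int c, int hw, const T* in, T* out) {
  for (int i = 0; i < n; i++) {
    for (int j = 0; j < c; j++) {
      for (int k = 0; k < hw; k++) {
        out[(i * hw + k) * c + j] = in[(i * c + j) * hw + k];
      }
    }
  }
}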
......@@ -20,8 +20,7 @@ namespace lite {
namespace kernels {
namespace host {
class FeedCompute
: public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
class FeedCompute : public KernelLite<TARGET(kHost), PRECISION(kAny)> {
public:
using param_t = operators::FeedParam;
......@@ -40,7 +39,7 @@ class FeedCompute
} // namespace paddle
REGISTER_LITE_KERNEL(
feed, kHost, kAny, kAny, paddle::lite::kernels::host::FeedCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
feed, kHost, kAny, kNCHW, paddle::lite::kernels::host::FeedCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))})
.Finalize();
......@@ -20,8 +20,7 @@ namespace lite {
namespace kernels {
namespace host {
class FetchCompute
: public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
class FetchCompute : public KernelLite<TARGET(kHost), PRECISION(kAny)> {
public:
using param_t = operators::FeedParam;
......@@ -43,11 +42,7 @@ class FetchCompute
} // namespace paddle
REGISTER_LITE_KERNEL(
fetch, kHost, kAny, kAny, paddle::lite::kernels::host::FetchCompute, def)
.BindInput("X",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
.BindOutput("Out",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
fetch, kHost, kAny, kNCHW, paddle::lite::kernels::host::FetchCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny))})
.Finalize();