提交 aefb4ea3 编写于 作者: J juncaipeng 提交者: GitHub

Optimize quant_dequant (#2215)

* Add DeleteQuantOpFuser
* Add fake_quantize_dequantize_moving_avg_abs_max_op
* Add DeleteQuantDequantOpFuser
上级 52a093ee
......@@ -14,6 +14,7 @@
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <fstream>
#include <vector>
#include "lite/api/cxx_api.h"
#include "lite/api/paddle_use_kernels.h"
......@@ -22,6 +23,10 @@
#include "lite/api/test_helper.h"
#include "lite/core/op_registry.h"
DEFINE_string(input_img_txt_path,
"",
"if set input_img_txt_path, read the img filename as input.");
namespace paddle {
namespace lite {
......@@ -36,8 +41,18 @@ void TestModel(const std::vector<Place>& valid_places) {
input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
auto* data = input_tensor->mutable_data<float>();
auto item_size = input_tensor->dims().production();
for (int i = 0; i < item_size; i++) {
data[i] = 1;
if (FLAGS_input_img_txt_path.empty()) {
for (int i = 0; i < item_size; i++) {
data[i] = 1;
}
} else {
std::fstream fs(FLAGS_input_img_txt_path, std::ios::in);
if (!fs.is_open()) {
LOG(FATAL) << "open input_img_txt error.";
}
for (int i = 0; i < item_size; i++) {
fs >> data[i];
}
}
for (int i = 0; i < FLAGS_warmup; ++i) {
......
......@@ -15,7 +15,6 @@
#include "lite/core/mir/fusion/quant_dequant_fuse_pass.h"
#include <list>
#include <memory>
#include <unordered_set>
#include <vector>
#include "lite/api/paddle_place.h"
#include "lite/core/mir/fusion/quant_dequant_op_fuser.h"
......@@ -26,63 +25,25 @@ namespace lite {
namespace mir {
void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// obtain useful values and save to quantized_node, remove quant_nodes and
// releated nodes
std::unordered_set<std::string> quant_types = {
// delete quant node
std::vector<std::string> quant_op_types = {
"fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"};
std::vector<Node*> quant_nodes;
for (auto& cur_node : graph->mutable_nodes()) {
if (cur_node.IsStmt() && quant_types.count(cur_node.stmt()->op_type())) {
quant_nodes.push_back(&cur_node);
}
}
for (auto quant_node : quant_nodes) {
// find input nodes and output nodes
std::list<Node*> input_nodes = quant_node->inlinks;
std::list<Node*> output_nodes = quant_node->outlinks;
CHECK_EQ(input_nodes.size(), 2);
CHECK_EQ(output_nodes.size(), 2);
bool front_is_scale = input_nodes.front()->arg()->is_weight;
Node* input_scale_node =
front_is_scale ? input_nodes.front() : input_nodes.back();
Node* input_act_node =
front_is_scale ? input_nodes.back() : input_nodes.front();
front_is_scale = output_nodes.front()->arg()->is_weight;
Node* output_scale_node =
front_is_scale ? output_nodes.front() : output_nodes.back();
Node* output_act_node =
front_is_scale ? output_nodes.back() : output_nodes.front();
// relink nodes and save value to quantized_node
int bit_length = quant_node->stmt()->op_info()->GetAttr<int>("bit_length");
int range = ((1 << (bit_length - 1)) - 1);
auto* scope = quant_node->stmt()->op()->scope();
auto scale_tensor = scope->FindVar(output_scale_node->arg()->name)
->GetMutable<lite::Tensor>();
float scale_value = scale_tensor->data<float>()[0] / range;
auto outlinks = output_act_node->outlinks;
for (auto* quantized_node_ptr : outlinks) {
quantized_node_ptr->stmt()->mutable_op_info()->SetAttr<int>("bit_length",
bit_length);
quantized_node_ptr->stmt()->mutable_op_info()->SetAttr<float>(
"input_scale", scale_value);
IR_NODE_LINK_TO(input_act_node, quantized_node_ptr)
RemoveDirectedLink(output_act_node, quantized_node_ptr);
}
// delete nodes and edges
std::unordered_set<const Node*> nodes2rm = {
input_scale_node, quant_node, output_scale_node, output_act_node};
GraphSafeRemoveNodes(graph.get(), nodes2rm);
for (auto& op_type : quant_op_types) {
fusion::DeleteQuantOpFuser fuser(op_type);
fuser(graph.get());
}
// fuse quantized node and dequant node
std::unordered_set<std::string> quantized_op_types = {
std::vector<std::string> quantized_op_types = {
"conv2d", "mul", "depthwise_conv2d"};
for (auto& op_type : quantized_op_types) {
fusion::QuantDequantOpFuser fuser(op_type);
fusion::DequantOpFuser fuser(op_type);
fuser(graph.get());
}
// delete quant_dequant_node
for (auto op_type : {"pool2d", "elementwise_add"}) {
fusion::DeleteQuantDequantOpFuser fuser(op_type);
fuser(graph.get());
}
}
......
......@@ -14,6 +14,7 @@
#include "lite/core/mir/fusion/quant_dequant_op_fuser.h"
#include <memory>
#include <unordered_set>
#include <vector>
#include "lite/utils/string.h"
......@@ -22,7 +23,61 @@ namespace lite {
namespace mir {
namespace fusion {
void QuantDequantOpFuser::BuildPattern() {
void DeleteQuantOpFuser::BuildPattern() {
auto* input_scale_node = VarNode("input_scale_node")
->assert_is_op_input(quant_op_type_, "InScale");
auto* input_act_node =
VarNode("input_act_node")->assert_is_op_input(quant_op_type_, "X");
auto* quant_node =
OpNode("quant_node", quant_op_type_)->assert_is_op(quant_op_type_);
auto* output_scale_node =
VarNode("output_scale_node")
->assert_is_op_output(quant_op_type_, "OutScale");
auto* output_act_node =
VarNode("output_act_node")->assert_is_op_output(quant_op_type_, "Out");
quant_node->LinksFrom({input_scale_node, input_act_node});
output_scale_node->LinksFrom({quant_node});
output_act_node->LinksFrom({quant_node});
VLOG(4) << "DeleteQuantOpFuser BuildPattern quant_op_type:" << quant_op_type_;
}
void DeleteQuantOpFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto* input_scale_node = matched.at("input_scale_node");
auto* input_act_node = matched.at("input_act_node");
auto* quant_node = matched.at("quant_node");
auto* output_scale_node = matched.at("output_scale_node");
auto* output_act_node = matched.at("output_act_node");
// obtain values, save values and relink node
int bit_length = quant_node->stmt()->op_info()->GetAttr<int>("bit_length");
int range = ((1 << (bit_length - 1)) - 1);
auto* scope = quant_node->stmt()->op()->scope();
auto* scale_tensor = scope->FindVar(output_scale_node->arg()->name)
->GetMutable<lite::Tensor>();
float scale_value = scale_tensor->data<float>()[0] / range;
auto outlinks = output_act_node->outlinks;
for (auto* quantized_node : outlinks) {
auto* op_desc = quantized_node->stmt()->mutable_op_info();
op_desc->SetAttr<int>("bit_length", bit_length);
op_desc->SetAttr<float>("input_scale", scale_value);
IR_NODE_LINK_TO(input_act_node, quantized_node)
}
// delete nodes and edges
std::unordered_set<const Node*> nodes2rm = {
input_scale_node, quant_node, output_scale_node, output_act_node};
GraphSafeRemoveNodes(graph, nodes2rm);
}
cpp::OpDesc DeleteQuantOpFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc;
return op_desc;
}
void DequantOpFuser::BuildPattern() {
std::string weight_name = "";
if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") {
weight_name = "Filter";
......@@ -55,10 +110,11 @@ void QuantDequantOpFuser::BuildPattern() {
quantized_op_out->LinksFrom({quantized_op});
dequant_op->LinksFrom({quantized_op_out});
dequant_op_out->LinksFrom({dequant_op});
VLOG(4) << "DeQuantOpFuser BuildPattern op_type:" << op_type_;
}
void QuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
void DequantOpFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto* quant_op_input = matched.at("quantized_op_input");
auto* quantized_op_weight = matched.at("quantized_op_weight");
auto* quantized_op = matched.at("quantized_op");
......@@ -127,7 +183,174 @@ void QuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
IR_NODE_LINK_TO(new_quantized_op_node, dequant_op_out);
}
cpp::OpDesc QuantDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc DequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc;
return op_desc;
}
void DeleteQuantDequantOpFuser::BuildPattern() {
std::string quant_dequant_op_type =
"fake_quantize_dequantize_moving_average_abs_max";
if (quantized_op_type_ == "pool2d") {
auto* input_scale_node =
VarNode("input_scale_node")
->assert_is_op_input(quant_dequant_op_type, "InScale");
auto* input_act_node = VarNode("input_act_node")
->assert_is_op_input(quant_dequant_op_type, "X");
auto* quant_dequant_node =
OpNode("quant_dequant_node", quant_dequant_op_type)
->assert_is_op(quant_dequant_op_type);
auto* output_scale_node =
VarNode("output_scale_node")
->assert_is_op_output(quant_dequant_op_type, "OutScale");
auto* output_act_node =
VarNode("output_act_node")
->assert_is_op_output(quant_dequant_op_type, "Out");
auto* quantized_node = OpNode("quantized_node", quantized_op_type_)
->assert_is_op(quantized_op_type_);
quant_dequant_node->LinksFrom({input_scale_node, input_act_node});
output_scale_node->LinksFrom({quant_dequant_node});
output_act_node->LinksFrom({quant_dequant_node});
quantized_node->LinksFrom({output_act_node});
} else if (quantized_op_type_ == "elementwise_add") {
auto* input_scale_left_node =
VarNode("input_scale_left_node")
->assert_is_op_input(quant_dequant_op_type, "InScale");
auto* input_act_left_node =
VarNode("input_act_left_node")
->assert_is_op_input(quant_dequant_op_type, "X");
auto* quant_dequant_left_node =
OpNode("quant_dequant_left_node", quant_dequant_op_type)
->assert_is_op(quant_dequant_op_type);
auto* output_scale_left_node =
VarNode("output_scale_left_node")
->assert_is_op_output(quant_dequant_op_type, "OutScale");
auto* output_act_left_node =
VarNode("output_act_left_node")
->assert_is_op_output(quant_dequant_op_type, "Out")
->assert_is_op_input(quantized_op_type_, "X");
quant_dequant_left_node->LinksFrom(
{input_scale_left_node, input_act_left_node});
output_scale_left_node->LinksFrom({quant_dequant_left_node});
output_act_left_node->LinksFrom({quant_dequant_left_node});
auto* input_scale_right_node =
VarNode("input_scale_right_node")
->assert_is_op_input(quant_dequant_op_type, "InScale");
auto* input_act_right_node =
VarNode("input_act_right_node")
->assert_is_op_input(quant_dequant_op_type, "X");
auto* quant_dequant_right_node =
OpNode("quant_dequant_right_node", quant_dequant_op_type)
->assert_is_op(quant_dequant_op_type);
auto* output_scale_right_node =
VarNode("output_scale_right_node")
->assert_is_op_output(quant_dequant_op_type, "OutScale");
auto* output_act_right_node =
VarNode("output_act_right_node")
->assert_is_op_output(quant_dequant_op_type, "Out")
->assert_is_op_input(quantized_op_type_, "Y");
quant_dequant_right_node->LinksFrom(
{input_scale_right_node, input_act_right_node});
output_scale_right_node->LinksFrom({quant_dequant_right_node});
output_act_right_node->LinksFrom({quant_dequant_right_node});
auto* quantized_node = OpNode("quantized_node", quantized_op_type_)
->assert_is_op(quantized_op_type_);
quantized_node->LinksFrom({output_act_left_node, output_act_right_node});
} else {
LOG(FATAL) << "No support quantized_op_type:" << quantized_op_type_;
}
VLOG(4) << "DeleteQuantDequantOpFuser BuildPattern op_type:"
<< quantized_op_type_;
}
void DeleteQuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
if (quantized_op_type_ == "pool2d") {
auto* input_scale_node = matched.at("input_scale_node");
auto* input_act_node = matched.at("input_act_node");
auto* quant_dequant_node = matched.at("quant_dequant_node");
auto* output_scale_node = matched.at("output_scale_node");
auto* output_act_node = matched.at("output_act_node");
auto* quantized_node = matched.at("quantized_node");
// obtain values, save values and relink node
int bit_length =
quant_dequant_node->stmt()->op_info()->GetAttr<int>("bit_length");
int range = ((1 << (bit_length - 1)) - 1);
auto* scope = quant_dequant_node->stmt()->op()->scope();
auto* scale_tensor = scope->FindVar(output_scale_node->arg()->name)
->GetMutable<lite::Tensor>();
float scale_value = scale_tensor->data<float>()[0] / range;
auto* op_desc = quantized_node->stmt()->mutable_op_info();
op_desc->SetAttr<int>("bit_length", bit_length);
op_desc->SetAttr<float>("input_scale", scale_value);
op_desc->SetInput("X", {input_act_node->arg()->name});
IR_NODE_LINK_TO(input_act_node, quantized_node)
// delete nodes and edges
std::unordered_set<const Node*> nodes2rm = {input_scale_node,
quant_dequant_node,
output_scale_node,
output_act_node};
GraphSafeRemoveNodes(graph, nodes2rm);
} else if (quantized_op_type_ == "elementwise_add") {
auto* input_scale_left_node = matched.at("input_scale_left_node");
auto* input_act_left_node = matched.at("input_act_left_node");
auto* quant_dequant_left_node = matched.at("quant_dequant_left_node");
auto* output_scale_left_node = matched.at("output_scale_left_node");
auto* output_act_left_node = matched.at("output_act_left_node");
auto* input_scale_right_node = matched.at("input_scale_right_node");
auto* input_act_right_node = matched.at("input_act_right_node");
auto* quant_dequant_right_node = matched.at("quant_dequant_right_node");
auto* output_scale_right_node = matched.at("output_scale_right_node");
auto* output_act_right_node = matched.at("output_act_right_node");
auto* quantized_node = matched.at("quantized_node");
// obtain values, save values and relink node
int bit_length =
quant_dequant_left_node->stmt()->op_info()->GetAttr<int>("bit_length");
int range = ((1 << (bit_length - 1)) - 1);
auto* scope = quant_dequant_left_node->stmt()->op()->scope();
auto* left_scale_tensor =
scope->FindVar(output_scale_left_node->arg()->name)
->GetMutable<lite::Tensor>();
float left_scale_value = left_scale_tensor->data<float>()[0] / range;
auto* right_scale_tensor =
scope->FindVar(output_scale_right_node->arg()->name)
->GetMutable<lite::Tensor>();
float right_scale_value = right_scale_tensor->data<float>()[0] / range;
auto* op_desc = quantized_node->stmt()->mutable_op_info();
op_desc->SetAttr<int>("bit_length", bit_length);
op_desc->SetAttr<float>("x_input_scale", left_scale_value);
op_desc->SetAttr<float>("y_input_scale", right_scale_value);
op_desc->SetInput("X", {input_act_left_node->arg()->name});
op_desc->SetInput("Y", {input_act_right_node->arg()->name});
IR_NODE_LINK_TO(input_act_left_node, quantized_node)
IR_NODE_LINK_TO(input_act_right_node, quantized_node)
// delete nodes and edges
std::unordered_set<const Node*> nodes2rm = {input_scale_left_node,
quant_dequant_left_node,
output_scale_left_node,
output_act_left_node,
input_scale_right_node,
quant_dequant_right_node,
output_scale_right_node,
output_act_right_node};
GraphSafeRemoveNodes(graph, nodes2rm);
} else {
LOG(FATAL) << "No support quantized_op_type:" << quantized_op_type_;
}
}
cpp::OpDesc DeleteQuantDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc;
return op_desc;
}
......
......@@ -34,11 +34,25 @@ namespace fusion {
* the quantized_op.
* In addition, the fuser delete fake_quant and fake_dequant op in the graph at
* the last.
*/
class QuantDequantOpFuser : public FuseBase {
*/
class DeleteQuantOpFuser : public FuseBase {
public:
explicit DeleteQuantOpFuser(const std::string& quant_op_type)
: quant_op_type_(quant_op_type) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
private:
std::string quant_op_type_{};
};
class DequantOpFuser : public FuseBase {
public:
explicit QuantDequantOpFuser(const std::string& op_type)
: op_type_(op_type) {}
explicit DequantOpFuser(const std::string& op_type) : op_type_(op_type) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
......@@ -49,6 +63,27 @@ class QuantDequantOpFuser : public FuseBase {
std::string op_type_{};
};
/* The pattern like "fake_quantize_dequantize_moving_average_abs_max +
* pooled/elementwise_add" can be deteted by this fuser. The fuser
* extract the input_scale form fake_quant_dequant_op and save into
* the quantized_op. Besides, the fuser delete fake_quant_dequant_op in
* the graph.
*/
class DeleteQuantDequantOpFuser : public FuseBase {
public:
explicit DeleteQuantDequantOpFuser(const std::string& quantized_op_type)
: quantized_op_type_(quantized_op_type) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
private:
std::string quantized_op_type_{};
};
} // namespace fusion
} // namespace mir
} // namespace lite
......
......@@ -75,6 +75,7 @@ add_operator(fake_quantize_range_abs_max_op basic SRCS fake_quantize_range_abs_m
add_operator(sequence_expand_as_op_lite basic SRCS sequence_expand_as_op.cc DEPS ${op_DEPS})
add_operator(range_op basic SRCS range_op.cc DEPS ${op_DEPS})
add_operator(assign_value_op basic SRCS assign_value_op.cc DEPS ${op_DEPS})
add_operator(fake_quantize_dequantize_moving_avg_abs_max_op basic SRCS fake_quantize_dequantize_moving_avg_max_abs.cc DEPS ${op_DEPS})
# for OCR specific
add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/fake_quantize_dequantize_moving_avg_max_abs.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(
fake_quantize_dequantize_moving_average_abs_max,
paddle::lite::operators::FakeQuantizeDequantizeMovingAvgMaxAbsOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/core/tensor.h"
#include "lite/operators/op_params.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class FakeQuantizeDequantizeMovingAvgMaxAbsOpLite : public OpLite {
public:
FakeQuantizeDequantizeMovingAvgMaxAbsOpLite() {}
explicit FakeQuantizeDequantizeMovingAvgMaxAbsOpLite(const std::string &type)
: OpLite(type) {}
bool CheckShape() const override { return true; }
bool InferShape() const override { return true; }
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
auto x = op_desc.Input("X").front();
auto in_scale = op_desc.Input("InScale").front();
auto out = op_desc.Output("Out").front();
auto out_scale = op_desc.Output("OutScale").front();
param_.x = scope->FindVar(x)->GetMutable<lite::Tensor>();
param_.in_scale = scope->FindVar(in_scale)->GetMutable<lite::Tensor>();
param_.out = scope->FindVar(out)->GetMutable<lite::Tensor>();
param_.out_scale = scope->FindVar(out_scale)->GetMutable<lite::Tensor>();
param_.bit_length = op_desc.GetAttr<int>("bit_length");
return true;
}
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override {
return "fake_quantize_dequantize_moving_avg_max_abs";
}
private:
mutable FakeQuantizeMovingAvgMaxAbsParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -294,6 +294,8 @@ struct PoolParam {
bool ceil_mode{false};
bool use_quantizer{false};
std::string data_format{"AnyLayout"};
// for int8
WITH_INT8_CONFIG
};
// For Dropout op
......@@ -332,7 +334,10 @@ struct ElementwiseParam {
const lite::Tensor* Y{};
lite::Tensor* Out{};
int axis{-1}; // for broadcasting.
// for int8
WITH_INT8_CONFIG
float x_input_scale{1.0};
float y_input_scale{1.0};
};
struct ElementwiseGradParam {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册