Commit 0260d322 authored by juncaipeng, committed by GitHub

Optimize quant_dequant_fuse_pass (#2169)

* optimize quant_dequant_fuse_pass, test=develop
Parent 9aa795ca
@@ -74,6 +74,21 @@ void TestModel(const std::vector<Place>& valid_places,
1e-6);
}
}
auto* out_data = out->data<float>();
LOG(INFO) << "output data:";
for (int i = 0; i < out->numel(); i += step) {
LOG(INFO) << out_data[i];
}
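// find the largest output value and its index (argmax)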
float max_val = out_data[0];
int max_val_arg = 0;
for (int i = 1; i < out->numel(); i++) {
if (max_val < out_data[i]) {
max_val = out_data[i];
max_val_arg = i;
}
}
LOG(INFO) << "max val:" << max_val << ", max_val_arg:" << max_val_arg;
}
TEST(MobileNetV1, test_arm) {
......
@@ -13,7 +13,9 @@
// limitations under the License.
#include "lite/core/mir/fusion/quant_dequant_fuse_pass.h"
#include <list>
#include <memory>
#include <unordered_set>
#include <vector>
#include "lite/api/paddle_place.h"
#include "lite/core/mir/fusion/quant_dequant_op_fuser.h"
@@ -24,18 +26,60 @@ namespace lite {
namespace mir {
void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// obtain useful values and save to quantized_node, remove quant_nodes and
// related nodes
std::unordered_set<std::string> quant_types = {
"fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"};
std::unordered_set<std::string> quantized_op_types = {
"conv2d", "mul", "depthwise_conv2d"};
for (auto& quant_type : quant_types) {
for (auto& op_type : quantized_op_types) {
for (int i = 6; i >= 1; i--) {
fusion::QuantDequantOpFuser fuser(op_type, quant_type, i);
fuser(graph.get());
for (auto& cur_node : graph->mutable_nodes()) {
if (cur_node.IsStmt() && quant_types.count(cur_node.stmt()->op_type())) {
// find input nodes and output nodes
std::list<Node*> input_nodes = cur_node.inlinks;
std::list<Node*> output_nodes = cur_node.outlinks;
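// a fake_quantize op has two inputs (X, InScale) and two outputs
// (Out, OutScale); the scale nodes are persistable weights, which is how
// they are distinguished from the activations below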
CHECK_EQ(input_nodes.size(), 2);
CHECK_EQ(output_nodes.size(), 2);
bool front_is_scale = input_nodes.front()->arg()->is_weight;
Node* input_scale_node =
front_is_scale ? input_nodes.front() : input_nodes.back();
Node* input_act_node =
front_is_scale ? input_nodes.back() : input_nodes.front();
front_is_scale = output_nodes.front()->arg()->is_weight;
Node* output_scale_node =
front_is_scale ? output_nodes.front() : output_nodes.back();
Node* output_act_node =
front_is_scale ? output_nodes.back() : output_nodes.front();
// relink nodes and save value to quantized_node
int bit_length = cur_node.stmt()->op_info()->GetAttr<int>("bit_length");
int range = ((1 << (bit_length - 1)) - 1);
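// e.g. the default bit_length of 8 gives range = (1 << 7) - 1 = 127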
auto* scope = cur_node.stmt()->op()->scope();
auto scale_tensor = scope->FindVar(output_scale_node->arg()->name)
->GetMutable<lite::Tensor>();
float scale_value = scale_tensor->data<float>()[0] / range;
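// OutScale holds the recorded max(abs(activation)); dividing by range
// gives the float size of one quantized unit, which downstream ops
// consume as "input_scale"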
for (auto* quantized_node_ptr : output_act_node->outlinks) {
quantized_node_ptr->stmt()->mutable_op_info()->SetAttr<int>(
"bit_length", bit_length);
quantized_node_ptr->stmt()->mutable_op_info()->SetAttr<float>(
"input_scale", scale_value);
IR_NODE_LINK_TO(input_act_node, quantized_node_ptr)
RemoveDirectedLink(output_act_node, quantized_node_ptr);
}
// delete nodes and edges
std::unordered_set<const Node*> nodes2rm = {
input_scale_node, &cur_node, output_scale_node, output_act_node};
GraphSafeRemoveNodes(graph.get(), nodes2rm);
}
}
// fuse quantized node and dequant node
std::unordered_set<std::string> quantized_op_types = {
"conv2d", "mul", "depthwise_conv2d"};
for (auto& op_type : quantized_op_types) {
fusion::QuantDequantOpFuser fuser(op_type);
fuser(graph.get());
}
}
} // namespace mir
......
@@ -23,170 +23,108 @@ namespace mir {
namespace fusion {
void QuantDequantOpFuser::BuildPattern() {
const int kNumFields = 5;
const int kQuantizedWeightOffset = 0;
const int kQuantizedOpOffset = 1;
const int kQuantizedOpOutOffset = 2;
const int kDequantOpOffset = 3;
const int kDequantOpOutOffset = 4;
std::string weight_name = "";
if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") {
weight_name = "Filter";
} else {
weight_name = "Y";
}
auto* quant_op_input = VarNode("quant_op_input")
->assert_is_op_input(quant_type_, "X")
->AsInput();
auto* quant_op_in_scale = VarNode("quant_op_in_scale")
->assert_is_op_input(quant_type_, "InScale")
->AsIntermediate();
auto* quant_op = OpNode("quant_op", quant_type_)
->assert_is_op(quant_type_)
->AsIntermediate();
auto* quant_op_out_scale =
VarNode("quant_op_out_scale")
->assert_is_op_output(quant_type_, "OutScale")
->assert_is_op_input("fake_dequantize_max_abs", "Scale")
->AsIntermediate();
auto* quant_op_out = VarNode("quant_op_out")
->assert_is_op_output(quant_type_, "Out")
->assert_is_op_input(op_type_)
->AsIntermediate();
auto* quantized_op_input =
VarNode("quantized_op_input")->assert_is_op_input(op_type_)->AsInput();
auto* quantized_op_weight = VarNode("quantized_op_weight")
->assert_is_op_input(op_type_, weight_name)
->AsInput();
auto* quantized_op = OpNode("quantized_op", op_type_)
->assert_is_op(op_type_)
->AsIntermediate();
std::vector<PMNode*> nodes;
for (int i = 0; i < times_; i++) {
nodes.push_back(VarNode(string_format("quantized_op_weight%d", i))
->assert_is_op_input(op_type_, weight_name)
->AsInput());
nodes.push_back(OpNode(string_format("quantized_op%d", i), op_type_)
->assert_is_op(op_type_)
->AsIntermediate());
nodes.push_back(VarNode(string_format("quantized_op_out%d", i))
->assert_is_op_output(op_type_)
->assert_is_op_input("fake_dequantize_max_abs", "X")
->AsIntermediate());
nodes.push_back(
OpNode(string_format("dequant_op%d", i), "fake_dequantize_max_abs")
->assert_is_op("fake_dequantize_max_abs")
->AsIntermediate());
nodes.push_back(VarNode(string_format("dequant_op_out%d", i))
->assert_is_op_output("fake_dequantize_max_abs", "Out")
->AsOutput());
}
quant_op->LinksFrom({quant_op_input, quant_op_in_scale});
quant_op_out->LinksFrom({quant_op});
quant_op_out_scale->LinksFrom({quant_op});
for (int i = 0; i < times_; i++) {
nodes[i * kNumFields + kQuantizedOpOffset]->LinksFrom(
{quant_op_out, nodes[i * kNumFields + kQuantizedWeightOffset]});
nodes[i * kNumFields + kQuantizedOpOutOffset]->LinksFrom(
{nodes[i * kNumFields + kQuantizedOpOffset]});
nodes[i * kNumFields + kDequantOpOffset]->LinksFrom(
{nodes[i * kNumFields + kQuantizedOpOutOffset], quant_op_out_scale});
nodes[i * kNumFields + kDequantOpOutOffset]->LinksFrom(
{nodes[i * kNumFields + kDequantOpOffset]});
}
auto* quantized_op_out =
VarNode("quantized_op_out")
->assert_is_op_output(op_type_)
->assert_is_op_input("fake_dequantize_max_abs", "X")
->AsIntermediate();
auto* dequant_op = OpNode("dequant_op", "fake_dequantize_max_abs")
->assert_is_op("fake_dequantize_max_abs")
->AsIntermediate();
auto* dequant_op_out =
VarNode("dequant_op_out")
->assert_is_op_output("fake_dequantize_max_abs", "Out")
->AsOutput();
quantized_op->LinksFrom({quantized_op_input, quantized_op_weight});
quantized_op_out->LinksFrom({quantized_op});
dequant_op->LinksFrom({quantized_op_out});
dequant_op_out->LinksFrom({dequant_op});
}
void QuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
const int kNumFields = 5;
const int kQuantizedWeightOffset = 0;
const int kQuantizedOpOffset = 1;
const int kDequantOpOffset = 3;
const int kDequantOpOutOffset = 4;
auto* quant_op_input = matched.at("quant_op_input");
auto* quant_op_in_scale = matched.at("quant_op_in_scale");
auto* quant_op = matched.at("quant_op");
std::vector<Node*> nodes;
for (int i = 0; i < times_; i++) {
nodes.push_back(matched.at(string_format("quantized_op_weight%d", i)));
nodes.push_back(matched.at(string_format("quantized_op%d", i)));
nodes.push_back(matched.at(string_format("quantized_op_out%d", i)));
nodes.push_back(matched.at(string_format("dequant_op%d", i)));
nodes.push_back(matched.at(string_format("dequant_op_out%d", i)));
}
int bit_length = quant_op->stmt()->op_info()->GetAttr<int>("bit_length");
auto* scope = quant_op->stmt()->op()->scope();
auto& valid_places = quant_op->stmt()->op()->valid_places();
auto* quant_op_input = matched.at("quantized_op_input");
auto* quantized_op_weight = matched.at("quantized_op_weight");
auto* quantized_op = matched.at("quantized_op");
auto* dequant_op = matched.at("dequant_op");
auto* dequant_op_out = matched.at("dequant_op_out");
// obtain input_scale and weight_scale
auto* scope = quantized_op->stmt()->op()->scope();
auto& valid_places = quantized_op->stmt()->op()->valid_places();
int bit_length = quantized_op->stmt()->op_info()->GetAttr<int>("bit_length");
int range = ((1 << (bit_length - 1)) - 1);
auto input_scale_t = scope->FindVar(quant_op_in_scale->arg()->name)
->GetMutable<lite::Tensor>();
float input_scale = input_scale_t->data<float>()[0] / range;
VLOG(4) << "range: " << range << " input_scale: " << input_scale;
for (int i = 0; i < times_; i++) {
float max_range = nodes[i * kNumFields + kDequantOpOffset]
->stmt()
->op_info()
->GetAttr<float>("max_range");
// weight_scale = max(abs(weight)) / range
float whole_weight_scale =
static_cast<float>(range * range) / max_range / range;
cpp::OpDesc op_desc =
*nodes[i * kNumFields + kQuantizedOpOffset]->stmt()->op_info();
auto quantized_weight_var_name =
nodes[i * kNumFields + kQuantizedWeightOffset]->arg()->name;
auto quantized_weight_t =
scope->FindVar(quantized_weight_var_name)->GetMutable<lite::Tensor>();
std::vector<float> weight_scale;
int weight_scale_size;
if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") {
op_desc.SetInput("Input", {matched.at("quant_op_input")->arg()->name});
op_desc.SetOutput(
"Output", {nodes[i * kNumFields + kDequantOpOutOffset]->arg()->name});
// Conv weight shape: Cout * Cin * kh * kw, the weight_scale_size should
// be Cout.
weight_scale_size = quantized_weight_t->dims()[0];
} else if (op_type_ == "mul") {
op_desc.SetInput("X", {matched.at("quant_op_input")->arg()->name});
op_desc.SetOutput(
"Out", {nodes[i * kNumFields + kDequantOpOutOffset]->arg()->name});
// Fc weight: Cin * Cout, the weight_scale_size should be Cout.
weight_scale_size = quantized_weight_t->dims()[1];
}
for (int i = 0; i < weight_scale_size; i++) {
weight_scale.push_back(whole_weight_scale);
}
op_desc.SetAttr("enable_int8", true);
op_desc.SetAttr("input_scale", input_scale);
op_desc.SetAttr("weight_scale", weight_scale);
Tensor temp_tensor;
temp_tensor.CopyDataFrom(*quantized_weight_t);
float* temp_data = temp_tensor.mutable_data<float>();
size_t weight_num = quantized_weight_t->data_size();
int8_t* quantized_weight_data = quantized_weight_t->mutable_data<int8_t>();
// change the weight from the float type to int8 type.
for (size_t i = 0; i < weight_num; i++) {
quantized_weight_data[i] = static_cast<int8_t>(temp_data[i]);
}
quantized_weight_t->set_persistable(true);
quantized_weight_t->set_precision(PRECISION(kInt8));
auto quantized_op = LiteOpRegistry::Global().Create(op_type_);
quantized_op->Attach(op_desc, scope);
auto* new_op_node =
graph->GraphCreateInstructNode(quantized_op, valid_places);
IR_NODE_LINK_TO(quant_op_input, new_op_node);
IR_NODE_LINK_TO(nodes[i * kNumFields + kQuantizedWeightOffset],
new_op_node);
IR_NODE_LINK_TO(new_op_node, nodes[i * kNumFields + kDequantOpOutOffset]);
float input_scale =
quantized_op->stmt()->op_info()->GetAttr<float>("input_scale");
float max_range = dequant_op->stmt()->op_info()->GetAttr<float>("max_range");
float whole_weight_scale =
static_cast<float>(range * range) / max_range / range;
// max_range = range * range / max(abs(weight))
// weight_scale = range * range / (range * range / max(abs(weight))) / range
// = max(abs(weight)) / range
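// e.g. with bit_length = 8 (range = 127) and max(abs(weight)) = 0.5:
// max_range = 127 * 127 / 0.5 = 32258 and
// whole_weight_scale = 127 * 127 / 32258 / 127 = 0.5 / 127 ~= 0.003937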
// set op desc
cpp::OpDesc op_desc = *quantized_op->stmt()->op_info();
auto quantized_weight_var_name = quantized_op_weight->arg()->name;
auto quantized_weight_t =
scope->FindVar(quantized_weight_var_name)->GetMutable<lite::Tensor>();
std::vector<float> weight_scale;
int weight_scale_size;
if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") {
op_desc.SetInput("Input", {quant_op_input->arg()->name});
op_desc.SetOutput("Output", {dequant_op_out->arg()->name});
// Conv weight shape: Cout * Cin * kh * kw, the weight_scale_size should
// be Cout.
weight_scale_size = quantized_weight_t->dims()[0];
} else if (op_type_ == "mul") {
op_desc.SetInput("X", {quant_op_input->arg()->name});
op_desc.SetOutput("Out", {dequant_op_out->arg()->name});
// Fc weight: Cin * Cout, the weight_scale_size should be Cout.
weight_scale_size = quantized_weight_t->dims()[1];
}
for (int i = 0; i < weight_scale_size; i++) {
weight_scale.push_back(whole_weight_scale);
}
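// every channel receives the same scale because fake_dequantize_max_abs
// records a single max_range for the whole layer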
op_desc.SetAttr("enable_int8", true);
op_desc.SetAttr("input_scale", input_scale);
op_desc.SetAttr("weight_scale", weight_scale);
// change the weight from the float type to int8 type.
Tensor temp_tensor;
temp_tensor.CopyDataFrom(*quantized_weight_t);
float* temp_data = temp_tensor.mutable_data<float>();
size_t weight_num = quantized_weight_t->data_size();
int8_t* quantized_weight_data = quantized_weight_t->mutable_data<int8_t>();
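// the float weights are assumed to already hold integer values in
// [-range, range] (written by the quantization-aware training flow), so a
// plain cast drops no information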
for (size_t i = 0; i < weight_num; i++) {
quantized_weight_data[i] = static_cast<int8_t>(temp_data[i]);
}
quantized_weight_t->set_persistable(true);
quantized_weight_t->set_precision(PRECISION(kInt8));
// new op and relink nodes
auto new_quantized_op = LiteOpRegistry::Global().Create(op_type_);
new_quantized_op->Attach(op_desc, scope);
auto* new_quantized_op_node =
graph->GraphCreateInstructNode(new_quantized_op, valid_places);
IR_NODE_LINK_TO(quant_op_input, new_quantized_op_node);
IR_NODE_LINK_TO(quantized_op_weight, new_quantized_op_node);
IR_NODE_LINK_TO(new_quantized_op_node, dequant_op_out);
}
cpp::OpDesc QuantDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
......
@@ -37,10 +37,8 @@ namespace fusion {
*/
class QuantDequantOpFuser : public FuseBase {
public:
explicit QuantDequantOpFuser(const std::string& op_type,
const std::string& quant_type,
int times)
: op_type_(op_type), quant_type_(quant_type), times_(times) {}
explicit QuantDequantOpFuser(const std::string& op_type)
: op_type_(op_type) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
@@ -48,9 +46,7 @@ class QuantDequantOpFuser : public FuseBase {
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
private:
std::string op_type_{"conv2d"};
std::string quant_type_;
int times_;
std::string op_type_{};
};
} // namespace fusion
......
@@ -35,7 +35,8 @@ using param_t = Any;
bool enable_int8{false}; \
float input_scale{1.0}; \
std::vector<float> weight_scale{}; \
float output_scale{1.0};
float output_scale{1.0}; \
int bit_length{8};
/// ----------------------- Functional operators ------------------------------
struct FeedParam {
......