Unverified commit 253acb80, authored by juncaipeng, committed by GitHub

Optimize quant_dequant_fuse_pass (#2169)

* optimize quant_dequant_fuse_pass, test=develop
Parent 508ca98b
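In short: the old pass ran a QuantDequantOpFuser for every combination of quantized op type, quant op type, and a hard-coded repeat count (6 down to 1). The new pass splits the work in two: it first deletes the fake-quantize ops in a single graph walk, saving `bit_length` and the derived `input_scale` onto every consuming op, and then runs a much simpler fuser that only matches a quantized op together with its `fake_dequantize_max_abs` node.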
...
@@ -74,6 +74,21 @@ void TestModel(const std::vector<Place>& valid_places,
           1e-6);
     }
   }
+
+  auto* out_data = out->data<float>();
+  LOG(INFO) << "output data:";
+  for (int i = 0; i < out->numel(); i += step) {
+    LOG(INFO) << out_data[i];
+  }
+  float max_val = out_data[0];
+  int max_val_arg = 0;
+  for (int i = 1; i < out->numel(); i++) {
+    if (max_val < out_data[i]) {
+      max_val = out_data[i];
+      max_val_arg = i;
+    }
+  }
+  LOG(INFO) << "max val:" << max_val << ", max_val_arg:" << max_val_arg;
 }

 TEST(MobileNetV1, test_arm) {
...
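The block added above logs every `step`-th output value and then scans the whole output tensor for its maximum and the index where it occurs, i.e. the top-1 prediction for a classifier such as MobileNetV1. That makes it easy to check by eye that the quantized model still predicts the same label as the float one.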
...
@@ -13,7 +13,9 @@
 // limitations under the License.

 #include "lite/core/mir/fusion/quant_dequant_fuse_pass.h"
+#include <list>
 #include <memory>
+#include <unordered_set>
 #include <vector>
 #include "lite/api/paddle_place.h"
 #include "lite/core/mir/fusion/quant_dequant_op_fuser.h"
@@ -24,18 +26,60 @@ namespace lite {
 namespace mir {

 void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
+  // obtain useful values and save to quantized_node, remove quant_nodes and
+  // related nodes
   std::unordered_set<std::string> quant_types = {
       "fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"};
-  std::unordered_set<std::string> quantized_op_types = {
-      "conv2d", "mul", "depthwise_conv2d"};
-  for (auto& quant_type : quant_types) {
-    for (auto& op_type : quantized_op_types) {
-      for (int i = 6; i >= 1; i--) {
-        fusion::QuantDequantOpFuser fuser(op_type, quant_type, i);
-        fuser(graph.get());
+  for (auto& cur_node : graph->mutable_nodes()) {
+    if (cur_node.IsStmt() && quant_types.count(cur_node.stmt()->op_type())) {
+      // find input nodes and output nodes
+      std::list<Node*> input_nodes = cur_node.inlinks;
+      std::list<Node*> output_nodes = cur_node.outlinks;
+      CHECK_EQ(input_nodes.size(), 2);
+      CHECK_EQ(output_nodes.size(), 2);
+      bool front_is_scale = input_nodes.front()->arg()->is_weight;
+      Node* input_scale_node =
+          front_is_scale ? input_nodes.front() : input_nodes.back();
+      Node* input_act_node =
+          front_is_scale ? input_nodes.back() : input_nodes.front();
+      front_is_scale = output_nodes.front()->arg()->is_weight;
+      Node* output_scale_node =
+          front_is_scale ? output_nodes.front() : output_nodes.back();
+      Node* output_act_node =
+          front_is_scale ? output_nodes.back() : output_nodes.front();
+
+      // relink nodes and save value to quantized_node
+      int bit_length = cur_node.stmt()->op_info()->GetAttr<int>("bit_length");
+      int range = ((1 << (bit_length - 1)) - 1);
+      auto* scope = cur_node.stmt()->op()->scope();
+      auto scale_tensor = scope->FindVar(output_scale_node->arg()->name)
+                              ->GetMutable<lite::Tensor>();
+      float scale_value = scale_tensor->data<float>()[0] / range;
+      for (auto* quantized_node_ptr : output_act_node->outlinks) {
+        quantized_node_ptr->stmt()->mutable_op_info()->SetAttr<int>(
+            "bit_length", bit_length);
+        quantized_node_ptr->stmt()->mutable_op_info()->SetAttr<float>(
+            "input_scale", scale_value);
+        IR_NODE_LINK_TO(input_act_node, quantized_node_ptr)
+        RemoveDirectedLink(output_act_node, quantized_node_ptr);
       }
+
+      // delete nodes and edges
+      std::unordered_set<const Node*> nodes2rm = {
+          input_scale_node, &cur_node, output_scale_node, output_act_node};
+      GraphSafeRemoveNodes(graph.get(), nodes2rm);
     }
   }
+
+  // fuse quantized node and dequant node
+  std::unordered_set<std::string> quantized_op_types = {
+      "conv2d", "mul", "depthwise_conv2d"};
+  for (auto& op_type : quantized_op_types) {
+    fusion::QuantDequantOpFuser fuser(op_type);
+    fuser(graph.get());
+  }
 }

 } // namespace mir
...
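A minimal standalone sketch of the scale arithmetic the pass performs, assuming the default `bit_length` of 8; the `out_scale` value here is hypothetical, not taken from the patch:

#include <cstdio>

int main() {
  int bit_length = 8;                       // default, see op_params.h below
  int range = (1 << (bit_length - 1)) - 1;  // 127 for signed int8
  float out_scale = 5.08f;                  // hypothetical OutScale tensor value
  float input_scale = out_scale / range;    // what the pass stores on consumers
  std::printf("range = %d, input_scale = %f\n", range, input_scale);
  return 0;
}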
...
@@ -23,170 +23,108 @@ namespace mir {
 namespace fusion {

 void QuantDequantOpFuser::BuildPattern() {
-  const int kNumFields = 5;
-  const int kQuantizedWeightOffset = 0;
-  const int kQuantizedOpOffset = 1;
-  const int kQuantizedOpOutOffset = 2;
-  const int kDequantOpOffset = 3;
-  const int kDequantOpOutOffset = 4;
   std::string weight_name = "";
   if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") {
     weight_name = "Filter";
   } else {
     weight_name = "Y";
   }
-  auto* quant_op_input = VarNode("quant_op_input")
-                             ->assert_is_op_input(quant_type_, "X")
-                             ->AsInput();
-  auto* quant_op_in_scale = VarNode("quant_op_in_scale")
-                                ->assert_is_op_input(quant_type_, "InScale")
-                                ->AsIntermediate();
-  auto* quant_op = OpNode("quant_op", quant_type_)
-                       ->assert_is_op(quant_type_)
-                       ->AsIntermediate();
-  auto* quant_op_out_scale =
-      VarNode("quant_op_out_scale")
-          ->assert_is_op_output(quant_type_, "OutScale")
-          ->assert_is_op_input("fake_dequantize_max_abs", "Scale")
-          ->AsIntermediate();
-  auto* quant_op_out = VarNode("quant_op_out")
-                           ->assert_is_op_output(quant_type_, "Out")
-                           ->assert_is_op_input(op_type_)
-                           ->AsIntermediate();
-  std::vector<PMNode*> nodes;
-  for (int i = 0; i < times_; i++) {
-    nodes.push_back(VarNode(string_format("quantized_op_weight%d", i))
-                        ->assert_is_op_input(op_type_, weight_name)
-                        ->AsInput());
-    nodes.push_back(OpNode(string_format("quantized_op%d", i), op_type_)
-                        ->assert_is_op(op_type_)
-                        ->AsIntermediate());
-    nodes.push_back(VarNode(string_format("quantized_op_out%d", i))
-                        ->assert_is_op_output(op_type_)
-                        ->assert_is_op_input("fake_dequantize_max_abs", "X")
-                        ->AsIntermediate());
-    nodes.push_back(
-        OpNode(string_format("dequant_op%d", i), "fake_dequantize_max_abs")
-            ->assert_is_op("fake_dequantize_max_abs")
-            ->AsIntermediate());
-    nodes.push_back(VarNode(string_format("dequant_op_out%d", i))
-                        ->assert_is_op_output("fake_dequantize_max_abs", "Out")
-                        ->AsOutput());
-  }
-  quant_op->LinksFrom({quant_op_input, quant_op_in_scale});
-  quant_op_out->LinksFrom({quant_op});
-  quant_op_out_scale->LinksFrom({quant_op});
-  for (int i = 0; i < times_; i++) {
-    nodes[i * kNumFields + kQuantizedOpOffset]->LinksFrom(
-        {quant_op_out, nodes[i * kNumFields + kQuantizedWeightOffset]});
-    nodes[i * kNumFields + kQuantizedOpOutOffset]->LinksFrom(
-        {nodes[i * kNumFields + kQuantizedOpOffset]});
-    nodes[i * kNumFields + kDequantOpOffset]->LinksFrom(
-        {nodes[i * kNumFields + kQuantizedOpOutOffset], quant_op_out_scale});
-    nodes[i * kNumFields + kDequantOpOutOffset]->LinksFrom(
-        {nodes[i * kNumFields + kDequantOpOffset]});
-  }
+  auto* quantized_op_input =
+      VarNode("quantized_op_input")->assert_is_op_input(op_type_)->AsInput();
+  auto* quantized_op_weight = VarNode("quantized_op_weight")
+                                  ->assert_is_op_input(op_type_, weight_name)
+                                  ->AsInput();
+  auto* quantized_op = OpNode("quantized_op", op_type_)
+                           ->assert_is_op(op_type_)
+                           ->AsIntermediate();
+  auto* quantized_op_out =
+      VarNode("quantized_op_out")
+          ->assert_is_op_output(op_type_)
+          ->assert_is_op_input("fake_dequantize_max_abs", "X")
+          ->AsIntermediate();
+  auto* dequant_op = OpNode("dequant_op", "fake_dequantize_max_abs")
+                         ->assert_is_op("fake_dequantize_max_abs")
+                         ->AsIntermediate();
+  auto* dequant_op_out =
+      VarNode("dequant_op_out")
+          ->assert_is_op_output("fake_dequantize_max_abs", "Out")
+          ->AsOutput();
+
+  quantized_op->LinksFrom({quantized_op_input, quantized_op_weight});
+  quantized_op_out->LinksFrom({quantized_op});
+  dequant_op->LinksFrom({quantized_op_out});
+  dequant_op_out->LinksFrom({dequant_op});
 }

 void QuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
                                         const key2nodes_t& matched) {
-  const int kNumFields = 5;
-  const int kQuantizedWeightOffset = 0;
-  const int kQuantizedOpOffset = 1;
-  const int kDequantOpOffset = 3;
-  const int kDequantOpOutOffset = 4;
-  auto* quant_op_input = matched.at("quant_op_input");
-  auto* quant_op_in_scale = matched.at("quant_op_in_scale");
-  auto* quant_op = matched.at("quant_op");
-
-  std::vector<Node*> nodes;
-  for (int i = 0; i < times_; i++) {
-    nodes.push_back(matched.at(string_format("quantized_op_weight%d", i)));
-    nodes.push_back(matched.at(string_format("quantized_op%d", i)));
-    nodes.push_back(matched.at(string_format("quantized_op_out%d", i)));
-    nodes.push_back(matched.at(string_format("dequant_op%d", i)));
-    nodes.push_back(matched.at(string_format("dequant_op_out%d", i)));
-  }
-  int bit_length = quant_op->stmt()->op_info()->GetAttr<int>("bit_length");
-  auto* scope = quant_op->stmt()->op()->scope();
-  auto& valid_places = quant_op->stmt()->op()->valid_places();
+  auto* quant_op_input = matched.at("quantized_op_input");
+  auto* quantized_op_weight = matched.at("quantized_op_weight");
+  auto* quantized_op = matched.at("quantized_op");
+  auto* dequant_op = matched.at("dequant_op");
+  auto* dequant_op_out = matched.at("dequant_op_out");
+
+  // obtain input_scale and weight_scale
+  auto* scope = quantized_op->stmt()->op()->scope();
+  auto& valid_places = quantized_op->stmt()->op()->valid_places();
+  int bit_length = quantized_op->stmt()->op_info()->GetAttr<int>("bit_length");
   int range = ((1 << (bit_length - 1)) - 1);
-  auto input_scale_t = scope->FindVar(quant_op_in_scale->arg()->name)
-                           ->GetMutable<lite::Tensor>();
-  float input_scale = input_scale_t->data<float>()[0] / range;
-
-  VLOG(4) << "range: " << range << " input_scale: " << input_scale;
-  for (int i = 0; i < times_; i++) {
-    float max_range = nodes[i * kNumFields + kDequantOpOffset]
-                          ->stmt()
-                          ->op_info()
-                          ->GetAttr<float>("max_range");
-    // weight_scale = max(abs(weight))
-    float whole_weight_scale =
-        static_cast<float>(range * range) / max_range / range;
-
-    cpp::OpDesc op_desc =
-        *nodes[i * kNumFields + kQuantizedOpOffset]->stmt()->op_info();
-
-    auto quantized_weight_var_name =
-        nodes[i * kNumFields + kQuantizedWeightOffset]->arg()->name;
-    auto quantized_weight_t =
-        scope->FindVar(quantized_weight_var_name)->GetMutable<lite::Tensor>();
-    std::vector<float> weight_scale;
-    int weight_scale_size;
-
-    if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") {
-      op_desc.SetInput("Input", {matched.at("quant_op_input")->arg()->name});
-      op_desc.SetOutput(
-          "Output", {nodes[i * kNumFields + kDequantOpOutOffset]->arg()->name});
-      // Conv weight shape: Cout * Cin * kh * kw, the weight_scale_size should
-      // be Cout.
-      weight_scale_size = quantized_weight_t->dims()[0];
-    } else if (op_type_ == "mul") {
-      op_desc.SetInput("X", {matched.at("quant_op_input")->arg()->name});
-      op_desc.SetOutput(
-          "Out", {nodes[i * kNumFields + kDequantOpOutOffset]->arg()->name});
-      // Fc weight: Cin * Cout, the weight_scale_size should be Cout.
-      weight_scale_size = quantized_weight_t->dims()[1];
-    }
-    for (int i = 0; i < weight_scale_size; i++) {
-      weight_scale.push_back(whole_weight_scale);
-    }
-    op_desc.SetAttr("enable_int8", true);
-    op_desc.SetAttr("input_scale", input_scale);
-    op_desc.SetAttr("weight_scale", weight_scale);
-    Tensor temp_tensor;
-    temp_tensor.CopyDataFrom(*quantized_weight_t);
-    float* temp_data = temp_tensor.mutable_data<float>();
-    size_t weight_num = quantized_weight_t->data_size();
-    int8_t* quantized_weight_data = quantized_weight_t->mutable_data<int8_t>();
-    // change the weight from the float type to int8 type.
-    for (size_t i = 0; i < weight_num; i++) {
-      quantized_weight_data[i] = static_cast<int8_t>(temp_data[i]);
-    }
-    quantized_weight_t->set_persistable(true);
-    quantized_weight_t->set_precision(PRECISION(kInt8));
-    auto quantized_op = LiteOpRegistry::Global().Create(op_type_);
-    quantized_op->Attach(op_desc, scope);
-    auto* new_op_node =
-        graph->GraphCreateInstructNode(quantized_op, valid_places);
-    IR_NODE_LINK_TO(quant_op_input, new_op_node);
-    IR_NODE_LINK_TO(nodes[i * kNumFields + kQuantizedWeightOffset],
-                    new_op_node);
-    IR_NODE_LINK_TO(new_op_node, nodes[i * kNumFields + kDequantOpOutOffset]);
-  }
+  float input_scale =
+      quantized_op->stmt()->op_info()->GetAttr<float>("input_scale");
+  float max_range = dequant_op->stmt()->op_info()->GetAttr<float>("max_range");
+  float whole_weight_scale =
+      static_cast<float>(range * range) / max_range / range;
+  // max_range = range * range / max(abs(weight))
+  // weight_scale = range * range / (range * range / max(abs(weight))) / range
+  //              = max(abs(weight)) / range
+
+  // set op desc
+  cpp::OpDesc op_desc = *quantized_op->stmt()->op_info();
+  auto quantized_weight_var_name = quantized_op_weight->arg()->name;
+  auto quantized_weight_t =
+      scope->FindVar(quantized_weight_var_name)->GetMutable<lite::Tensor>();
+  std::vector<float> weight_scale;
+  int weight_scale_size;
+  if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") {
+    op_desc.SetInput("Input", {quant_op_input->arg()->name});
+    op_desc.SetOutput("Output", {dequant_op_out->arg()->name});
+    // Conv weight shape: Cout * Cin * kh * kw, the weight_scale_size should
+    // be Cout.
+    weight_scale_size = quantized_weight_t->dims()[0];
+  } else if (op_type_ == "mul") {
+    op_desc.SetInput("X", {quant_op_input->arg()->name});
+    op_desc.SetOutput("Out", {dequant_op_out->arg()->name});
+    // Fc weight: Cin * Cout, the weight_scale_size should be Cout.
+    weight_scale_size = quantized_weight_t->dims()[1];
+  }
+  for (int i = 0; i < weight_scale_size; i++) {
+    weight_scale.push_back(whole_weight_scale);
+  }
+  op_desc.SetAttr("enable_int8", true);
+  op_desc.SetAttr("input_scale", input_scale);
+  op_desc.SetAttr("weight_scale", weight_scale);

+  // change the weight from the float type to int8 type.
+  Tensor temp_tensor;
+  temp_tensor.CopyDataFrom(*quantized_weight_t);
+  float* temp_data = temp_tensor.mutable_data<float>();
+  size_t weight_num = quantized_weight_t->data_size();
+  int8_t* quantized_weight_data = quantized_weight_t->mutable_data<int8_t>();
+  for (size_t i = 0; i < weight_num; i++) {
+    quantized_weight_data[i] = static_cast<int8_t>(temp_data[i]);
+  }
+  quantized_weight_t->set_persistable(true);
+  quantized_weight_t->set_precision(PRECISION(kInt8));
+
+  // new op and relink nodes
+  auto new_quantized_op = LiteOpRegistry::Global().Create(op_type_);
+  new_quantized_op->Attach(op_desc, scope);
+  auto* new_quantized_op_node =
+      graph->GraphCreateInstructNode(new_quantized_op, valid_places);
+  IR_NODE_LINK_TO(quant_op_input, new_quantized_op_node);
+  IR_NODE_LINK_TO(quantized_op_weight, new_quantized_op_node);
+  IR_NODE_LINK_TO(new_quantized_op_node, dequant_op_out);
 }

 cpp::OpDesc QuantDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
...
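The `whole_weight_scale` expression above simply inverts the `max_range` encoding used by `fake_dequantize_max_abs`, as the comment derives. A standalone check of that algebra, with a hypothetical `max(abs(weight))` of 0.5:

#include <cstdio>

int main() {
  int range = 127;                        // (1 << 7) - 1 for 8-bit weights
  float max_abs_weight = 0.5f;            // hypothetical weight extremum
  float max_range = range * range / max_abs_weight;  // attr set by the dequant op
  float whole_weight_scale =
      static_cast<float>(range * range) / max_range / range;
  // both printed values should be max_abs_weight / range = 0.003937...
  std::printf("%.9f vs %.9f\n", whole_weight_scale, max_abs_weight / range);
  return 0;
}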
...
@@ -37,10 +37,8 @@ namespace fusion {
 */
 class QuantDequantOpFuser : public FuseBase {
  public:
-  explicit QuantDequantOpFuser(const std::string& op_type,
-                               const std::string& quant_type,
-                               int times)
-      : op_type_(op_type), quant_type_(quant_type), times_(times) {}
+  explicit QuantDequantOpFuser(const std::string& op_type)
+      : op_type_(op_type) {}

   void BuildPattern() override;
   void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
@@ -48,9 +46,7 @@ class QuantDequantOpFuser : public FuseBase {
   cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;

  private:
-  std::string op_type_{"conv2d"};
-  std::string quant_type_;
-  int times_;
+  std::string op_type_{};
 };

 } // namespace fusion
...
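Because the quant ops are now removed up front by the pass itself, the fuser's pattern no longer depends on the quantize op type or on how many quantized ops share one quantized activation; the `quant_type` and `times` constructor parameters are therefore dropped, and a caller just writes `fusion::QuantDequantOpFuser fuser("conv2d");`.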
...
@@ -35,7 +35,8 @@ using param_t = Any;
   bool enable_int8{false};            \
   float input_scale{1.0};             \
   std::vector<float> weight_scale{};  \
-  float output_scale{1.0};
+  float output_scale{1.0};            \
+  int bit_length{8};

 /// ----------------------- Functional operators ------------------------------
 struct FeedParam {
...
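`int bit_length{8}` is appended to the macro that declares the shared int8 fields, so every op parameter struct expanding it now carries the quantization bit width (defaulting to 8) next to `enable_int8`, `input_scale`, and `weight_scale`, matching the `bit_length` attribute the pass copies onto quantized ops.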