// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.h"

namespace paddle {
namespace framework {
namespace ir {

// Fuses a (fake_quantize -> quantized_op -> fake_dequantize) subgraph into a
// single int8-enabled op.  One matched subgraph consists of a single
// fake-quantize op whose output feeds `times` parallel
// (quantized_op -> fake-dequantize) chains; each chain is rewired so the
// original fp32 input goes straight into the quantized op, with the input and
// weight scales recorded as attributes for int8 execution.
//
// graph        : graph rewritten in place.
// scope        : scope holding the scale tensors referenced by the matched ops.
// times        : number of parallel quantized-op branches fed by one quant op.
// op_type      : op to fuse ("conv2d", "depthwise_conv2d", "conv2d_fusion",
//                "mul" or "fc").
// quant_type   : fake-quantize op type producing the int8 input.
// dequant_type : fake-dequantize op type consuming the quantized output.
void RunQuantDequant(ir::Graph* graph, Scope* scope, int times,
                     const std::string& op_type, const std::string& quant_type,
                     const std::string& dequant_type) {
  const std::string pattern_name = "quant_dequant_fuse";
  // Number of pattern nodes recorded per repetition in `nodes` below.  The
  // channel-wise dequantize variant carries one extra node (the per-channel
  // weight-scale tensor), hence the += 1 further down.
  int kNumFields = 5;
  const int kQuantizedWeightOffset = 0;
  const int kQuantizedOpOffset = 1;
  const int kQuantizedOpOutOffset = 2;
  const int kDequantOpOffset = 3;
  const int kDequantOpOutOffset = 4;
  const int kDequantOpWeightScaleOffset = 5;

  if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
    kNumFields += 1;
  }

  GraphPatternDetector gpd;
  auto* x = gpd.mutable_pattern()
                ->NewNode("x")
                ->assert_is_op_input(quant_type, "X")
                ->AsInput();

  std::string quantized_op_type = op_type;
  // Name of the weight input slot, which differs per op type.
  std::string weight_name = "";
  if (op_type == "conv2d" || op_type == "depthwise_conv2d" ||
      op_type == "conv2d_fusion") {
    weight_name = "Filter";
  } else if (op_type == "mul") {
    weight_name = "Y";
  } else if (op_type == "fc") {
    weight_name = "W";
  } else {
    // BUG FIX: the original called PADDLE_ENFORCE with only the message
    // string, so the (always non-null) char* was evaluated as the condition
    // and the check could never fire — unsupported op types fell through with
    // an empty weight_name.  Pass an explicit false condition so it aborts.
    PADDLE_ENFORCE(
        false,
        "QuantDequantFuse: We only support conv2d, conv2d_fusion, fc, mul for "
        "now.");
  }

  patterns::QuantDequantOpFuse pattern(gpd.mutable_pattern(), pattern_name);
  pattern(x, quantized_op_type, weight_name, times, quant_type, dequant_type);

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    PADDLE_ENFORCE(subgraph.count(x));
    auto* input_node = subgraph.at(x);
    Node* quant_op_in_scale =
        subgraph.at(pattern.GetPDNode("quant_op_in_scale"));
    Node* quant_op = subgraph.at(pattern.GetPDNode("quant_op"));
    Node* quant_op_out_scale =
        subgraph.at(pattern.GetPDNode("quant_op_out_scale"));
    Node* quant_op_out = subgraph.at(pattern.GetPDNode("quant_op_out"));

    // Flatten the kNumFields pattern nodes of every repetition into `nodes`;
    // repetition i occupies nodes[i * kNumFields + <offset>].
    std::vector<Node*> nodes;
    for (int i = 0; i < times; i++) {
      nodes.push_back(subgraph.at(
          pattern.GetPDNode("quantized_op_weight" + std::to_string(i))));
      nodes.push_back(
          subgraph.at(pattern.GetPDNode("quantized_op" + std::to_string(i))));
      nodes.push_back(subgraph.at(
          pattern.GetPDNode("quantized_op_out" + std::to_string(i))));
      nodes.push_back(
          subgraph.at(pattern.GetPDNode("dequant_op" + std::to_string(i))));
      nodes.push_back(
          subgraph.at(pattern.GetPDNode("dequant_op_out" + std::to_string(i))));
      if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
        nodes.push_back(subgraph.at(
            pattern.GetPDNode("dequant_channel_scale" + std::to_string(i))));
      }
    }

    int bit_length = boost::get<int>(quant_op->Op()->GetAttr("bit_length"));
    int range = ((1 << (bit_length - 1)) - 1);
    // Prepare input scale: read the scalar InScale tensor from the scope.
    std::string input_scale_var_name = quant_op->Op()->Input("InScale").front();
    PADDLE_ENFORCE(scope);
    const LoDTensor& input_scale_tensor =
        scope->FindVar(input_scale_var_name)->Get<LoDTensor>();

    PADDLE_ENFORCE(paddle::platform::is_cpu_place(input_scale_tensor.place()));
    const float* input_scale_data = input_scale_tensor.data<float>();
    float input_scale = input_scale_data[0];
    std::unordered_set<const Node*> delete_nodes;

    for (int i = 0; i < times; i++) {
      std::vector<float> weight_scale;

      // Get weight scale from dequant op.
      if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
        auto scales_name =
            nodes[i * kNumFields + kDequantOpOffset]->Op()->Input("Scales");
        PADDLE_ENFORCE(scales_name.size() == 2);
        const LoDTensor& channel_scale_tensor =
            scope->FindVar(scales_name[0])->Get<LoDTensor>();
        PADDLE_ENFORCE(
            paddle::platform::is_cpu_place(channel_scale_tensor.place()));
        const float* channel_scale_data = channel_scale_tensor.data<float>();
        // Use a distinct index name here: the original re-declared `i`,
        // shadowing the outer repetition index.
        for (int j = 0; j < channel_scale_tensor.numel(); j++) {
          weight_scale.push_back(channel_scale_data[j]);
        }
        delete_nodes.insert(
            nodes[i * kNumFields + kDequantOpWeightScaleOffset]);
      } else {
        // For the non-channel-wise op, recover the weight scale from the
        // dequant op's max_range attribute: max_range = range^2 / scale.
        float max_range = boost::get<float>(
            nodes[i * kNumFields + kDequantOpOffset]->Op()->GetAttr(
                "max_range"));
        weight_scale.push_back((range * range) / max_range);
      }

      // Create the new op desc: a copy of the quantized op's desc, rewired
      // from the original fp32 input directly to the dequant op's output.
      auto base_op_desc =
          *nodes[i * kNumFields + kQuantizedOpOffset]->Op()->Proto();
      std::string new_input = input_node->Name();
      std::string new_output =
          nodes[i * kNumFields + kDequantOpOutOffset]->Name();

      framework::OpDesc new_op_desc(base_op_desc, nullptr);
      new_op_desc.SetType(quantized_op_type);

      // Input/output slot names differ per op type.
      if (quantized_op_type == "conv2d" ||
          quantized_op_type == "conv2d_fusion" ||
          quantized_op_type == "depthwise_conv2d") {
        new_op_desc.SetInput("Input", {new_input});
        new_op_desc.SetOutput("Output", {new_output});
      } else if (quantized_op_type == "fc") {
        new_op_desc.SetInput("Input", {new_input});
        new_op_desc.SetOutput("Out", {new_output});
      } else if (quantized_op_type == "mul") {
        new_op_desc.SetInput("X", {new_input});
        new_op_desc.SetOutput("Out", {new_output});
      }

      // Record the scales so the int8 kernel can consume them directly.
      new_op_desc.SetAttr("enable_int8", true);
      new_op_desc.SetAttr("input_scale", input_scale);
      new_op_desc.SetAttr("weight_scale", weight_scale);
      new_op_desc.Flush();
      auto* new_op = graph->CreateOpNode(&new_op_desc);
      IR_NODE_LINK_TO(input_node, new_op);
      IR_NODE_LINK_TO(nodes[i * kNumFields + kQuantizedWeightOffset], new_op);
      IR_NODE_LINK_TO(new_op, nodes[i * kNumFields + kDequantOpOutOffset]);

      delete_nodes.insert(nodes[i * kNumFields + kQuantizedOpOffset]);
      delete_nodes.insert(nodes[i * kNumFields + kQuantizedOpOutOffset]);
      delete_nodes.insert(nodes[i * kNumFields + kDequantOpOffset]);
    }

    delete_nodes.insert(quant_op_in_scale);
    delete_nodes.insert(quant_op);
    delete_nodes.insert(quant_op_out);
    delete_nodes.insert(quant_op_out_scale);
    // Delete the unneeded nodes.
    GraphSafeRemoveNodes(graph, delete_nodes);
  };
  gpd(graph, handler);
}

// Pass entry point: attempt the quant/dequant fusion for every combination of
// supported dequantize op, quantize op, and fusable compute op.  For each
// combination, branch counts are tried from 6 down to 1.
void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const {
  const std::string pattern_name = "quant_dequant_fuse";
  FusePassBase::Init(pattern_name, graph);

  std::unordered_set<std::string> supported_dequant_ops = {
      "fake_dequantize_max_abs", "fake_channel_wise_dequantize_max_abs"};
  std::unordered_set<std::string> supported_quant_ops = {
      "fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"};
  std::unordered_set<std::string> fusable_op_types = {"conv2d", "mul",
                                                      "depthwise_conv2d"};
  auto* scope = param_scope();
  for (const auto& dequant_op : supported_dequant_ops) {
    for (const auto& quant_op : supported_quant_ops) {
      for (const auto& fusable_op : fusable_op_types) {
        for (int branch_count = 6; branch_count >= 1; --branch_count) {
          RunQuantDequant(graph, scope, branch_count, fusable_op, quant_op,
                          dequant_op);
        }
      }
    }
  }
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(quant_conv2d_dequant_fuse_pass,
              paddle::framework::ir::QuantDequantFusePass);