// cpu_quantize_pass.cc
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h"

#include <algorithm>
#include <sstream>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/string/pretty_log.h"

namespace paddle {
namespace framework {
namespace ir {

29
using EigenVectorArrayMap = Eigen::Map<Eigen::Array<double, Eigen::Dynamic, 1>>;
30 31
using EigenVectorArrayMapFloat =
    Eigen::Map<Eigen::Array<float, Eigen::Dynamic, 1>>;
32 33
using string::PrettyLogDetail;

34 35 36 37 38 39 40 41 42
namespace {

void UnlinkNodes(ir::Node* a, ir::Node* b) {
  a->outputs.erase(std::remove(a->outputs.begin(), a->outputs.end(), b),
                   a->outputs.end());
  b->inputs.erase(std::remove(b->inputs.begin(), b->inputs.end(), a),
                  b->inputs.end());
}

43
void MarkAndLogCannotQuantizeOp(Node* op, const char* details = nullptr) {
44 45 46
  std::stringstream msg_ss;
  msg_ss << "Cannot quantize operator " << op->Name()
         << " (type: " << op->Op()->Type() << ", id: " << op->id() << ").";
47
  if (details) msg_ss << " " << details;
48 49
  VLOG(2) << msg_ss.str().c_str();
  op->Op()->SetAttr("mkldnn_data_type", std::string("float32"));
50 51
}

52 53 54 55 56 57
void LogScaleIsMissingForVarName(const std::string& name) {
  VLOG(4) << "Quantization scale for the variable " << name << " is missing.";
}

void LogScaleIsMissingForVarNode(Node* node) {
  LogScaleIsMissingForVarName(node->Name());
58 59
}

60
void LogQuantizationDisabled(Node* op) {
61
  VLOG(2) << "Quantization skipped for operator " << op->Name()
62
          << " (type: " << op->Op()->Type() << ", id: " << op->id()
63
          << "). Attribute mkldnn_data_type != \"int8\".";
64 65
}

66 67
void LogQuantizedOpsCounter(const std::string& type,
                            const int counter,
68 69 70 71 72 73 74
                            const char* details = nullptr) {
  std::stringstream msg_ss;
  msg_ss << "---    quantized " << counter << " " << type << " ops";
  if (details) msg_ss << " " << details;
  PrettyLogDetail(msg_ss.str().c_str());
}

75 76 77 78
}  // namespace

// Largest magnitudes representable by the 8-bit quantized types.
enum {
  U8_MAX = 255,  // unsigned 8-bit maximum
  S8_MAX = 127,  // signed 8-bit maximum
};

void CPUQuantizePass::QuantizeInput(Graph* g,
                                    Node* op,
                                    Node* input,
                                    std::string input_name,
                                    double scale_to_one,
84
                                    bool is_input_unsigned,
85 86
                                    std::string scale_attr_name,
                                    float shift,
87
                                    std::string shift_attr_name) const {
M
Michał Gallus 已提交
88 89 90
  auto inputs = op->Op()->InputNames();
  bool name_found =
      std::find(inputs.begin(), inputs.end(), input_name) != inputs.end();
91 92
  PADDLE_ENFORCE_EQ(name_found,
                    true,
93 94
                    platform::errors::InvalidArgument(
                        "Var(%s) isn't the input of the %s operator.",
95 96
                        input_name,
                        op->Op()->Type()));
97
  unsigned max = is_input_unsigned ? U8_MAX : S8_MAX;
98 99 100 101 102 103 104 105 106 107 108 109 110
  float scale = scale_to_one * max;

  // Create quantize output variable
  VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out"));
  auto* quantize_out_node = g->CreateVarNode(&quantize_out_desc);

  // create a quantize op node
  OpDesc q_desc;
  q_desc.SetType("quantize");
  q_desc.SetInput("Input", std::vector<std::string>({input->Name()}));
  q_desc.SetOutput("Output",
                   std::vector<std::string>({quantize_out_node->Name()}));
  q_desc.SetAttr("Scale", scale);
111 112
  q_desc.SetAttr("Shift", shift);
  q_desc.SetAttr("is_negative_input", !is_input_unsigned);
113

Z
Zuza 已提交
114 115 116
  // fix to fc format error
  if (op->Op()->Type() == "fc" &&
      op->Op()->GetAttrIfExists<int>("in_num_col_dims") == 2) {
117 118 119
    q_desc.SetAttr(
        "output_format",
        Has("data_layout") ? Get<std::string>("data_layout") : "NCHW");
Z
Zuza 已提交
120
  } else {
121 122 123
    q_desc.SetAttr(
        "output_format",
        Has("data_layout") ? Get<std::string>("data_layout") : "NHWC");
Z
Zuza 已提交
124
  }
125 126 127 128 129 130 131 132 133 134 135 136 137
  auto quantize_op = g->CreateOpNode(&q_desc);  // OpDesc will be copied.

  // update op's input
  op->Op()->SetInput(input_name,
                     std::vector<std::string>({quantize_out_node->Name()}));

  // link quantize op
  UnlinkNodes(input, op);
  IR_NODE_LINK_TO(input, quantize_op);
  IR_NODE_LINK_TO(quantize_op, quantize_out_node);
  IR_NODE_LINK_TO(quantize_out_node, op);

  if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale);
138
  if (!shift_attr_name.empty()) op->Op()->SetAttr(shift_attr_name, shift);
139 140
}

141 142 143
void CPUQuantizePass::QuantizeInputs(Graph* g,
                                     Node* op,
                                     std::string input_name,
144
                                     bool are_inputs_unsigned,
145 146
                                     std::string scale_attr_name,
                                     float shift,
147
                                     std::string shift_attr_name) const {
148
  auto inputs = op->inputs;
149
  auto output = op->outputs[0];
150 151
  PADDLE_ENFORCE_GE(inputs.size(),
                    1,
152 153
                    platform::errors::InvalidArgument(
                        "OP(%s)'s inputs(%d) must be equal or greater than 1.",
154 155 156 157
                        op->Name(),
                        inputs.size()));
  PADDLE_ENFORCE_EQ(op->outputs.size(),
                    1,
158
                    platform::errors::InvalidArgument(
159 160
                        "OP(%s)'s outputs(%d) must be equal to 1.",
                        op->Name(),
161
                        op->outputs.size()));
162 163 164 165 166 167 168 169

  // create a quantize op desc prototype
  OpDesc q_desc;
  q_desc.SetType("quantize");

  std::vector<Node*> quantize_out_nodes(inputs.size());
  std::vector<std::string> quantize_out_node_names(inputs.size());

170
  double scale_out = GetScaleValueForNode(output);
171
  unsigned max = are_inputs_unsigned ? U8_MAX : S8_MAX;
172
  float scale = scale_out * max;
173 174 175 176 177 178 179 180

  for (size_t i = 0; i < inputs.size(); i++) {
    // Create quantize output variable
    VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out"));
    quantize_out_nodes[i] = g->CreateVarNode(&quantize_out_desc);
    quantize_out_node_names[i] = quantize_out_nodes[i]->Name();

    q_desc.SetAttr("Scale", scale);
181
    q_desc.SetAttr("Shift", shift);
182 183 184
    q_desc.SetInput("Input", std::vector<std::string>({inputs[i]->Name()}));
    q_desc.SetOutput("Output",
                     std::vector<std::string>({quantize_out_node_names[i]}));
185
    q_desc.SetAttr("is_negative_input", !are_inputs_unsigned);
186 187 188 189 190 191 192 193 194 195 196 197 198
    auto quantize_op = g->CreateOpNode(&q_desc);  // OpDesc will be copied.

    // link quantize op
    UnlinkNodes(inputs[i], op);
    IR_NODE_LINK_TO(inputs[i], quantize_op);
    IR_NODE_LINK_TO(quantize_op, quantize_out_nodes[i]);
    IR_NODE_LINK_TO(quantize_out_nodes[i], op);
  }

  // update op's input
  op->Op()->SetInput(input_name, quantize_out_node_names);

  if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale);
199
  if (!shift_attr_name.empty()) op->Op()->SetAttr(shift_attr_name, shift);
200 201
}

202 203 204
void CPUQuantizePass::DequantizeOutput(Graph* g,
                                       Node* op,
                                       Node* output,
205
                                       std::string output_name,
206 207
                                       double scale_to_one,
                                       bool is_unsigned,
208
                                       std::string scale_attr_name) const {
M
Michał Gallus 已提交
209 210 211
  auto outputs = op->Op()->OutputNames();
  bool name_found =
      std::find(outputs.begin(), outputs.end(), output_name) != outputs.end();
212 213
  PADDLE_ENFORCE_EQ(name_found,
                    true,
M
Michał Gallus 已提交
214
                    platform::errors::InvalidArgument(
215
                        "Var(%s) isn't the output of the %s operator.",
216 217
                        output_name,
                        op->Op()->Type()));
218 219 220 221 222 223 224 225 226 227 228 229 230 231
  unsigned max = is_unsigned ? U8_MAX : S8_MAX;
  float scale = scale_to_one * max;

  // Create dequantize input variable
  VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
  auto* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc);

  // create a dequantize op node for output.
  OpDesc deq_desc;
  deq_desc.SetType("dequantize");
  deq_desc.SetInput("Input",
                    std::vector<std::string>({dequantize_in_node->Name()}));
  deq_desc.SetOutput("Output", std::vector<std::string>({output->Name()}));
  deq_desc.SetAttr("Scale", scale);
232
  deq_desc.SetAttr("is_negative_input", !is_unsigned);
233 234 235 236 237 238 239 240 241 242 243 244 245 246 247
  auto dequantize_op = g->CreateOpNode(&deq_desc);  // OpDesc will be copied.

  // update op's output
  op->Op()->SetOutput(output_name,
                      std::vector<std::string>({dequantize_in_node->Name()}));

  // link dequantize op
  UnlinkNodes(op, output);
  IR_NODE_LINK_TO(op, dequantize_in_node);
  IR_NODE_LINK_TO(dequantize_in_node, dequantize_op);
  IR_NODE_LINK_TO(dequantize_op, output);

  if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale);
}

248 249 250
bool CPUQuantizePass::AreScalesPresentForVarNames(
    std::vector<std::string> names) const {
  bool present = true;
B
baoachun 已提交
251 252 253 254 255 256 257 258 259 260 261 262 263 264
  if (var_quant_scales_->empty()) {
    auto& scales = Get<VarQuantScale>("quant_var_scales");
    for (auto name : names) {
      if (scales.find(name) == scales.end()) {
        present = false;
        LogScaleIsMissingForVarName(name);
      }
    }
  } else {
    for (auto name : names) {
      if (var_quant_scales_->find(name) == var_quant_scales_->end()) {
        present = false;
        LogScaleIsMissingForVarName(name);
      }
265 266 267 268 269
    }
  }
  return present;
}

270
bool CPUQuantizePass::AreScalesPresentForNodes(
271
    std::initializer_list<Node*> nodes) const {
272
  bool present = true;
B
baoachun 已提交
273 274 275 276 277 278 279 280 281 282 283 284 285 286
  if (var_quant_scales_->empty()) {
    auto& scales = Get<VarQuantScale>("quant_var_scales");
    for (auto node : nodes) {
      if (scales.count(node->Name()) == 0) {
        present = false;
        LogScaleIsMissingForVarNode(node);
      }
    }
  } else {
    for (auto node : nodes) {
      if (var_quant_scales_->count(node->Name()) == 0) {
        present = false;
        LogScaleIsMissingForVarNode(node);
      }
287 288 289 290 291
    }
  }
  return present;
}

292 293
std::pair<bool, LoDTensor> CPUQuantizePass::GetScaleDataByName(
    const std::string& name) const {
B
baoachun 已提交
294 295 296 297 298
  if (var_quant_scales_->empty()) {
    auto& scales = Get<VarQuantScale>("quant_var_scales");
    return scales.at(name);
  }
  return var_quant_scales_->at(name);
299 300
}

301 302
std::pair<bool, LoDTensor> CPUQuantizePass::GetScaleDataForNode(
    const Node* node) const {
303 304 305 306 307
  return GetScaleDataByName(node->Name());
}

// Returns only the scale tensor stored for variable `name`; the signedness
// flag is discarded.
LoDTensor CPUQuantizePass::GetScaleTensorByName(const std::string& name) const {
  auto scale_data = GetScaleDataByName(name);
  return std::move(scale_data.second);
}

// Returns only the scale tensor for `node`; the signedness flag is discarded.
LoDTensor CPUQuantizePass::GetScaleTensorForNode(const Node* node) const {
  auto scale_data = GetScaleDataForNode(node);
  return std::move(scale_data.second);
}

// Returns the first element (read as double) of `node`'s scale tensor. When
// `is_unsigned` is non-null it receives the signedness flag stored with it.
double CPUQuantizePass::GetScaleValueForNode(const Node* node,
                                             bool* is_unsigned) const {
  auto scale_data = GetScaleDataForNode(node);
  if (is_unsigned) {
    *is_unsigned = scale_data.first;
  }
  const double* values = scale_data.second.data<double>();
  return values[0];
}

// True when op `node` is a dequantize op or is marked to run with int8 data.
bool CPUQuantizePass::IsOpDequantized(const Node* node) const {
  auto* op_desc = node->Op();
  if (op_desc->Type() == "dequantize") return true;
  return platform::HasOpINT8DataType(op_desc);
}

bool CPUQuantizePass::IsOpQuantized(const Node* node) const {
327 328 329 330 331 332
  // return true only if all of outputs are ops and their are either quantize or
  // have int8 data type
  return all_of(node->outputs.begin(), node->outputs.end(), [](Node* output) {
    return (output->IsOp() && (output->Op()->Type() == "quantize" ||
                               platform::HasOpINT8DataType(output->Op())));
  });
333 334
}

B
baoachun 已提交
335
// Fills var_quant_scales_ from quantization info recorded in `graph`, via
// GetInfoFromTheFirstOp with the "has_quant_info" / "var_quant_scales" keys
// (presumably stored as attributes on the graph's first op — see that helper).
void CPUQuantizePass::GetQuantInfo(Graph* graph) const {
  GetInfoFromTheFirstOp(graph,
                        "has_quant_info",
                        "var_quant_scales",
                        var_quant_scales_);
}

// Quantizes every conv2d matched by the ConvResidual pattern (with or without
// a ResidualData input, per `with_residual_data`) whose scales are known.
// It inserts quantize ops on the inputs and stores the filter scales in
// "Scale_weights". It then either dequantizes the output or, when no output
// scale exists, sets force_fp32_output.
void CPUQuantizePass::QuantizeConv(Graph* graph,
                                   bool with_residual_data) const {
  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
  patterns::ConvResidual conv_pattern{pattern, name_scope_};
  conv_pattern(with_residual_data);

  int quantize_conv_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize conv2d op";
    GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(conv_op->Op())) {
      LogQuantizationDisabled(conv_op);
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);

    auto has_output_scale = AreScalesPresentForNodes({conv_output});
    if (with_residual_data && !has_output_scale) {
      MarkAndLogCannotQuantizeOp(
          conv_op,
          "Conv op with ResidualData input cannot be quantized "
          "without output scale.");
      return;
    }

    if (with_residual_data) {
      GET_IR_NODE_FROM_SUBGRAPH(
          conv_residual_data, conv_residual_data, conv_pattern);
      if (!AreScalesPresentForNodes(
              {conv_input, conv_filter, conv_residual_data})) {
        MarkAndLogCannotQuantizeOp(conv_op,
                                   "No scale available for the operator");
        return;
      }

      bool is_residual_unsigned{false};
      auto residual_scale =
          GetScaleValueForNode(conv_residual_data, &is_residual_unsigned);

      QuantizeInput(g,
                    conv_op,
                    conv_residual_data,
                    "ResidualData",
                    residual_scale,
                    is_residual_unsigned,
                    "Scale_in_eltwise");
    } else {
      if (!AreScalesPresentForNodes({conv_input, conv_filter})) {
        MarkAndLogCannotQuantizeOp(conv_op,
                                   "No scale available for the operator");
        return;
      }
    }

    bool is_input_unsigned{false};
    auto input_scale = GetScaleValueForNode(conv_input, &is_input_unsigned);
    QuantizeInput(g,
                  conv_op,
                  conv_input,
                  "Input",
                  input_scale,
                  is_input_unsigned,
                  "Scale_in");

    // Scale_weights = filter scales * S8_MAX. Copies of a LoDTensor share
    // their allocation, so the tensor returned here aliases the entry stored
    // in the scale map. Compute into a separate vector instead of rescaling
    // the tensor in place. Otherwise a filter shared by several ops, or a
    // second run of the pass over the same scales, would be multiplied again.
    const auto filter_scale_tensor = GetScaleTensorForNode(conv_filter);
    const double* filter_scale_data = filter_scale_tensor.data<double>();
    const int64_t filter_scale_numel = filter_scale_tensor.numel();
    std::vector<float> filter_scale;
    filter_scale.reserve(static_cast<size_t>(filter_scale_numel));
    for (int64_t i = 0; i < filter_scale_numel; ++i) {
      filter_scale.push_back(static_cast<float>(
          filter_scale_data[i] * static_cast<double>(S8_MAX)));
    }

    conv_op->Op()->SetAttr("Scale_weights", filter_scale);

    // if quantization scale is missing for output tensor, return fp32 data
    if (has_output_scale) {
      bool is_output_unsigned{false};
      auto output_scale =
          GetScaleValueForNode(conv_output, &is_output_unsigned);
      DequantizeOutput(g,
                       conv_op,
                       conv_output,
                       "Output",
                       output_scale,
                       is_output_unsigned,
                       "Scale_out");
    } else {
      conv_op->Op()->SetAttr("force_fp32_output", true);
    }

    // change threshold in bounded ReLu: fuse_alpha is rescaled by Scale_out.
    // NOTE(review): when has_output_scale is false, Scale_out is not set
    // above; this relies on the op already carrying a Scale_out attribute —
    // confirm.
    if (conv_op->Op()->GetAttrIfExists<std::string>("fuse_activation") ==
        "relu6") {
      float scale_out =
          PADDLE_GET_CONST(float, conv_op->Op()->GetAttr("Scale_out"));
      float threshold =
          PADDLE_GET_CONST(float, conv_op->Op()->GetAttr("fuse_alpha"));
      conv_op->Op()->SetAttr("fuse_alpha", scale_out * threshold);
    }

    ++quantize_conv_count;
  };

  gpd(graph, handler);
  AddStatis(quantize_conv_count);

  LogQuantizedOpsCounter(
      "conv2d",
      quantize_conv_count,
      ((with_residual_data) ? "with residual connection" : ""));
}

// Quantizes every oneDNN fc op (use_mkldnn set) matched by the FCMKLDNN
// pattern whose input and weight scales are known. It quantizes the input
// and stores the weight scales in "Scale_weights". It then either
// dequantizes the output or, when no output scale exists, sets
// force_fp32_output.
void CPUQuantizePass::QuantizeFc(Graph* graph) const {
  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
  patterns::FCMKLDNN fc_pattern{pattern, name_scope_};
  auto* fc_input = gpd.mutable_pattern()
                       ->NewNode("fc_quantizer/input")
                       ->AsInput()
                       ->assert_is_op_input("fc", "Input");
  fc_pattern(fc_input, false);

  int quantize_fc_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize fc op";
    GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fc_pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(fc->Op())) {
      LogQuantizationDisabled(fc);
      return;
    }
    if (!fc->Op()->GetAttrIfExists<bool>("use_mkldnn")) {
      MarkAndLogCannotQuantizeOp(fc, "use_mkldnn attribute set to false");
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(weights, weights, fc_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(input, input, fc_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(output, output, fc_pattern);

    if (!AreScalesPresentForNodes({input, weights})) {
      MarkAndLogCannotQuantizeOp(fc, "No scale available for the operator");
      return;
    }

    bool is_input_unsigned{false};
    auto input_scale = GetScaleValueForNode(input, &is_input_unsigned);
    QuantizeInput(
        g, fc, input, "Input", input_scale, is_input_unsigned, "Scale_in");

    // Scale_weights = weight scales * S8_MAX. Copies of a LoDTensor share
    // their allocation, so the tensor returned here aliases the entry stored
    // in the scale map. Compute into a separate vector instead of rescaling
    // the tensor in place, so shared weights, or a repeated pass over the
    // same scales, are not multiplied twice.
    const auto weight_scale_tensor = GetScaleTensorForNode(weights);
    const double* weight_scale_data = weight_scale_tensor.data<double>();
    const int64_t weight_scale_numel = weight_scale_tensor.numel();
    std::vector<float> filter_scale;
    filter_scale.reserve(static_cast<size_t>(weight_scale_numel));
    for (int64_t i = 0; i < weight_scale_numel; ++i) {
      filter_scale.push_back(static_cast<float>(
          weight_scale_data[i] * static_cast<double>(S8_MAX)));
    }

    fc->Op()->SetAttr("Scale_weights", filter_scale);

    // if quantization scale is missing for output tensor, return fp32 data
    if (AreScalesPresentForNodes({output})) {
      bool is_output_unsigned{false};
      auto output_scale = GetScaleValueForNode(output, &is_output_unsigned);
      DequantizeOutput(
          g, fc, output, "Out", output_scale, is_output_unsigned, "Scale_out");
    } else {
      fc->Op()->SetAttr("force_fp32_output", true);
    }

    ++quantize_fc_count;
  };

  gpd(graph, handler);
  AddStatis(quantize_fc_count);
  LogQuantizedOpsCounter("fc", quantize_fc_count);
}

// Wraps every int8-marked pool2d op in a quantize (on "X") / dequantize (on
// "Out") pair, provided scales are known for both its input and its output.
void CPUQuantizePass::QuantizePool(Graph* graph) const {
  GraphPatternDetector gpd;
  patterns::Pool pool_pattern{gpd.mutable_pattern(), name_scope_};
  pool_pattern();

  int quantize_pool_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize pool2d op";
    GET_IR_NODE_FROM_SUBGRAPH(pool_op, pool_op, pool_pattern);

    // Ops not marked for int8 execution are left untouched.
    if (!platform::HasOpINT8DataType(pool_op->Op())) {
      LogQuantizationDisabled(pool_op);
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(pool_input, pool_input, pool_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(pool_output, pool_output, pool_pattern);

    // Both ends need a scale because both are rewired.
    if (!AreScalesPresentForNodes({pool_input, pool_output})) {
      MarkAndLogCannotQuantizeOp(pool_op,
                                 "No scale available for the operator");
      return;
    }

    bool input_is_unsigned = false;
    const double input_scale =
        GetScaleValueForNode(pool_input, &input_is_unsigned);
    QuantizeInput(g, pool_op, pool_input, "X", input_scale, input_is_unsigned);

    bool output_is_unsigned = false;
    const double output_scale =
        GetScaleValueForNode(pool_output, &output_is_unsigned);
    DequantizeOutput(
        g, pool_op, pool_output, "Out", output_scale, output_is_unsigned);

    ++quantize_pool_count;
  };

  gpd(graph, handler);
  AddStatis(quantize_pool_count);
  LogQuantizedOpsCounter("pool2d", quantize_pool_count);
}

// Quantizes every int8-marked concat op whose output scale is known. All
// inputs ("X") are quantized with the output's scale and the output ("Out")
// is dequantized. Signedness is unsigned only if every input that has a
// scale is unsigned.
void CPUQuantizePass::QuantizeConcat(Graph* graph) const {
  GraphPatternDetector gpd;
  patterns::Concat concat_pattern{gpd.mutable_pattern(), name_scope_};
  concat_pattern();

  int quantize_concat_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize concat op";
    GET_IR_NODE_FROM_SUBGRAPH(concat_op, concat_op, concat_pattern);

    // Ops not marked for int8 execution are left untouched.
    if (!platform::HasOpINT8DataType(concat_op->Op())) {
      LogQuantizationDisabled(concat_op);
      return;
    }

    // Inputs without a known scale are ignored here. If all inputs were
    // unsigned, the output was set to unsigned during the scale calculation
    // step.
    bool are_all_inputs_unsigned = true;
    for (const Node* input : concat_op->inputs) {
      const std::string input_name = input->Name();
      if (!AreScalesPresentForVarNames({input_name})) continue;
      if (!GetScaleDataByName(input_name).first) {
        are_all_inputs_unsigned = false;
        break;
      }
    }

    GET_IR_NODE_FROM_SUBGRAPH(concat_out, concat_out, concat_pattern);

    if (!AreScalesPresentForNodes({concat_out})) {
      MarkAndLogCannotQuantizeOp(concat_op,
                                 "No scale available for the operator");
      return;
    }

    const double output_scale = GetScaleValueForNode(concat_out);

    QuantizeInputs(g, concat_op, "X", are_all_inputs_unsigned);
    DequantizeOutput(
        g, concat_op, concat_out, "Out", output_scale, are_all_inputs_unsigned);

    ++quantize_concat_count;
  };

  gpd(graph, handler);
  AddStatis(quantize_concat_count);
  LogQuantizedOpsCounter("concat", quantize_concat_count);
}

// Inserts a quantize op on the "Input" of every int8-marked prior_box op
// whose input scale is known. Only the input is rewired here; no dequantize
// op is added for its outputs.
void CPUQuantizePass::QuantizePriorBox(Graph* graph) const {
  GraphPatternDetector gpd;
  patterns::PriorBox prior_box_pattern{gpd.mutable_pattern(), name_scope_};
  prior_box_pattern();

  int quantize_prior_box_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize prior_box op";
    GET_IR_NODE_FROM_SUBGRAPH(prior_box_op, prior_box_op, prior_box_pattern);

    // Ops not marked for int8 execution are left untouched.
    if (!platform::HasOpINT8DataType(prior_box_op->Op())) {
      LogQuantizationDisabled(prior_box_op);
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(
        prior_box_input, prior_box_input, prior_box_pattern);

    if (!AreScalesPresentForNodes({prior_box_input})) {
      MarkAndLogCannotQuantizeOp(prior_box_op,
                                 "No scale available for the operator");
      return;
    }

    bool input_is_unsigned = false;
    const double input_scale =
        GetScaleValueForNode(prior_box_input, &input_is_unsigned);
    QuantizeInput(g,
                  prior_box_op,
                  prior_box_input,
                  "Input",
                  input_scale,
                  input_is_unsigned);

    ++quantize_prior_box_count;
  };

  gpd(graph, handler);
  AddStatis(quantize_prior_box_count);
  LogQuantizedOpsCounter("prior_box", quantize_prior_box_count);
}

// Quantizes every int8-marked op of type `immutable_type` (input slot
// `input_name`, output slot "Out") when a neighbour already works on int8
// data and its output scale is known. The output's scale is used both for
// the quantize step on the input and for the dequantize step on the output.
void CPUQuantizePass::QuantizeImmutable(Graph* graph,
                                        const std::string& immutable_type,
                                        const std::string& input_name) const {
  GraphPatternDetector gpd;
  patterns::Immutable immutable_pattern{gpd.mutable_pattern(), name_scope_};
  immutable_pattern(immutable_type, input_name);

  int quantize_immutable_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize " << immutable_type << " op";
    GET_IR_NODE_FROM_SUBGRAPH(immutable_op, immutable_op, immutable_pattern);

    // Ops not marked for int8 execution are left untouched.
    if (!platform::HasOpINT8DataType(immutable_op->Op())) {
      LogQuantizationDisabled(immutable_op);
      return;
    }
    GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, immutable_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(immutable_in, immutable_in, immutable_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(immutable_out, immutable_out, immutable_pattern);

    // Worth doing only if the producer is dequantized or every consumer of
    // the output is quantized.
    const bool producer_is_int8 = IsOpDequantized(prev_op);
    if (!producer_is_int8 && !IsOpQuantized(immutable_out)) {
      MarkAndLogCannotQuantizeOp(immutable_op,
                                 "No other quantizable operators nearby");
      return;
    }

    if (!AreScalesPresentForNodes({immutable_out})) {
      MarkAndLogCannotQuantizeOp(immutable_op,
                                 "No scale available for the operator");
      return;
    }

    // One lookup serves both sides: input and output share the output scale.
    bool is_unsigned = false;
    const double scale = GetScaleValueForNode(immutable_out, &is_unsigned);

    QuantizeInput(
        g, immutable_op, immutable_in, input_name, scale, is_unsigned);
    DequantizeOutput(
        g, immutable_op, immutable_out, "Out", scale, is_unsigned);

    ++quantize_immutable_count;
  };

  gpd(graph, handler);
  AddStatis(quantize_immutable_count);
  LogQuantizedOpsCounter(immutable_type, quantize_immutable_count);
}

// Quantizes every int8-marked matmul matched by MatmulWithInputOps (with or
// without a ResidualData input, per `with_residual`) when at least one
// producer is dequantized and the required scales are known. X and Y must
// agree in signedness. The output is dequantized when its scale is known;
// otherwise force_fp32_output is set.
void CPUQuantizePass::QuantizeMatmul(Graph* graph, bool with_residual) const {
  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
  patterns::MatmulWithInputOps matmul_pattern{pattern, name_scope_};
  matmul_pattern(with_residual);

  int quantize_matmul_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize matmul op";
    GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, matmul_pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(matmul_op->Op())) {
      LogQuantizationDisabled(matmul_op);
      return;
    }
    GET_IR_NODE_FROM_SUBGRAPH(prev_op_x, prev_op_x, matmul_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(prev_op_y, prev_op_y, matmul_pattern);

    // skip if prev ops are not quantized
    if (!IsOpDequantized(prev_op_x) && !IsOpDequantized(prev_op_y)) {
      MarkAndLogCannotQuantizeOp(matmul_op,
                                 "No other quantizable operators nearby");
      return;
    }
    GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, matmul_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, matmul_pattern);

    // Computed once and reused below (re-querying would log a missing output
    // scale twice).
    auto has_output_scale = AreScalesPresentForNodes({matmul_out});
    if (with_residual && !has_output_scale) {
      MarkAndLogCannotQuantizeOp(
          matmul_op,
          "Matmul op with ResidualData input cannot be quantized "
          "without output scale.");
      return;
    }

    if (!AreScalesPresentForNodes({matmul_in_x, matmul_in_y})) {
      MarkAndLogCannotQuantizeOp(matmul_op,
                                 "No scale available for the operator");
      return;
    }

    bool is_x_unsigned{false}, is_y_unsigned{false};
    auto input_x_scale = GetScaleValueForNode(matmul_in_x, &is_x_unsigned);
    auto input_y_scale = GetScaleValueForNode(matmul_in_y, &is_y_unsigned);
    PADDLE_ENFORCE_EQ(is_x_unsigned,
                      is_y_unsigned,
                      platform::errors::InvalidArgument(
                          "Matmul inputs should have the same "
                          "attribute of signed/unsigned, but they "
                          "are different: x(%d), y(%d).",
                          is_x_unsigned,
                          is_y_unsigned));

    if (with_residual) {
      GET_IR_NODE_FROM_SUBGRAPH(
          matmul_residual_data, matmul_residual_data, matmul_pattern);
      if (!AreScalesPresentForNodes({matmul_residual_data})) {
        MarkAndLogCannotQuantizeOp(matmul_op,
                                   "No scale available for the operator");
        return;
      }
      bool is_residual_unsigned{false};
      auto residual_scale =
          GetScaleValueForNode(matmul_residual_data, &is_residual_unsigned);

      QuantizeInput(g,
                    matmul_op,
                    matmul_residual_data,
                    "ResidualData",
                    residual_scale,
                    is_residual_unsigned,
                    "Scale_in_eltwise");
    }

    QuantizeInput(g,
                  matmul_op,
                  matmul_in_x,
                  "X",
                  input_x_scale,
                  is_x_unsigned,
                  "Scale_x");
    QuantizeInput(g,
                  matmul_op,
                  matmul_in_y,
                  "Y",
                  input_y_scale,
                  is_y_unsigned,
                  "Scale_y");

    // if quantization scale is missing for output tensor, return fp32 data
    if (has_output_scale) {
      bool is_output_unsigned{false};
      auto output_scale = GetScaleValueForNode(matmul_out, &is_output_unsigned);
      DequantizeOutput(g,
                       matmul_op,
                       matmul_out,
                       "Out",
                       output_scale,
                       is_output_unsigned,
                       "Scale_out");
    } else {
      matmul_op->Op()->SetAttr("force_fp32_output", true);
    }

    ++quantize_matmul_count;
  };
  gpd(graph, handler);
  AddStatis(quantize_matmul_count);
  LogQuantizedOpsCounter("matmul",
                         quantize_matmul_count,
                         (with_residual ? "with residual connection" : ""));
}

void CPUQuantizePass::QuantizeElementwise(
852
    Graph* graph, const std::string& elementwise_type) const {
853 854
  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
855
  patterns::ElementwiseOp elementwise_pattern{pattern, name_scope_};
856

857
  elementwise_pattern(elementwise_type);
858

Z
Zuza 已提交
859
  int quantize_elementwise_count = 0;
860 861
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
Z
Zuza 已提交
862
    VLOG(4) << "Quantize " + elementwise_type + " op";
863 864
    GET_IR_NODE_FROM_SUBGRAPH(
        elementwise_op, elementwise_op, elementwise_pattern);
865 866

    // skip if should not be quantized
Z
Zuza 已提交
867 868
    if (!platform::HasOpINT8DataType(elementwise_op->Op())) {
      LogQuantizationDisabled(elementwise_op);
869 870 871
      return;
    }

872 873 874 875 876 877 878 879 880 881 882 883
    auto x_name = elementwise_op->Op()->Input("X");
    auto y_name = elementwise_op->Op()->Input("Y");
    Node *elementwise_x, *elementwise_y;

    for (auto& input : elementwise_op->inputs) {
      if (input->Name() == x_name[0]) elementwise_x = input;
      if (input->Name() == y_name[0]) elementwise_y = input;
    }
    if (!elementwise_x || !elementwise_y) {
      return;
    }

884 885
    GET_IR_NODE_FROM_SUBGRAPH(
        elementwise_out, elementwise_out, elementwise_pattern);
886

887
    if (!AreScalesPresentForNodes(
Z
Zuza 已提交
888
            {elementwise_x, elementwise_y, elementwise_out})) {
889 890
      MarkAndLogCannotQuantizeOp(elementwise_op,
                                 "No scale available for the operator");
891 892 893 894
      return;
    }

    bool is_x_unsigned{false}, is_y_unsigned{false};
Z
Zuza 已提交
895 896
    auto input_x_scale = GetScaleValueForNode(elementwise_x, &is_x_unsigned);
    auto input_y_scale = GetScaleValueForNode(elementwise_y, &is_y_unsigned);
897 898 899

    // TODO(sfraczek): add support for different signness
    if (is_x_unsigned != is_y_unsigned) {
900 901
      MarkAndLogCannotQuantizeOp(
          elementwise_op, "Elementwise inputs must be of the same type.");
902 903 904
      return;
    }

905 906 907 908 909 910 911 912 913 914 915 916 917 918
    QuantizeInput(g,
                  elementwise_op,
                  elementwise_x,
                  "X",
                  input_x_scale,
                  is_x_unsigned,
                  "Scale_x");
    QuantizeInput(g,
                  elementwise_op,
                  elementwise_y,
                  "Y",
                  input_y_scale,
                  is_y_unsigned,
                  "Scale_y");
919

920 921
    bool is_output_unsigned{false};
    auto output_scale =
Z
Zuza 已提交
922
        GetScaleValueForNode(elementwise_out, &is_output_unsigned);
923

924 925 926 927 928 929 930
    DequantizeOutput(g,
                     elementwise_op,
                     elementwise_out,
                     "Out",
                     output_scale,
                     is_output_unsigned,
                     "Scale_out");
931

Z
Zuza 已提交
932
    ++quantize_elementwise_count;
933 934
  };
  gpd(graph, handler);
Z
Zuza 已提交
935
  AddStatis(quantize_elementwise_count);
936
  LogQuantizedOpsCounter(elementwise_type, quantize_elementwise_count);
937 938
}

939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960
// Quantizes every fusion_gru operator for which platform::HasOpINT8DataType()
// holds. Input "X" is quantized with a shift of 128 when its scale marks it as
// signed (0 otherwise). The WeightX scales multiplied by S8_MAX are stored in
// the "Scale_weights" attribute, and the operator is forced to return fp32
// data. Operators lacking scales for X or WeightX are marked via
// MarkAndLogCannotQuantizeOp() instead.
void CPUQuantizePass::QuantizeFusionGru(Graph* graph) const {
  GraphPatternDetector gpd;
  patterns::FusionGru pattern{gpd.mutable_pattern(), name_scope_};
  pattern();

  int quantize_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize fusion_gru op";
    GET_IR_NODE_FROM_SUBGRAPH(op, op, pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(op->Op())) {
      LogQuantizationDisabled(op);
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(x, x, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(weight_h, weight_h, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(weight_x, weight_x, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(out, out, pattern);

    if (!AreScalesPresentForNodes({x, weight_x})) {
      MarkAndLogCannotQuantizeOp(op, "No scale available for the operator");
      return;
    }

    bool is_x_unsigned{false};
    auto input_x_scale = GetScaleValueForNode(x, &is_x_unsigned);

    double input_x_shift{128.};
    if (is_x_unsigned) input_x_shift = 0.;

    QuantizeInput(g,
                  op,
                  x,
                  "X",
                  input_x_scale,
                  is_x_unsigned,
                  "Scale_data",
                  input_x_shift,
                  "Shift_data");

    // Read the weight scales through a const view instead of scaling them in
    // place: the tensor returned by GetScaleTensorForNode() may share its
    // buffer with the stored scale data, and mutating it would change the
    // scales seen by any later lookup of the same variable.
    const auto weight_scale_tensor = GetScaleTensorForNode(weight_x);
    const double* weight_scales = weight_scale_tensor.data<double>();
    const int64_t num_weight_scales = weight_scale_tensor.numel();
    std::vector<float> scale_weights;
    scale_weights.reserve(num_weight_scales);
    for (int64_t i = 0; i < num_weight_scales; ++i) {
      scale_weights.push_back(static_cast<float>(
          weight_scales[i] * static_cast<double>(S8_MAX)));
    }

    op->Op()->SetAttr("Scale_weights", scale_weights);
    // return fp32 data
    op->Op()->SetAttr("force_fp32_output", true);

    ++quantize_count;
  };
  gpd(graph, handler);
  AddStatis(quantize_count);
  LogQuantizedOpsCounter("fusion_gru", quantize_count);
}

// Quantizes every multi_gru operator for which platform::HasOpINT8DataType()
// holds. Input "X" is quantized (shift 128 when signed, 0 when unsigned).
// For each variable listed in the "WeightX" input, a new persistable FP32
// variable is created in the parameter scope. It holds that variable's scales
// multiplied by S8_MAX, and the new variable names are set as the
// "Scale_weights" input. The operator is forced to return fp32 data.
void CPUQuantizePass::QuantizeMultiGru(Graph* graph) const {
  GraphPatternDetector gpd;
  patterns::MultiGru pattern{gpd.mutable_pattern(), name_scope_};
  pattern();

  int quantize_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize multi_gru op";
    GET_IR_NODE_FROM_SUBGRAPH(gru, gru, pattern);

    // Operators not selected for int8 execution are left as they are.
    if (!platform::HasOpINT8DataType(gru->Op())) {
      LogQuantizationDisabled(gru);
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(x, x, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(wx, wx, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(h, h, pattern);

    auto wx_names = gru->Op()->Input("WeightX");
    const bool scales_available = AreScalesPresentForNodes({x}) &&
                                  AreScalesPresentForVarNames(wx_names);
    if (!scales_available) {
      MarkAndLogCannotQuantizeOp(gru, "No scale available for the operator");
      return;
    }

    bool is_x_unsigned{false};
    auto input_x_scale = GetScaleValueForNode(x, &is_x_unsigned);
    double input_x_shift = is_x_unsigned ? 0. : 128.;

    QuantizeInput(g,
                  gru,
                  x,
                  "X",
                  input_x_scale,
                  is_x_unsigned,
                  "Scale_data",
                  input_x_shift,
                  "Shift_data");

    auto* scope = param_scope();
    std::vector<std::string> w_scale_var_names;
    for (const auto& wx_name : wx_names) {
      auto src_scales = GetScaleTensorByName(wx_name);
      EigenVectorArrayMap src_map{src_scales.data<double>(),
                                  src_scales.numel()};

      // Describe a persistable FP32 variable shaped like the source scales.
      VarDesc scale_desc(patterns::PDNodeName("multi_gru", "w_scale"));
      scale_desc.SetShape(phi::vectorize(src_scales.dims()));
      scale_desc.SetDataType(proto::VarType::FP32);
      scale_desc.SetLoDLevel(src_scales.lod().size());
      scale_desc.SetPersistable(true);
      auto* scale_node = g->CreateVarNode(&scale_desc);

      // Fill the scope tensor backing the new node with scales * S8_MAX.
      auto* dst_tensor =
          scope->Var(scale_node->Name())->GetMutable<LoDTensor>();
      dst_tensor->Resize(src_scales.dims());
      auto* dst_values =
          dst_tensor->mutable_data<float>(platform::CPUPlace());
      EigenVectorArrayMapFloat dst_map{dst_values, dst_tensor->numel()};
      dst_map = src_map.cast<float>() * static_cast<float>(S8_MAX);

      w_scale_var_names.push_back(scale_node->Name());
      IR_NODE_LINK_TO(scale_node, gru);
    }

    gru->Op()->SetInput("Scale_weights", w_scale_var_names);
    // return fp32 data
    gru->Op()->SetAttr("force_fp32_output", true);

    ++quantize_count;
  };
  gpd(graph, handler);
  AddStatis(quantize_count);
  LogQuantizedOpsCounter("multi_gru", quantize_count);
}

// Quantizes every fusion_lstm operator for which
// platform::HasOpINT8DataType() holds. Input "X" is quantized with a shift of
// 128 when its scale marks it as signed (0 otherwise). The WeightX scales
// multiplied by S8_MAX are stored in the "Scale_weights" attribute, and the
// operator is forced to return fp32 data. Operators lacking scales for X or
// WeightX are marked via MarkAndLogCannotQuantizeOp() instead.
void CPUQuantizePass::QuantizeFusionLSTM(Graph* graph) const {
  GraphPatternDetector gpd;
  patterns::FusionLSTM pattern{gpd.mutable_pattern(), name_scope_};
  pattern();

  int quantize_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize fusion_lstm op";
    GET_IR_NODE_FROM_SUBGRAPH(op, op, pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(op->Op())) {
      LogQuantizationDisabled(op);
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(x, x, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(weight_h, weight_h, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(weight_x, weight_x, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(hidden, hidden, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(cell, cell, pattern);

    // Both the input data and the input weights need scales to quantize.
    if (!AreScalesPresentForNodes({x, weight_x})) {
      MarkAndLogCannotQuantizeOp(op, "No scale available for the operator");
      return;
    }

    bool is_x_unsigned{false};
    auto input_x_scale = GetScaleValueForNode(x, &is_x_unsigned);

    double input_x_shift{128.};
    if (is_x_unsigned) input_x_shift = 0.;

    QuantizeInput(g,
                  op,
                  x,
                  "X",
                  input_x_scale,
                  is_x_unsigned,
                  "Scale_data",
                  input_x_shift,
                  "Shift_data");

    // Read the weight scales through a const view instead of scaling them in
    // place: the tensor returned by GetScaleTensorForNode() may share its
    // buffer with the stored scale data, and mutating it would change the
    // scales seen by any later lookup of the same variable.
    const auto weight_scale_tensor = GetScaleTensorForNode(weight_x);
    const double* weight_scales = weight_scale_tensor.data<double>();
    const int64_t num_weight_scales = weight_scale_tensor.numel();
    std::vector<float> scale_weights;
    scale_weights.reserve(num_weight_scales);
    for (int64_t i = 0; i < num_weight_scales; ++i) {
      scale_weights.push_back(static_cast<float>(
          weight_scales[i] * static_cast<double>(S8_MAX)));
    }

    op->Op()->SetAttr("Scale_weights", scale_weights);
    // return fp32 data
    op->Op()->SetAttr("force_fp32_output", true);

    ++quantize_count;
  };
  gpd(graph, handler);
  AddStatis(quantize_count);
  LogQuantizedOpsCounter("fusion_lstm", quantize_count);
}

// Pass entry point. It first checks that both the graph and the parameter
// scope are non-null and loads the quantization info via GetQuantInfo(). It
// then runs each per-operator quantization routine in a fixed order.
void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
  VLOG(3) << "Quantizing the graph.";
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
  FusePassBase::Init(name_scope_, graph);

  PADDLE_ENFORCE_NOT_NULL(
      param_scope(),
      platform::errors::InvalidArgument("Scope cannot be nullptr."));

  GetQuantInfo(graph);

  // Convolutions: first the variant without residual data, then the one with.
  for (const bool with_residual_data : {false, true}) {
    QuantizeConv(graph, with_residual_data);
  }
  QuantizePool(graph);
  QuantizeConcat(graph);
  QuantizePriorBox(graph);
  QuantizeFc(graph);
  // Matmuls: first the variant without residual data, then the one with.
  for (const bool with_residual_data : {false, true}) {
    QuantizeMatmul(graph, with_residual_data);
  }

  // Operator types handled by QuantizeImmutable(), each paired with the
  // input name passed to it. Processed in the listed order.
  const std::vector<std::pair<std::string, std::string>> immutable_ops{
      {"reshape2", "X"},
      {"transpose2", "X"},
      {"slice", "Input"},
      {"shape", "Input"},
      {"nearest_interp", "X"},
      {"nearest_interp_v2", "X"}};
  for (const auto& op_and_input : immutable_ops) {
    QuantizeImmutable(graph, op_and_input.first, op_and_input.second);
  }

  for (const char* elementwise_type :
       {"elementwise_add", "elementwise_mul", "elementwise_sub"}) {
    QuantizeElementwise(graph, elementwise_type);
  }

  QuantizeFusionGru(graph);
  QuantizeMultiGru(graph);
  QuantizeFusionLSTM(graph);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

// Registers CPUQuantizePass under the name "cpu_quantize_pass" and declares
// "quant_var_scales" as a required pass attribute.
REGISTER_PASS(cpu_quantize_pass, paddle::framework::ir::CPUQuantizePass)
    .RequirePassAttr("quant_var_scales");