// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h"

#include <algorithm>
#include <sstream>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/string/pretty_log.h"

namespace paddle {
namespace framework {
namespace ir {

using EigenVectorArrayMap = Eigen::Map<Eigen::Array<double, Eigen::Dynamic, 1>>;
using EigenVectorArrayMapFloat =
    Eigen::Map<Eigen::Array<float, Eigen::Dynamic, 1>>;
using string::PrettyLogDetail;

namespace {

void UnlinkNodes(ir::Node* a, ir::Node* b) {
  a->outputs.erase(std::remove(a->outputs.begin(), a->outputs.end(), b),
                   a->outputs.end());
  b->inputs.erase(std::remove(b->inputs.begin(), b->inputs.end(), a),
                  b->inputs.end());
}

void MarkAndLogCannotQuantizeOp(Node* op, const char* details = nullptr) {
  std::stringstream msg_ss;
  msg_ss << "Cannot quantize operator " << op->Name()
         << " (type: " << op->Op()->Type() << ", id: " << op->id() << ").";
  if (details) msg_ss << " " << details;
  VLOG(2) << msg_ss.str();
  op->Op()->SetAttr("mkldnn_data_type", std::string("float32"));
}

void LogScaleIsMissingForVarName(const std::string& name) {
  VLOG(4) << "Quantization scale for the variable " << name << " is missing.";
}

void LogScaleIsMissingForVarNode(Node* node) {
  LogScaleIsMissingForVarName(node->Name());
}

void LogQuantizationDisabled(Node* op) {
  VLOG(2) << "Quantization skipped for operator " << op->Name()
          << " (type: " << op->Op()->Type() << ", id: " << op->id()
          << "). Attribute mkldnn_data_type != \"int8\".";
}

void LogQuantizedOpsCounter(const std::string& type,
                            const int counter,
                            const char* details = nullptr) {
  std::stringstream msg_ss;
  msg_ss << "---    quantized " << counter << " " << type << " ops";
  if (details) msg_ss << " " << details;
  PrettyLogDetail(msg_ss.str().c_str());
}

}  // namespace

enum { U8_MAX = 255, S8_MAX = 127 };

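// QuantizeInput rewires `input -> op` into `input -> quantize -> op`. The
// quantize op maps FP32 data into the int8 range: the Scale attribute is
// scale_to_one * 255 for unsigned (u8) inputs and scale_to_one * 127 for
// signed (s8) inputs, e.g. scale_to_one = 0.05 with a signed input gives
// Scale = 0.05 * 127 = 6.35.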
void CPUQuantizePass::QuantizeInput(Graph* g,
                                    Node* op,
                                    Node* input,
                                    std::string input_name,
                                    double scale_to_one,
                                    bool is_input_unsigned,
                                    std::string scale_attr_name,
                                    float shift,
                                    std::string shift_attr_name) const {
  auto inputs = op->Op()->InputNames();
  bool name_found =
      std::find(inputs.begin(), inputs.end(), input_name) != inputs.end();
  PADDLE_ENFORCE_EQ(name_found,
                    true,
                    platform::errors::InvalidArgument(
                        "Var(%s) isn't the input of the %s operator.",
                        input_name,
                        op->Op()->Type()));
  unsigned max = is_input_unsigned ? U8_MAX : S8_MAX;
  float scale = scale_to_one * max;

  // Create quantize output variable
  VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out"));
  auto* quantize_out_node = g->CreateVarNode(&quantize_out_desc);

  // create a quantize op node
  OpDesc q_desc;
  q_desc.SetType("quantize");
  q_desc.SetInput("Input", std::vector<std::string>({input->Name()}));
  q_desc.SetOutput("Output",
                   std::vector<std::string>({quantize_out_node->Name()}));
  q_desc.SetAttr("Scale", scale);
  q_desc.SetAttr("Shift", shift);
  q_desc.SetAttr("is_negative_input", !is_input_unsigned);

  // Workaround for an fc output format error: when in_num_col_dims == 2,
  // the quantize op's default output_format must be NCHW instead of NHWC.
  if (op->Op()->Type() == "fc" &&
      op->Op()->GetAttrIfExists<int>("in_num_col_dims") == 2) {
    q_desc.SetAttr(
        "output_format",
        Has("data_layout") ? Get<std::string>("data_layout") : "NCHW");
  } else {
    q_desc.SetAttr(
        "output_format",
        Has("data_layout") ? Get<std::string>("data_layout") : "NHWC");
  }
  auto quantize_op = g->CreateOpNode(&q_desc);  // OpDesc will be copied.

  // update op's input
  op->Op()->SetInput(input_name,
                     std::vector<std::string>({quantize_out_node->Name()}));

  // link quantize op
  UnlinkNodes(input, op);
  IR_NODE_LINK_TO(input, quantize_op);
  IR_NODE_LINK_TO(quantize_op, quantize_out_node);
  IR_NODE_LINK_TO(quantize_out_node, op);

  if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale);
  if (!shift_attr_name.empty()) op->Op()->SetAttr(shift_attr_name, shift);
}

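// QuantizeInputs inserts one quantize op per input node of a multi-input
// argument (e.g. concat's "X"); all inputs share a single scale derived from
// the op's single output.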
void CPUQuantizePass::QuantizeInputs(Graph* g,
                                     Node* op,
                                     std::string input_name,
                                     bool are_inputs_unsigned,
                                     std::string scale_attr_name,
                                     float shift,
                                     std::string shift_attr_name) const {
  auto inputs = op->inputs;
  auto output = op->outputs[0];
  PADDLE_ENFORCE_GE(inputs.size(),
                    1,
                    platform::errors::InvalidArgument(
                        "OP(%s)'s inputs(%d) must be equal to or greater than 1.",
                        op->Name(),
                        inputs.size()));
  PADDLE_ENFORCE_EQ(op->outputs.size(),
                    1,
                    platform::errors::InvalidArgument(
                        "OP(%s)'s outputs(%d) must be equal to 1.",
                        op->Name(),
                        op->outputs.size()));

  // create a quantize op desc prototype
  OpDesc q_desc;
  q_desc.SetType("quantize");

  std::vector<Node*> quantize_out_nodes(inputs.size());
  std::vector<std::string> quantize_out_node_names(inputs.size());

  double scale_out = GetScaleValueForNode(output);
  unsigned max = are_inputs_unsigned ? U8_MAX : S8_MAX;
  float scale = scale_out * max;

  for (size_t i = 0; i < inputs.size(); i++) {
    // Create quantize output variable
    VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out"));
    quantize_out_nodes[i] = g->CreateVarNode(&quantize_out_desc);
    quantize_out_node_names[i] = quantize_out_nodes[i]->Name();

    q_desc.SetAttr("Scale", scale);
    q_desc.SetAttr("Shift", shift);
    q_desc.SetInput("Input", std::vector<std::string>({inputs[i]->Name()}));
    q_desc.SetOutput("Output",
                     std::vector<std::string>({quantize_out_node_names[i]}));
    q_desc.SetAttr("is_negative_input", !are_inputs_unsigned);
    auto quantize_op = g->CreateOpNode(&q_desc);  // OpDesc will be copied.

    // link quantize op
    UnlinkNodes(inputs[i], op);
    IR_NODE_LINK_TO(inputs[i], quantize_op);
    IR_NODE_LINK_TO(quantize_op, quantize_out_nodes[i]);
    IR_NODE_LINK_TO(quantize_out_nodes[i], op);
  }

  // update op's input
  op->Op()->SetInput(input_name, quantize_out_node_names);

  if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale);
  if (!shift_attr_name.empty()) op->Op()->SetAttr(shift_attr_name, shift);
}

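// DequantizeOutput rewires `op -> output` into `op -> dequantize -> output`,
// so downstream FP32 consumers keep receiving dequantized data.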
void CPUQuantizePass::DequantizeOutput(Graph* g,
                                       Node* op,
                                       Node* output,
                                       std::string output_name,
                                       double scale_to_one,
                                       bool is_unsigned,
                                       std::string scale_attr_name) const {
  auto outputs = op->Op()->OutputNames();
  bool name_found =
      std::find(outputs.begin(), outputs.end(), output_name) != outputs.end();
  PADDLE_ENFORCE_EQ(name_found,
                    true,
                    platform::errors::InvalidArgument(
                        "Var(%s) isn't the output of the %s operator.",
                        output_name,
                        op->Op()->Type()));
  unsigned max = is_unsigned ? U8_MAX : S8_MAX;
  float scale = scale_to_one * max;

  // Create dequantize input variable
  VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
  auto* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc);

  // create a dequantize op node for output.
  OpDesc deq_desc;
  deq_desc.SetType("dequantize");
  deq_desc.SetInput("Input",
                    std::vector<std::string>({dequantize_in_node->Name()}));
  deq_desc.SetOutput("Output", std::vector<std::string>({output->Name()}));
  deq_desc.SetAttr("Scale", scale);
  deq_desc.SetAttr("is_negative_input", !is_unsigned);
  auto dequantize_op = g->CreateOpNode(&deq_desc);  // OpDesc will be copied.

  // update op's output
  op->Op()->SetOutput(output_name,
                      std::vector<std::string>({dequantize_in_node->Name()}));

  // link dequantize op
  UnlinkNodes(op, output);
  IR_NODE_LINK_TO(op, dequantize_in_node);
  IR_NODE_LINK_TO(dequantize_in_node, dequantize_op);
  IR_NODE_LINK_TO(dequantize_op, output);

  if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale);
}

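// Scales are looked up either in the cached var_quant_scales_ map (filled
// from the first op's attributes by GetQuantInfo) or, when that map is
// empty, in the required "quant_var_scales" pass attribute.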
bool CPUQuantizePass::AreScalesPresentForVarNames(
    std::vector<std::string> names) const {
  bool present = true;
  if (var_quant_scales_->empty()) {
    auto& scales = Get<VarQuantScale>("quant_var_scales");
    for (auto name : names) {
      if (scales.find(name) == scales.end()) {
        present = false;
        LogScaleIsMissingForVarName(name);
      }
    }
  } else {
    for (auto name : names) {
      if (var_quant_scales_->find(name) == var_quant_scales_->end()) {
        present = false;
        LogScaleIsMissingForVarName(name);
      }
    }
  }
  return present;
}

bool CPUQuantizePass::AreScalesPresentForNodes(
    std::initializer_list<Node*> nodes) const {
  bool present = true;
  if (var_quant_scales_->empty()) {
    auto& scales = Get<VarQuantScale>("quant_var_scales");
    for (auto node : nodes) {
      if (scales.count(node->Name()) == 0) {
        present = false;
        LogScaleIsMissingForVarNode(node);
      }
    }
  } else {
    for (auto node : nodes) {
      if (var_quant_scales_->count(node->Name()) == 0) {
        present = false;
        LogScaleIsMissingForVarNode(node);
      }
    }
  }
  return present;
}

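// Each scale entry is a {is_unsigned, scale tensor} pair; for activations
// the scale value is the first tensor element.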
std::pair<bool, LoDTensor> CPUQuantizePass::GetScaleDataByName(
    const std::string& name) const {
  if (var_quant_scales_->empty()) {
    auto& scales = Get<VarQuantScale>("quant_var_scales");
    return scales.at(name);
  }
  return var_quant_scales_->at(name);
}

std::pair<bool, LoDTensor> CPUQuantizePass::GetScaleDataForNode(
    const Node* node) const {
  return GetScaleDataByName(node->Name());
}

LoDTensor CPUQuantizePass::GetScaleTensorByName(const std::string& name) const {
  return GetScaleDataByName(name).second;
}

LoDTensor CPUQuantizePass::GetScaleTensorForNode(const Node* node) const {
  return GetScaleDataForNode(node).second;
}

double CPUQuantizePass::GetScaleValueForNode(const Node* node,
                                             bool* is_unsigned) const {
  auto scale_data = GetScaleDataForNode(node);
  if (is_unsigned != nullptr) *is_unsigned = scale_data.first;
  return scale_data.second.data<double>()[0];
}

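// IsOpDequantized / IsOpQuantized check whether the neighbouring producer or
// consumers already work on int8 data; several quantizers below fire only
// when at least one neighbour does.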
bool CPUQuantizePass::IsOpDequantized(const Node* node) const {
  return node->Op()->Type() == "dequantize" ||
         platform::HasOpINT8DataType(node->Op());
}

bool CPUQuantizePass::IsOpQuantized(const Node* node) const {
  // Return true only if all outputs are ops and each of them is either a
  // quantize op or has int8 data type.
  return all_of(node->outputs.begin(), node->outputs.end(), [](Node* output) {
    return (output->IsOp() && (output->Op()->Type() == "quantize" ||
                               platform::HasOpINT8DataType(output->Op())));
  });
}

void CPUQuantizePass::GetQuantInfo(Graph* graph) const {
  GetInfoFromTheFirstOp(
      graph, "has_quant_info", "var_quant_scales", var_quant_scales_);
}

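// Conv quantization: quantize Input (and ResidualData when fused), multiply
// the per-output-channel weight scales by S8_MAX, then either dequantize the
// output or set force_fp32_output when no output scale is known.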
void CPUQuantizePass::QuantizeConv(Graph* graph,
                                   bool with_residual_data) const {
  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
  patterns::ConvResidual conv_pattern{pattern, name_scope_};
  conv_pattern(with_residual_data);

  int quantize_conv_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize conv2d op";
    GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(conv_op->Op())) {
      LogQuantizationDisabled(conv_op);
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);

    auto has_output_scale = AreScalesPresentForNodes({conv_output});
    if (with_residual_data && !has_output_scale) {
      MarkAndLogCannotQuantizeOp(
          conv_op,
          "Conv op with ResidualData input cannot be quantized "
          "without output scale.");
      return;
    }

    if (with_residual_data) {
      GET_IR_NODE_FROM_SUBGRAPH(
          conv_residual_data, conv_residual_data, conv_pattern);
      if (!AreScalesPresentForNodes(
              {conv_input, conv_filter, conv_residual_data})) {
        MarkAndLogCannotQuantizeOp(conv_op,
                                   "No scale available for the operator");
        return;
      }

      bool is_residual_unsigned{false};
      auto residual_scale =
          GetScaleValueForNode(conv_residual_data, &is_residual_unsigned);

      QuantizeInput(g,
                    conv_op,
                    conv_residual_data,
                    "ResidualData",
                    residual_scale,
                    is_residual_unsigned,
                    "Scale_in_eltwise");
    } else {
      if (!AreScalesPresentForNodes({conv_input, conv_filter})) {
        MarkAndLogCannotQuantizeOp(conv_op,
                                   "No scale available for the operator");
        return;
      }
    }

    bool is_input_unsigned{false};
    auto input_scale = GetScaleValueForNode(conv_input, &is_input_unsigned);
    QuantizeInput(g,
                  conv_op,
                  conv_input,
                  "Input",
                  input_scale,
                  is_input_unsigned,
                  "Scale_in");

    auto filter_scale_tensor = GetScaleTensorForNode(conv_filter);
    EigenVectorArrayMap eigen_tensor{filter_scale_tensor.data<double>(),
                                     filter_scale_tensor.numel()};

    // If the scale value of a weight is already multiplied by S8_MAX, it does
    // not need to be multiplied again
    if (std::find(change_weight_->begin(),
                  change_weight_->end(),
                  conv_filter->Name()) == change_weight_->end()) {
      eigen_tensor *= static_cast<double>(S8_MAX);
      change_weight_->push_back(conv_filter->Name());
    }

    std::vector<float> filter_scale{
        filter_scale_tensor.data<double>(),
        filter_scale_tensor.data<double>() + filter_scale_tensor.numel()};

    conv_op->Op()->SetAttr("Scale_weights", filter_scale);

    // if quantization scale is missing for output tensor, return fp32 data
    if (has_output_scale) {
      bool is_output_unsigned{false};
      auto output_scale =
          GetScaleValueForNode(conv_output, &is_output_unsigned);
      DequantizeOutput(g,
                       conv_op,
                       conv_output,
                       "Output",
                       output_scale,
                       is_output_unsigned,
                       "Scale_out");
    } else {
      conv_op->Op()->SetAttr("force_fp32_output", true);
    }

    // change threshold in bounded ReLU
    if (conv_op->Op()->GetAttrIfExists<std::string>("fuse_activation") ==
        "relu6") {
      float scale_out =
          PADDLE_GET_CONST(float, conv_op->Op()->GetAttr("Scale_out"));
      float threshold =
          PADDLE_GET_CONST(float, conv_op->Op()->GetAttr("fuse_alpha"));
      conv_op->Op()->SetAttr("fuse_alpha", scale_out * threshold);
    }

    ++quantize_conv_count;
  };

  gpd(graph, handler);
  AddStatis(quantize_conv_count);

  LogQuantizedOpsCounter(
      "conv2d",
      quantize_conv_count,
      ((with_residual_data) ? "with residual connection" : ""));
}

void CPUQuantizePass::QuantizeFc(Graph* graph) const {
  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
  patterns::FCMKLDNN fc_pattern{pattern, name_scope_};
  fc_pattern(false /* with_residual */);

  int quantize_fc_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize fc op";
    GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fc_pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(fc->Op())) {
      LogQuantizationDisabled(fc);
      return;
    }
    if (!fc->Op()->GetAttrIfExists<bool>("use_mkldnn")) {
      MarkAndLogCannotQuantizeOp(fc, "use_mkldnn attribute set to false");
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(weights, weights, fc_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(input, input, fc_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(output, output, fc_pattern);

    if (!AreScalesPresentForNodes({input, weights})) {
      MarkAndLogCannotQuantizeOp(fc, "No scale available for the operator");
      return;
    }

    bool is_input_unsigned{false};
    auto input_scale = GetScaleValueForNode(input, &is_input_unsigned);
    QuantizeInput(
        g, fc, input, "Input", input_scale, is_input_unsigned, "Scale_in");

    auto weight_scale_tensor = GetScaleTensorForNode(weights);
    EigenVectorArrayMap eigen_tensor{weight_scale_tensor.data<double>(),
                                     weight_scale_tensor.numel()};
    eigen_tensor *= static_cast<double>(S8_MAX);
    std::vector<float> filter_scale{
        weight_scale_tensor.data<double>(),
        weight_scale_tensor.data<double>() + weight_scale_tensor.numel()};

    fc->Op()->SetAttr("Scale_weights", filter_scale);

    // if quantization scale is missing for output tensor, return fp32 data
    if (AreScalesPresentForNodes({output})) {
      bool is_output_unsigned{false};
      auto output_scale = GetScaleValueForNode(output, &is_output_unsigned);
      DequantizeOutput(
          g, fc, output, "Out", output_scale, is_output_unsigned, "Scale_out");
    } else {
      fc->Op()->SetAttr("force_fp32_output", true);
    }

    ++quantize_fc_count;
  };

  gpd(graph, handler);
  AddStatis(quantize_fc_count);
  LogQuantizedOpsCounter("fc", quantize_fc_count);
}

void CPUQuantizePass::QuantizePool(Graph* graph) const {
  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
  patterns::Pool pool_pattern{pattern, name_scope_};
  pool_pattern();

  int quantize_pool_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize pool2d op";
    GET_IR_NODE_FROM_SUBGRAPH(pool_op, pool_op, pool_pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(pool_op->Op())) {
      LogQuantizationDisabled(pool_op);
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(pool_input, pool_input, pool_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(pool_output, pool_output, pool_pattern);

    if (!AreScalesPresentForNodes({pool_input, pool_output})) {
      MarkAndLogCannotQuantizeOp(pool_op,
                                 "No scale available for the operator");
      return;
    }

    bool is_input_unsigned{false};
    auto input_scale = GetScaleValueForNode(pool_input, &is_input_unsigned);
    QuantizeInput(g, pool_op, pool_input, "X", input_scale, is_input_unsigned);

    bool is_output_unsigned{false};
    auto output_scale = GetScaleValueForNode(pool_output, &is_output_unsigned);
    DequantizeOutput(
        g, pool_op, pool_output, "Out", output_scale, is_output_unsigned);

    ++quantize_pool_count;
  };

  gpd(graph, handler);
  AddStatis(quantize_pool_count);
  LogQuantizedOpsCounter("pool2d", quantize_pool_count);
}

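// Concat reuses the signedness computed during scale calibration: the output
// scale is applied to every input, and the output is treated as unsigned
// only when all inputs are unsigned.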
void CPUQuantizePass::QuantizeConcat(Graph* graph) const {
  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
  patterns::Concat concat_pattern{pattern, name_scope_};
  concat_pattern();

  int quantize_concat_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize concat op";
    GET_IR_NODE_FROM_SUBGRAPH(concat_op, concat_op, concat_pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(concat_op->Op())) {
      LogQuantizationDisabled(concat_op);
      return;
    }

    bool are_all_inputs_unsigned{true};
    // if all inputs were unsigned, then the output was set to unsigned
    // during the scale calculation step
    auto inputs = concat_op->inputs;
    for (size_t i = 0; i < inputs.size(); i++) {
      if (AreScalesPresentForVarNames({inputs[i]->Name()})) {
        auto scale_data = GetScaleDataByName(inputs[i]->Name());
        if (scale_data.first == false) {
          are_all_inputs_unsigned = false;
          break;
        }
      }
    }

    GET_IR_NODE_FROM_SUBGRAPH(concat_out, concat_out, concat_pattern);

    if (!AreScalesPresentForNodes({concat_out})) {
      MarkAndLogCannotQuantizeOp(concat_op,
                                 "No scale available for the operator");
      return;
    }

    auto output_scale = GetScaleValueForNode(concat_out);

    QuantizeInputs(g, concat_op, "X", are_all_inputs_unsigned);

    DequantizeOutput(
        g, concat_op, concat_out, "Out", output_scale, are_all_inputs_unsigned);
    ++quantize_concat_count;
  };

  gpd(graph, handler);
  AddStatis(quantize_concat_count);
  LogQuantizedOpsCounter("concat", quantize_concat_count);
}

void CPUQuantizePass::QuantizePriorBox(Graph* graph) const {
  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
  patterns::PriorBox prior_box_pattern{pattern, name_scope_};
  prior_box_pattern();

  int quantize_prior_box_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize prior_box op";
    GET_IR_NODE_FROM_SUBGRAPH(prior_box_op, prior_box_op, prior_box_pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(prior_box_op->Op())) {
      LogQuantizationDisabled(prior_box_op);
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(
        prior_box_input, prior_box_input, prior_box_pattern);

    if (!AreScalesPresentForNodes({prior_box_input})) {
      MarkAndLogCannotQuantizeOp(prior_box_op,
                                 "No scale available for the operator");
      return;
    }

    bool is_input_unsigned{false};
    auto input_scale =
        GetScaleValueForNode(prior_box_input, &is_input_unsigned);
    QuantizeInput(g,
                  prior_box_op,
                  prior_box_input,
                  "Input",
                  input_scale,
                  is_input_unsigned);

    ++quantize_prior_box_count;
  };

  gpd(graph, handler);
  AddStatis(quantize_prior_box_count);
  LogQuantizedOpsCounter("prior_box", quantize_prior_box_count);
}

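// "Immutable" ops (reshape2, transpose2, slice, nearest_interp*) only move
// data around without changing values, so the output scale is reused for
// the input as well.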
void CPUQuantizePass::QuantizeImmutable(Graph* graph,
                                        const std::string& immutable_type,
                                        const std::string& input_name) const {
  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
  patterns::Immutable immutable_pattern{pattern, name_scope_};
  immutable_pattern(immutable_type, input_name);

  int quantize_immutable_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize " + immutable_type + " op";
    GET_IR_NODE_FROM_SUBGRAPH(immutable_op, immutable_op, immutable_pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(immutable_op->Op())) {
      LogQuantizationDisabled(immutable_op);
      return;
    }
    GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, immutable_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(immutable_in, immutable_in, immutable_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(immutable_out, immutable_out, immutable_pattern);

    // skip if the previous op is not dequantized and the following ops are
    // not quantized
    if (!IsOpDequantized(prev_op) && !IsOpQuantized(immutable_out)) {
      MarkAndLogCannotQuantizeOp(immutable_op,
                                 "No other quantizable operators nearby");
      return;
    }

    // skip if the dtype of immutable_in is not float32
    auto dtype = immutable_in->Var()->GetDataType();
    if (dtype != proto::VarType::FP32) {
      MarkAndLogCannotQuantizeOp(immutable_op,
                                 "The input dtype is not float32.");
      return;
    }

    if (!AreScalesPresentForNodes({immutable_out})) {
      MarkAndLogCannotQuantizeOp(immutable_op,
                                 "No scale available for the operator");
      return;
    }

    bool is_input_unsigned{false};
    auto input_scale = GetScaleValueForNode(immutable_out, &is_input_unsigned);

    QuantizeInput(g,
                  immutable_op,
                  immutable_in,
                  input_name,
                  input_scale,
                  is_input_unsigned);

    bool is_output_unsigned{false};
    auto output_scale =
        GetScaleValueForNode(immutable_out, &is_output_unsigned);
    DequantizeOutput(g,
                     immutable_op,
                     immutable_out,
                     "Out",
                     output_scale,
                     is_output_unsigned);

    ++quantize_immutable_count;
  };

  gpd(graph, handler);
  AddStatis(quantize_immutable_count);
  LogQuantizedOpsCounter(immutable_type, quantize_immutable_count);
}

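// Matmul is quantized only when at least one of its inputs comes from a
// dequantized producer, and both inputs must share the same signedness.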
void CPUQuantizePass::QuantizeMatmul(Graph* graph, bool with_residual) const {
  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
  patterns::MatmulWithInputOps matmul_pattern{pattern, name_scope_};
  matmul_pattern(with_residual);

  int quantize_matmul_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize matmul op";
    GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, matmul_pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(matmul_op->Op())) {
      LogQuantizationDisabled(matmul_op);
      return;
    }
    GET_IR_NODE_FROM_SUBGRAPH(prev_op_x, prev_op_x, matmul_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(prev_op_y, prev_op_y, matmul_pattern);

    // skip if prev ops are not quantized
    if (!IsOpDequantized(prev_op_x) && !IsOpDequantized(prev_op_y)) {
      MarkAndLogCannotQuantizeOp(matmul_op,
                                 "No other quantizable operators nearby");
      return;
    }
    GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, matmul_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, matmul_pattern);

    auto has_output_scale = AreScalesPresentForNodes({matmul_out});
    if (with_residual && !has_output_scale) {
      MarkAndLogCannotQuantizeOp(
          matmul_op,
          "Matmul op with ResidualData input cannot be quantized "
          "without output scale.");
      return;
    }

    if (!AreScalesPresentForNodes({matmul_in_x, matmul_in_y})) {
      MarkAndLogCannotQuantizeOp(matmul_op,
                                 "No scale available for the operator");
      return;
    }

    bool is_x_unsigned{false}, is_y_unsigned{false};
    auto input_x_scale = GetScaleValueForNode(matmul_in_x, &is_x_unsigned);
    auto input_y_scale = GetScaleValueForNode(matmul_in_y, &is_y_unsigned);
    PADDLE_ENFORCE_EQ(is_x_unsigned,
                      is_y_unsigned,
                      platform::errors::InvalidArgument(
                          "Matmul inputs should have the same "
                          "attribute of signed/unsigned, but they "
                          "are different: x(%d), y(%d).",
                          is_x_unsigned,
                          is_y_unsigned));

    if (with_residual) {
      GET_IR_NODE_FROM_SUBGRAPH(
          matmul_residual_data, matmul_residual_data, matmul_pattern);
      if (!AreScalesPresentForNodes({matmul_residual_data})) {
        MarkAndLogCannotQuantizeOp(matmul_op,
                                   "No scale available for the operator");
        return;
      }
      bool is_residual_unsigned{false};
      auto residual_scale =
          GetScaleValueForNode(matmul_residual_data, &is_residual_unsigned);

      QuantizeInput(g,
                    matmul_op,
                    matmul_residual_data,
                    "ResidualData",
                    residual_scale,
                    is_residual_unsigned,
                    "Scale_in_eltwise");
    }

    QuantizeInput(g,
                  matmul_op,
                  matmul_in_x,
                  "X",
                  input_x_scale,
                  is_x_unsigned,
                  "Scale_x");
    QuantizeInput(g,
                  matmul_op,
                  matmul_in_y,
                  "Y",
                  input_y_scale,
                  is_y_unsigned,
                  "Scale_y");

    // if quantization scale is missing for output tensor, return fp32 data
    if (AreScalesPresentForNodes({matmul_out})) {
      bool is_output_unsigned{false};
      auto output_scale = GetScaleValueForNode(matmul_out, &is_output_unsigned);
      DequantizeOutput(g,
                       matmul_op,
                       matmul_out,
                       "Out",
                       output_scale,
                       is_output_unsigned,
                       "Scale_out");
    } else {
      matmul_op->Op()->SetAttr("force_fp32_output", true);
    }

    ++quantize_matmul_count;
  };
  gpd(graph, handler);
  AddStatis(quantize_matmul_count);
  LogQuantizedOpsCounter("matmul",
                         quantize_matmul_count,
                         (with_residual ? "with residual connection" : ""));
}

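// Elementwise add/mul/sub: X and Y are resolved by variable name among the
// op's input nodes; inputs of different signedness are not supported yet.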
void CPUQuantizePass::QuantizeElementwise(
    Graph* graph, const std::string& elementwise_type) const {
  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
  patterns::ElementwiseOp elementwise_pattern{pattern, name_scope_};

  elementwise_pattern(elementwise_type);

  int quantize_elementwise_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize " + elementwise_type + " op";
    GET_IR_NODE_FROM_SUBGRAPH(
        elementwise_op, elementwise_op, elementwise_pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(elementwise_op->Op())) {
      LogQuantizationDisabled(elementwise_op);
      return;
    }

    auto x_name = elementwise_op->Op()->Input("X");
    auto y_name = elementwise_op->Op()->Input("Y");
    Node *elementwise_x{nullptr}, *elementwise_y{nullptr};

    for (auto& input : elementwise_op->inputs) {
      if (input->Name() == x_name[0]) elementwise_x = input;
      if (input->Name() == y_name[0]) elementwise_y = input;
    }
    if (!elementwise_x || !elementwise_y) {
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(
        elementwise_out, elementwise_out, elementwise_pattern);

    if (!AreScalesPresentForNodes(
            {elementwise_x, elementwise_y, elementwise_out})) {
      MarkAndLogCannotQuantizeOp(elementwise_op,
                                 "No scale available for the operator");
      return;
    }

    bool is_x_unsigned{false}, is_y_unsigned{false};
    auto input_x_scale = GetScaleValueForNode(elementwise_x, &is_x_unsigned);
    auto input_y_scale = GetScaleValueForNode(elementwise_y, &is_y_unsigned);

    // TODO(sfraczek): add support for different signedness
    if (is_x_unsigned != is_y_unsigned) {
      MarkAndLogCannotQuantizeOp(
          elementwise_op, "Elementwise inputs must be of the same type.");
      return;
    }

    QuantizeInput(g,
                  elementwise_op,
                  elementwise_x,
                  "X",
                  input_x_scale,
                  is_x_unsigned,
                  "Scale_x");
    QuantizeInput(g,
                  elementwise_op,
                  elementwise_y,
                  "Y",
                  input_y_scale,
                  is_y_unsigned,
                  "Scale_y");

    bool is_output_unsigned{false};
    auto output_scale =
        GetScaleValueForNode(elementwise_out, &is_output_unsigned);

    DequantizeOutput(g,
                     elementwise_op,
                     elementwise_out,
                     "Out",
                     output_scale,
                     is_output_unsigned,
                     "Scale_out");

    ++quantize_elementwise_count;
  };
  gpd(graph, handler);
  AddStatis(quantize_elementwise_count);
  LogQuantizedOpsCounter(elementwise_type, quantize_elementwise_count);
}

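// GRU/LSTM ops consume u8 input: signed data is shifted by 128 into the
// unsigned range via Shift_data, weight scales (times S8_MAX) are passed
// via Scale_weights, and the op itself emits FP32 (force_fp32_output).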
void CPUQuantizePass::QuantizeFusionGru(Graph* graph) const {
  GraphPatternDetector gpd;
  patterns::FusionGru pattern{gpd.mutable_pattern(), name_scope_};
  pattern();

  int quantize_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize fusion_gru op";
    GET_IR_NODE_FROM_SUBGRAPH(op, op, pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(op->Op())) {
      LogQuantizationDisabled(op);
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(x, x, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(weight_h, weight_h, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(weight_x, weight_x, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(out, out, pattern);

    if (!AreScalesPresentForNodes({x, weight_x})) {
      MarkAndLogCannotQuantizeOp(op, "No scale available for the operator");
      return;
    }

    bool is_x_unsigned{false};
    auto input_x_scale = GetScaleValueForNode(x, &is_x_unsigned);

    double input_x_shift{128.};
    if (is_x_unsigned) input_x_shift = 0.;

    QuantizeInput(g,
                  op,
                  x,
                  "X",
                  input_x_scale,
                  is_x_unsigned,
                  "Scale_data",
                  input_x_shift,
                  "Shift_data");

    auto weight_scale_tensor = GetScaleTensorForNode(weight_x);
    EigenVectorArrayMap eigen_tensor{weight_scale_tensor.data<double>(),
                                     weight_scale_tensor.numel()};
    eigen_tensor *= static_cast<double>(S8_MAX);
    std::vector<float> scale_weights{
        weight_scale_tensor.data<double>(),
        weight_scale_tensor.data<double>() + weight_scale_tensor.numel()};

    op->Op()->SetAttr("Scale_weights", scale_weights);
    // return fp32 data
    op->Op()->SetAttr("force_fp32_output", true);

    ++quantize_count;
  };
  gpd(graph, handler);
  AddStatis(quantize_count);
  LogQuantizedOpsCounter("fusion_gru", quantize_count);
}

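// multi_gru takes its weight scales as persistable input variables rather
// than an attribute, so a new FP32 scale tensor is created in the scope for
// each WeightX input.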
void CPUQuantizePass::QuantizeMultiGru(Graph* graph) const {
  GraphPatternDetector gpd;
  patterns::MultiGru pattern{gpd.mutable_pattern(), name_scope_};
  pattern();

  int quantize_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize multi_gru op";
    GET_IR_NODE_FROM_SUBGRAPH(gru, gru, pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(gru->Op())) {
      LogQuantizationDisabled(gru);
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(x, x, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(wx, wx, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(h, h, pattern);

    auto wx_names = gru->Op()->Input("WeightX");
    if (!AreScalesPresentForNodes({x}) ||
        !AreScalesPresentForVarNames(wx_names)) {
      MarkAndLogCannotQuantizeOp(gru, "No scale available for the operator");
      return;
    }

    bool is_x_unsigned{false};
    auto input_x_scale = GetScaleValueForNode(x, &is_x_unsigned);

    double input_x_shift{128.};
    if (is_x_unsigned) input_x_shift = 0.;

    QuantizeInput(g,
                  gru,
                  x,
                  "X",
                  input_x_scale,
                  is_x_unsigned,
                  "Scale_data",
                  input_x_shift,
                  "Shift_data");

    auto* scope = param_scope();
    int wx_size = wx_names.size();
    std::vector<std::string> w_scale_var_names;
    for (int i = 0; i < wx_size; ++i) {
      auto scale_tensor_src = GetScaleTensorByName(wx_names[i]);
      EigenVectorArrayMap eigen_tensor_src{scale_tensor_src.data<double>(),
                                           scale_tensor_src.numel()};

      VarDesc scale_var_desc(patterns::PDNodeName("multi_gru", "w_scale"));

      scale_var_desc.SetShape(phi::vectorize(scale_tensor_src.dims()));
      scale_var_desc.SetDataType(proto::VarType::FP32);
      scale_var_desc.SetLoDLevel(scale_tensor_src.lod().size());
      scale_var_desc.SetPersistable(true);
      auto* w_scale_node = g->CreateVarNode(&scale_var_desc);

      auto* w_scale_tensor_dst =
          scope->Var(w_scale_node->Name())->GetMutable<LoDTensor>();
      w_scale_tensor_dst->Resize(scale_tensor_src.dims());
      auto* dst_data =
          w_scale_tensor_dst->mutable_data<float>(platform::CPUPlace());
      EigenVectorArrayMapFloat eigen_tensor_dst{dst_data,
                                                w_scale_tensor_dst->numel()};
      eigen_tensor_dst =
          eigen_tensor_src.cast<float>() * static_cast<float>(S8_MAX);
      w_scale_var_names.push_back(w_scale_node->Name());
      IR_NODE_LINK_TO(w_scale_node, gru);
    }

    gru->Op()->SetInput("Scale_weights", w_scale_var_names);
    // return fp32 data
    gru->Op()->SetAttr("force_fp32_output", true);

    ++quantize_count;
  };
  gpd(graph, handler);
  AddStatis(quantize_count);
  LogQuantizedOpsCounter("multi_gru", quantize_count);
}

void CPUQuantizePass::QuantizeFusionLSTM(Graph* graph) const {
  GraphPatternDetector gpd;
  patterns::FusionLSTM pattern{gpd.mutable_pattern(), name_scope_};
  pattern();

  int quantize_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize fusion_lstm op";
    GET_IR_NODE_FROM_SUBGRAPH(op, op, pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(op->Op())) {
      LogQuantizationDisabled(op);
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(x, x, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(weight_h, weight_h, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(weight_x, weight_x, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(hidden, hidden, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(cell, cell, pattern);

    // Starting from here there may be issues
    if (!AreScalesPresentForNodes({x, weight_x})) {
      MarkAndLogCannotQuantizeOp(op, "No scale available for the operator");
      return;
    }

    bool is_x_unsigned{false};
    auto input_x_scale = GetScaleValueForNode(x, &is_x_unsigned);

    double input_x_shift{128.};
    if (is_x_unsigned) input_x_shift = 0.;

    QuantizeInput(g,
                  op,
                  x,
                  "X",
                  input_x_scale,
                  is_x_unsigned,
                  "Scale_data",
                  input_x_shift,
                  "Shift_data");

    auto weight_scale_tensor = GetScaleTensorForNode(weight_x);
    EigenVectorArrayMap eigen_tensor{weight_scale_tensor.data<double>(),
                                     weight_scale_tensor.numel()};
    eigen_tensor *= static_cast<double>(S8_MAX);
    std::vector<float> scale_weights{
        weight_scale_tensor.data<double>(),
        weight_scale_tensor.data<double>() + weight_scale_tensor.numel()};

    op->Op()->SetAttr("Scale_weights", scale_weights);
    // return fp32 data
    op->Op()->SetAttr("force_fp32_output", true);

    ++quantize_count;
  };
  gpd(graph, handler);
  AddStatis(quantize_count);
  LogQuantizedOpsCounter("fusion_lstm", quantize_count);
}

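// ApplyImpl runs every per-op quantizer over the graph; scales must be
// available beforehand, hence the "quant_var_scales" pass attribute required
// at registration below.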
void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
  VLOG(3) << "Quantizing the graph.";
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
  FusePassBase::Init(name_scope_, graph);

  PADDLE_ENFORCE_NOT_NULL(
      param_scope(),
      platform::errors::InvalidArgument("Scope cannot be nullptr."));

  GetQuantInfo(graph);
  QuantizeConv(graph, false /* with_residual_data */);
  QuantizeConv(graph, true /* with_residual_data */);
  QuantizePool(graph);
  QuantizeConcat(graph);
  QuantizePriorBox(graph);
  QuantizeFc(graph);
  QuantizeMatmul(graph, false /* with_residual_data */);
  QuantizeMatmul(graph, true /* with_residual_data */);
  QuantizeImmutable(graph, "reshape2", "X");
  QuantizeImmutable(graph, "transpose2", "X");
  QuantizeImmutable(graph, "slice", "Input");
  QuantizeImmutable(graph, "nearest_interp", "X");
  QuantizeImmutable(graph, "nearest_interp_v2", "X");
  QuantizeElementwise(graph, "elementwise_add");
  QuantizeElementwise(graph, "elementwise_mul");
  QuantizeElementwise(graph, "elementwise_sub");
  QuantizeFusionGru(graph);
  QuantizeMultiGru(graph);
  QuantizeFusionLSTM(graph);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(cpu_quantize_pass, paddle::framework::ir::CPUQuantizePass)
    .RequirePassAttr("quant_var_scales");