cpu_quantize_pass.cc
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h"

#include <sstream>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/string/pretty_log.h"

namespace paddle {
namespace framework {
namespace ir {

using EigenVectorArrayMap = Eigen::Map<Eigen::Array<double, Eigen::Dynamic, 1>>;
using EigenVectorArrayMapFloat =
    Eigen::Map<Eigen::Array<float, Eigen::Dynamic, 1>>;
using string::PrettyLogDetail;

namespace {

void UnlinkNodes(ir::Node* a, ir::Node* b) {
  a->outputs.erase(std::remove(a->outputs.begin(), a->outputs.end(), b),
                   a->outputs.end());
  b->inputs.erase(std::remove(b->inputs.begin(), b->inputs.end(), a),
                  b->inputs.end());
}

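// Logs why an operator cannot be quantized and marks it to stay in FP32 by
// resetting its mkldnn_data_type attribute to "float32".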
void MarkAndLogCannotQuantizeOp(Node* op, const char* details = nullptr) {
  std::stringstream msg_ss;
  msg_ss << "Cannot quantize operator " << op->Name()
         << " (type: " << op->Op()->Type() << ", id: " << op->id() << ").";
  if (details) msg_ss << " " << details;
  VLOG(2) << msg_ss.str().c_str();
  op->Op()->SetAttr("mkldnn_data_type", std::string("float32"));
}

void LogScaleIsMissingForVarName(const std::string& name) {
  VLOG(4) << "Quantization scale for the variable " << name << " is missing.";
}

void LogScaleIsMissingForVarNode(Node* node) {
  LogScaleIsMissingForVarName(node->Name());
}

void LogQuantizationDisabled(Node* op) {
  VLOG(2) << "Quantization skipped for operator " << op->Name()
          << " (type: " << op->Op()->Type() << ", id: " << op->id()
          << "). Attribute mkldnn_data_type != \"int8\".";
}

void LogQuantizedOpsCounter(const std::string& type,
                            const int counter,
                            const char* details = nullptr) {
  std::stringstream msg_ss;
  msg_ss << "---    quantized " << counter << " " << type << " ops";
  if (details) msg_ss << " " << details;
  PrettyLogDetail(msg_ss.str().c_str());
}

}  // namespace

enum { U8_MAX = 255, S8_MAX = 127 };

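// Inserts a quantize op in front of op's input `input_name`: a new quantize
// output variable is created, the graph links are rewired, and the computed
// scale (and optional shift) can be stored on the quantized op's attributes.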
void CPUQuantizePass::QuantizeInput(Graph* g,
                                    Node* op,
                                    Node* input,
                                    std::string input_name,
                                    double scale_to_one,
                                    bool is_input_unsigned,
                                    std::string scale_attr_name,
                                    float shift,
                                    std::string shift_attr_name) const {
  auto inputs = op->Op()->InputNames();
  bool name_found =
      std::find(inputs.begin(), inputs.end(), input_name) != inputs.end();
  PADDLE_ENFORCE_EQ(name_found,
                    true,
                    platform::errors::InvalidArgument(
                        "Var(%s) isn't the input of the %s operator.",
                        input_name,
                        op->Op()->Type()));
  unsigned max = is_input_unsigned ? U8_MAX : S8_MAX;
  float scale = scale_to_one * max;

  // Create quantize output variable
  VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out"));
  auto* quantize_out_node = g->CreateVarNode(&quantize_out_desc);

  // create a quantize op node
  OpDesc q_desc;
  q_desc.SetType("quantize");
  q_desc.SetInput("Input", std::vector<std::string>({input->Name()}));
  q_desc.SetOutput("Output",
                   std::vector<std::string>({quantize_out_node->Name()}));
  q_desc.SetAttr("Scale", scale);
  q_desc.SetAttr("Shift", shift);
  q_desc.SetAttr("is_negative_input", !is_input_unsigned);

  // workaround for an fc output format error when in_num_col_dims == 2
  if (op->Op()->Type() == "fc" &&
      op->Op()->GetAttrIfExists<int>("in_num_col_dims") == 2) {
    q_desc.SetAttr(
        "output_format",
        Has("data_layout") ? Get<std::string>("data_layout") : "NCHW");
  } else {
    q_desc.SetAttr(
        "output_format",
        Has("data_layout") ? Get<std::string>("data_layout") : "NHWC");
  }
  auto quantize_op = g->CreateOpNode(&q_desc);  // OpDesc will be copied.

  // update op's input
  op->Op()->SetInput(input_name,
                     std::vector<std::string>({quantize_out_node->Name()}));

  // link quantize op
  UnlinkNodes(input, op);
  IR_NODE_LINK_TO(input, quantize_op);
  IR_NODE_LINK_TO(quantize_op, quantize_out_node);
  IR_NODE_LINK_TO(quantize_out_node, op);

  if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale);
  if (!shift_attr_name.empty()) op->Op()->SetAttr(shift_attr_name, shift);
}

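// Variant of QuantizeInput for ops with a variable number of inputs (e.g.
// concat); every input is quantized with the single output scale.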
void CPUQuantizePass::QuantizeInputs(Graph* g,
                                     Node* op,
                                     std::string input_name,
                                     bool are_inputs_unsigned,
                                     std::string scale_attr_name,
                                     float shift,
                                     std::string shift_attr_name) const {
  auto inputs = op->inputs;
  auto output = op->outputs[0];
  PADDLE_ENFORCE_GE(inputs.size(),
                    1,
                    platform::errors::InvalidArgument(
                        "OP(%s)'s inputs(%d) must be equal or greater than 1.",
                        op->Name(),
                        inputs.size()));
  PADDLE_ENFORCE_EQ(op->outputs.size(),
                    1,
                    platform::errors::InvalidArgument(
                        "OP(%s)'s outputs(%d) must be equal to 1.",
                        op->Name(),
                        op->outputs.size()));

  // create a quantize op desc prototype
  OpDesc q_desc;
  q_desc.SetType("quantize");

  std::vector<Node*> quantize_out_nodes(inputs.size());
  std::vector<std::string> quantize_out_node_names(inputs.size());

  double scale_out = GetScaleValueForNode(output);
  unsigned max = are_inputs_unsigned ? U8_MAX : S8_MAX;
  float scale = scale_out * max;

  for (size_t i = 0; i < inputs.size(); i++) {
    // Create quantize output variable
    VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out"));
    quantize_out_nodes[i] = g->CreateVarNode(&quantize_out_desc);
    quantize_out_node_names[i] = quantize_out_nodes[i]->Name();

    q_desc.SetAttr("Scale", scale);
    q_desc.SetAttr("Shift", shift);
    q_desc.SetInput("Input", std::vector<std::string>({inputs[i]->Name()}));
    q_desc.SetOutput("Output",
                     std::vector<std::string>({quantize_out_node_names[i]}));
    q_desc.SetAttr("is_negative_input", !are_inputs_unsigned);
    auto quantize_op = g->CreateOpNode(&q_desc);  // OpDesc will be copied.

    // link quantize op
    UnlinkNodes(inputs[i], op);
    IR_NODE_LINK_TO(inputs[i], quantize_op);
    IR_NODE_LINK_TO(quantize_op, quantize_out_nodes[i]);
    IR_NODE_LINK_TO(quantize_out_nodes[i], op);
  }

  // update op's input
  op->Op()->SetInput(input_name, quantize_out_node_names);

  if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale);
  if (!shift_attr_name.empty()) op->Op()->SetAttr(shift_attr_name, shift);
}

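// Inserts a dequantize op behind op's output `output_name`, so the op can
// produce INT8 data while the rest of the graph still receives FP32.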
void CPUQuantizePass::DequantizeOutput(Graph* g,
                                       Node* op,
                                       Node* output,
                                       std::string output_name,
                                       double scale_to_one,
                                       bool is_unsigned,
                                       std::string scale_attr_name) const {
  auto outputs = op->Op()->OutputNames();
  bool name_found =
      std::find(outputs.begin(), outputs.end(), output_name) != outputs.end();
  PADDLE_ENFORCE_EQ(name_found,
                    true,
                    platform::errors::InvalidArgument(
                        "Var(%s) isn't the output of the %s operator.",
                        output_name,
                        op->Op()->Type()));
  unsigned max = is_unsigned ? U8_MAX : S8_MAX;
  float scale = scale_to_one * max;

  // Create dequantize input variable
  VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
  auto* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc);

  // create a dequantize op node for output.
  OpDesc deq_desc;
  deq_desc.SetType("dequantize");
  deq_desc.SetInput("Input",
                    std::vector<std::string>({dequantize_in_node->Name()}));
  deq_desc.SetOutput("Output", std::vector<std::string>({output->Name()}));
  deq_desc.SetAttr("Scale", scale);
  deq_desc.SetAttr("is_negative_input", !is_unsigned);
  auto dequantize_op = g->CreateOpNode(&deq_desc);  // OpDesc will be copied.

  // update op's output
  op->Op()->SetOutput(output_name,
                      std::vector<std::string>({dequantize_in_node->Name()}));

  // link dequantize op
  UnlinkNodes(op, output);
  IR_NODE_LINK_TO(op, dequantize_in_node);
  IR_NODE_LINK_TO(dequantize_in_node, dequantize_op);
  IR_NODE_LINK_TO(dequantize_op, output);

  if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale);
}

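// The scale lookup helpers below read either the var_quant_scales_ map filled
// by GetQuantInfo() or, when it is empty, the "quant_var_scales" pass
// attribute.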
bool CPUQuantizePass::AreScalesPresentForVarNames(
    std::vector<std::string> names) const {
  bool present = true;
  if (var_quant_scales_->empty()) {
    auto& scales = Get<VarQuantScale>("quant_var_scales");
    for (auto name : names) {
      if (scales.find(name) == scales.end()) {
        present = false;
        LogScaleIsMissingForVarName(name);
      }
    }
  } else {
    for (auto name : names) {
      if (var_quant_scales_->find(name) == var_quant_scales_->end()) {
        present = false;
        LogScaleIsMissingForVarName(name);
      }
    }
  }
  return present;
}

bool CPUQuantizePass::AreScalesPresentForNodes(
    std::initializer_list<Node*> nodes) const {
  bool present = true;
  if (var_quant_scales_->empty()) {
    auto& scales = Get<VarQuantScale>("quant_var_scales");
    for (auto node : nodes) {
      if (scales.count(node->Name()) == 0) {
        present = false;
        LogScaleIsMissingForVarNode(node);
      }
    }
  } else {
    for (auto node : nodes) {
      if (var_quant_scales_->count(node->Name()) == 0) {
        present = false;
        LogScaleIsMissingForVarNode(node);
      }
    }
  }
  return present;
}

std::pair<bool, phi::DenseTensor> CPUQuantizePass::GetScaleDataByName(
    const std::string& name) const {
  if (var_quant_scales_->empty()) {
    auto& scales = Get<VarQuantScale>("quant_var_scales");
    return scales.at(name);
  }
  return var_quant_scales_->at(name);
}

std::pair<bool, phi::DenseTensor> CPUQuantizePass::GetScaleDataForNode(
    const Node* node) const {
  return GetScaleDataByName(node->Name());
}

phi::DenseTensor CPUQuantizePass::GetScaleTensorByName(
    const std::string& name) const {
  return GetScaleDataByName(name).second;
}

phi::DenseTensor CPUQuantizePass::GetScaleTensorForNode(
    const Node* node) const {
  return GetScaleDataForNode(node).second;
}

double CPUQuantizePass::GetScaleValueForNode(const Node* node,
                                             bool* is_unsigned) const {
  auto scale_data = GetScaleDataForNode(node);
  if (is_unsigned != nullptr) *is_unsigned = scale_data.first;
  return scale_data.second.data<double>()[0];
}

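// An op counts as dequantized if it is a dequantize op itself or already
// carries int8 data type.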
bool CPUQuantizePass::IsOpDequantized(const Node* node) const {
  return node->Op()->Type() == "dequantize" ||
         platform::HasOpINT8DataType(node->Op());
}

bool CPUQuantizePass::IsOpQuantized(const Node* node) const {
  // return true only if all outputs are ops and they are either quantize ops
  // or have int8 data type
  return all_of(node->outputs.begin(), node->outputs.end(), [](Node* output) {
    return (output->IsOp() && (output->Op()->Type() == "quantize" ||
                               platform::HasOpINT8DataType(output->Op())));
  });
}

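// Loads the per-variable quantization scales that an earlier pass stored on
// the first op of the graph into var_quant_scales_.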
void CPUQuantizePass::GetQuantInfo(Graph* graph) const {
  GetInfoFromTheFirstOp(
      graph, "has_quant_info", "var_quant_scales", var_quant_scales_);
}

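// Quantizes conv2d ops (optionally with a residual connection): the input,
// the residual data and the filter scales are handled here; when no output
// scale is known, force_fp32_output is set instead of dequantizing.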
void CPUQuantizePass::QuantizeConv(Graph* graph,
                                   bool with_residual_data) const {
  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
  patterns::ConvResidual conv_pattern{pattern, name_scope_};
  conv_pattern(with_residual_data);

  int quantize_conv_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize conv2d op";
    GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(conv_op->Op())) {
      LogQuantizationDisabled(conv_op);
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);

    auto has_output_scale = AreScalesPresentForNodes({conv_output});
    if (with_residual_data && !has_output_scale) {
      MarkAndLogCannotQuantizeOp(
          conv_op,
          "Conv op with ResidualData input cannot be quantized "
          "without output scale.");
      return;
    }

    if (with_residual_data) {
      GET_IR_NODE_FROM_SUBGRAPH(
          conv_residual_data, conv_residual_data, conv_pattern);
      if (!AreScalesPresentForNodes(
              {conv_input, conv_filter, conv_residual_data})) {
        MarkAndLogCannotQuantizeOp(conv_op,
                                   "No scale available for the operator");
        return;
      }

      bool is_residual_unsigned{false};
      auto residual_scale =
          GetScaleValueForNode(conv_residual_data, &is_residual_unsigned);

      QuantizeInput(g,
                    conv_op,
                    conv_residual_data,
                    "ResidualData",
                    residual_scale,
                    is_residual_unsigned,
                    "Scale_in_eltwise");
    } else {
      if (!AreScalesPresentForNodes({conv_input, conv_filter})) {
        MarkAndLogCannotQuantizeOp(conv_op,
                                   "No scale available for the operator");
        return;
      }
    }

    bool is_input_unsigned{false};
    auto input_scale = GetScaleValueForNode(conv_input, &is_input_unsigned);
    QuantizeInput(g,
                  conv_op,
                  conv_input,
                  "Input",
                  input_scale,
                  is_input_unsigned,
                  "Scale_in");

    auto filter_scale_tensor = GetScaleTensorForNode(conv_filter);
    EigenVectorArrayMap eigen_tensor{filter_scale_tensor.data<double>(),
                                     filter_scale_tensor.numel()};

    // If the scale value of a weight is already multiplied by S8_MAX, it does
    // not need to be multiplied again
    if (std::find(change_weight_->begin(),
                  change_weight_->end(),
                  conv_filter->Name()) == change_weight_->end()) {
      eigen_tensor *= static_cast<double>(S8_MAX);
      change_weight_->push_back(conv_filter->Name());
    }

    std::vector<float> filter_scale{
        filter_scale_tensor.data<double>(),
        filter_scale_tensor.data<double>() + filter_scale_tensor.numel()};

    conv_op->Op()->SetAttr("Scale_weights", filter_scale);

    // if quantization scale is missing for output tensor, return fp32 data
    if (has_output_scale) {
      bool is_output_unsigned{false};
      auto output_scale =
          GetScaleValueForNode(conv_output, &is_output_unsigned);
      DequantizeOutput(g,
                       conv_op,
                       conv_output,
                       "Output",
                       output_scale,
                       is_output_unsigned,
                       "Scale_out");
    } else {
      conv_op->Op()->SetAttr("force_fp32_output", true);
    }

    // change threshold in bounded ReLU
    if (conv_op->Op()->GetAttrIfExists<std::string>("fuse_activation") ==
        "relu6") {
      float scale_out =
          PADDLE_GET_CONST(float, conv_op->Op()->GetAttr("Scale_out"));
      float threshold =
          PADDLE_GET_CONST(float, conv_op->Op()->GetAttr("fuse_alpha"));
      conv_op->Op()->SetAttr("fuse_alpha", scale_out * threshold);
    }

    ++quantize_conv_count;
  };

  gpd(graph, handler);
  AddStatis(quantize_conv_count);

  LogQuantizedOpsCounter(
      "conv2d",
      quantize_conv_count,
      ((with_residual_data) ? "with residual connection" : ""));
}

void CPUQuantizePass::QuantizeFc(Graph* graph) const {
  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
  patterns::FCMKLDNN fc_pattern{pattern, name_scope_};
  fc_pattern(false /* with_residual */);

  int quantize_fc_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize fc op";
    GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fc_pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(fc->Op())) {
      LogQuantizationDisabled(fc);
      return;
    }
    if (!fc->Op()->GetAttrIfExists<bool>("use_mkldnn")) {
      MarkAndLogCannotQuantizeOp(fc, "use_mkldnn attribute set to false");
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(weights, weights, fc_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(input, input, fc_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(output, output, fc_pattern);

    if (!AreScalesPresentForNodes({input, weights})) {
      MarkAndLogCannotQuantizeOp(fc, "No scale available for the operator");
      return;
    }

    bool is_input_unsigned{false};
    auto input_scale = GetScaleValueForNode(input, &is_input_unsigned);
    QuantizeInput(
        g, fc, input, "Input", input_scale, is_input_unsigned, "Scale_in");

    auto weight_scale_tensor = GetScaleTensorForNode(weights);
    EigenVectorArrayMap eigen_tensor{weight_scale_tensor.data<double>(),
                                     weight_scale_tensor.numel()};
    eigen_tensor *= static_cast<double>(S8_MAX);
    std::vector<float> filter_scale{
        weight_scale_tensor.data<double>(),
        weight_scale_tensor.data<double>() + weight_scale_tensor.numel()};

    fc->Op()->SetAttr("Scale_weights", filter_scale);

    // if quantization scale is missing for output tensor, return fp32 data
    if (AreScalesPresentForNodes({output})) {
      bool is_output_unsigned{false};
      auto output_scale = GetScaleValueForNode(output, &is_output_unsigned);
      DequantizeOutput(
          g, fc, output, "Out", output_scale, is_output_unsigned, "Scale_out");
    } else {
      fc->Op()->SetAttr("force_fp32_output", true);
    }

    ++quantize_fc_count;
  };

  gpd(graph, handler);
  AddStatis(quantize_fc_count);
  LogQuantizedOpsCounter("fc", quantize_fc_count);
}

void CPUQuantizePass::QuantizePool(Graph* graph) const {
  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
  patterns::Pool pool_pattern{pattern, name_scope_};
  pool_pattern();

  int quantize_pool_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize pool2d op";
    GET_IR_NODE_FROM_SUBGRAPH(pool_op, pool_op, pool_pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(pool_op->Op())) {
      LogQuantizationDisabled(pool_op);
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(pool_input, pool_input, pool_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(pool_output, pool_output, pool_pattern);

    if (!AreScalesPresentForNodes({pool_input, pool_output})) {
      MarkAndLogCannotQuantizeOp(pool_op,
                                 "No scale available for the operator");
      return;
    }

    bool is_input_unsigned{false};
    auto input_scale = GetScaleValueForNode(pool_input, &is_input_unsigned);
    QuantizeInput(g, pool_op, pool_input, "X", input_scale, is_input_unsigned);

    bool is_output_unsigned{false};
    auto output_scale = GetScaleValueForNode(pool_output, &is_output_unsigned);
    DequantizeOutput(
        g, pool_op, pool_output, "Out", output_scale, is_output_unsigned);

    ++quantize_pool_count;
  };

  gpd(graph, handler);
  AddStatis(quantize_pool_count);
  LogQuantizedOpsCounter("pool2d", quantize_pool_count);
}

void CPUQuantizePass::QuantizeConcat(Graph* graph) const {
  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
  patterns::Concat concat_pattern{pattern, name_scope_};
  concat_pattern();

  int quantize_concat_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize concat op";
    GET_IR_NODE_FROM_SUBGRAPH(concat_op, concat_op, concat_pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(concat_op->Op())) {
      LogQuantizationDisabled(concat_op);
      return;
    }

    bool are_all_inputs_unsigned{true};
    // if all inputs were unsigned, then the output was set to unsigned
    // during the scale calculation step
    auto inputs = concat_op->inputs;
    for (size_t i = 0; i < inputs.size(); i++) {
      if (AreScalesPresentForVarNames({inputs[i]->Name()})) {
        auto scale_data = GetScaleDataByName(inputs[i]->Name());
        if (scale_data.first == false) {
          are_all_inputs_unsigned = false;
          break;
        }
      }
    }

    GET_IR_NODE_FROM_SUBGRAPH(concat_out, concat_out, concat_pattern);

    if (!AreScalesPresentForNodes({concat_out})) {
      MarkAndLogCannotQuantizeOp(concat_op,
                                 "No scale available for the operator");
      return;
    }

    auto output_scale = GetScaleValueForNode(concat_out);

    QuantizeInputs(g, concat_op, "X", are_all_inputs_unsigned);

    DequantizeOutput(
        g, concat_op, concat_out, "Out", output_scale, are_all_inputs_unsigned);
    ++quantize_concat_count;
  };

  gpd(graph, handler);
  AddStatis(quantize_concat_count);
  LogQuantizedOpsCounter("concat", quantize_concat_count);
}

void CPUQuantizePass::QuantizePriorBox(Graph* graph) const {
  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
  patterns::PriorBox prior_box_pattern{pattern, name_scope_};
  prior_box_pattern();

  int quantize_prior_box_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize prior_box op";
    GET_IR_NODE_FROM_SUBGRAPH(prior_box_op, prior_box_op, prior_box_pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(prior_box_op->Op())) {
      LogQuantizationDisabled(prior_box_op);
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(
        prior_box_input, prior_box_input, prior_box_pattern);

    if (!AreScalesPresentForNodes({prior_box_input})) {
      MarkAndLogCannotQuantizeOp(prior_box_op,
                                 "No scale available for the operator");
      return;
    }

    bool is_input_unsigned{false};
    auto input_scale =
        GetScaleValueForNode(prior_box_input, &is_input_unsigned);
    QuantizeInput(g,
                  prior_box_op,
                  prior_box_input,
                  "Input",
                  input_scale,
                  is_input_unsigned);

    ++quantize_prior_box_count;
  };

  gpd(graph, handler);
  AddStatis(quantize_prior_box_count);
  LogQuantizedOpsCounter("prior_box", quantize_prior_box_count);
}

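// Quantizes "immutable" ops (reshape2, transpose2, slice, nearest_interp*),
// which reuse their output scale for both the input and the output.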
void CPUQuantizePass::QuantizeImmutable(Graph* graph,
                                        const std::string& immutable_type,
                                        const std::string& input_name) const {
  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
  patterns::Immutable immutable_pattern{pattern, name_scope_};
  immutable_pattern(immutable_type, input_name);

  int quantize_immutable_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize " + immutable_type + " op";
    GET_IR_NODE_FROM_SUBGRAPH(immutable_op, immutable_op, immutable_pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(immutable_op->Op())) {
      LogQuantizationDisabled(immutable_op);
      return;
    }
    GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, immutable_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(immutable_in, immutable_in, immutable_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(immutable_out, immutable_out, immutable_pattern);

    // skip if neither the previous op nor the next op is quantized
    if (!IsOpDequantized(prev_op) && !IsOpQuantized(immutable_out)) {
      MarkAndLogCannotQuantizeOp(immutable_op,
                                 "No other quantizable operators nearby");
      return;
    }

    // skip if the dtype of immutable_in is not float32
    auto dtype = immutable_in->Var()->GetDataType();
    if (dtype != proto::VarType::FP32) {
      MarkAndLogCannotQuantizeOp(immutable_op, "The input dtype is not float.");
      return;
    }

    if (!AreScalesPresentForNodes({immutable_out})) {
      MarkAndLogCannotQuantizeOp(immutable_op,
                                 "No scale available for the operator");
      return;
    }

    bool is_input_unsigned{false};
    auto input_scale = GetScaleValueForNode(immutable_out, &is_input_unsigned);

    QuantizeInput(g,
                  immutable_op,
                  immutable_in,
                  input_name,
                  input_scale,
                  is_input_unsigned);

    bool is_output_unsigned{false};
    auto output_scale =
        GetScaleValueForNode(immutable_out, &is_output_unsigned);
    DequantizeOutput(g,
                     immutable_op,
                     immutable_out,
                     "Out",
                     output_scale,
                     is_output_unsigned);

    ++quantize_immutable_count;
  };

  gpd(graph, handler);
  AddStatis(quantize_immutable_count);
  LogQuantizedOpsCounter(immutable_type, quantize_immutable_count);
}

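// Quantizes matmul ops; both inputs must share the same signedness, and a
// residual connection additionally requires an output scale.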
void CPUQuantizePass::QuantizeMatmul(Graph* graph, bool with_residual) const {
  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
  patterns::MatmulWithInputOps matmul_pattern{pattern, name_scope_};
  matmul_pattern(with_residual);

  int quantize_matmul_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize matmul op";
    GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, matmul_pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(matmul_op->Op())) {
      LogQuantizationDisabled(matmul_op);
      return;
    }
    GET_IR_NODE_FROM_SUBGRAPH(prev_op_x, prev_op_x, matmul_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(prev_op_y, prev_op_y, matmul_pattern);

    // skip if prev ops are not quantized
    if (!IsOpDequantized(prev_op_x) && !IsOpDequantized(prev_op_y)) {
      MarkAndLogCannotQuantizeOp(matmul_op,
                                 "No other quantizable operators nearby");
      return;
    }
    GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, matmul_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, matmul_pattern);

    auto has_output_scale = AreScalesPresentForNodes({matmul_out});
    if (with_residual && !has_output_scale) {
      MarkAndLogCannotQuantizeOp(
          matmul_op,
          "Matmul op with ResidualData input cannot be quantized "
          "without output scale.");
      return;
    }

    if (!AreScalesPresentForNodes({matmul_in_x, matmul_in_y})) {
      MarkAndLogCannotQuantizeOp(matmul_op,
                                 "No scale available for the operator");
      return;
    }

    bool is_x_unsigned{false}, is_y_unsigned{false};
    auto input_x_scale = GetScaleValueForNode(matmul_in_x, &is_x_unsigned);
    auto input_y_scale = GetScaleValueForNode(matmul_in_y, &is_y_unsigned);
    PADDLE_ENFORCE_EQ(is_x_unsigned,
                      is_y_unsigned,
                      platform::errors::InvalidArgument(
                          "Matmul inputs should have the same "
                          "attribute of signed/unsigned, but they "
                          "are different: x(%d), y(%d).",
                          is_x_unsigned,
                          is_y_unsigned));

    if (with_residual) {
      GET_IR_NODE_FROM_SUBGRAPH(
          matmul_residual_data, matmul_residual_data, matmul_pattern);
      if (!AreScalesPresentForNodes({matmul_residual_data})) {
        MarkAndLogCannotQuantizeOp(matmul_op,
                                   "No scale available for the operator");
        return;
      }
      bool is_residual_unsigned{false};
      auto residual_scale =
          GetScaleValueForNode(matmul_residual_data, &is_residual_unsigned);

      QuantizeInput(g,
                    matmul_op,
                    matmul_residual_data,
                    "ResidualData",
                    residual_scale,
                    is_residual_unsigned,
                    "Scale_in_eltwise");
    }

    QuantizeInput(g,
                  matmul_op,
                  matmul_in_x,
                  "X",
                  input_x_scale,
                  is_x_unsigned,
                  "Scale_x");
    QuantizeInput(g,
                  matmul_op,
                  matmul_in_y,
                  "Y",
                  input_y_scale,
                  is_y_unsigned,
                  "Scale_y");

    // if quantization scale is missing for output tensor, return fp32 data
    if (AreScalesPresentForNodes({matmul_out})) {
      bool is_output_unsigned{false};
      auto output_scale = GetScaleValueForNode(matmul_out, &is_output_unsigned);
      DequantizeOutput(g,
                       matmul_op,
                       matmul_out,
                       "Out",
                       output_scale,
                       is_output_unsigned,
                       "Scale_out");
    } else {
      matmul_op->Op()->SetAttr("force_fp32_output", true);
    }

    ++quantize_matmul_count;
  };
  gpd(graph, handler);
  AddStatis(quantize_matmul_count);
  LogQuantizedOpsCounter("matmul",
                         quantize_matmul_count,
                         (with_residual ? "with residual connection" : ""));
}

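// Quantizes elementwise_add/mul/sub; scales must be present for both inputs
// and the output, and the inputs must share the same signedness.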
void CPUQuantizePass::QuantizeElementwise(
    Graph* graph, const std::string& elementwise_type) const {
  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
  patterns::ElementwiseOp elementwise_pattern{pattern, name_scope_};

  elementwise_pattern(elementwise_type);

  int quantize_elementwise_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize " + elementwise_type + " op";
    GET_IR_NODE_FROM_SUBGRAPH(
        elementwise_op, elementwise_op, elementwise_pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(elementwise_op->Op())) {
      LogQuantizationDisabled(elementwise_op);
      return;
    }

    auto x_name = elementwise_op->Op()->Input("X");
    auto y_name = elementwise_op->Op()->Input("Y");
    Node *elementwise_x{nullptr}, *elementwise_y{nullptr};

    for (auto& input : elementwise_op->inputs) {
      if (input->Name() == x_name[0]) elementwise_x = input;
      if (input->Name() == y_name[0]) elementwise_y = input;
    }
    if (!elementwise_x || !elementwise_y) {
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(
        elementwise_out, elementwise_out, elementwise_pattern);

    if (!AreScalesPresentForNodes(
            {elementwise_x, elementwise_y, elementwise_out})) {
      MarkAndLogCannotQuantizeOp(elementwise_op,
                                 "No scale available for the operator");
      return;
    }

    bool is_x_unsigned{false}, is_y_unsigned{false};
    auto input_x_scale = GetScaleValueForNode(elementwise_x, &is_x_unsigned);
    auto input_y_scale = GetScaleValueForNode(elementwise_y, &is_y_unsigned);

    // TODO(sfraczek): add support for different signedness
    if (is_x_unsigned != is_y_unsigned) {
      MarkAndLogCannotQuantizeOp(
          elementwise_op, "Elementwise inputs must be of the same type.");
      return;
    }

    QuantizeInput(g,
                  elementwise_op,
                  elementwise_x,
                  "X",
                  input_x_scale,
                  is_x_unsigned,
                  "Scale_x");
    QuantizeInput(g,
                  elementwise_op,
                  elementwise_y,
                  "Y",
                  input_y_scale,
                  is_y_unsigned,
                  "Scale_y");

    bool is_output_unsigned{false};
    auto output_scale =
        GetScaleValueForNode(elementwise_out, &is_output_unsigned);

    DequantizeOutput(g,
                     elementwise_op,
                     elementwise_out,
                     "Out",
                     output_scale,
                     is_output_unsigned,
                     "Scale_out");

    ++quantize_elementwise_count;
  };
  gpd(graph, handler);
  AddStatis(quantize_elementwise_count);
  LogQuantizedOpsCounter(elementwise_type, quantize_elementwise_count);
}

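// The GRU/LSTM fusion ops below quantize the X input (shifted by 128 when the
// data is signed) and the WeightX scales, while their outputs stay in FP32.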
void CPUQuantizePass::QuantizeFusionGru(Graph* graph) const {
  GraphPatternDetector gpd;
  patterns::FusionGru pattern{gpd.mutable_pattern(), name_scope_};
  pattern();

  int quantize_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize fusion_gru op";
    GET_IR_NODE_FROM_SUBGRAPH(op, op, pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(op->Op())) {
      LogQuantizationDisabled(op);
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(x, x, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(weight_h, weight_h, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(weight_x, weight_x, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(out, out, pattern);

    if (!AreScalesPresentForNodes({x, weight_x})) {
      MarkAndLogCannotQuantizeOp(op, "No scale available for the operator");
      return;
    }

    bool is_x_unsigned{false};
    auto input_x_scale = GetScaleValueForNode(x, &is_x_unsigned);

    double input_x_shift{128.};
    if (is_x_unsigned) input_x_shift = 0.;

    QuantizeInput(g,
                  op,
                  x,
                  "X",
                  input_x_scale,
                  is_x_unsigned,
                  "Scale_data",
                  input_x_shift,
                  "Shift_data");

    auto weight_scale_tensor = GetScaleTensorForNode(weight_x);
    EigenVectorArrayMap eigen_tensor{weight_scale_tensor.data<double>(),
                                     weight_scale_tensor.numel()};
    eigen_tensor *= static_cast<double>(S8_MAX);
    std::vector<float> scale_weights{
        weight_scale_tensor.data<double>(),
        weight_scale_tensor.data<double>() + weight_scale_tensor.numel()};

    op->Op()->SetAttr("Scale_weights", scale_weights);
    // return fp32 data
    op->Op()->SetAttr("force_fp32_output", true);

    ++quantize_count;
  };
  gpd(graph, handler);
  AddStatis(quantize_count);
  LogQuantizedOpsCounter("fusion_gru", quantize_count);
}

void CPUQuantizePass::QuantizeMultiGru(Graph* graph) const {
  GraphPatternDetector gpd;
  patterns::MultiGru pattern{gpd.mutable_pattern(), name_scope_};
  pattern();

  int quantize_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize multi_gru op";
    GET_IR_NODE_FROM_SUBGRAPH(gru, gru, pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(gru->Op())) {
      LogQuantizationDisabled(gru);
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(x, x, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(wx, wx, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(h, h, pattern);

    auto wx_names = gru->Op()->Input("WeightX");
    if (!AreScalesPresentForNodes({x}) ||
        !AreScalesPresentForVarNames(wx_names)) {
      MarkAndLogCannotQuantizeOp(gru, "No scale available for the operator");
      return;
    }

    bool is_x_unsigned{false};
    auto input_x_scale = GetScaleValueForNode(x, &is_x_unsigned);

    double input_x_shift{128.};
    if (is_x_unsigned) input_x_shift = 0.;

    QuantizeInput(g,
                  gru,
                  x,
                  "X",
                  input_x_scale,
                  is_x_unsigned,
                  "Scale_data",
                  input_x_shift,
                  "Shift_data");

    auto* scope = param_scope();
    int wx_size = wx_names.size();
    std::vector<std::string> w_scale_var_names;
    for (int i = 0; i < wx_size; ++i) {
      auto scale_tensor_src = GetScaleTensorByName(wx_names[i]);
      EigenVectorArrayMap eigen_tensor_src{scale_tensor_src.data<double>(),
                                           scale_tensor_src.numel()};

      VarDesc scale_var_desc(patterns::PDNodeName("multi_gru", "w_scale"));

      scale_var_desc.SetShape(phi::vectorize(scale_tensor_src.dims()));
      scale_var_desc.SetDataType(proto::VarType::FP32);
      scale_var_desc.SetLoDLevel(scale_tensor_src.lod().size());
      scale_var_desc.SetPersistable(true);
      auto* w_scale_node = g->CreateVarNode(&scale_var_desc);

      auto* w_scale_tensor_dst =
          scope->Var(w_scale_node->Name())->GetMutable<phi::DenseTensor>();
      w_scale_tensor_dst->Resize(scale_tensor_src.dims());
      auto* dst_data =
          w_scale_tensor_dst->mutable_data<float>(platform::CPUPlace());
      EigenVectorArrayMapFloat eigen_tensor_dst{dst_data,
                                                w_scale_tensor_dst->numel()};
      eigen_tensor_dst =
          eigen_tensor_src.cast<float>() * static_cast<float>(S8_MAX);
      w_scale_var_names.push_back(w_scale_node->Name());
      IR_NODE_LINK_TO(w_scale_node, gru);
    }

    gru->Op()->SetInput("Scale_weights", w_scale_var_names);
    // return fp32 data
    gru->Op()->SetAttr("force_fp32_output", true);

    ++quantize_count;
  };
  gpd(graph, handler);
  AddStatis(quantize_count);
  LogQuantizedOpsCounter("multi_gru", quantize_count);
}

void CPUQuantizePass::QuantizeFusionLSTM(Graph* graph) const {
  GraphPatternDetector gpd;
  patterns::FusionLSTM pattern{gpd.mutable_pattern(), name_scope_};
  pattern();

  int quantize_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize fusion_lstm op";
    GET_IR_NODE_FROM_SUBGRAPH(op, op, pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(op->Op())) {
      LogQuantizationDisabled(op);
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(x, x, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(weight_h, weight_h, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(weight_x, weight_x, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(hidden, hidden, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(cell, cell, pattern);

    // Starting from here there may be issues
    if (!AreScalesPresentForNodes({x, weight_x})) {
      MarkAndLogCannotQuantizeOp(op, "No scale available for the operator");
      return;
    }

    bool is_x_unsigned{false};
    auto input_x_scale = GetScaleValueForNode(x, &is_x_unsigned);

    double input_x_shift{128.};
    if (is_x_unsigned) input_x_shift = 0.;

    QuantizeInput(g,
                  op,
                  x,
                  "X",
                  input_x_scale,
                  is_x_unsigned,
                  "Scale_data",
                  input_x_shift,
                  "Shift_data");

    auto weight_scale_tensor = GetScaleTensorForNode(weight_x);
    EigenVectorArrayMap eigen_tensor{weight_scale_tensor.data<double>(),
                                     weight_scale_tensor.numel()};
    eigen_tensor *= static_cast<double>(S8_MAX);
    std::vector<float> scale_weights{
        weight_scale_tensor.data<double>(),
        weight_scale_tensor.data<double>() + weight_scale_tensor.numel()};

    op->Op()->SetAttr("Scale_weights", scale_weights);
    // return fp32 data
    op->Op()->SetAttr("force_fp32_output", true);

    ++quantize_count;
  };
  gpd(graph, handler);
  AddStatis(quantize_count);
  LogQuantizedOpsCounter("fusion_lstm", quantize_count);
}

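// Entry point of the pass: collects the quantization scales and then runs the
// per-operator quantization routines over the whole graph.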
void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
  VLOG(3) << "Quantizing the graph.";
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
  FusePassBase::Init(name_scope_, graph);

  PADDLE_ENFORCE_NOT_NULL(
      param_scope(),
      platform::errors::InvalidArgument("Scope cannot be nullptr."));

  GetQuantInfo(graph);
  QuantizeConv(graph, false /* with_residual_data */);
  QuantizeConv(graph, true /* with_residual_data */);
  QuantizePool(graph);
  QuantizeConcat(graph);
  QuantizePriorBox(graph);
  QuantizeFc(graph);
  QuantizeMatmul(graph, false /* with_residual_data */);
  QuantizeMatmul(graph, true /* with_residual_data */);
  QuantizeImmutable(graph, "reshape2", "X");
  QuantizeImmutable(graph, "transpose2", "X");
  QuantizeImmutable(graph, "slice", "Input");
  QuantizeImmutable(graph, "nearest_interp", "X");
  QuantizeImmutable(graph, "nearest_interp_v2", "X");
  QuantizeElementwise(graph, "elementwise_add");
  QuantizeElementwise(graph, "elementwise_mul");
  QuantizeElementwise(graph, "elementwise_sub");
  QuantizeFusionGru(graph);
  QuantizeMultiGru(graph);
  QuantizeFusionLSTM(graph);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(cpu_quantize_pass, paddle::framework::ir::CPUQuantizePass)
    .RequirePassAttr("quant_var_scales");