// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h"

#include <sstream>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/utils/string/pretty_log.h"

namespace paddle {
namespace framework {
namespace ir {

using EigenVectorArrayMap = Eigen::Map<Eigen::Array<double, Eigen::Dynamic, 1>>;
using EigenVectorArrayMapFloat =
    Eigen::Map<Eigen::Array<float, Eigen::Dynamic, 1>>;
using string::PrettyLogDetail;

namespace {

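// Removes the graph edge between nodes a and b by erasing b from a's output
// list and a from b's input list.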
void UnlinkNodes(ir::Node* a, ir::Node* b) {
  a->outputs.erase(std::remove(a->outputs.begin(), a->outputs.end(), b),
                   a->outputs.end());
  b->inputs.erase(std::remove(b->inputs.begin(), b->inputs.end(), a),
                  b->inputs.end());
}

void MarkAndLogCannotQuantizeOp(Node* op, const char* details = nullptr) {
  std::stringstream msg_ss;
  msg_ss << "Cannot quantize operator " << op->Name()
         << " (type: " << op->Op()->Type() << ", id: " << op->id() << ").";
  if (details) msg_ss << " " << details;
  VLOG(2) << msg_ss.str().c_str();
  op->Op()->SetAttr("mkldnn_data_type", std::string("float32"));
}

void LogScaleIsMissingForVarName(const std::string& name) {
  VLOG(4) << "Quantization scale for the variable " << name << " is missing.";
}

void LogScaleIsMissingForVarNode(Node* node) {
  LogScaleIsMissingForVarName(node->Name());
}

void LogQuantizationDisabled(Node* op) {
  VLOG(2) << "Quantization skipped for operator " << op->Name()
          << " (type: " << op->Op()->Type() << ", id: " << op->id()
          << "). Attribute mkldnn_data_type != \"int8\".";
}

void LogQuantizedOpsCounter(const std::string& type,
                            const int counter,
                            const char* details = nullptr) {
  std::stringstream msg_ss;
  msg_ss << "---    quantized " << counter << " " << type << " ops";
  if (details) msg_ss << " " << details;
  PrettyLogDetail(msg_ss.str().c_str());
}

}  // namespace

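// Upper bounds of the quantized value ranges; a tensor's quantization scale
// is its scale-to-one value multiplied by the matching bound.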
enum { U8_MAX = 255, S8_MAX = 127 };

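// Inserts a quantize op between `input` and `op`: `op` is rewired to read the
// quantize op's output instead of the original (fp32) input variable.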
void CPUQuantizePass::QuantizeInput(Graph* g,
                                    Node* op,
                                    Node* input,
                                    std::string input_name,
                                    double scale_to_one,
                                    bool is_input_unsigned,
                                    std::string scale_attr_name,
                                    float shift,
                                    std::string shift_attr_name) const {
  auto inputs = op->Op()->InputNames();
  bool name_found =
      std::find(inputs.begin(), inputs.end(), input_name) != inputs.end();
  PADDLE_ENFORCE_EQ(name_found,
                    true,
                    platform::errors::InvalidArgument(
                        "Var(%s) isn't the input of the %s operator.",
                        input_name,
                        op->Op()->Type()));
  unsigned max = is_input_unsigned ? U8_MAX : S8_MAX;
  float scale = scale_to_one * max;

  // Create quantize output variable
  VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out"));
  auto* quantize_out_node = g->CreateVarNode(&quantize_out_desc);

  // create a quantize op node
  OpDesc q_desc;
  q_desc.SetType("quantize");
  q_desc.SetInput("Input", std::vector<std::string>({input->Name()}));
  q_desc.SetOutput("Output",
                   std::vector<std::string>({quantize_out_node->Name()}));
  q_desc.SetAttr("Scale", scale);
  q_desc.SetAttr("Shift", shift);
  q_desc.SetAttr("is_negative_input", !is_input_unsigned);

  // Workaround for an fc output format error: when in_num_col_dims == 2 the
  // default output format must be NCHW instead of NHWC.
  if (op->Op()->Type() == "fc" &&
      op->Op()->GetAttrIfExists<int>("in_num_col_dims") == 2) {
    q_desc.SetAttr(
        "output_format",
        Has("data_layout") ? Get<std::string>("data_layout") : "NCHW");
  } else {
    q_desc.SetAttr(
        "output_format",
        Has("data_layout") ? Get<std::string>("data_layout") : "NHWC");
  }
  auto quantize_op = g->CreateOpNode(&q_desc);  // OpDesc will be copied.

  // update op's input
  op->Op()->SetInput(input_name,
                     std::vector<std::string>({quantize_out_node->Name()}));

  // link quantize op
  UnlinkNodes(input, op);
  IR_NODE_LINK_TO(input, quantize_op);
  IR_NODE_LINK_TO(quantize_op, quantize_out_node);
  IR_NODE_LINK_TO(quantize_out_node, op);

  if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale);
  if (!shift_attr_name.empty()) op->Op()->SetAttr(shift_attr_name, shift);
}

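// Variant of QuantizeInput for ops that take a variable number of inputs
// under one name (e.g. concat's X): every input is quantized with a single
// scale derived from the op's output.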
void CPUQuantizePass::QuantizeInputs(Graph* g,
                                     Node* op,
                                     std::string input_name,
                                     bool are_inputs_unsigned,
                                     std::string scale_attr_name,
                                     float shift,
                                     std::string shift_attr_name) const {
  auto inputs = op->inputs;
  auto var_names = op->Op()->Inputs().at(input_name);
  std::vector<std::string> unique_var_names;
  for (unsigned i = 0; i < var_names.size(); i++)
    if (std::find(unique_var_names.begin(),
                  unique_var_names.end(),
                  var_names[i]) == unique_var_names.end())
      unique_var_names.push_back(var_names[i]);

  auto output = op->outputs[0];
  PADDLE_ENFORCE_GE(inputs.size(),
                    1,
                    platform::errors::InvalidArgument(
                        "OP(%s)'s inputs(%d) must be equal or greater than 1.",
                        op->Name(),
                        inputs.size()));
  PADDLE_ENFORCE_EQ(op->outputs.size(),
                    1,
                    platform::errors::InvalidArgument(
                        "OP(%s)'s outputs(%d) must be equal to 1.",
                        op->Name(),
                        op->outputs.size()));

  // create a quantize op desc prototype
  OpDesc q_desc;
  q_desc.SetType("quantize");
  std::vector<Node*> quantize_out_nodes(inputs.size());
  std::vector<std::string> quantize_out_node_names(inputs.size());

  double scale_out = GetScaleValueForNode(output);
  unsigned max = are_inputs_unsigned ? U8_MAX : S8_MAX;
  float scale = scale_out * max;

  for (size_t var_id = 0; var_id < unique_var_names.size(); var_id++) {
    auto index = -1;
    for (size_t it = 0; it < inputs.size(); it++) {
      if (inputs[it]->Name() == unique_var_names[var_id]) index = it;
    }

    if (index == -1) {
      PADDLE_ENFORCE_NE(index,
                        -1,
                        platform::errors::InvalidArgument(
                            "Var(%s) isn't the input of the %s operator.",
                            unique_var_names[var_id],
                            op->Op()->Type()));
    }

    auto* input = inputs.at(index);

    VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out"));
    quantize_out_nodes[var_id] = g->CreateVarNode(&quantize_out_desc);
    quantize_out_node_names[var_id] = quantize_out_nodes[var_id]->Name();

    q_desc.SetAttr("Scale", scale);
    q_desc.SetAttr("Shift", shift);
    q_desc.SetInput("Input", std::vector<std::string>({input->Name()}));
    q_desc.SetOutput(
        "Output", std::vector<std::string>({quantize_out_node_names[var_id]}));
    q_desc.SetAttr("is_negative_input", !are_inputs_unsigned);
    auto quantize_op = g->CreateOpNode(&q_desc);  // OpDesc will be copied.

    // link quantize op
    UnlinkNodes(input, op);
    IR_NODE_LINK_TO(input, quantize_op);
    IR_NODE_LINK_TO(quantize_op, quantize_out_nodes[var_id]);
    IR_NODE_LINK_TO(quantize_out_nodes[var_id], op);
  }

  // If any input names were duplicated, re-insert them in their original
  // order, reusing the quantize output node created for the first occurrence.
  for (size_t i = unique_var_names.size(); i < var_names.size(); i++) {
    auto index = std::find(
        unique_var_names.begin(), unique_var_names.end(), var_names[i]);
    if (index != unique_var_names.end()) {
      auto id = std::distance(unique_var_names.begin(), index);
      quantize_out_node_names[i] = quantize_out_nodes[id]->Name();
      IR_NODE_LINK_TO(quantize_out_nodes[id], op);
    }
  }

  // update op's input
  op->Op()->SetInput(input_name, quantize_out_node_names);

  if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale);
  if (!shift_attr_name.empty()) op->Op()->SetAttr(shift_attr_name, shift);
}

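// Inserts a dequantize op after `op`: `op` is rewired to write to a new int8
// variable, which the dequantize op converts back to the original fp32
// output.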
void CPUQuantizePass::DequantizeOutput(Graph* g,
                                       Node* op,
                                       Node* output,
                                       std::string output_name,
                                       double scale_to_one,
                                       bool is_unsigned,
                                       std::string scale_attr_name) const {
  auto outputs = op->Op()->OutputNames();
  bool name_found =
      std::find(outputs.begin(), outputs.end(), output_name) != outputs.end();
  PADDLE_ENFORCE_EQ(name_found,
                    true,
                    platform::errors::InvalidArgument(
                        "Var(%s) isn't the output of the %s operator.",
                        output_name,
                        op->Op()->Type()));
  unsigned max = is_unsigned ? U8_MAX : S8_MAX;
  float scale = scale_to_one * max;

  // Create dequantize input variable
  VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
  auto* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc);

  // create a dequantize op node for output.
  OpDesc deq_desc;
  deq_desc.SetType("dequantize");
  deq_desc.SetInput("Input",
                    std::vector<std::string>({dequantize_in_node->Name()}));
  deq_desc.SetOutput("Output", std::vector<std::string>({output->Name()}));
  deq_desc.SetAttr("Scale", scale);
  deq_desc.SetAttr("is_negative_input", !is_unsigned);
  auto dequantize_op = g->CreateOpNode(&deq_desc);  // OpDesc will be copied.

  // update op's output
  op->Op()->SetOutput(output_name,
                      std::vector<std::string>({dequantize_in_node->Name()}));

  // link dequantize op
  UnlinkNodes(op, output);
  IR_NODE_LINK_TO(op, dequantize_in_node);
  IR_NODE_LINK_TO(dequantize_in_node, dequantize_op);
  IR_NODE_LINK_TO(dequantize_op, output);

  if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale);
}

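// Multi-output variant of DequantizeOutput, used for ops such as split:
// every output under `output_name` gets its own dequantize op.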
void CPUQuantizePass::DequantizeOutputs(Graph* g,
                                        Node* op,
                                        std::string output_name,
                                        double scale_to_one,
                                        bool is_unsigned,
                                        std::string scale_attr_name) const {
  auto outputs = op->outputs;
  auto var_names = op->Op()->Outputs().at(output_name);

  PADDLE_ENFORCE_GE(outputs.size(),
                    1,
                    platform::errors::InvalidArgument(
                        "OP(%s)'s outputs(%d) must be equal or greater than 1.",
                        op->Name(),
                        outputs.size()));

  std::vector<std::string> dequantize_in_node_names(outputs.size());
  std::vector<Node*> dequantize_in_nodes(outputs.size());

  unsigned max = is_unsigned ? U8_MAX : S8_MAX;
  float scale = scale_to_one * max;

  for (size_t var_id = 0; var_id < var_names.size(); var_id++) {
    auto index = -1;
    for (size_t it = 0; it < outputs.size(); it++) {
      if (outputs[it]->Name() == var_names[var_id]) index = it;
    }

    if (index == -1) {
      PADDLE_ENFORCE_NE(index,
                        -1,
                        platform::errors::InvalidArgument(
                            "Var(%s) isn't the input of the %s operator.",
                            var_names[var_id],
                            op->Op()->Type()));
    }

    auto* output = outputs.at(index);

    // Create dequantize input variable
    VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
    dequantize_in_nodes[var_id] = g->CreateVarNode(&dequantize_in_desc);
    dequantize_in_node_names[var_id] = dequantize_in_nodes[var_id]->Name();

    // create a dequantize op node for output.
    OpDesc deq_desc;
    deq_desc.SetType("dequantize");
    deq_desc.SetInput(
        "Input", std::vector<std::string>({dequantize_in_node_names[var_id]}));
    deq_desc.SetOutput("Output", std::vector<std::string>({output->Name()}));
    deq_desc.SetAttr("Scale", scale);
    deq_desc.SetAttr("is_negative_input", !is_unsigned);
    auto dequantize_op = g->CreateOpNode(&deq_desc);  // OpDesc will be copied.

    // link dequantize op
    UnlinkNodes(op, output);
    IR_NODE_LINK_TO(op, dequantize_in_nodes[var_id]);
    IR_NODE_LINK_TO(dequantize_in_nodes[var_id], dequantize_op);
    IR_NODE_LINK_TO(dequantize_op, output);
  }

  // update op's output
  op->Op()->SetOutput(output_name, dequantize_in_node_names);
  if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale);
}

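// Scales are looked up either in the cached var_quant_scales_ map (filled by
// GetQuantInfo) or, if the cache is empty, in the "quant_var_scales" pass
// attribute.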
bool CPUQuantizePass::AreScalesPresentForVarNames(
    std::vector<std::string> names) const {
  bool present = true;
  if (var_quant_scales_->empty()) {
    auto& scales = Get<VarQuantScale>("quant_var_scales");
    for (auto name : names) {
      if (scales.find(name) == scales.end()) {
        present = false;
        LogScaleIsMissingForVarName(name);
      }
    }
  } else {
    for (auto name : names) {
      if (var_quant_scales_->find(name) == var_quant_scales_->end()) {
        present = false;
        LogScaleIsMissingForVarName(name);
      }
    }
  }
  return present;
}

bool CPUQuantizePass::AreScalesPresentForNodes(
    std::initializer_list<Node*> nodes) const {
  bool present = true;
  if (var_quant_scales_->empty()) {
    auto& scales = Get<VarQuantScale>("quant_var_scales");
    for (auto node : nodes) {
      if (scales.count(node->Name()) == 0) {
        present = false;
        LogScaleIsMissingForVarNode(node);
      }
    }
  } else {
    for (auto node : nodes) {
      if (var_quant_scales_->count(node->Name()) == 0) {
        present = false;
        LogScaleIsMissingForVarNode(node);
      }
    }
  }
  return present;
}

std::pair<bool, phi::DenseTensor> CPUQuantizePass::GetScaleDataByName(
    const std::string& name) const {
  if (var_quant_scales_->empty()) {
    auto& scales = Get<VarQuantScale>("quant_var_scales");
    return scales.at(name);
  }
  return var_quant_scales_->at(name);
}

std::pair<bool, phi::DenseTensor> CPUQuantizePass::GetScaleDataForNode(
    const Node* node) const {
  return GetScaleDataByName(node->Name());
}

phi::DenseTensor CPUQuantizePass::GetScaleTensorByName(
    const std::string& name) const {
  return GetScaleDataByName(name).second;
}

phi::DenseTensor CPUQuantizePass::GetScaleTensorForNode(
    const Node* node) const {
  return GetScaleDataForNode(node).second;
}

double CPUQuantizePass::GetScaleValueForNode(const Node* node,
                                             bool* is_unsigned) const {
  auto scale_data = GetScaleDataForNode(node);
  if (is_unsigned != nullptr) *is_unsigned = scale_data.first;
  return scale_data.second.data<double>()[0];
}

bool CPUQuantizePass::IsOpDequantized(const Node* node) const {
  return node->Op()->Type() == "dequantize" ||
         platform::HasOpINT8DataType(node->Op());
}

bool CPUQuantizePass::IsOpQuantized(const Node* node) const {
  // Return true only if all outputs are ops that either are quantize ops or
  // have an int8 data type.
  return all_of(node->outputs.begin(), node->outputs.end(), [](Node* output) {
    return (output->IsOp() && (output->Op()->Type() == "quantize" ||
                               platform::HasOpINT8DataType(output->Op())));
  });
}

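// Caches in var_quant_scales_ the scales that earlier passes stored in the
// graph (see GetInfoFromTheTmpOp).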
void CPUQuantizePass::GetQuantInfo(Graph* graph) const {
  GetInfoFromTheTmpOp(
      graph, "has_quant_info", "var_quant_scales", var_quant_scales_);
}

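// Quantizes conv ops matched by the ConvResidual pattern: the input (and the
// optional residual input) gets a quantize op, per-channel weight scales are
// written to "Scale_weights", and the output is dequantized when its scale is
// known; otherwise force_fp32_output is set.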
void CPUQuantizePass::QuantizeConv(Graph* graph,
                                   const std::string& conv_type,
                                   bool with_residual_data) const {
  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
  patterns::ConvResidual conv_pattern{pattern, name_scope_};
  conv_pattern(conv_type, with_residual_data);

  int quantize_conv_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize conv2d op";
    GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
    if (conv_op->Op()->Type() == "conv2d") {
      ConvertToFusedOp(conv_op->Op());
    }

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(conv_op->Op())) {
      LogQuantizationDisabled(conv_op);
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);

    auto has_output_scale = AreScalesPresentForNodes({conv_output});
    if (with_residual_data && !has_output_scale) {
      MarkAndLogCannotQuantizeOp(
          conv_op,
          "Conv op with ResidualData input cannot be quantized "
          "without output scale.");
      return;
    }

    if (with_residual_data) {
      GET_IR_NODE_FROM_SUBGRAPH(
          conv_residual_data, conv_residual_data, conv_pattern);
      if (!AreScalesPresentForNodes(
              {conv_input, conv_filter, conv_residual_data})) {
        MarkAndLogCannotQuantizeOp(conv_op,
                                   "No scale available for the operator");
        return;
      }

      bool is_residual_unsigned{false};
      auto residual_scale =
          GetScaleValueForNode(conv_residual_data, &is_residual_unsigned);

      QuantizeInput(g,
                    conv_op,
                    conv_residual_data,
                    "ResidualData",
                    residual_scale,
                    is_residual_unsigned,
                    "Scale_in_eltwise");
    } else {
      if (!AreScalesPresentForNodes({conv_input, conv_filter})) {
        MarkAndLogCannotQuantizeOp(conv_op,
                                   "No scale available for the operator");
        return;
      }
    }

    bool is_input_unsigned{false};
    auto input_scale = GetScaleValueForNode(conv_input, &is_input_unsigned);
    QuantizeInput(g,
                  conv_op,
                  conv_input,
                  "Input",
                  input_scale,
                  is_input_unsigned,
                  "Scale_in");

    auto filter_scale_tensor = GetScaleTensorForNode(conv_filter);
    EigenVectorArrayMap eigen_tensor{filter_scale_tensor.data<double>(),
                                     filter_scale_tensor.numel()};

    // If the scale value of a weight is already multiplied by S8_MAX, it does
    // not need to be multiplied again
    if (std::find(change_weight_->begin(),
                  change_weight_->end(),
                  conv_filter->Name()) == change_weight_->end()) {
      eigen_tensor *= static_cast<double>(S8_MAX);
      change_weight_->push_back(conv_filter->Name());
    }

    std::vector<float> filter_scale{
        filter_scale_tensor.data<double>(),
        filter_scale_tensor.data<double>() + filter_scale_tensor.numel()};

    conv_op->Op()->SetAttr("Scale_weights", filter_scale);

    // if quantization scale is missing for output tensor, return fp32 data
    if (has_output_scale) {
      bool is_output_unsigned{false};
      auto output_scale =
          GetScaleValueForNode(conv_output, &is_output_unsigned);
      DequantizeOutput(g,
                       conv_op,
                       conv_output,
                       "Output",
                       output_scale,
                       is_output_unsigned,
                       "Scale_out");
    } else {
      conv_op->Op()->SetAttr("force_fp32_output", true);
    }

    // Change the clipping threshold in bounded ReLU (relu6): fuse_alpha holds
    // the fp32 threshold, which must be rescaled by Scale_out once the output
    // is quantized.
    if (conv_op->Op()->GetAttrIfExists<std::string>("fuse_activation") ==
        "relu6") {
      float scale_out =
          PADDLE_GET_CONST(float, conv_op->Op()->GetAttr("Scale_out"));
      float threshold =
          PADDLE_GET_CONST(float, conv_op->Op()->GetAttr("fuse_alpha"));
      conv_op->Op()->SetAttr("fuse_alpha", scale_out * threshold);
    }

    ++quantize_conv_count;
  };

  gpd(graph, handler);
  AddStatis(quantize_conv_count);

  LogQuantizedOpsCounter(
      conv_type,
      quantize_conv_count,
      ((with_residual_data) ? "with residual connection" : ""));
}

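// Quantizes fc ops, optionally with a residual connection, mirroring the conv
// handling above.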
void CPUQuantizePass::QuantizeFc(Graph* graph, bool with_residual_data) const {
  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
  patterns::FCMKLDNN fc_pattern{pattern, name_scope_};
  fc_pattern(with_residual_data);

  int quantize_fc_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize fc op " << (with_residual_data ? "with" : "without")
            << " residual data";
    GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fc_pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(fc->Op())) {
      LogQuantizationDisabled(fc);
      return;
    }

    if (!fc->Op()->GetAttrIfExists<bool>("use_mkldnn")) {
      MarkAndLogCannotQuantizeOp(fc, "use_mkldnn attribute set to false");
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(weights, weights, fc_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(input, input, fc_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(output, output, fc_pattern);

    if (!AreScalesPresentForNodes({input, weights})) {
      MarkAndLogCannotQuantizeOp(fc, "No scale available for the operator");
      return;
    }

    if (with_residual_data) {
      GET_IR_NODE_FROM_SUBGRAPH(residual_data, residual_data, fc_pattern);
      if (!AreScalesPresentForNodes({residual_data})) {
        MarkAndLogCannotQuantizeOp(fc, "No scale available for the operator");
        return;
      }

      bool is_residual_unsigned{false};
      auto residual_scale =
          GetScaleValueForNode(residual_data, &is_residual_unsigned);

      QuantizeInput(g,
                    fc,
                    residual_data,
                    "ResidualData",
                    residual_scale,
                    is_residual_unsigned,
                    "Scale_in_eltwise");
    }

    bool is_input_unsigned{false};
    auto input_scale = GetScaleValueForNode(input, &is_input_unsigned);
    QuantizeInput(
        g, fc, input, "Input", input_scale, is_input_unsigned, "Scale_in");

    auto weight_scale_tensor = GetScaleTensorForNode(weights);
    EigenVectorArrayMap eigen_tensor{weight_scale_tensor.data<double>(),
                                     weight_scale_tensor.numel()};
    eigen_tensor *= static_cast<double>(S8_MAX);
    std::vector<float> filter_scale{
        weight_scale_tensor.data<double>(),
        weight_scale_tensor.data<double>() + weight_scale_tensor.numel()};

    fc->Op()->SetAttr("Scale_weights", filter_scale);

    // if quantization scale is missing for output tensor, return fp32 data
    if (AreScalesPresentForNodes({output})) {
      bool is_output_unsigned{false};
      auto output_scale = GetScaleValueForNode(output, &is_output_unsigned);
      DequantizeOutput(
          g, fc, output, "Out", output_scale, is_output_unsigned, "Scale_out");
    } else {
      fc->Op()->SetAttr("force_fp32_output", true);
    }

    ++quantize_fc_count;
  };

  gpd(graph, handler);
  AddStatis(quantize_fc_count);
  LogQuantizedOpsCounter("fc",
                         quantize_fc_count,
                         with_residual_data ? "with residual connection" : "");
}

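// Wraps pool2d in a quantize/dequantize pair, using the scales recorded for
// its input and output tensors.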
void CPUQuantizePass::QuantizePool(Graph* graph) const {
  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
  patterns::Pool pool_pattern{pattern, name_scope_};
  pool_pattern();

  int quantize_pool_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize pool2d op";
    GET_IR_NODE_FROM_SUBGRAPH(pool_op, pool_op, pool_pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(pool_op->Op())) {
      LogQuantizationDisabled(pool_op);
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(pool_input, pool_input, pool_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(pool_output, pool_output, pool_pattern);

    if (!AreScalesPresentForNodes({pool_input, pool_output})) {
      MarkAndLogCannotQuantizeOp(pool_op,
                                 "No scale available for the operator");
      return;
    }

    bool is_input_unsigned{false};
    auto input_scale = GetScaleValueForNode(pool_input, &is_input_unsigned);
    QuantizeInput(g, pool_op, pool_input, "X", input_scale, is_input_unsigned);

    bool is_output_unsigned{false};
    auto output_scale = GetScaleValueForNode(pool_output, &is_output_unsigned);
    DequantizeOutput(
        g, pool_op, pool_output, "Out", output_scale, is_output_unsigned);

    ++quantize_pool_count;
  };

  gpd(graph, handler);
  AddStatis(quantize_pool_count);
  LogQuantizedOpsCounter("pool2d", quantize_pool_count);
}

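// All concat inputs are quantized with the output scale (see QuantizeInputs),
// so they end up sharing one scale; the result is treated as unsigned only if
// every input was unsigned.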
void CPUQuantizePass::QuantizeConcat(Graph* graph) const {
  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
  patterns::Concat concat_pattern{pattern, name_scope_};
  concat_pattern();

  int quantize_concat_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize concat op";
    GET_IR_NODE_FROM_SUBGRAPH(concat_op, concat_op, concat_pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(concat_op->Op())) {
      LogQuantizationDisabled(concat_op);
      return;
    }

    bool are_all_inputs_unsigned{true};
    // if all inputs were unsigned, then the output was set to unsigned
    // during the scale calculation step
    auto inputs = concat_op->inputs;
    for (size_t i = 0; i < inputs.size(); i++) {
      if (AreScalesPresentForVarNames({inputs[i]->Name()})) {
        auto scale_data = GetScaleDataByName(inputs[i]->Name());
        if (scale_data.first == false) {
          are_all_inputs_unsigned = false;
          break;
        }
      }
    }

    GET_IR_NODE_FROM_SUBGRAPH(concat_out, concat_out, concat_pattern);

    if (!AreScalesPresentForNodes({concat_out})) {
      MarkAndLogCannotQuantizeOp(concat_op,
                                 "No scale available for the operator");
      return;
    }

    auto output_scale = GetScaleValueForNode(concat_out);

    QuantizeInputs(g, concat_op, "X", are_all_inputs_unsigned);

    DequantizeOutput(
        g, concat_op, concat_out, "Out", output_scale, are_all_inputs_unsigned);
    ++quantize_concat_count;
  };

  gpd(graph, handler);
  AddStatis(quantize_concat_count);
  LogQuantizedOpsCounter("concat", quantize_concat_count);
}

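// Only the input of prior_box is quantized; its outputs depend on the input
// shape rather than on its values, so no dequantize op is needed.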
void CPUQuantizePass::QuantizePriorBox(Graph* graph) const {
  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
  patterns::PriorBox prior_box_pattern{pattern, name_scope_};
  prior_box_pattern();

  int quantize_prior_box_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize prior_box op";
    GET_IR_NODE_FROM_SUBGRAPH(prior_box_op, prior_box_op, prior_box_pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(prior_box_op->Op())) {
      LogQuantizationDisabled(prior_box_op);
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(
        prior_box_input, prior_box_input, prior_box_pattern);

    if (!AreScalesPresentForNodes({prior_box_input})) {
      MarkAndLogCannotQuantizeOp(prior_box_op,
                                 "No scale available for the operator");
      return;
    }

    bool is_input_unsigned{false};
    auto input_scale =
        GetScaleValueForNode(prior_box_input, &is_input_unsigned);
    QuantizeInput(g,
                  prior_box_op,
                  prior_box_input,
                  "Input",
                  input_scale,
                  is_input_unsigned);

    ++quantize_prior_box_count;
  };

  gpd(graph, handler);
  AddStatis(quantize_prior_box_count);
  LogQuantizedOpsCounter("prior_box", quantize_prior_box_count);
}

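// Handles "immutable" ops (reshape2, fused_transpose, slice, interpolation,
// split) that only rearrange data. They are quantized only when a neighboring
// op is already quantized, to avoid pointless quantize/dequantize pairs.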
void CPUQuantizePass::QuantizeImmutable(Graph* graph,
                                        const std::string& immutable_type,
                                        const std::string& input_name) const {
  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
  patterns::Immutable immutable_pattern{pattern, name_scope_};
  immutable_pattern(immutable_type, input_name);

  int quantize_immutable_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize " + immutable_type + " op";
    GET_IR_NODE_FROM_SUBGRAPH(immutable_op, immutable_op, immutable_pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(immutable_op->Op())) {
      LogQuantizationDisabled(immutable_op);
      return;
    }
    GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, immutable_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(immutable_in, immutable_in, immutable_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(immutable_out, immutable_out, immutable_pattern);

    // skip if the previous op is not dequantized and the ops consuming the
    // output are not quantized
    if (!IsOpDequantized(prev_op) && !IsOpQuantized(immutable_out)) {
      MarkAndLogCannotQuantizeOp(immutable_op,
                                 "No other quantizable operators nearby");
      return;
    }

    // skip if the dtype of immutable_in is not float32
    auto dtype = immutable_in->Var()->GetDataType();
    if (dtype != proto::VarType::FP32) {
      MarkAndLogCannotQuantizeOp(immutable_op, "The input dtype is not float.");
      return;
    }

    if (!AreScalesPresentForNodes({immutable_out})) {
      MarkAndLogCannotQuantizeOp(immutable_op,
                                 "No scale available for the operator");
      return;
    }

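    // Immutable ops do not change tensor values, so the output scale is
    // reused for quantizing the input.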
    bool is_input_unsigned{false};
    auto input_scale = GetScaleValueForNode(immutable_out, &is_input_unsigned);

    QuantizeInput(g,
                  immutable_op,
                  immutable_in,
                  input_name,
                  input_scale,
                  is_input_unsigned);

    bool is_output_unsigned{false};
    auto output_scale =
        GetScaleValueForNode(immutable_out, &is_output_unsigned);
    if (immutable_type == "split") {  // ops with multiple outputs
      DequantizeOutputs(
          g, immutable_op, "Out", output_scale, is_output_unsigned);
    } else {
      DequantizeOutput(g,
                       immutable_op,
                       immutable_out,
                       "Out",
                       output_scale,
                       is_output_unsigned);
    }
    ++quantize_immutable_count;
  };

  gpd(graph, handler);
  AddStatis(quantize_immutable_count);
  LogQuantizedOpsCounter(immutable_type, quantize_immutable_count);
}

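// Quantizes fused matmul ops, optionally with a residual input; X and Y must
// have the same signedness (enforced below).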
void CPUQuantizePass::QuantizeMatmul(Graph* graph, bool with_residual) const {
  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
  patterns::FusedMatmul matmul_pattern{pattern, name_scope_};
  matmul_pattern(with_residual);

  int quantize_matmul_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize matmul op";
    GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, matmul_pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(matmul_op->Op())) {
      LogQuantizationDisabled(matmul_op);
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, matmul_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, matmul_pattern);

    auto has_output_scale = AreScalesPresentForNodes({matmul_out});
    if (with_residual && !has_output_scale) {
      MarkAndLogCannotQuantizeOp(
          matmul_op,
          "Matmul op with ResidualData input cannot be quantized "
          "without output scale.");
      return;
    }

    if (!AreScalesPresentForNodes({matmul_in_x, matmul_in_y})) {
      MarkAndLogCannotQuantizeOp(matmul_op,
                                 "No scale available for the operator");
      return;
    }

    bool is_x_unsigned{false}, is_y_unsigned{false};
    auto input_x_scale = GetScaleValueForNode(matmul_in_x, &is_x_unsigned);
    auto input_y_scale = GetScaleValueForNode(matmul_in_y, &is_y_unsigned);
    PADDLE_ENFORCE_EQ(is_x_unsigned,
                      is_y_unsigned,
                      platform::errors::InvalidArgument(
                          "Matmul inputs should have the same "
                          "attribute of signed/unsigned, but they "
                          "are different: x(%d), y(%d).",
                          is_x_unsigned,
                          is_y_unsigned));

    if (with_residual) {
      GET_IR_NODE_FROM_SUBGRAPH(
          matmul_residual_data, matmul_residual_data, matmul_pattern);
      if (!AreScalesPresentForNodes({matmul_residual_data})) {
        MarkAndLogCannotQuantizeOp(matmul_op,
                                   "No scale available for the operator");
        return;
      }
      bool is_residual_unsigned{false};
      auto residual_scale =
          GetScaleValueForNode(matmul_residual_data, &is_residual_unsigned);

      QuantizeInput(g,
                    matmul_op,
                    matmul_residual_data,
                    "ResidualData",
                    residual_scale,
                    is_residual_unsigned,
                    "Scale_in_eltwise");
    }

    QuantizeInput(g,
                  matmul_op,
                  matmul_in_x,
                  "X",
                  input_x_scale,
                  is_x_unsigned,
                  "Scale_x");
    QuantizeInput(g,
                  matmul_op,
                  matmul_in_y,
                  "Y",
                  input_y_scale,
                  is_y_unsigned,
                  "Scale_y");

    // if quantization scale is missing for output tensor, return fp32 data
    if (AreScalesPresentForNodes({matmul_out})) {
      bool is_output_unsigned{false};
      auto output_scale = GetScaleValueForNode(matmul_out, &is_output_unsigned);
      DequantizeOutput(g,
                       matmul_op,
                       matmul_out,
                       "Out",
                       output_scale,
                       is_output_unsigned,
                       "Scale_out");
    } else {
      matmul_op->Op()->SetAttr("force_fp32_output", true);
    }

    ++quantize_matmul_count;
  };
  gpd(graph, handler);
  AddStatis(quantize_matmul_count);
  LogQuantizedOpsCounter("matmul",
                         quantize_matmul_count,
                         (with_residual ? "with residual connection" : ""));
}

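// Quantizes elementwise_add/mul/sub: both inputs and the output need known
// scales, and the two inputs must share the same signedness.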
void CPUQuantizePass::QuantizeElementwise(
    Graph* graph, const std::string& elementwise_type) const {
  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
  patterns::ElementwiseOp elementwise_pattern{pattern, name_scope_};

  elementwise_pattern(elementwise_type);

  int quantize_elementwise_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize " + elementwise_type + " op";
    GET_IR_NODE_FROM_SUBGRAPH(
        elementwise_op, elementwise_op, elementwise_pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(elementwise_op->Op())) {
      LogQuantizationDisabled(elementwise_op);
      return;
    }

    auto x_name = elementwise_op->Op()->Input("X");
    auto y_name = elementwise_op->Op()->Input("Y");
    Node *elementwise_x{nullptr}, *elementwise_y{nullptr};

    for (auto& input : elementwise_op->inputs) {
      if (input->Name() == x_name[0]) elementwise_x = input;
      if (input->Name() == y_name[0]) elementwise_y = input;
    }
    if (!elementwise_x || !elementwise_y) {
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(
        elementwise_out, elementwise_out, elementwise_pattern);

    if (!AreScalesPresentForNodes(
            {elementwise_x, elementwise_y, elementwise_out})) {
      MarkAndLogCannotQuantizeOp(elementwise_op,
                                 "No scale available for the operator");
      return;
    }

    bool is_x_unsigned{false}, is_y_unsigned{false};
    auto input_x_scale = GetScaleValueForNode(elementwise_x, &is_x_unsigned);
    auto input_y_scale = GetScaleValueForNode(elementwise_y, &is_y_unsigned);

    if (is_x_unsigned != is_y_unsigned) {
      MarkAndLogCannotQuantizeOp(
          elementwise_op, "Elementwise inputs must be of the same type.");
      return;
    }

    QuantizeInput(g,
                  elementwise_op,
                  elementwise_x,
                  "X",
                  input_x_scale,
                  is_x_unsigned,
                  "Scale_x");
    QuantizeInput(g,
                  elementwise_op,
                  elementwise_y,
                  "Y",
                  input_y_scale,
                  is_y_unsigned,
                  "Scale_y");

    bool is_output_unsigned{false};
    auto output_scale =
        GetScaleValueForNode(elementwise_out, &is_output_unsigned);

    DequantizeOutput(g,
                     elementwise_op,
                     elementwise_out,
                     "Out",
                     output_scale,
                     is_output_unsigned,
                     "Scale_out");

    ++quantize_elementwise_count;
  };
  gpd(graph, handler);
  AddStatis(quantize_elementwise_count);
  LogQuantizedOpsCounter(elementwise_type, quantize_elementwise_count);
}

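// Quantizes fusion_gru. Signed input data is quantized with a shift of 128
// (Shift_data), mapping it into the unsigned range expected by the underlying
// int8 GRU kernel; weights get S8_MAX scales and the op keeps fp32 output
// (force_fp32_output).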
void CPUQuantizePass::QuantizeFusionGru(Graph* graph) const {
  GraphPatternDetector gpd;
  patterns::FusionGru pattern{gpd.mutable_pattern(), name_scope_};
  pattern();

  int quantize_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize fusion_gru op";
    GET_IR_NODE_FROM_SUBGRAPH(op, op, pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(op->Op())) {
      LogQuantizationDisabled(op);
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(x, x, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(weight_h, weight_h, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(weight_x, weight_x, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(out, out, pattern);

    if (!AreScalesPresentForNodes({x, weight_x})) {
      MarkAndLogCannotQuantizeOp(op, "No scale available for the operator");
      return;
    }

    bool is_x_unsigned{false};
    auto input_x_scale = GetScaleValueForNode(x, &is_x_unsigned);

    double input_x_shift{128.};
    if (is_x_unsigned) input_x_shift = 0.;

    QuantizeInput(g,
                  op,
                  x,
                  "X",
                  input_x_scale,
                  is_x_unsigned,
                  "Scale_data",
                  input_x_shift,
                  "Shift_data");

    auto weight_scale_tensor = GetScaleTensorForNode(weight_x);
    EigenVectorArrayMap eigen_tensor{weight_scale_tensor.data<double>(),
                                     weight_scale_tensor.numel()};
    eigen_tensor *= static_cast<double>(S8_MAX);
    std::vector<float> scale_weights{
        weight_scale_tensor.data<double>(),
        weight_scale_tensor.data<double>() + weight_scale_tensor.numel()};

    op->Op()->SetAttr("Scale_weights", scale_weights);
    // return fp32 data
    op->Op()->SetAttr("force_fp32_output", true);

    ++quantize_count;
  };
  gpd(graph, handler);
  AddStatis(quantize_count);
  LogQuantizedOpsCounter("fusion_gru", quantize_count);
}

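// multi_gru carries one WeightX tensor per GRU layer, so a persistable fp32
// scale variable is materialized per layer and passed via the "Scale_weights"
// input instead of a single attribute.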
void CPUQuantizePass::QuantizeMultiGru(Graph* graph) const {
  GraphPatternDetector gpd;
  patterns::MultiGru pattern{gpd.mutable_pattern(), name_scope_};
  pattern();

  int quantize_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize multi_gru op";
    GET_IR_NODE_FROM_SUBGRAPH(gru, gru, pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(gru->Op())) {
      LogQuantizationDisabled(gru);
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(x, x, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(wx, wx, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(h, h, pattern);

    auto wx_names = gru->Op()->Input("WeightX");
    if (!AreScalesPresentForNodes({x}) ||
        !AreScalesPresentForVarNames(wx_names)) {
      MarkAndLogCannotQuantizeOp(gru, "No scale available for the operator");
      return;
    }

    bool is_x_unsigned{false};
    auto input_x_scale = GetScaleValueForNode(x, &is_x_unsigned);

    double input_x_shift{128.};
    if (is_x_unsigned) input_x_shift = 0.;

    QuantizeInput(g,
                  gru,
                  x,
                  "X",
                  input_x_scale,
                  is_x_unsigned,
                  "Scale_data",
                  input_x_shift,
                  "Shift_data");

    auto* scope = param_scope();
    int wx_size = wx_names.size();
    std::vector<std::string> w_scale_var_names;
    for (int i = 0; i < wx_size; ++i) {
      auto scale_tensor_src = GetScaleTensorByName(wx_names[i]);
      EigenVectorArrayMap eigen_tensor_src{scale_tensor_src.data<double>(),
                                           scale_tensor_src.numel()};

      VarDesc scale_var_desc(patterns::PDNodeName("multi_gru", "w_scale"));

      scale_var_desc.SetShape(phi::vectorize(scale_tensor_src.dims()));
      scale_var_desc.SetDataType(proto::VarType::FP32);
      scale_var_desc.SetLoDLevel(scale_tensor_src.lod().size());
      scale_var_desc.SetPersistable(true);
      auto* w_scale_node = g->CreateVarNode(&scale_var_desc);

      auto* w_scale_tensor_dst =
          scope->Var(w_scale_node->Name())->GetMutable<phi::DenseTensor>();
      w_scale_tensor_dst->Resize(scale_tensor_src.dims());
      auto* dst_data = w_scale_tensor_dst->mutable_data<float>(phi::CPUPlace());
      EigenVectorArrayMapFloat eigen_tensor_dst{dst_data,
                                                w_scale_tensor_dst->numel()};
      eigen_tensor_dst =
          eigen_tensor_src.cast<float>() * static_cast<float>(S8_MAX);
      w_scale_var_names.push_back(w_scale_node->Name());
      IR_NODE_LINK_TO(w_scale_node, gru);
    }

    gru->Op()->SetInput("Scale_weights", w_scale_var_names);
    // return fp32 data
    gru->Op()->SetAttr("force_fp32_output", true);

    ++quantize_count;
  };
  gpd(graph, handler);
  AddStatis(quantize_count);
  LogQuantizedOpsCounter("multi_gru", quantize_count);
}

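// Quantizes fusion_lstm, following the fusion_gru scheme: shifted input
// quantization, S8_MAX weight scales, and fp32 output.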
void CPUQuantizePass::QuantizeFusionLSTM(Graph* graph) const {
  GraphPatternDetector gpd;
  patterns::FusionLSTM pattern{gpd.mutable_pattern(), name_scope_};
  pattern();

  int quantize_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize fusion_lstm op";
    GET_IR_NODE_FROM_SUBGRAPH(op, op, pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(op->Op())) {
      LogQuantizationDisabled(op);
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(x, x, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(weight_h, weight_h, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(weight_x, weight_x, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(hidden, hidden, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(cell, cell, pattern);

    // Starting from here there may be issues
    if (!AreScalesPresentForNodes({x, weight_x})) {
      MarkAndLogCannotQuantizeOp(op, "No scale available for the operator");
      return;
    }

    bool is_x_unsigned{false};
    auto input_x_scale = GetScaleValueForNode(x, &is_x_unsigned);

    // In the QAT process scales are prepared only for the int8 data type, so
    // lstm scales should behave as if the input were int8 to get correct
    // accuracy
    is_x_unsigned = false;

    double input_x_shift{128.};
    if (is_x_unsigned) input_x_shift = 0.;

    QuantizeInput(g,
                  op,
                  x,
                  "X",
                  input_x_scale,
                  is_x_unsigned,
                  "Scale_data",
                  input_x_shift,
                  "Shift_data");

    auto weight_scale_tensor = GetScaleTensorForNode(weight_x);
    EigenVectorArrayMap eigen_tensor{weight_scale_tensor.data<double>(),
                                     weight_scale_tensor.numel()};
    eigen_tensor *= static_cast<double>(S8_MAX);
    std::vector<float> scale_weights{
        weight_scale_tensor.data<double>(),
        weight_scale_tensor.data<double>() + weight_scale_tensor.numel()};

    op->Op()->SetAttr("Scale_weights", scale_weights);
    // return fp32 data
    op->Op()->SetAttr("force_fp32_output", true);

    ++quantize_count;
  };
  gpd(graph, handler);
  AddStatis(quantize_count);
  LogQuantizedOpsCounter("fusion_lstm", quantize_count);
}

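// Entry point of the pass: loads the collected scales, then quantizes each
// supported operator type in turn.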
void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
  VLOG(3) << "Quantizing the graph.";
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
  FusePassBase::Init(name_scope_, graph);

  PADDLE_ENFORCE_NOT_NULL(
      param_scope(),
      platform::errors::InvalidArgument("Scope cannot be nullptr."));

  GetQuantInfo(graph);
  QuantizeConv(graph, "fused_conv2d", false /* with_residual_data */);
  QuantizeConv(graph, "fused_conv2d", true /* with_residual_data */);
  QuantizeConv(graph, "conv2d", false /* with_residual_data */);
  QuantizeConv(graph, "conv2d", true /* with_residual_data */);
  QuantizePool(graph);
  QuantizeConcat(graph);
  QuantizePriorBox(graph);
  QuantizeFc(graph, false /* with_residual_data */);
  QuantizeFc(graph, true /* with_residual_data */);
  QuantizeMatmul(graph, false /* with_residual_data */);
  QuantizeMatmul(graph, true /* with_residual_data */);
  QuantizeImmutable(graph, "reshape2", "X");
  QuantizeImmutable(graph, "fused_transpose", "X");
  QuantizeImmutable(graph, "slice", "Input");
  QuantizeImmutable(graph, "nearest_interp", "X");
  QuantizeImmutable(graph, "nearest_interp_v2", "X");
  QuantizeImmutable(graph, "split", "X");
  QuantizeElementwise(graph, "elementwise_add");
  QuantizeElementwise(graph, "elementwise_mul");
  QuantizeElementwise(graph, "elementwise_sub");
  QuantizeFusionGru(graph);
  QuantizeMultiGru(graph);
  QuantizeFusionLSTM(graph);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(cpu_quantize_pass, paddle::framework::ir::CPUQuantizePass)
    .RequirePassAttr("quant_var_scales");