cpu_quantize_pass.cc 46.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16
#include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h"

17
#include <sstream>
18 19
#include <utility>
#include <vector>
W
wanghuancoder 已提交
20

B
baoachun 已提交
21
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
22
#include "paddle/fluid/platform/mkldnn_helper.h"
23 24 25 26 27 28
#include "paddle/fluid/string/pretty_log.h"

namespace paddle {
namespace framework {
namespace ir {

29
using EigenVectorArrayMap = Eigen::Map<Eigen::Array<double, Eigen::Dynamic, 1>>;
30 31
using EigenVectorArrayMapFloat =
    Eigen::Map<Eigen::Array<float, Eigen::Dynamic, 1>>;
32 33
using string::PrettyLogDetail;

34 35 36 37 38 39 40 41 42
namespace {

void UnlinkNodes(ir::Node* a, ir::Node* b) {
  a->outputs.erase(std::remove(a->outputs.begin(), a->outputs.end(), b),
                   a->outputs.end());
  b->inputs.erase(std::remove(b->inputs.begin(), b->inputs.end(), a),
                  b->inputs.end());
}

43
void MarkAndLogCannotQuantizeOp(Node* op, const char* details = nullptr) {
44 45 46
  std::stringstream msg_ss;
  msg_ss << "Cannot quantize operator " << op->Name()
         << " (type: " << op->Op()->Type() << ", id: " << op->id() << ").";
47
  if (details) msg_ss << " " << details;
48 49
  VLOG(2) << msg_ss.str().c_str();
  op->Op()->SetAttr("mkldnn_data_type", std::string("float32"));
50 51
}

52 53 54 55 56 57
void LogScaleIsMissingForVarName(const std::string& name) {
  VLOG(4) << "Quantization scale for the variable " << name << " is missing.";
}

void LogScaleIsMissingForVarNode(Node* node) {
  LogScaleIsMissingForVarName(node->Name());
58 59
}

60
void LogQuantizationDisabled(Node* op) {
61
  VLOG(2) << "Quantization skipped for operator " << op->Name()
62
          << " (type: " << op->Op()->Type() << ", id: " << op->id()
63
          << "). Attribute mkldnn_data_type != \"int8\".";
64 65
}

66 67
void LogQuantizedOpsCounter(const std::string& type,
                            const int counter,
68 69 70 71 72 73 74
                            const char* details = nullptr) {
  std::stringstream msg_ss;
  msg_ss << "---    quantized " << counter << " " << type << " ops";
  if (details) msg_ss << " " << details;
  PrettyLogDetail(msg_ss.str().c_str());
}

75 76 77 78
}  // namespace

// Maximum values of unsigned (u8) and signed (s8) 8-bit integers. The
// normalized scales read from the scale maps are multiplied by one of these,
// depending on whether the quantized data is unsigned, to obtain the
// quantize/dequantize "Scale" attribute.
enum { U8_MAX = 255, S8_MAX = 127 };

79 80 81 82 83
void CPUQuantizePass::QuantizeInput(Graph* g,
                                    Node* op,
                                    Node* input,
                                    std::string input_name,
                                    double scale_to_one,
84
                                    bool is_input_unsigned,
85 86
                                    std::string scale_attr_name,
                                    float shift,
87
                                    std::string shift_attr_name) const {
M
Michał Gallus 已提交
88 89 90
  auto inputs = op->Op()->InputNames();
  bool name_found =
      std::find(inputs.begin(), inputs.end(), input_name) != inputs.end();
91 92
  PADDLE_ENFORCE_EQ(name_found,
                    true,
93 94
                    platform::errors::InvalidArgument(
                        "Var(%s) isn't the input of the %s operator.",
95 96
                        input_name,
                        op->Op()->Type()));
97
  unsigned max = is_input_unsigned ? U8_MAX : S8_MAX;
98 99 100 101 102 103 104 105 106 107 108 109 110
  float scale = scale_to_one * max;

  // Create quantize output variable
  VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out"));
  auto* quantize_out_node = g->CreateVarNode(&quantize_out_desc);

  // create a quantize op node
  OpDesc q_desc;
  q_desc.SetType("quantize");
  q_desc.SetInput("Input", std::vector<std::string>({input->Name()}));
  q_desc.SetOutput("Output",
                   std::vector<std::string>({quantize_out_node->Name()}));
  q_desc.SetAttr("Scale", scale);
111 112
  q_desc.SetAttr("Shift", shift);
  q_desc.SetAttr("is_negative_input", !is_input_unsigned);
113

Z
Zuza 已提交
114 115 116
  // fix to fc format error
  if (op->Op()->Type() == "fc" &&
      op->Op()->GetAttrIfExists<int>("in_num_col_dims") == 2) {
117 118 119
    q_desc.SetAttr(
        "output_format",
        Has("data_layout") ? Get<std::string>("data_layout") : "NCHW");
Z
Zuza 已提交
120
  } else {
121 122 123
    q_desc.SetAttr(
        "output_format",
        Has("data_layout") ? Get<std::string>("data_layout") : "NHWC");
Z
Zuza 已提交
124
  }
125 126 127 128 129 130 131 132 133 134 135 136 137
  auto quantize_op = g->CreateOpNode(&q_desc);  // OpDesc will be copied.

  // update op's input
  op->Op()->SetInput(input_name,
                     std::vector<std::string>({quantize_out_node->Name()}));

  // link quantize op
  UnlinkNodes(input, op);
  IR_NODE_LINK_TO(input, quantize_op);
  IR_NODE_LINK_TO(quantize_op, quantize_out_node);
  IR_NODE_LINK_TO(quantize_out_node, op);

  if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale);
138
  if (!shift_attr_name.empty()) op->Op()->SetAttr(shift_attr_name, shift);
139 140
}

141 142 143
void CPUQuantizePass::QuantizeInputs(Graph* g,
                                     Node* op,
                                     std::string input_name,
144
                                     bool are_inputs_unsigned,
145 146
                                     std::string scale_attr_name,
                                     float shift,
147
                                     std::string shift_attr_name) const {
148
  auto inputs = op->inputs;
149 150 151 152 153 154 155 156
  auto var_names = op->Op()->Inputs().at(input_name);
  std::vector<std::string> unique_var_names;
  for (unsigned i = 0; i < var_names.size(); i++)
    if (std::find(unique_var_names.begin(),
                  unique_var_names.end(),
                  var_names[i]) == unique_var_names.end())
      unique_var_names.push_back(var_names[i]);

157
  auto output = op->outputs[0];
158 159
  PADDLE_ENFORCE_GE(inputs.size(),
                    1,
160 161
                    platform::errors::InvalidArgument(
                        "OP(%s)'s inputs(%d) must be equal or greater than 1.",
162 163 164 165
                        op->Name(),
                        inputs.size()));
  PADDLE_ENFORCE_EQ(op->outputs.size(),
                    1,
166
                    platform::errors::InvalidArgument(
167 168
                        "OP(%s)'s outputs(%d) must be equal to 1.",
                        op->Name(),
169
                        op->outputs.size()));
170 171 172 173 174 175 176

  // create a quantize op desc prototype
  OpDesc q_desc;
  q_desc.SetType("quantize");
  std::vector<Node*> quantize_out_nodes(inputs.size());
  std::vector<std::string> quantize_out_node_names(inputs.size());

177
  double scale_out = GetScaleValueForNode(output);
178
  unsigned max = are_inputs_unsigned ? U8_MAX : S8_MAX;
179
  float scale = scale_out * max;
180

181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197
  for (size_t var_id = 0; var_id < unique_var_names.size(); var_id++) {
    auto index = -1;
    for (size_t it = 0; it < inputs.size(); it++) {
      if (inputs[it]->Name() == unique_var_names[var_id]) index = it;
    }

    if (index == -1) {
      PADDLE_ENFORCE_NE(index,
                        -1,
                        platform::errors::InvalidArgument(
                            "Var(%s) isn't the input of the %s operator.",
                            unique_var_names[var_id],
                            op->Op()->Type()));
    }

    auto* input = inputs.at(index);

198
    VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out"));
199 200
    quantize_out_nodes[var_id] = g->CreateVarNode(&quantize_out_desc);
    quantize_out_node_names[var_id] = quantize_out_nodes[var_id]->Name();
201 202

    q_desc.SetAttr("Scale", scale);
203
    q_desc.SetAttr("Shift", shift);
204 205 206
    q_desc.SetInput("Input", std::vector<std::string>({input->Name()}));
    q_desc.SetOutput(
        "Output", std::vector<std::string>({quantize_out_node_names[var_id]}));
207
    q_desc.SetAttr("is_negative_input", !are_inputs_unsigned);
208 209 210
    auto quantize_op = g->CreateOpNode(&q_desc);  // OpDesc will be copied.

    // link quantize op
211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226
    UnlinkNodes(input, op);
    IR_NODE_LINK_TO(input, quantize_op);
    IR_NODE_LINK_TO(quantize_op, quantize_out_nodes[var_id]);
    IR_NODE_LINK_TO(quantize_out_nodes[var_id], op);
  }

  // If any inputs were duplicated, now you have to enter them in the correct
  // order.
  for (size_t i = unique_var_names.size(); i < var_names.size(); i++) {
    auto index = std::find(
        unique_var_names.begin(), unique_var_names.end(), var_names[i]);
    if (index != unique_var_names.end()) {
      auto id = std::distance(unique_var_names.begin(), index);
      quantize_out_node_names[i] = quantize_out_nodes[id]->Name();
      IR_NODE_LINK_TO(quantize_out_nodes[id], op);
    }
227 228 229 230 231 232
  }

  // update op's input
  op->Op()->SetInput(input_name, quantize_out_node_names);

  if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale);
233
  if (!shift_attr_name.empty()) op->Op()->SetAttr(shift_attr_name, shift);
234 235
}

236 237 238
void CPUQuantizePass::DequantizeOutput(Graph* g,
                                       Node* op,
                                       Node* output,
239
                                       std::string output_name,
240 241
                                       double scale_to_one,
                                       bool is_unsigned,
242
                                       std::string scale_attr_name) const {
M
Michał Gallus 已提交
243 244 245
  auto outputs = op->Op()->OutputNames();
  bool name_found =
      std::find(outputs.begin(), outputs.end(), output_name) != outputs.end();
246 247
  PADDLE_ENFORCE_EQ(name_found,
                    true,
M
Michał Gallus 已提交
248
                    platform::errors::InvalidArgument(
249
                        "Var(%s) isn't the output of the %s operator.",
250 251
                        output_name,
                        op->Op()->Type()));
252 253 254 255 256 257 258 259 260 261 262 263 264 265
  unsigned max = is_unsigned ? U8_MAX : S8_MAX;
  float scale = scale_to_one * max;

  // Create dequantize input variable
  VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
  auto* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc);

  // create a dequantize op node for output.
  OpDesc deq_desc;
  deq_desc.SetType("dequantize");
  deq_desc.SetInput("Input",
                    std::vector<std::string>({dequantize_in_node->Name()}));
  deq_desc.SetOutput("Output", std::vector<std::string>({output->Name()}));
  deq_desc.SetAttr("Scale", scale);
266
  deq_desc.SetAttr("is_negative_input", !is_unsigned);
267 268 269 270 271 272 273 274 275 276 277 278 279 280 281
  auto dequantize_op = g->CreateOpNode(&deq_desc);  // OpDesc will be copied.

  // update op's output
  op->Op()->SetOutput(output_name,
                      std::vector<std::string>({dequantize_in_node->Name()}));

  // link dequantize op
  UnlinkNodes(op, output);
  IR_NODE_LINK_TO(op, dequantize_in_node);
  IR_NODE_LINK_TO(dequantize_in_node, dequantize_op);
  IR_NODE_LINK_TO(dequantize_op, output);

  if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale);
}

P
Paulina Gacek 已提交
282 283 284 285 286 287 288
// Inserts one `dequantize` op per variable bound to the `output_name` slot
// of `op`. All of them use Scale = scale_to_one * (U8_MAX or S8_MAX).
// The slot is rewritten to list the new intermediate variables, in the
// slot's own order. When non-empty, `scale_attr_name` names an attribute of
// `op` that receives the same scale.
void CPUQuantizePass::DequantizeOutputs(Graph* g,
                                        Node* op,
                                        std::string output_name,
                                        double scale_to_one,
                                        bool is_unsigned,
                                        std::string scale_attr_name) const {
  auto outputs = op->outputs;
  auto var_names = op->Op()->Outputs().at(output_name);

  PADDLE_ENFORCE_GE(outputs.size(),
                    1,
                    platform::errors::InvalidArgument(
                        "OP(%s)'s outputs(%d) must be equal or greater than 1.",
                        op->Name(),
                        outputs.size()));

  unsigned max = is_unsigned ? U8_MAX : S8_MAX;
  float scale = scale_to_one * max;

  // One entry per argument of `output_name`. This is deliberately not one
  // per output node of `op`, which may include nodes of other output slots.
  std::vector<std::string> dequantize_in_node_names;
  dequantize_in_node_names.reserve(var_names.size());

  for (const auto& var_name : var_names) {
    // If several output nodes carry this name, take the last one.
    auto output_it = std::find_if(
        outputs.rbegin(), outputs.rend(), [&var_name](const Node* node) {
          return node->Name() == var_name;
        });
    PADDLE_ENFORCE_EQ(output_it != outputs.rend(),
                      true,
                      platform::errors::InvalidArgument(
                          "Var(%s) isn't the output of the %s operator.",
                          var_name,
                          op->Op()->Type()));
    Node* output = *output_it;

    // Create dequantize input variable
    VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in"));
    Node* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc);
    dequantize_in_node_names.push_back(dequantize_in_node->Name());

    // create a dequantize op node for output.
    OpDesc deq_desc;
    deq_desc.SetType("dequantize");
    deq_desc.SetInput("Input",
                      std::vector<std::string>({dequantize_in_node->Name()}));
    deq_desc.SetOutput("Output", std::vector<std::string>({output->Name()}));
    deq_desc.SetAttr("Scale", scale);
    deq_desc.SetAttr("is_negative_input", !is_unsigned);
    auto dequantize_op = g->CreateOpNode(&deq_desc);  // OpDesc will be copied.

    // link dequantize op
    UnlinkNodes(op, output);
    IR_NODE_LINK_TO(op, dequantize_in_node);
    IR_NODE_LINK_TO(dequantize_in_node, dequantize_op);
    IR_NODE_LINK_TO(dequantize_op, output);
  }

  // update op's output
  op->Op()->SetOutput(output_name, dequantize_in_node_names);
  if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale);
}

348 349 350
bool CPUQuantizePass::AreScalesPresentForVarNames(
    std::vector<std::string> names) const {
  bool present = true;
B
baoachun 已提交
351 352 353 354 355 356 357 358 359 360 361 362 363 364
  if (var_quant_scales_->empty()) {
    auto& scales = Get<VarQuantScale>("quant_var_scales");
    for (auto name : names) {
      if (scales.find(name) == scales.end()) {
        present = false;
        LogScaleIsMissingForVarName(name);
      }
    }
  } else {
    for (auto name : names) {
      if (var_quant_scales_->find(name) == var_quant_scales_->end()) {
        present = false;
        LogScaleIsMissingForVarName(name);
      }
365 366 367 368 369
    }
  }
  return present;
}

370
bool CPUQuantizePass::AreScalesPresentForNodes(
371
    std::initializer_list<Node*> nodes) const {
372
  bool present = true;
B
baoachun 已提交
373 374 375 376 377 378 379 380 381 382 383 384 385 386
  if (var_quant_scales_->empty()) {
    auto& scales = Get<VarQuantScale>("quant_var_scales");
    for (auto node : nodes) {
      if (scales.count(node->Name()) == 0) {
        present = false;
        LogScaleIsMissingForVarNode(node);
      }
    }
  } else {
    for (auto node : nodes) {
      if (var_quant_scales_->count(node->Name()) == 0) {
        present = false;
        LogScaleIsMissingForVarNode(node);
      }
387 388 389 390 391
    }
  }
  return present;
}

392
std::pair<bool, phi::DenseTensor> CPUQuantizePass::GetScaleDataByName(
393
    const std::string& name) const {
B
baoachun 已提交
394 395 396 397 398
  if (var_quant_scales_->empty()) {
    auto& scales = Get<VarQuantScale>("quant_var_scales");
    return scales.at(name);
  }
  return var_quant_scales_->at(name);
399 400
}

401
std::pair<bool, phi::DenseTensor> CPUQuantizePass::GetScaleDataForNode(
402
    const Node* node) const {
403 404 405
  return GetScaleDataByName(node->Name());
}

406 407
// Returns only the scale tensor part of the entry stored for `name`.
phi::DenseTensor CPUQuantizePass::GetScaleTensorByName(
    const std::string& name) const {
  return std::get<1>(GetScaleDataByName(name));
}

411 412
// Returns the scale tensor stored for the variable named by `node`.
phi::DenseTensor CPUQuantizePass::GetScaleTensorForNode(
    const Node* node) const {
  return GetScaleTensorByName(node->Name());
}

// Returns the first element of `node`'s scale tensor, read as double.
// When `is_unsigned` is non-null, it also receives the stored unsigned flag.
double CPUQuantizePass::GetScaleValueForNode(const Node* node,
                                             bool* is_unsigned) const {
  auto scale_entry = GetScaleDataForNode(node);
  if (is_unsigned) {
    *is_unsigned = scale_entry.first;
  }
  return *scale_entry.second.data<double>();
}

423 424
// True when `node` is a dequantize op or an op marked with int8 data type.
bool CPUQuantizePass::IsOpDequantized(const Node* node) const {
  const auto* op_desc = node->Op();
  if (op_desc->Type() == "dequantize") {
    return true;
  }
  return platform::HasOpINT8DataType(op_desc);
}

bool CPUQuantizePass::IsOpQuantized(const Node* node) const {
429 430 431 432 433 434
  // return true only if all of outputs are ops and their are either quantize or
  // have int8 data type
  return all_of(node->outputs.begin(), node->outputs.end(), [](Node* output) {
    return (output->IsOp() && (output->Op()->Type() == "quantize" ||
                               platform::HasOpINT8DataType(output->Op())));
  });
435 436
}

B
baoachun 已提交
437
// Fills var_quant_scales_ via GetInfoFromTheFirstOp, using the
// "has_quant_info" flag and the "var_quant_scales" key suffix.
// NOTE(review): presumably this reads scales stored on an op of `graph` by an
// earlier pass; see mkldnn_pass_util.h to confirm.
void CPUQuantizePass::GetQuantInfo(Graph* graph) const {
  const std::string flag_name = "has_quant_info";
  const std::string key_suffix = "var_quant_scales";
  GetInfoFromTheFirstOp(graph, flag_name, key_suffix, var_quant_scales_);
}

442
void CPUQuantizePass::QuantizeConv(Graph* graph,
443
                                   const std::string& conv_type,
444 445 446 447
                                   bool with_residual_data) const {
  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
  patterns::ConvResidual conv_pattern{pattern, name_scope_};
448
  conv_pattern(conv_type, with_residual_data);
449 450 451 452 453 454 455 456

  int quantize_conv_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize conv2d op";
    GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);

    // skip if should not be quantized
457
    if (!platform::HasOpINT8DataType(conv_op->Op())) {
458 459 460
      LogQuantizationDisabled(conv_op);
      return;
    }
461 462 463 464 465

    GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);

466
    auto has_output_scale = AreScalesPresentForNodes({conv_output});
W
Wojciech Uss 已提交
467
    if (with_residual_data && !has_output_scale) {
468 469 470 471
      MarkAndLogCannotQuantizeOp(
          conv_op,
          "Conv op with ResidualData input cannot be quantized "
          "without output scale.");
W
Wojciech Uss 已提交
472 473 474
      return;
    }

475
    if (with_residual_data) {
476 477
      GET_IR_NODE_FROM_SUBGRAPH(
          conv_residual_data, conv_residual_data, conv_pattern);
478
      if (!AreScalesPresentForNodes(
479
              {conv_input, conv_filter, conv_residual_data})) {
480 481
        MarkAndLogCannotQuantizeOp(conv_op,
                                   "No scale available for the operator");
482
        return;
483
      }
484 485 486 487 488

      bool is_residual_unsigned{false};
      auto residual_scale =
          GetScaleValueForNode(conv_residual_data, &is_residual_unsigned);

489 490 491 492 493 494 495
      QuantizeInput(g,
                    conv_op,
                    conv_residual_data,
                    "ResidualData",
                    residual_scale,
                    is_residual_unsigned,
                    "Scale_in_eltwise");
496
    } else {
497
      if (!AreScalesPresentForNodes({conv_input, conv_filter})) {
498 499
        MarkAndLogCannotQuantizeOp(conv_op,
                                   "No scale available for the operator");
500
        return;
501
      }
502 503
    }

504 505
    bool is_input_unsigned{false};
    auto input_scale = GetScaleValueForNode(conv_input, &is_input_unsigned);
506 507 508 509 510 511 512
    QuantizeInput(g,
                  conv_op,
                  conv_input,
                  "Input",
                  input_scale,
                  is_input_unsigned,
                  "Scale_in");
513

514
    auto filter_scale_tensor = GetScaleTensorForNode(conv_filter);
515
    EigenVectorArrayMap eigen_tensor{filter_scale_tensor.data<double>(),
516
                                     filter_scale_tensor.numel()};
517 518 519 520 521 522 523 524 525 526

    // If the scale value of a weight is already multiplied by S8_MAX, it does
    // not need to be multiplied again
    if (std::find(change_weight_->begin(),
                  change_weight_->end(),
                  conv_filter->Name()) == change_weight_->end()) {
      eigen_tensor *= static_cast<double>(S8_MAX);
      change_weight_->push_back(conv_filter->Name());
    }

527 528 529 530 531 532
    std::vector<float> filter_scale{
        filter_scale_tensor.data<double>(),
        filter_scale_tensor.data<double>() + filter_scale_tensor.numel()};

    conv_op->Op()->SetAttr("Scale_weights", filter_scale);

533
    // if quantization scale is missing for output tensor, return fp32 data
W
Wojciech Uss 已提交
534
    if (has_output_scale) {
535 536 537
      bool is_output_unsigned{false};
      auto output_scale =
          GetScaleValueForNode(conv_output, &is_output_unsigned);
538 539 540 541 542 543 544
      DequantizeOutput(g,
                       conv_op,
                       conv_output,
                       "Output",
                       output_scale,
                       is_output_unsigned,
                       "Scale_out");
545 546 547
    } else {
      conv_op->Op()->SetAttr("force_fp32_output", true);
    }
548

549
    // change threshold in bounded ReLu
550 551
    if (conv_op->Op()->GetAttrIfExists<std::string>("fuse_activation") ==
        "relu6") {
552
      float scale_out =
R
Ruibiao Chen 已提交
553
          PADDLE_GET_CONST(float, conv_op->Op()->GetAttr("Scale_out"));
554
      float threshold =
R
Ruibiao Chen 已提交
555
          PADDLE_GET_CONST(float, conv_op->Op()->GetAttr("fuse_alpha"));
556
      conv_op->Op()->SetAttr("fuse_alpha", scale_out * threshold);
557 558
    }

559 560 561 562 563 564
    ++quantize_conv_count;
  };

  gpd(graph, handler);
  AddStatis(quantize_conv_count);

565
  LogQuantizedOpsCounter(
566
      conv_type,
567
      quantize_conv_count,
568
      ((with_residual_data) ? "with residual connection" : ""));
569 570
}

571
// Quantizes MKLDNN `fc` ops matched by the FCMKLDNN pattern, optionally with
// a ResidualData input.
//
// Input (and ResidualData) get quantize ops. Weight scales are multiplied by
// S8_MAX and stored in "Scale_weights". The output gets a dequantize op when
// its scale is known; otherwise force_fp32_output is set.
void CPUQuantizePass::QuantizeFc(Graph* graph, bool with_residual_data) const {
  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
  patterns::FCMKLDNN fc_pattern{pattern, name_scope_};
  fc_pattern(with_residual_data);

  int quantize_fc_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize fc op " << (with_residual_data ? "with" : "without")
            << " residual data";
    GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fc_pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(fc->Op())) {
      LogQuantizationDisabled(fc);
      return;
    }

    if (!fc->Op()->GetAttrIfExists<bool>("use_mkldnn")) {
      MarkAndLogCannotQuantizeOp(fc, "use_mkldnn attribute set to false");
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(weights, weights, fc_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(input, input, fc_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(output, output, fc_pattern);

    if (!AreScalesPresentForNodes({input, weights})) {
      MarkAndLogCannotQuantizeOp(fc, "No scale available for the operator");
      return;
    }

    if (with_residual_data) {
      GET_IR_NODE_FROM_SUBGRAPH(residual_data, residual_data, fc_pattern);
      if (!AreScalesPresentForNodes({residual_data})) {
        MarkAndLogCannotQuantizeOp(fc, "No scale available for the operator");
        return;
      }

      bool is_residual_unsigned{false};
      auto residual_scale =
          GetScaleValueForNode(residual_data, &is_residual_unsigned);

      QuantizeInput(g,
                    fc,
                    residual_data,
                    "ResidualData",
                    residual_scale,
                    is_residual_unsigned,
                    "Scale_in_eltwise");
    }

    bool is_input_unsigned{false};
    auto input_scale = GetScaleValueForNode(input, &is_input_unsigned);
    QuantizeInput(
        g, fc, input, "Input", input_scale, is_input_unsigned, "Scale_in");

    auto weight_scale_tensor = GetScaleTensorForNode(weights);
    EigenVectorArrayMap eigen_tensor{weight_scale_tensor.data<double>(),
                                     weight_scale_tensor.numel()};

    // The scale data is multiplied in place. QuantizeConv relies on that
    // multiplication persisting in the stored scales (hence its
    // change_weight_ guard), so a weight shared by several ops must be
    // multiplied by S8_MAX only once. Use the same guard here.
    if (std::find(change_weight_->begin(),
                  change_weight_->end(),
                  weights->Name()) == change_weight_->end()) {
      eigen_tensor *= static_cast<double>(S8_MAX);
      change_weight_->push_back(weights->Name());
    }

    std::vector<float> filter_scale{
        weight_scale_tensor.data<double>(),
        weight_scale_tensor.data<double>() + weight_scale_tensor.numel()};

    fc->Op()->SetAttr("Scale_weights", filter_scale);

    // if quantization scale is missing for output tensor, return fp32 data
    if (AreScalesPresentForNodes({output})) {
      bool is_output_unsigned{false};
      auto output_scale = GetScaleValueForNode(output, &is_output_unsigned);
      DequantizeOutput(
          g, fc, output, "Out", output_scale, is_output_unsigned, "Scale_out");
    } else {
      fc->Op()->SetAttr("force_fp32_output", true);
    }

    ++quantize_fc_count;
  };

  gpd(graph, handler);
  AddStatis(quantize_fc_count);
  LogQuantizedOpsCounter("fc",
                         quantize_fc_count,
                         with_residual_data ? "with residual connection" : "");
}

659 660 661 662 663 664 665 666 667 668 669 670 671
// Quantizes int8-marked pool2d ops. The "X" input gets a quantize op and
// the "Out" output a dequantize op; both input and output scales are
// required.
void CPUQuantizePass::QuantizePool(Graph* graph) const {
  GraphPatternDetector gpd;
  patterns::Pool pool_pattern{gpd.mutable_pattern(), name_scope_};
  pool_pattern();

  int pool_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize pool2d op";
    GET_IR_NODE_FROM_SUBGRAPH(pool_op, pool_op, pool_pattern);

    if (!platform::HasOpINT8DataType(pool_op->Op())) {
      LogQuantizationDisabled(pool_op);
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(pool_input, pool_input, pool_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(pool_output, pool_output, pool_pattern);

    if (!AreScalesPresentForNodes({pool_input, pool_output})) {
      MarkAndLogCannotQuantizeOp(pool_op,
                                 "No scale available for the operator");
      return;
    }

    bool input_is_unsigned = false;
    const double input_scale =
        GetScaleValueForNode(pool_input, &input_is_unsigned);
    QuantizeInput(g, pool_op, pool_input, "X", input_scale, input_is_unsigned);

    bool output_is_unsigned = false;
    const double output_scale =
        GetScaleValueForNode(pool_output, &output_is_unsigned);
    DequantizeOutput(
        g, pool_op, pool_output, "Out", output_scale, output_is_unsigned);

    ++pool_count;
  };

  gpd(graph, handler);
  AddStatis(pool_count);
  LogQuantizedOpsCounter("pool2d", pool_count);
}

703 704 705 706 707 708 709 710 711 712 713 714 715
// Quantizes int8-marked concat ops. Every "X" input is quantized with the
// output's scale (via QuantizeInputs), and "Out" gets a dequantize op. The
// output scale is required.
void CPUQuantizePass::QuantizeConcat(Graph* graph) const {
  GraphPatternDetector gpd;
  patterns::Concat concat_pattern{gpd.mutable_pattern(), name_scope_};
  concat_pattern();

  int concat_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize concat op";
    GET_IR_NODE_FROM_SUBGRAPH(concat_op, concat_op, concat_pattern);

    if (!platform::HasOpINT8DataType(concat_op->Op())) {
      LogQuantizationDisabled(concat_op);
      return;
    }

    // Unsigned unless some input with a known scale is signed. If all inputs
    // were unsigned, then the output was set to unsigned during the scale
    // calculation step. Inputs without a scale are ignored here.
    const auto& concat_inputs = concat_op->inputs;
    const bool are_all_inputs_unsigned = std::none_of(
        concat_inputs.begin(), concat_inputs.end(), [this](const Node* in) {
          return AreScalesPresentForVarNames({in->Name()}) &&
                 !GetScaleDataByName(in->Name()).first;
        });

    GET_IR_NODE_FROM_SUBGRAPH(concat_out, concat_out, concat_pattern);

    if (!AreScalesPresentForNodes({concat_out})) {
      MarkAndLogCannotQuantizeOp(concat_op,
                                 "No scale available for the operator");
      return;
    }

    const double output_scale = GetScaleValueForNode(concat_out);

    QuantizeInputs(g, concat_op, "X", are_all_inputs_unsigned);
    DequantizeOutput(
        g, concat_op, concat_out, "Out", output_scale, are_all_inputs_unsigned);

    ++concat_count;
  };

  gpd(graph, handler);
  AddStatis(concat_count);
  LogQuantizedOpsCounter("concat", concat_count);
}

757 758 759 760 761 762 763 764 765 766 767 768 769
// Quantizes the "Input" of int8-marked prior_box ops. Only a quantize op is
// inserted; the outputs are left untouched.
void CPUQuantizePass::QuantizePriorBox(Graph* graph) const {
  GraphPatternDetector gpd;
  patterns::PriorBox prior_box_pattern{gpd.mutable_pattern(), name_scope_};
  prior_box_pattern();

  int prior_box_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize prior_box op";
    GET_IR_NODE_FROM_SUBGRAPH(prior_box_op, prior_box_op, prior_box_pattern);

    if (!platform::HasOpINT8DataType(prior_box_op->Op())) {
      LogQuantizationDisabled(prior_box_op);
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(
        prior_box_input, prior_box_input, prior_box_pattern);

    if (!AreScalesPresentForNodes({prior_box_input})) {
      MarkAndLogCannotQuantizeOp(prior_box_op,
                                 "No scale available for the operator");
      return;
    }

    bool input_is_unsigned = false;
    const double input_scale =
        GetScaleValueForNode(prior_box_input, &input_is_unsigned);
    QuantizeInput(g,
                  prior_box_op,
                  prior_box_input,
                  "Input",
                  input_scale,
                  input_is_unsigned);

    ++prior_box_count;
  };

  gpd(graph, handler);
  AddStatis(prior_box_count);
  LogQuantizedOpsCounter("prior_box", prior_box_count);
}

802 803 804
void CPUQuantizePass::QuantizeImmutable(Graph* graph,
                                        const std::string& immutable_type,
                                        const std::string& input_name) const {
  // Match ops of `immutable_type` through their `input_name` slot. Each match
  // provides the op, its producer `prev_op`, and its input/output variables.
  GraphPatternDetector detector;
  patterns::Immutable imm_pattern{detector.mutable_pattern(), name_scope_};
  imm_pattern(immutable_type, input_name);

  int num_quantized = 0;
  auto quantize_one = [&](const GraphPatternDetector::subgraph_t& subgraph,
                          Graph* graph_ptr) {
    VLOG(4) << "Quantize " + immutable_type + " op";
    GET_IR_NODE_FROM_SUBGRAPH(immutable_op, immutable_op, imm_pattern);

    // Ops that are not marked for INT8 execution are left as they are.
    if (!platform::HasOpINT8DataType(immutable_op->Op())) {
      LogQuantizationDisabled(immutable_op);
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, imm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(immutable_in, immutable_in, imm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(immutable_out, immutable_out, imm_pattern);

    // Only proceed when a neighbour is already on the INT8 path, as judged by
    // IsOpDequantized (producer) or IsOpQuantized (output variable).
    const bool has_int8_neighbour =
        IsOpDequantized(prev_op) || IsOpQuantized(immutable_out);
    if (!has_int8_neighbour) {
      MarkAndLogCannotQuantizeOp(immutable_op,
                                 "No other quantizable operators nearby");
      return;
    }

    // Only float32 inputs are handled.
    if (immutable_in->Var()->GetDataType() != proto::VarType::FP32) {
      MarkAndLogCannotQuantizeOp(immutable_op, "The input dtype is not float.");
      return;
    }

    if (!AreScalesPresentForNodes({immutable_out})) {
      MarkAndLogCannotQuantizeOp(immutable_op,
                                 "No scale available for the operator");
      return;
    }

    // Both the input quantization and the output dequantization use the
    // scale recorded for the output variable.
    bool in_is_unsigned = false;
    const auto in_scale = GetScaleValueForNode(immutable_out, &in_is_unsigned);
    QuantizeInput(graph_ptr,
                  immutable_op,
                  immutable_in,
                  input_name,
                  in_scale,
                  in_is_unsigned);

    bool out_is_unsigned = false;
    const auto out_scale =
        GetScaleValueForNode(immutable_out, &out_is_unsigned);

    const bool has_multiple_outputs = (immutable_type == "split");
    if (!has_multiple_outputs) {
      DequantizeOutput(graph_ptr,
                       immutable_op,
                       immutable_out,
                       "Out",
                       out_scale,
                       out_is_unsigned);
    } else {
      // split produces several outputs; they are handled by DequantizeOutputs.
      DequantizeOutputs(
          graph_ptr, immutable_op, "Out", out_scale, out_is_unsigned);
    }

    ++num_quantized;
  };

  detector(graph, quantize_one);
  AddStatis(num_quantized);
  LogQuantizedOpsCounter(immutable_type, num_quantized);
}

877
void CPUQuantizePass::QuantizeMatmul(Graph* graph, bool with_residual) const {
  // Match matmul ops (with or without a ResidualData input, depending on
  // `with_residual`) together with the ops producing their X and Y inputs.
  GraphPatternDetector detector;
  patterns::MatmulWithInputOps mm_pattern{detector.mutable_pattern(),
                                          name_scope_};
  mm_pattern(with_residual);

  int num_quantized = 0;
  auto quantize_one = [&](const GraphPatternDetector::subgraph_t& subgraph,
                          Graph* graph_ptr) {
    VLOG(4) << "Quantize matmul op";
    GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, mm_pattern);

    // Ops that are not marked for INT8 execution are left as they are.
    if (!platform::HasOpINT8DataType(matmul_op->Op())) {
      LogQuantizationDisabled(matmul_op);
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(prev_op_x, prev_op_x, mm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(prev_op_y, prev_op_y, mm_pattern);

    // At least one producer must satisfy IsOpDequantized.
    const bool any_input_dequantized =
        IsOpDequantized(prev_op_x) || IsOpDequantized(prev_op_y);
    if (!any_input_dequantized) {
      MarkAndLogCannotQuantizeOp(matmul_op,
                                 "No other quantizable operators nearby");
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, mm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, mm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, mm_pattern);

    // The ResidualData variant additionally requires an output scale.
    const bool out_has_scale = AreScalesPresentForNodes({matmul_out});
    if (with_residual && !out_has_scale) {
      MarkAndLogCannotQuantizeOp(
          matmul_op,
          "Matmul op with ResidualData input cannot be quantized "
          "without output scale.");
      return;
    }

    if (!AreScalesPresentForNodes({matmul_in_x, matmul_in_y})) {
      MarkAndLogCannotQuantizeOp(matmul_op,
                                 "No scale available for the operator");
      return;
    }

    bool x_is_unsigned = false;
    bool y_is_unsigned = false;
    const auto x_scale = GetScaleValueForNode(matmul_in_x, &x_is_unsigned);
    const auto y_scale = GetScaleValueForNode(matmul_in_y, &y_is_unsigned);
    // Mixed signedness between X and Y is rejected as an error.
    PADDLE_ENFORCE_EQ(x_is_unsigned,
                      y_is_unsigned,
                      platform::errors::InvalidArgument(
                          "Matmul inputs should have the same "
                          "attribute of signed/unsigned, but they "
                          "are different: x(%d), y(%d).",
                          x_is_unsigned,
                          y_is_unsigned));

    if (with_residual) {
      GET_IR_NODE_FROM_SUBGRAPH(
          matmul_residual_data, matmul_residual_data, mm_pattern);
      if (!AreScalesPresentForNodes({matmul_residual_data})) {
        MarkAndLogCannotQuantizeOp(matmul_op,
                                   "No scale available for the operator");
        return;
      }
      bool residual_is_unsigned = false;
      const auto residual_scale =
          GetScaleValueForNode(matmul_residual_data, &residual_is_unsigned);
      QuantizeInput(graph_ptr,
                    matmul_op,
                    matmul_residual_data,
                    "ResidualData",
                    residual_scale,
                    residual_is_unsigned,
                    "Scale_in_eltwise");
    }

    QuantizeInput(graph_ptr,
                  matmul_op,
                  matmul_in_x,
                  "X",
                  x_scale,
                  x_is_unsigned,
                  "Scale_x");
    QuantizeInput(graph_ptr,
                  matmul_op,
                  matmul_in_y,
                  "Y",
                  y_scale,
                  y_is_unsigned,
                  "Scale_y");

    if (!AreScalesPresentForNodes({matmul_out})) {
      // No output scale: the op keeps producing fp32 data.
      matmul_op->Op()->SetAttr("force_fp32_output", true);
    } else {
      bool out_is_unsigned = false;
      const auto out_scale = GetScaleValueForNode(matmul_out, &out_is_unsigned);
      DequantizeOutput(graph_ptr,
                       matmul_op,
                       matmul_out,
                       "Out",
                       out_scale,
                       out_is_unsigned,
                       "Scale_out");
    }

    ++num_quantized;
  };

  detector(graph, quantize_one);
  AddStatis(num_quantized);
  LogQuantizedOpsCounter("matmul",
                         num_quantized,
                         (with_residual ? "with residual connection" : ""));
}

Z
Zuza 已提交
994
void CPUQuantizePass::QuantizeElementwise(
995
    Graph* graph, const std::string& elementwise_type) const {
996 997
  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
998
  patterns::ElementwiseOp elementwise_pattern{pattern, name_scope_};
999

1000
  elementwise_pattern(elementwise_type);
1001

Z
Zuza 已提交
1002
  int quantize_elementwise_count = 0;
1003 1004
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
Z
Zuza 已提交
1005
    VLOG(4) << "Quantize " + elementwise_type + " op";
1006 1007
    GET_IR_NODE_FROM_SUBGRAPH(
        elementwise_op, elementwise_op, elementwise_pattern);
1008 1009

    // skip if should not be quantized
Z
Zuza 已提交
1010 1011
    if (!platform::HasOpINT8DataType(elementwise_op->Op())) {
      LogQuantizationDisabled(elementwise_op);
1012 1013 1014
      return;
    }

1015 1016
    auto x_name = elementwise_op->Op()->Input("X");
    auto y_name = elementwise_op->Op()->Input("Y");
1017
    Node *elementwise_x{nullptr}, *elementwise_y{nullptr};
1018 1019 1020 1021 1022 1023 1024 1025 1026

    for (auto& input : elementwise_op->inputs) {
      if (input->Name() == x_name[0]) elementwise_x = input;
      if (input->Name() == y_name[0]) elementwise_y = input;
    }
    if (!elementwise_x || !elementwise_y) {
      return;
    }

1027 1028
    GET_IR_NODE_FROM_SUBGRAPH(
        elementwise_out, elementwise_out, elementwise_pattern);
1029

1030
    if (!AreScalesPresentForNodes(
Z
Zuza 已提交
1031
            {elementwise_x, elementwise_y, elementwise_out})) {
1032 1033
      MarkAndLogCannotQuantizeOp(elementwise_op,
                                 "No scale available for the operator");
1034 1035 1036 1037
      return;
    }

    bool is_x_unsigned{false}, is_y_unsigned{false};
Z
Zuza 已提交
1038 1039
    auto input_x_scale = GetScaleValueForNode(elementwise_x, &is_x_unsigned);
    auto input_y_scale = GetScaleValueForNode(elementwise_y, &is_y_unsigned);
1040 1041 1042

    // TODO(sfraczek): add support for different signness
    if (is_x_unsigned != is_y_unsigned) {
1043 1044
      MarkAndLogCannotQuantizeOp(
          elementwise_op, "Elementwise inputs must be of the same type.");
1045 1046 1047
      return;
    }

1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061
    QuantizeInput(g,
                  elementwise_op,
                  elementwise_x,
                  "X",
                  input_x_scale,
                  is_x_unsigned,
                  "Scale_x");
    QuantizeInput(g,
                  elementwise_op,
                  elementwise_y,
                  "Y",
                  input_y_scale,
                  is_y_unsigned,
                  "Scale_y");
1062

1063 1064
    bool is_output_unsigned{false};
    auto output_scale =
Z
Zuza 已提交
1065
        GetScaleValueForNode(elementwise_out, &is_output_unsigned);
1066

1067 1068 1069 1070 1071 1072 1073
    DequantizeOutput(g,
                     elementwise_op,
                     elementwise_out,
                     "Out",
                     output_scale,
                     is_output_unsigned,
                     "Scale_out");
1074

Z
Zuza 已提交
1075
    ++quantize_elementwise_count;
1076 1077
  };
  gpd(graph, handler);
Z
Zuza 已提交
1078
  AddStatis(quantize_elementwise_count);
1079
  LogQuantizedOpsCounter(elementwise_type, quantize_elementwise_count);
1080 1081
}

1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103
void CPUQuantizePass::QuantizeFusionGru(Graph* graph) const {
  // Match fusion_gru ops with their data input, weights and output.
  GraphPatternDetector gpd;
  patterns::FusionGru pattern{gpd.mutable_pattern(), name_scope_};
  pattern();

  int quantize_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize fusion_gru op";
    GET_IR_NODE_FROM_SUBGRAPH(op, op, pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(op->Op())) {
      LogQuantizationDisabled(op);
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(x, x, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(weight_h, weight_h, pattern);  // not used below
    GET_IR_NODE_FROM_SUBGRAPH(weight_x, weight_x, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(out, out, pattern);  // not used below

    if (!AreScalesPresentForNodes({x, weight_x})) {
      MarkAndLogCannotQuantizeOp(op, "No scale available for the operator");
      return;
    }

    bool is_x_unsigned{false};
    auto input_x_scale = GetScaleValueForNode(x, &is_x_unsigned);

    // Signed input data is quantized with a shift of 128; unsigned needs none.
    double input_x_shift{128.};
    if (is_x_unsigned) input_x_shift = 0.;

    QuantizeInput(g,
                  op,
                  x,
                  "X",
                  input_x_scale,
                  is_x_unsigned,
                  "Scale_data",
                  input_x_shift,
                  "Shift_data");

    // Scale_weights[i] = weight_x scale[i] * S8_MAX (multiplied in double,
    // then stored as float). The result is built in a fresh vector instead of
    // multiplying the scale tensor in place. The tensor returned by
    // GetScaleTensorForNode may share its buffer with the pass's stored
    // scales, so an in-place update could leak into, and compound for, later
    // lookups of the same variable. QuantizeMultiGru also leaves the source
    // scales untouched.
    const auto weight_scale_tensor = GetScaleTensorForNode(weight_x);
    const double* weight_scales = weight_scale_tensor.data<double>();
    const int64_t num_scales = weight_scale_tensor.numel();
    std::vector<float> scale_weights;
    scale_weights.reserve(static_cast<size_t>(num_scales));
    for (int64_t i = 0; i < num_scales; ++i) {
      scale_weights.push_back(
          static_cast<float>(weight_scales[i] * static_cast<double>(S8_MAX)));
    }

    op->Op()->SetAttr("Scale_weights", scale_weights);
    // return fp32 data
    op->Op()->SetAttr("force_fp32_output", true);

    ++quantize_count;
  };
  gpd(graph, handler);
  AddStatis(quantize_count);
  LogQuantizedOpsCounter("fusion_gru", quantize_count);
}

1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167
void CPUQuantizePass::QuantizeMultiGru(Graph* graph) const {
  // Match multi_gru ops with their data input, WeightX and hidden output.
  GraphPatternDetector detector;
  patterns::MultiGru gru_pattern{detector.mutable_pattern(), name_scope_};
  gru_pattern();

  int num_quantized = 0;
  auto quantize_one = [&](const GraphPatternDetector::subgraph_t& subgraph,
                          Graph* graph_ptr) {
    VLOG(4) << "Quantize multi_gru op";
    GET_IR_NODE_FROM_SUBGRAPH(gru, gru, gru_pattern);

    // Ops that are not marked for INT8 execution are left as they are.
    if (!platform::HasOpINT8DataType(gru->Op())) {
      LogQuantizationDisabled(gru);
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(x, x, gru_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(wx, wx, gru_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(h, h, gru_pattern);

    // Scales are needed for the data input and for every WeightX variable.
    auto wx_names = gru->Op()->Input("WeightX");
    const bool scales_present =
        AreScalesPresentForNodes({x}) && AreScalesPresentForVarNames(wx_names);
    if (!scales_present) {
      MarkAndLogCannotQuantizeOp(gru, "No scale available for the operator");
      return;
    }

    bool x_is_unsigned = false;
    const auto x_scale = GetScaleValueForNode(x, &x_is_unsigned);
    // Signed data gets a shift of 128; unsigned data gets none.
    const double x_shift = x_is_unsigned ? 0. : 128.;

    QuantizeInput(graph_ptr,
                  gru,
                  x,
                  "X",
                  x_scale,
                  x_is_unsigned,
                  "Scale_data",
                  x_shift,
                  "Shift_data");

    // For each WeightX variable, create a new persistable FP32 variable in
    // the parameter scope that holds its scales multiplied by S8_MAX, and
    // link it to the op. The source scale tensors are only read.
    auto* scope = param_scope();
    std::vector<std::string> w_scale_var_names;
    w_scale_var_names.reserve(wx_names.size());
    for (const auto& wx_name : wx_names) {
      auto src_tensor = GetScaleTensorByName(wx_name);
      EigenVectorArrayMap src_scales{src_tensor.data<double>(),
                                     src_tensor.numel()};

      VarDesc scale_desc(patterns::PDNodeName("multi_gru", "w_scale"));
      scale_desc.SetShape(phi::vectorize(src_tensor.dims()));
      scale_desc.SetDataType(proto::VarType::FP32);
      scale_desc.SetLoDLevel(src_tensor.lod().size());
      scale_desc.SetPersistable(true);
      auto* scale_node = graph_ptr->CreateVarNode(&scale_desc);

      auto* dst_tensor =
          scope->Var(scale_node->Name())->GetMutable<phi::DenseTensor>();
      dst_tensor->Resize(src_tensor.dims());
      auto* dst_data = dst_tensor->mutable_data<float>(platform::CPUPlace());
      EigenVectorArrayMapFloat dst_scales{dst_data, dst_tensor->numel()};
      dst_scales = src_scales.cast<float>() * static_cast<float>(S8_MAX);

      w_scale_var_names.push_back(scale_node->Name());
      IR_NODE_LINK_TO(scale_node, gru);
    }

    gru->Op()->SetInput("Scale_weights", w_scale_var_names);
    // The op keeps producing fp32 data.
    gru->Op()->SetAttr("force_fp32_output", true);

    ++num_quantized;
  };

  detector(graph, quantize_one);
  AddStatis(num_quantized);
  LogQuantizedOpsCounter("multi_gru", num_quantized);
}

1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252
void CPUQuantizePass::QuantizeFusionLSTM(Graph* graph) const {
  // Match fusion_lstm ops with their data input, weights, hidden and cell.
  GraphPatternDetector gpd;
  patterns::FusionLSTM pattern{gpd.mutable_pattern(), name_scope_};
  pattern();

  int quantize_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Quantize fusion_lstm op";
    GET_IR_NODE_FROM_SUBGRAPH(op, op, pattern);

    // skip if should not be quantized
    if (!platform::HasOpINT8DataType(op->Op())) {
      LogQuantizationDisabled(op);
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(x, x, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(weight_h, weight_h, pattern);  // not used below
    GET_IR_NODE_FROM_SUBGRAPH(weight_x, weight_x, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(hidden, hidden, pattern);  // not used below
    GET_IR_NODE_FROM_SUBGRAPH(cell, cell, pattern);      // not used below

    // Quantization needs scales for the data input and for WeightX.
    if (!AreScalesPresentForNodes({x, weight_x})) {
      MarkAndLogCannotQuantizeOp(op, "No scale available for the operator");
      return;
    }

    bool is_x_unsigned{false};
    auto input_x_scale = GetScaleValueForNode(x, &is_x_unsigned);

    // Signed input data is quantized with a shift of 128; unsigned needs none.
    double input_x_shift{128.};
    if (is_x_unsigned) input_x_shift = 0.;

    QuantizeInput(g,
                  op,
                  x,
                  "X",
                  input_x_scale,
                  is_x_unsigned,
                  "Scale_data",
                  input_x_shift,
                  "Shift_data");

    // Scale_weights[i] = weight_x scale[i] * S8_MAX (multiplied in double,
    // then stored as float). The result is built in a fresh vector instead of
    // multiplying the scale tensor in place. The tensor returned by
    // GetScaleTensorForNode may share its buffer with the pass's stored
    // scales, so an in-place update could leak into, and compound for, later
    // lookups of the same variable.
    const auto weight_scale_tensor = GetScaleTensorForNode(weight_x);
    const double* weight_scales = weight_scale_tensor.data<double>();
    const int64_t num_scales = weight_scale_tensor.numel();
    std::vector<float> scale_weights;
    scale_weights.reserve(static_cast<size_t>(num_scales));
    for (int64_t i = 0; i < num_scales; ++i) {
      scale_weights.push_back(
          static_cast<float>(weight_scales[i] * static_cast<double>(S8_MAX)));
    }

    op->Op()->SetAttr("Scale_weights", scale_weights);
    // return fp32 data
    op->Op()->SetAttr("force_fp32_output", true);

    ++quantize_count;
  };
  gpd(graph, handler);
  AddStatis(quantize_count);
  LogQuantizedOpsCounter("fusion_lstm", quantize_count);
}

1292
void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
  VLOG(3) << "Quantizing the graph.";
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
  FusePassBase::Init(name_scope_, graph);

  PADDLE_ENFORCE_NOT_NULL(
      param_scope(),
      platform::errors::InvalidArgument("Scope cannot be nullptr."));

  GetQuantInfo(graph);

  // The per-op quantizers below run in a fixed order. Each variant is
  // processed first without and then with residual data.
  for (const char* conv_type : {"conv2d", "fused_conv2d"}) {
    for (bool with_residual_data : {false, true}) {
      QuantizeConv(graph, conv_type, with_residual_data);
    }
  }

  QuantizePool(graph);
  QuantizeConcat(graph);
  QuantizePriorBox(graph);

  for (bool with_residual_data : {false, true}) {
    QuantizeFc(graph, with_residual_data);
  }
  for (bool with_residual_data : {false, true}) {
    QuantizeMatmul(graph, with_residual_data);
  }

  // (op type, name of the input slot to quantize) for QuantizeImmutable.
  const std::vector<std::pair<std::string, std::string>> immutable_ops{
      {"reshape2", "X"},
      {"transpose2", "X"},
      {"slice", "Input"},
      {"nearest_interp", "X"},
      {"nearest_interp_v2", "X"},
      {"split", "X"}};
  for (const auto& type_and_input : immutable_ops) {
    QuantizeImmutable(graph, type_and_input.first, type_and_input.second);
  }

  for (const char* elementwise_type :
       {"elementwise_add", "elementwise_mul", "elementwise_sub"}) {
    QuantizeElementwise(graph, elementwise_type);
  }

  QuantizeFusionGru(graph);
  QuantizeMultiGru(graph);
  QuantizeFusionLSTM(graph);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

// Register CPUQuantizePass under the name "cpu_quantize_pass" and declare
// "quant_var_scales" as a required pass attribute.
REGISTER_PASS(cpu_quantize_pass, paddle::framework::ir::CPUQuantizePass)
    .RequirePassAttr("quant_var_scales");