// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/platform/place.h"

namespace paddle {
namespace framework {
namespace ir {

// Appends an operator of `type` named `name` to block 0 of `prog`, wiring the
// given input/output variable names and setting the quantization-scale
// attribute each op type expects ("Scale" for quantize/dequantize,
// "Scale_out" for conv2d/requantize).
void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
           const std::vector<std::string>& inputs,
           const std::vector<std::string>& outputs, bool use_mkldnn,
           float scale = 0) {
  auto* op = prog->MutableBlock(0)->AppendOp();
  op->SetType(type);
  op->SetAttr("use_mkldnn", use_mkldnn);
  op->SetAttr("name", name);
  if (type == "conv2d") {
    op->SetAttr("Scale_out", scale);
    op->SetInput("Input", {inputs[0]});
    // Filter and Bias are optional in these test graphs.
    if (inputs.size() > 1) op->SetInput("Filter", {inputs[1]});
    if (inputs.size() > 2) op->SetInput("Bias", {inputs[2]});
    op->SetOutput("Output", {outputs[0]});
  } else if (type == "quantize") {
    op->SetInput("Input", {inputs[0]});
    op->SetOutput("Output", {outputs[0]});
    op->SetAttr("Scale", scale);
  } else if (type == "dequantize") {
    op->SetInput("Input", {inputs[0]});
    op->SetOutput("Output", {outputs[0]});
    op->SetAttr("Scale", scale);
  } else if (type == "requantize") {
    op->SetInput("Input", {inputs[0]});
    op->SetOutput("Output", {outputs[0]});
    op->SetAttr("Scale_out", scale);
  } else if (type == "concat") {
    op->SetInput("X", inputs);
    op->SetOutput("Out", outputs);
  } else if (type == "fc") {
    // Validate before indexing so a malformed call fails with a clear
    // error message instead of out-of-range access on inputs[1].
    PADDLE_ENFORCE_EQ(inputs.size(), 2UL,
                      platform::errors::InvalidArgument(
                          "The fc inputs should contain input and weights, but "
                          "now the size of inputs is %d",
                          inputs.size()));
    op->SetInput("Input", {inputs[0]});
    op->SetInput("W", {inputs[1]});
    op->SetOutput("Out", outputs);
  }
}

// (a,w1,b1)->Conv1->d
// d->Dequant(scale1)->e
// e->Quant(scale2)->f
// (f,w2,b2)->Conv2->i
ProgramDesc BuildConvRequantProgramDesc(bool use_mkldnn, float scale_out,
                                        float scale1, float scale2) {
  ProgramDesc prog;
  for (auto& v : std::initializer_list<std::string>(
           {"a", "w1", "b1", "d", "e", "f", "w2", "b2", "i"})) {
    auto* var = prog.MutableBlock(0)->Var(v);
    // Weights (w*) and biases (b*) are marked persistable so passes treat
    // them as parameters rather than activations.
    if (v.find("w") == 0 || v.find("b") == 0) {
      var->SetPersistable(true);
    }
  }

  SetOp(&prog, "conv2d", "Conv1", {"a", "w1", "b1"}, {"d"}, use_mkldnn,
        scale_out);
  SetOp(&prog, "dequantize", "Dequant", {"d"}, {"e"}, use_mkldnn, scale1);
  SetOp(&prog, "quantize", "Quant", {"e"}, {"f"}, use_mkldnn, scale2);
  SetOp(&prog, "conv2d", "Conv2", {"f", "w2", "b2"}, {"i"}, use_mkldnn,
        scale_out);
  return prog;
}

// Variable names shared by most graph builders below; each becomes a graph
// node and receives a tensor holder in PrepareGraph.
static const std::initializer_list<std::string> variable_names{
    "a", "b", "c", "d", "e", "f", "g", "h"};

// a->Conv1->b
// b->Dequant(scale1)->c
// c->Quant1(scale2)->d and d->Conv2->e
// c->Conv3->f
// c->Quant2(scale3)->g and g->Conv4->h
ProgramDesc BuildConvMultiOutputProgramDesc(bool use_mkldnn, float scale_out,
                                            float scale1, float scale2,
                                            float scale3) {
  ProgramDesc prog;
  for (auto& v : variable_names) {
    prog.MutableBlock(0)->Var(v);
  }

  SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out);
  SetOp(&prog, "dequantize", "Dequant", {"b"}, {"c"}, use_mkldnn, scale1);

  // Dequant's output "c" has three consumers: two quantize ops and a conv.
  SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, use_mkldnn, scale2);
  SetOp(&prog, "conv2d", "Conv2", {"d"}, {"e"}, use_mkldnn, scale_out);

  SetOp(&prog, "conv2d", "Conv3", {"c"}, {"f"}, use_mkldnn, scale_out);

  SetOp(&prog, "quantize", "Quant2", {"c"}, {"g"}, use_mkldnn, scale3);
  SetOp(&prog, "conv2d", "Conv4", {"g"}, {"h"}, use_mkldnn, scale_out);

  return prog;
}

//  a->Conv1->b->Requant(scale1)->c
//  d->Conv2->e->Requant(scale2)->f
//  {c,f}->Concat
ProgramDesc BuildConvsRequantConcatProgramDesc(bool use_mkldnn, float scale_out,
                                               float scale1, float scale2) {
  ProgramDesc prog;
  for (auto& v : variable_names) {
    prog.MutableBlock(0)->Var(v);
  }

  SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out);
  SetOp(&prog, "requantize", "Requant1", {"b"}, {"c"}, use_mkldnn, scale1);

  SetOp(&prog, "conv2d", "Conv2", {"d"}, {"e"}, use_mkldnn, scale_out);
  SetOp(&prog, "requantize", "Requant2", {"e"}, {"f"}, use_mkldnn, scale2);

  // NOTE(review): the header comment says {c,f}->Concat, but this wires "c"
  // as Concat's only input and "f" as its output. Confirm whether inputs
  // should be {"c", "f"} with a separate output variable; changing it would
  // alter the node counts the tests below assert.
  SetOp(&prog, "concat", "Concat", {"c"}, {"f"}, use_mkldnn);

  return prog;
}

// a->Concat->b
// b->Dequant(scale1)->c
// c->Quant(scale2)->d
// d->Conv->e
ProgramDesc BuildConcatDequantQuantProgramDesc(bool use_mkldnn, float scale_out,
                                               float scale1, float scale2) {
  ProgramDesc prog;
  for (auto& v : variable_names) {
    prog.MutableBlock(0)->Var(v);
  }

  SetOp(&prog, "concat", "Concat", {"a"}, {"b"}, use_mkldnn);
  SetOp(&prog, "dequantize", "Dequant", {"b"}, {"c"}, use_mkldnn, scale1);
  SetOp(&prog, "quantize", "Quant", {"c"}, {"d"}, use_mkldnn, scale2);
  SetOp(&prog, "conv2d", "Conv2", {"d"}, {"e"}, use_mkldnn, scale_out);
  return prog;
}

// a->Conv1->b
// b->Requant1(Scale1)->c
// b->Requant2(Scale2)->d
ProgramDesc BuildConvMultiRequantProgramDesc(bool use_mkldnn, float scale_out,
                                             float scale1, float scale2) {
  ProgramDesc prog;
  for (auto& v : variable_names) {
    prog.MutableBlock(0)->Var(v);
  }
  // Conv1's single output "b" feeds two requantize ops.
  SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out);
  SetOp(&prog, "requantize", "Requant1", {"b"}, {"c"}, use_mkldnn, scale1);
  SetOp(&prog, "requantize", "Requant2", {"b"}, {"d"}, use_mkldnn, scale2);
  return prog;
}

// a->Conv1->b
// b->Dequant1(Scale1)->c
// c->Concat
ProgramDesc BuildConvDequantConcatProgramDesc(bool use_mkldnn, float scale_out,
                                              float scale) {
  ProgramDesc prog;
  // Register every shared test variable in block 0.
  auto* block = prog.MutableBlock(0);
  for (auto& var_name : variable_names) block->Var(var_name);
  // Chain: a -> Conv1 -> b -> Dequant1 -> c -> Concat1 -> d
  SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out);
  SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, scale);
  SetOp(&prog, "concat", "Concat1", {"c"}, {"d"}, use_mkldnn);
  return prog;
}

// a->fc->b
// b->Dequant1->c
// c->Concat1->d
ProgramDesc BuildFcDequantConcatProgramDesc(bool use_mkldnn, float scale_out,
                                            float scale) {
  ProgramDesc prog;
  // Register every shared test variable in block 0.
  auto* block = prog.MutableBlock(0);
  for (auto& var_name : variable_names) block->Var(var_name);
  // Chain: (a,w1) -> Fc1 -> b -> Dequant1 -> c -> Concat1 -> d
  SetOp(&prog, "fc", "Fc1", {"a", "w1"}, {"b"}, use_mkldnn, scale_out);
  SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, scale);
  SetOp(&prog, "concat", "Concat1", {"c"}, {"d"}, use_mkldnn);
  return prog;
}

// a->fc->b
// b->Dequant1->c
// b->concat->d
ProgramDesc BuildFcDequantFcProgramDesc(bool use_mkldnn, float scale_out,
                                        float scale) {
  ProgramDesc prog;
  // Register every shared test variable in block 0.
  auto* block = prog.MutableBlock(0);
  for (auto& var_name : variable_names) block->Var(var_name);
  // Fc1's output "b" has two consumers: the dequantize and the concat.
  SetOp(&prog, "fc", "Fc1", {"a", "w1"}, {"b"}, use_mkldnn, scale_out);
  SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, scale);
  SetOp(&prog, "concat", "Concat1", {"b"}, {"d"}, use_mkldnn);
  return prog;
}

// a->Conv1->b
// b->Dequant1(Scale1)->c
// b->Conv2->d
ProgramDesc BuildConvDequantConvProgramDesc(bool use_mkldnn, float scale_out,
                                            float scale) {
  ProgramDesc prog;
  // Register every shared test variable in block 0.
  auto* block = prog.MutableBlock(0);
  for (auto& var_name : variable_names) block->Var(var_name);
  // Conv1's output "b" feeds both a dequantize and a second conv.
  SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, scale_out);
  SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, use_mkldnn, scale);
  SetOp(&prog, "conv2d", "Conv2", {"b"}, {"d"}, use_mkldnn);
  return prog;
}

// a->concat->b
// b->Quant1(Scale1)->c
// b->Quant2(Scale2)->d
// b->concat->e
// c->fc->f
// d->fc->g
ProgramDesc BuildMultipleQuantizeProgramDesc(bool use_mkldnn, float first_scale,
                                             float second_scale) {
  ProgramDesc prog;
  // Register every shared test variable in block 0.
  auto* block = prog.MutableBlock(0);
  for (auto& var_name : variable_names) block->Var(var_name);

  // "b" feeds two quantize ops (possibly with equal scales) plus a plain
  // concat; each quantize output then feeds its own fc.
  SetOp(&prog, "concat", "Concat1", {"a"}, {"b"}, use_mkldnn);
  SetOp(&prog, "quantize", "Quantize1", {"b"}, {"c"}, use_mkldnn, first_scale);
  SetOp(&prog, "quantize", "Quantize2", {"b"}, {"d"}, use_mkldnn, second_scale);
  SetOp(&prog, "concat", "Concat2", {"b"}, {"e"}, use_mkldnn);
  SetOp(&prog, "fc", "Fc1", {"c", "w1"}, {"f"}, use_mkldnn, first_scale);
  SetOp(&prog, "fc", "Fc2", {"d", "w2"}, {"g"}, use_mkldnn, second_scale);
  return prog;
}

// Creates variable `var_name` in `scope` and backs it with a one-element
// FP32 LoDTensor so passes that look variables up in the scope find data.
void InitTensorHolder(Scope* scope, const paddle::platform::Place& place,
                      const char* var_name) {
  auto x = scope->Var(var_name);
  auto tensor = x->GetMutable<LoDTensor>();
  tensor->mutable_data(place, proto::VarType::FP32, 1);
}

// Creates the variables from `prog` in a scope, gives each shared test
// variable a one-element FP32 tensor, and attaches the scope to the graph.
// NOTE(review): `scope` is a stack local whose address is stored in the
// graph via SetNotOwned and then destroyed on return, before RegisterPass
// runs the pass — confirm the pass never dereferences the scope afterwards.
void PrepareGraph(std::unique_ptr<ir::Graph>* graph, const ProgramDesc& prog) {
  auto place = paddle::platform::CPUPlace();
  NaiveExecutor exe{place};
  Scope scope;
  exe.CreateVariables(prog, 0, true, &scope);

  for (auto& v : variable_names) {
    InitTensorHolder(&scope, place, v.c_str());
  }
  (*graph)->SetNotOwned(kParamScopeAttr, &scope);
}
// Fetches the squash pass from the registry and applies it to `graph`
// in place.
void RegisterPass(std::unique_ptr<ir::Graph>* graph) {
  auto pass = PassRegistry::Instance().Get("cpu_quantize_squash_pass");
  graph->reset(pass->Apply(graph->release()));
}
// check number of nodes
void CountNodeTest(const ProgramDesc& prog, int removed_nodes_num) {
  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
  PrepareGraph(&graph, prog);
283

284 285
  int original_nodes_num = graph->Nodes().size();
  RegisterPass(&graph);
286 287 288 289 290
  int current_nodes_num = graph->Nodes().size();

  EXPECT_EQ(original_nodes_num - removed_nodes_num, current_nodes_num);
}

// check op->scale_out
// Runs the pass, then verifies that every op whose "name" attribute equals
// `name` carries the expected "Scale_out" attribute value.
void EqualScaleOutTest(const ProgramDesc& prog, const std::string& name,
                       float scale) {
  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
  PrepareGraph(&graph, prog);
  RegisterPass(&graph);

  for (auto* node : graph->Nodes()) {
    if (!node->IsOp()) continue;
    auto* op_desc = node->Op();
    if (boost::get<std::string>(op_desc->GetAttr("name")) != name) continue;
    EXPECT_EQ(boost::get<float>(op_desc->GetAttr("Scale_out")), scale);
  }
}

// check requant_op scales
void CheckRequantScalesTest(const ProgramDesc& prog, float scale_in,
                            float scale_out) {
  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
311

312 313 314 315 316 317 318 319 320 321 322 323 324
  PrepareGraph(&graph, prog);
  RegisterPass(&graph);

  for (auto* node : graph->Nodes()) {
    if (node->IsOp() && node->Op()->Type() == "requantize") {
      float op_scale_in = boost::get<float>(node->Op()->GetAttr("Scale_in"));
      EXPECT_EQ(op_scale_in, scale_in);
      float op_scale_out = boost::get<float>(node->Op()->GetAttr("Scale_out"));
      EXPECT_EQ(op_scale_out, scale_out);
    }
  }
}

// check whether the force_fp32_output attribute is set as expected
// Runs the pass, then checks that every op of `op_type` has (or lacks) the
// "force_fp32_output" attribute as expected; a missing attribute reads as
// the default-constructed value via GetAttrIfExists.
void IsForceFp32OutputTest(const ProgramDesc& prog, std::string op_type,
                           bool target_is_force_fp32_output) {
  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));

  PrepareGraph(&graph, prog);
  RegisterPass(&graph);

  for (auto* node : graph->Nodes()) {
    if (!node->IsOp() || node->Op()->Type() != op_type) continue;
    EXPECT_EQ(node->Op()->GetAttrIfExists<bool>("force_fp32_output"),
              target_is_force_fp32_output);
  }
}

// From Conv1->d->Dequant->e->Quant->f->Conv2
// To Conv1->d->Conv2
TEST(CpuQuantizeSquashPass, equal_scales) {
  auto scale_out = 1.0f;
  auto scale = 1.2345f;
  auto use_mkldnn = true;
  // Equal dequant/quant scales cancel out entirely.
  // Remove 4 nodes: Dequant, Quant, e, f
  auto remove_nodes = 4;

  CountNodeTest(
      BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale, scale),
      remove_nodes);
}

// From Conv1->d->Dequant->e->Quant->f->Conv2
// First change to Conv1->d->Requant->f->Conv2
// Then Conv1->f->Conv2
TEST(CpuQuantizeSquashPass, unequal_scales) {
  auto scale_out = 1.0f;
  auto scale1 = 1.2345f;
  auto scale2 = 21.0f;
  auto use_mkldnn = true;
  // Remove 4 nodes: Dequant, Quant, e, d
  auto remove_nodes = 4;

  CountNodeTest(
      BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale1, scale2),
      remove_nodes);

  // Conv1 must take over the quantize scale as its output scale.
  EqualScaleOutTest(
      BuildConvRequantProgramDesc(use_mkldnn, scale_out, scale1, scale2),
      "Conv1", scale2);
}

//  a->Conv1->b->Requant->c
//  d->Conv2->e->Requant->f
//  {c,f}->Concat
TEST(CpuQuantizeSquashPass, equal_scales_squash_requantize) {
  // Delete both requantize op
  auto scale_out = 1.0f;
  auto scale = 1.2345f;
  auto use_mkldnn = true;
  // Remove 4 nodes: b, Requant1, e, Requant2
  auto remove_nodes = 4;
  CountNodeTest(
      BuildConvsRequantConcatProgramDesc(use_mkldnn, scale_out, scale, scale),
      remove_nodes);

  // check equal scale conv->scale_out and requant->scale_out
  EqualScaleOutTest(
      BuildConvsRequantConcatProgramDesc(use_mkldnn, scale_out, scale, scale),
      "Conv1", scale);
  EqualScaleOutTest(
      BuildConvsRequantConcatProgramDesc(use_mkldnn, scale_out, scale, scale),
      "Conv2", scale);
}

// from
// a->Conv1->b->Dequant(Scale1)->c
// c->Quant1(Scale1)->d and d->Conv2->e
// c->Quant2(Scale2)->g and g->Conv4->h
// c->Conv3->f
// to
// a->Conv1->b
// b->Conv2->e
// b->Requant(Scale_in = Scale1; Scale_out = Scale2)->g->Conv4->h
// b->Dequant(Scale1)->c->Conv3->f
TEST(CpuQuantizeSquashPass, branch_to_equal_unequal_and_fp32) {
  auto scale_out = 1.0f;
  auto scale = 1.2345f;
  auto scale2 = 21.0f;
  auto use_mkldnn = true;
  // Remove 3 nodes: Quant1, c, Quant2,
  // Insert 1 node: Requant
  // Net node-count change is therefore 2.
  auto remove_nodes = 2;

  CountNodeTest(BuildConvMultiOutputProgramDesc(use_mkldnn, scale_out, scale,
                                                scale, scale2),
                remove_nodes);
  CheckRequantScalesTest(BuildConvMultiOutputProgramDesc(use_mkldnn, scale_out,
                                                         scale, scale, scale2),
                         scale, scale2);
}

// a->Concat->b->Dequant->c->Quant->d->Conv->e
// to a->Concat->b->Requant->d->Conv->e
TEST(CpuQuantizeSquashPass,
     unequal_scales_squash_dequantize_quantize_into_requantize) {
  auto scale_out = 1.0f;
  auto scale = 1.2345f;
  auto scale2 = 21.0f;
  auto use_mkldnn = true;
  // Remove 3 nodes: Dequant1, c, Quant
  // Insert 1 node: Requant
  // Net node-count change is therefore 2.
  auto remove_nodes = 2;

  CountNodeTest(
      BuildConcatDequantQuantProgramDesc(use_mkldnn, scale_out, scale, scale2),
      remove_nodes);
  CheckRequantScalesTest(
      BuildConcatDequantQuantProgramDesc(use_mkldnn, scale_out, scale, scale2),
      scale, scale2);
}

// a->Conv1->b
// b->Requant1(Scale1)->c
// b->Requant2(Scale2)->d
TEST(CpuQuantizeSquashPass, more_than_one_conv_out_outputs) {
  auto scale_out = 1.0f;
  auto scale = 1.2345f;
  auto scale2 = 21.0f;
  auto use_mkldnn = true;
  // A conv output with multiple consumers must not be squashed:
  // nothing change
  auto remove_nodes = 0;
  CountNodeTest(
      BuildConvMultiRequantProgramDesc(use_mkldnn, scale_out, scale, scale2),
      remove_nodes);
}

// a->Conv1->c->Concat
TEST(CpuQuantizeSquashPass, conv_dequant_only_one_output) {
  auto scale_out = 1.0f;
  auto scale = 1.2345f;
  auto use_mkldnn = true;
  // remove 2 nodes: Dequant1, c
  auto remove_nodes = 2;
  CountNodeTest(BuildConvDequantConcatProgramDesc(use_mkldnn, scale_out, scale),
                remove_nodes);
  // After the squash the conv must emit FP32 directly.
  IsForceFp32OutputTest(
      BuildConvDequantConcatProgramDesc(use_mkldnn, scale_out, scale), "conv2d",
      true);
}

// If there are more than one op after conv->dequantize, do not fuse
TEST(CpuQuantizeSquashPass, conv_dequant_more_than_one_op_after_conv) {
  auto scale_out = 1.0f;
  auto scale = 1.2345f;
  auto use_mkldnn = true;
  // nothing change
  auto remove_nodes = 0;
  CountNodeTest(BuildConvDequantConvProgramDesc(use_mkldnn, scale_out, scale),
                remove_nodes);
  // The conv keeps its quantized output since it was not fused.
  IsForceFp32OutputTest(
      BuildConvDequantConvProgramDesc(use_mkldnn, scale_out, scale), "conv2d",
      false);
}

// from
// a->fc->b->Dequant1->c->Concat1->d
// to
// a->fc->c->Concat->d
TEST(CpuQuantizeSquashPass, fc_dequant_only_one_output) {
  const auto scale_out = 1.0f;
  const auto scale = 1.2345f;
  const auto use_mkldnn = true;
  // remove 2 nodes: b, Dequant1
  const auto remove_nodes = 2;
  CountNodeTest(BuildFcDequantConcatProgramDesc(use_mkldnn, scale_out, scale),
                remove_nodes);
  // After the squash the fc must emit FP32 directly.
  IsForceFp32OutputTest(
      BuildFcDequantConcatProgramDesc(use_mkldnn, scale_out, scale), "fc",
      true);
}

// If there are more than one op after fc->dequantize, do not fuse
TEST(CpuQuantizeSquashPass, fc_dequant_more_than_one_op_after_dequant) {
  auto scale_out = 1.0f;
  auto scale = 1.2345f;
  auto use_mkldnn = true;
  // nothing change
  auto remove_nodes = 0;
  CountNodeTest(BuildFcDequantFcProgramDesc(use_mkldnn, scale_out, scale),
                remove_nodes);
  // The fc keeps its quantized output since it was not fused.
  IsForceFp32OutputTest(
      BuildFcDequantFcProgramDesc(use_mkldnn, scale_out, scale), "fc", false);
}

// a->Concat1->b
// b->Concat2
// b->Quatize1(Scale)->c
// c->Fc1
// c->Fc2
TEST(CpuQuantizeSquashPass, quatize_with_same_scale) {
  const auto first_scale = 1.2345f;
  const auto second_scale = 1.2345f;
  const auto use_mkldnn = true;
  // Equal scales let the two quantize ops be merged:
  // remove nodes: Quantize2 + d
  const auto remove_nodes = 1 + 1;
  CountNodeTest(
      BuildMultipleQuantizeProgramDesc(use_mkldnn, first_scale, second_scale),
      remove_nodes);
}

// if scales are not the same, do not fuse
TEST(CpuQuantizeSquashPass, quatize_with_different_scale) {
  const auto first_scale = 1.2345f;
  const auto second_scale = 1.5432f;
  const auto use_mkldnn = true;
  // Different scales must block the merge:
  // nothing change
  const auto remove_nodes = 0;
  CountNodeTest(
      BuildMultipleQuantizeProgramDesc(use_mkldnn, first_scale, second_scale),
      remove_nodes);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

USE_PASS(cpu_quantize_squash_pass);