auto_mixed_precision_pass.cc 34.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/fluid/framework/ir/auto_mixed_precision_pass.h"
16 17 18

#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/operator.h"
19 20 21 22 23 24
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
25 26 27
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/backends/device_manager.h"
#endif
28 29 30 31 32 33 34

namespace paddle {
namespace framework {
namespace ir {

namespace {

35
// File-local shorthand for AutoMixedPrecisionPass::VarType. Its ::Type enum is
// used below both for var kinds (LOD_TENSOR, SELECTED_ROWS, ...) and for data
// types (FP32, FP16, ...).
using VarType = AutoMixedPrecisionPass::VarType;
36 37 38 39 40 41 42 43 44 45 46 47 48 49

bool PhiKernelSupportPrecision(
    const std::string& op_type,
    phi::Backend backend,
    phi::DataType data_type,
    phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) {
  const auto& kernels = phi::KernelFactory::Instance().kernels();
  if (kernels.count(op_type) == 0) {
    return false;
  }
  phi::KernelKey kernel_key(backend, layout, data_type);
  return phi::KernelFactory::Instance().HasKernel(op_type, kernel_key);
}

50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
static phi::Backend ConvertPlaceToBackend(const phi::Place& place) {
  switch (place.GetType()) {
    case phi::AllocationType::CPU:
      return phi::Backend::CPU;
    case phi::AllocationType::GPU:
      return phi::Backend::GPU;
    case phi::AllocationType::XPU:
      return phi::Backend::XPU;
    default:
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Cannot convert place(%d).", static_cast<int>(place.GetType())));
  }
  return phi::Backend::UNDEFINED;
}

65
bool KernelSupportPrecision(
66
    const std::string& op_type,
67
    phi::Backend backend,
68 69 70 71
    phi::DataType precision,
    phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) {
  auto phi_op_type = phi::TransToPhiKernelName(op_type);

72 73 74 75 76 77
  bool support =
      PhiKernelSupportPrecision(phi_op_type, backend, precision, layout);
  if (backend == phi::Backend::GPU) {
    support |= PhiKernelSupportPrecision(
        phi_op_type, phi::Backend::GPUDNN, precision, layout);
  }
78 79 80 81 82
  if (!support) {
    const auto& all_kernels = framework::OperatorWithKernel::AllOpKernels();
    auto it = all_kernels.find(op_type);
    if (it != all_kernels.end()) {
      for (const auto& kern_pair : it->second) {
83
        if (ConvertPlaceToBackend(kern_pair.first.place_) == backend &&
84 85 86 87 88 89 90 91 92 93 94
            kern_pair.first.data_type_ ==
                framework::TransToProtoVarType(precision)) {
          support = true;
          break;
        }
      }
    }
  }
  return support;
}

95 96 97 98 99 100 101
inline bool VarNodeHasDtype(Node* var_node) {
  auto type = var_node->Var()->GetType();
  return (type == VarType::SELECTED_ROWS) || (type == VarType::LOD_TENSOR) ||
         (type == VarType::LOD_TENSOR_ARRAY) || (type == VarType::STRINGS) ||
         (type == VarType::VOCAB);
}

102
inline bool IsFP32AndFP64(VarType::Type type) {
103 104 105
  return (type == VarType::FP64) || (type == VarType::FP32);
}

106
inline bool IsFP16AndBFP16(VarType::Type type) {
107 108 109 110 111
  return (type == VarType::FP16) || (type == VarType::BF16);
}

};  // namespace

112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
void DoInsertCastOp(Graph* graph,
                    Node* var_node,
                    Node* op_node,
                    VarType::Type from_type,
                    VarType::Type to_type,
                    framework::BlockDesc* block_desc,
                    int* suffix,
                    std::unordered_map<Node*, Node*>* cache) {
  if (from_type == to_type) return;

  auto update_cast_desc = [&](framework::OpDesc& desc,
                              const std::string& x_name,
                              const std::string& out_name,
                              const int in_dtype,
                              const int out_dtype) {
    desc.SetType("cast");
    desc.SetInput("X", {x_name});
    desc.SetOutput("Out", {out_name});
    desc.SetAttr("in_dtype", in_dtype);
    desc.SetAttr("out_dtype", out_dtype);
    desc.SetAttr("use_mkldnn", false);
    desc.SetAttr("with_quant_attr", false);
    desc.Flush();
  };

  if (cache->count(var_node) == 0) {
    // insert cast op between var_node and op_node
    std::string cast_input_name = var_node->Var()->Name();
140 141 142
    std::string cast_output_name = var_node->Var()->Name() +
                                   "_cast_auto_mixed.tmp_" +
                                   std::to_string((*suffix)++);
143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164
    framework::OpDesc cast_op_desc(block_desc);
    update_cast_desc(cast_op_desc,
                     cast_input_name,
                     cast_output_name,
                     static_cast<int>(from_type),
                     static_cast<int>(to_type));
    auto* cast_op_node = graph->CreateOpNode(&cast_op_desc);
    auto* cast_output_vardesc = block_desc->Var(cast_output_name);
    cast_output_vardesc->SetPersistable(false);
    cast_output_vardesc->SetDataType(to_type);
    cast_output_vardesc->SetShape(var_node->Var()->GetShape());
    auto* cast_output_node = graph->CreateVarNode(cast_output_vardesc);
    IR_NODE_LINK_TO(cast_op_node, cast_output_node);
    (*cache)[var_node] = cast_output_node;
  }
  op_node->Op()->Rename(var_node->Name(), cache->at(var_node)->Name());
  IR_NODE_LINK_TO(var_node, cache->at(var_node)->inputs[0]);
  IR_NODE_LINK_TO(cache->at(var_node), op_node);

  IR_NODE_UNLINK(var_node, op_node);
}

165 166 167 168
bool OpSupportPrecision(const std::string& op_type,
                        phi::Backend backend,
                        phi::DataType precision,
                        const std::unordered_set<std::string>& black_list) {
169 170
  return black_list.count(op_type) == 0 &&
         KernelSupportPrecision(op_type, backend, precision);
171 172 173 174 175
}

// The set of ops that support fp16 calculation and are considered
// numerically-dangerous, slower and whose effects may also be observed in
// downstream ops.
176
// ref to python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
177
void AutoMixedPrecisionPass::SetDefaultBlacklist() const {
178 179 180 181 182 183 184 185 186 187 188 189 190
  black_list_.insert({
      // numerically-dangerous
      "exp",
      "square",
      "log",
      "mean",
      "sum",
      "cos_sim",
      "softmax_with_cross_entropy",
      "sigmoid_cross_entropy_with_logits",
      "c_softmax_with_cross_entropy",
      "cross_entropy",
      "cross_entropy2",
191
#ifndef PADDLE_WITH_XPU
192 193
      // slower than fp32
      "conv2d_transpose",
194
#endif
195 196 197 198 199
      // default fp32 can avoid return inf when the sum value large than 65504
      "reduce_sum",
  });
}

200
void AutoMixedPrecisionPass::Init(Graph* graph) const {
201
  if (Has("enable_gpu_mixed") && Get<bool>("enable_gpu_mixed")) {
202
    backend_ = phi::Backend::GPU;
203 204 205 206 207
  } else if (Has("enable_xpu_mixed") && Get<bool>("enable_xpu_mixed")) {
    backend_ = phi::Backend::XPU;
  } else if (Has("enable_custom_device_mixed") &&
             Get<bool>("enable_custom_device_mixed")) {
    // transform Backend::CUSTOM to actual backend.
208 209 210 211 212 213 214 215 216 217 218 219
// Here, we only consider one custom backend.
#ifdef PADDLE_WITH_CUSTOM_DEVICE
    auto device_type = phi::DeviceManager::GetAllCustomDeviceTypes()[0];
    backend_ = static_cast<phi::Backend>(
        static_cast<size_t>(phi::Backend::NUM_BACKENDS) +
        phi::CustomRegisteredDeviceMap::Instance()
            .GetOrRegisterGlobalDeviceTypeId(device_type));
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Paddle is not compiled with CustomDevice. "
        "Cannot enable custom_device_mixed."));
#endif
220 221
  }

222 223 224 225 226 227 228 229 230
  if (Has("mixed_precision_mode")) {
    low_precision_ =
        static_cast<phi::DataType>(Get<int>("mixed_precision_mode"));
  }

  skip_pass_ = (backend_ == phi::Backend::UNDEFINED) ||
               (low_precision_ == phi::DataType::UNDEFINED);

  if (skip_pass_) return;
231

232 233
  black_list_ = Get<std::unordered_set<std::string>>("mixed_black_list");
  SetDefaultBlacklist();
234 235 236 237 238
  VLOG(4) << "black_list has ";
  for (const auto& name : black_list_) {
    VLOG(4) << " - " << name;
  }

239 240
  if (Has("enable_low_precision_io")) {
    enable_low_precision_io_ = Get<bool>("enable_low_precision_io");
241
  }
242 243 244 245 246 247 248 249 250 251

  auto graph_size = graph->SubGraphsSize();
  VLOG(4) << "graph size: " << graph_size;
  subgraphes_.resize(graph_size);
  all_op_nodes_.resize(graph_size);

  for (size_t i = 0; i < graph_size; i++) {
    subgraphes_[i] = graph->GetSubGraph(i);
    all_op_nodes_[i] = TopologySortOperations(*subgraphes_[i]);
    VLOG(4) << "subgraph " << i << " has " << all_op_nodes_[i].size()
252
            << " op nodes";
253 254 255 256 257 258 259 260 261 262 263 264
    for (auto* var_node : subgraphes_[i]->Nodes()) {
      if (!var_node->IsVar()) continue;

      auto var_name = var_node->Var()->Name();
      if (real_vars_.count(var_name) == 0) {
        real_vars_[var_name] = var_node;
        VLOG(4) << var_name << " is in graph " << i;
      }
    }
  }
}

265 266 267 268 269 270 271 272 273 274
void AutoMixedPrecisionPass::ApplyImpl(Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(graph,
                          platform::errors::PreconditionNotMet(
                              "During the auto_mixed_precision_pass, the graph "
                              "should not be nullptr."));
  PADDLE_ENFORCE_EQ(graph->IsMainGraph(),
                    true,
                    platform::errors::PreconditionNotMet(
                        "During the auto_mixed_precision_pass, the graph "
                        "should be main graph."));
275

276
  FusePassBase::Init("auto_mixed_precision", graph);
277 278 279

  Init(graph);
  VLOG(4) << "Init done";
280 281 282 283 284 285

  if (skip_pass_) {
    VLOG(3) << "Skip auto_mixed_precision_pass.";
    return;
  }

286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301
  SetOpUniqueType();
  VLOG(4) << "SetOpUniqueType done";
  GetOpPrecision();
  VLOG(4) << "GetOpPrecision done";
  UpdateOpPrecision();
  VLOG(4) << "UpdateOpPrecision done";
  SetVarPrecision();
  VLOG(4) << "SetVarPrecision done";
  ConvertWeightsData();
  VLOG(4) << "ConvertWeightsData done";
  ProcessOpWithDtypeAttr();
  VLOG(4) << "ProcessOpWithDtypeAttr done";
  InsertCastOp();
  VLOG(4) << "InsertCastOp done";
  RestoreOpOriginType();
  VLOG(4) << "RestoreOpOriginType done";
302
  LOG(INFO) << "The number of ops run at low precision ["
303 304
            << op_run_low_precision_.size() << "/"
            << op_original_type_.size() + 2 << "]";
305 306
}

307
void AutoMixedPrecisionPass::SetOpUniqueType() const {
308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323
  int suffix = 0;
  for (const auto& nodes : all_op_nodes_) {
    for (auto* op_node : nodes) {
      auto op_type = op_node->Op()->Type();

      if (op_type == "feed" || op_type == "fetch") continue;

      std::string unique_type = op_type + "_" + std::to_string(suffix++);
      op_original_type_[unique_type] = op_type;
      op_node->Op()->SetType(unique_type);
      op_node->Op()->Flush();
      VLOG(4) << "change op type: " << op_type << " ---> " << unique_type;
    }
  }
}

324
void AutoMixedPrecisionPass::RestoreOpOriginType() const {
325 326 327 328 329 330 331 332 333 334 335
  for (const auto& nodes : all_op_nodes_) {
    for (auto* op_node : nodes) {
      auto op_type = op_node->Op()->Type();
      op_node->Op()->SetType(GetOpOriginalType(op_type));
      op_node->Op()->Flush();
      VLOG(4) << "restore op type: " << op_type << " ---> "
              << op_node->Op()->Type();
    }
  }
}

336
inline std::string AutoMixedPrecisionPass::GetOpOriginalType(
337 338 339 340 341 342 343
    const std::string& op_type) const {
  if (op_original_type_.count(op_type)) {
    return op_original_type_.at(op_type);
  }
  return op_type;
}

344
void AutoMixedPrecisionPass::ProcessOpWithDtypeAttr() const {
345 346 347
  for (const auto& nodes : all_op_nodes_) {
    for (auto* op_node : nodes) {
      auto op_type = op_node->Op()->Type();
348 349 350 351 352 353 354 355 356 357 358 359 360 361 362

      if (op_node->Op()->HasAttr("in_dtype")) {
        auto* var_node = op_node->inputs[0];
        auto* real_var_node = real_vars_[var_node->Var()->Name()];
        if (IsFP16AndBFP16(real_var_node->Var()->GetDataType())) {
          op_node->Op()->SetAttr(
              "in_dtype",
              static_cast<int>(framework::TransToProtoVarType(low_precision_)));
          op_node->Op()->Flush();
          VLOG(4) << "process op with in_dtype attr: " << op_type << " ( "
                  << static_cast<int>(real_var_node->Var()->GetDataType())
                  << " --->" << static_cast<int>(low_precision_) << " )";
        }
      }

363
      if (op_run_low_precision_.count(op_type) == 0) continue;
364 365 366

      if (op_node->Op()->HasAttr("dtype")) {
        auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
367
        if (IsFP32AndFP64(static_cast<VarType::Type>(dtype))) {
368 369
          op_node->Op()->SetAttr(
              "dtype",
370
              static_cast<int>(framework::TransToProtoVarType(low_precision_)));
371 372
          op_node->Op()->Flush();
          VLOG(4) << "process op with dtype attr: " << op_type << " ( " << dtype
373
                  << " --->" << static_cast<int>(low_precision_) << " )";
374
        }
375
      } else if (op_node->Op()->HasAttr("out_dtype")) {
376
        auto out_dtype = op_node->Op()->GetAttrIfExists<int>("out_dtype");
377
        if (IsFP32AndFP64(static_cast<VarType::Type>(out_dtype))) {
378 379
          op_node->Op()->SetAttr(
              "out_dtype",
380
              static_cast<int>(framework::TransToProtoVarType(low_precision_)));
381 382
          op_node->Op()->Flush();
          VLOG(4) << "process op with out_dtype attr: " << op_type << " ( "
383
                  << out_dtype << " --->" << static_cast<int>(low_precision_)
384 385 386 387 388 389 390
                  << " )";
        }
      }
    }
  }
}

391
void AutoMixedPrecisionPass::GetOpPrecision() const {
392 393 394
  for (const auto& nodes : all_op_nodes_) {
    for (auto* op_node : nodes) {
      auto op_type = op_node->Op()->Type();
395
      bool support_low_precision = true;
396 397
      if (GetOpOriginalType(op_type) == "feed" ||
          GetOpOriginalType(op_type) == "fetch") {
398 399 400 401 402 403 404
        support_low_precision = enable_low_precision_io_;
      } else if (GetOpOriginalType(op_type) == "tensorrt_engine") {
        auto enable_fp16 = op_node->Op()->GetAttrIfExists<bool>("enable_fp16");
        auto enable_int8 = op_node->Op()->GetAttrIfExists<bool>("enable_int8");
        auto low_precision_io =
            op_node->Op()->GetAttrIfExists<bool>("enable_low_precision_io");
        support_low_precision = enable_fp16 && !enable_int8 && low_precision_io;
405
      } else {
406 407
        support_low_precision = OpSupportPrecision(
            GetOpOriginalType(op_type), backend_, low_precision_, black_list_);
408

409 410
        if (op_node->Op()->HasAttr("dtype")) {
          auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
411 412
          support_low_precision =
              support_low_precision &&
413 414 415
              IsFP32AndFP64(static_cast<VarType::Type>(dtype));
        } else if (op_node->Op()->HasAttr("out_dtype")) {
          auto out_dtype = op_node->Op()->GetAttrIfExists<int>("out_dtype");
416 417
          support_low_precision =
              support_low_precision &&
418 419
              (IsFP32AndFP64(static_cast<VarType::Type>(out_dtype)) ||
               out_dtype == -1);
420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438
        }

        // If scale op's "scale" and "bias" attr value exceed the range of fp16
        // and bf16, it cannot run at low precision.
        if (GetOpOriginalType(op_node->Op()->Type()) == "scale") {
          auto scale = op_node->Op()->GetAttrIfExists<float>("scale");
          auto bias = op_node->Op()->GetAttrIfExists<float>("bias");
          if (low_precision_ == phi::DataType::FLOAT16) {
            support_low_precision =
                support_low_precision &&
                phi::dtype::isfinite(static_cast<phi::dtype::float16>(scale)) &&
                phi::dtype::isfinite(static_cast<phi::dtype::float16>(bias));
          } else if (low_precision_ == phi::DataType::BFLOAT16) {
            support_low_precision =
                support_low_precision &&
                phi::dtype::isfinite(
                    static_cast<phi::dtype::bfloat16>(scale)) &&
                phi::dtype::isfinite(static_cast<phi::dtype::bfloat16>(bias));
          }
439 440
        }

441 442 443 444 445 446
        // if op's input var and output var is not dense tensor, the op should
        // not run at low precision.
        for (auto* in_var_node : op_node->inputs) {
          CHECK_EQ(in_var_node->IsVar(), true);
          auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
          if (real_in_var_node->Var()->Persistable()) continue;
447

448 449 450 451 452 453 454 455 456 457 458 459 460
          support_low_precision =
              support_low_precision &&
              (real_in_var_node->Var()->GetType() == VarType::LOD_TENSOR);
        }
        for (auto* out_var_node : op_node->outputs) {
          CHECK_EQ(out_var_node->IsVar(), true);
          auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
          if (real_out_var_node->Var()->Persistable()) continue;

          support_low_precision =
              support_low_precision &&
              (real_out_var_node->Var()->GetType() == VarType::LOD_TENSOR);
        }
461 462
      }

463 464 465
      if (support_low_precision) {
        op_run_low_precision_.insert(op_type);
        VLOG(4) << "support precision: " << op_type << " run at low precision";
466
      } else {
467 468
        VLOG(4) << "support precision: " << op_type
                << " not run at low precision";
469 470 471 472 473
      }
    }
  }
}

474 475
void AutoMixedPrecisionPass::UpdateOpPrecision() const {
  std::unordered_set<std::string> vars_should_not_low_precision;
476 477 478 479 480 481 482 483 484 485

  // var -> the var's all input op
  std::unordered_map<std::string, std::vector<Node*>> var_input_ops;

  auto GetVarInputOps = [&] {
    for (const auto& nodes : all_op_nodes_) {
      for (auto* op_node : nodes) {
        auto op_type = op_node->Op()->Type();

        if (GetOpOriginalType(op_type) == "fetch") continue;
486 487 488
        if (op_node->Op()->HasAttr("sub_block") &&
            GetOpOriginalType(op_type) != "tensorrt_engine")
          continue;
489 490 491 492 493 494 495 496 497 498 499

        for (auto* var_node : op_node->outputs) {
          CHECK_EQ(var_node->IsVar(), true);
          if (var_node->Var()->Persistable()) continue;
          if (!VarNodeHasDtype(var_node)) continue;

          var_input_ops[var_node->Var()->Name()].push_back(op_node);
          VLOG(4) << "var input ops: " << var_node->Var()->Name()
                  << " is output of " << op_type;
        }

500 501 502
        // the select_input op's input var should not convert to low precision.
        // when op's output var is select_input op's input var, the op should
        // not run at low precision.
503 504 505 506 507 508
        if (GetOpOriginalType(op_node->Op()->Type()) == "select_input") {
          for (auto* in_var_node : op_node->inputs) {
            CHECK_EQ(in_var_node->IsVar(), true);
            if (in_var_node->Var()->Persistable()) continue;
            if (!VarNodeHasDtype(in_var_node)) continue;

509
            vars_should_not_low_precision.insert(in_var_node->Var()->Name());
510 511
          }
        }
512 513

        // when op_1 only support cpu kernel. if op_2's intput var is op_1's
514
        // output var, then op_2 should not run at low precision.
515
        if (GetOpOriginalType(op_type) != "feed" &&
Y
Yuanle Liu 已提交
516
            GetOpOriginalType(op_type) != "tensorrt_engine" &&
517 518
            !KernelSupportPrecision(
                GetOpOriginalType(op_type), backend_, phi::DataType::FLOAT32)) {
519 520 521 522 523 524 525 526
          for (auto* out_var_node : op_node->outputs) {
            CHECK_EQ(out_var_node->IsVar(), true);
            if (out_var_node->Var()->Persistable()) continue;
            if (!VarNodeHasDtype(out_var_node)) continue;

            vars_should_not_low_precision.insert(out_var_node->Var()->Name());
          }
        }
527 528 529 530 531 532 533 534 535 536
      }
    }
  };
  GetVarInputOps();

  bool precision_updated = false;
  do {
    precision_updated = false;
    for (const auto& nodes : all_op_nodes_) {
      for (auto* op_node : nodes) {
537
        if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) continue;
538

539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557
        for (auto* in_var_node : op_node->inputs) {
          CHECK_EQ(in_var_node->IsVar(), true);
          if (!VarNodeHasDtype(in_var_node)) continue;

          auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
          if (real_in_var_node->Var()->Persistable()) continue;

          if (vars_should_not_low_precision.count(
                  real_in_var_node->Var()->Name())) {
            op_run_low_precision_.erase(op_node->Op()->Type());
            precision_updated = true;
            VLOG(4) << op_node->Op()->Type()
                    << " should not run at low precision.";
            break;
          }
        }

        if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) continue;

558 559 560 561 562 563 564
        for (auto* out_var_node : op_node->outputs) {
          CHECK_EQ(out_var_node->IsVar(), true);
          if (!VarNodeHasDtype(out_var_node)) continue;

          auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
          if (real_out_var_node->Var()->Persistable()) continue;

565
          bool not_run_low_precision = false;
566 567
          const auto& input_op_nodes =
              var_input_ops[real_out_var_node->Var()->Name()];
568 569 570
          if (vars_should_not_low_precision.count(
                  real_out_var_node->Var()->Name())) {
            not_run_low_precision = true;
571 572
          } else {
            for (auto* node : input_op_nodes) {
573 574
              if (op_run_low_precision_.count(node->Op()->Type()) == 0) {
                not_run_low_precision = true;
575 576 577 578
                break;
              }
            }
          }
579 580
          if (not_run_low_precision) {
            op_run_low_precision_.erase(op_node->Op()->Type());
581 582
            precision_updated = true;
            VLOG(4) << op_node->Op()->Type()
583
                    << " should not run at low precision.";
584 585 586 587 588 589 590 591 592
            break;
          }
        }
      }
    }
  } while (precision_updated);
}

// special ops, its weights should not be low precision.
593 594
bool AutoMixedPrecisionPass::InputVarsNotConvert(
    Node* op_node, const std::string& var_name) const {
595
  auto* op_desc = op_node->Op();
596 597 598 599 600 601
  if (GetOpOriginalType(op_desc->Type()) == "tensorrt_engine") {
    auto vecs = op_desc->Input("Xs");
    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
      return true;
    }
  } else if (GetOpOriginalType(op_desc->Type()) == "batch_norm") {
602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617
    auto vecs = op_desc->Input("Bias");
    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
      return true;
    }
    vecs = op_desc->Input("Mean");
    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
      return true;
    }
    vecs = op_desc->Input("Scale");
    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
      return true;
    }
    vecs = op_desc->Input("Variance");
    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
      return true;
    }
618 619 620 621 622 623 624 625 626
  } else if (GetOpOriginalType(op_desc->Type()) == "instance_norm") {
    auto vecs = op_desc->Input("Bias");
    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
      return true;
    }
    vecs = op_desc->Input("Scale");
    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
      return true;
    }
627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643
  } else if (GetOpOriginalType(op_desc->Type()) == "fused_multi_transformer") {
    auto vecs = op_desc->Input("LnScale");
    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
      return true;
    }
    vecs = op_desc->Input("LnBias");
    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
      return true;
    }
    vecs = op_desc->Input("FFNLnScale");
    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
      return true;
    }
    vecs = op_desc->Input("FFNLnBias");
    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
      return true;
    }
644 645 646 647 648 649 650 651 652 653
  } else if (GetOpOriginalType(op_desc->Type()) ==
             "fused_bias_dropout_residual_layer_norm") {
    auto vecs = op_desc->Input("LnScale");
    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
      return true;
    }
    vecs = op_desc->Input("LnBias");
    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
      return true;
    }
654
  }
655 656 657 658 659 660 661 662 663 664 665

  if (backend_ == phi::Backend::XPU) {
    if (GetOpOriginalType(op_desc->Type()) == "layer_norm") {
      auto vecs = op_desc->Input("Bias");
      if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
        return true;
      }
      vecs = op_desc->Input("Scale");
      if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
        return true;
      }
666 667 668 669 670 671 672 673 674
    } else if (GetOpOriginalType(op_desc->Type()) == "instance_norm") {
      auto vecs = op_desc->Input("Bias");
      if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
        return true;
      }
      vecs = op_desc->Input("Scale");
      if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
        return true;
      }
675 676 677
    }
  }

678 679 680
  return false;
}

681 682
bool AutoMixedPrecisionPass::OutputVarsNotConvert(
    Node* op_node, const std::string& var_name) const {
683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702
  auto* op_desc = op_node->Op();
  // batch_norm's input and output (variance and mean) are the same.
  if (GetOpOriginalType(op_desc->Type()) == "batch_norm") {
    auto vecs = op_desc->Output("MeanOut");
    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
      return true;
    }
    vecs = op_desc->Output("VarianceOut");
    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
      return true;
    }
    vecs = op_desc->Output("SavedMean");
    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
      return true;
    }
    vecs = op_desc->Output("SavedVariance");
    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
      return true;
    }
  }
703 704 705 706 707 708 709 710 711 712 713 714 715 716

  if (backend_ == phi::Backend::XPU) {
    if (GetOpOriginalType(op_desc->Type()) == "layer_norm") {
      auto vecs = op_desc->Output("Mean");
      if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
        return true;
      }
      vecs = op_desc->Output("Variance");
      if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
        return true;
      }
    }
  }

717 718 719
  return false;
}

720
void AutoMixedPrecisionPass::SetVarPrecision() const {
721 722
  for (const auto& nodes : all_op_nodes_) {
    for (auto* op_node : nodes) {
723 724 725 726 727
      if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) {
        continue;
      }

      if (GetOpOriginalType(op_node->Op()->Type()) != "feed") {
728 729 730 731 732 733
        for (auto* in_var_node : op_node->inputs) {
          CHECK_EQ(in_var_node->IsVar(), true);

          auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
          auto in_var_name = real_in_var_node->Var()->Name();

734
          if (!IsFP32AndFP64(real_in_var_node->Var()->GetDataType())) continue;
735 736 737 738 739
          if (!VarNodeHasDtype(real_in_var_node)) continue;
          if (InputVarsNotConvert(op_node, in_var_name)) continue;

          if (real_in_var_node->Var()->Persistable()) {
            real_in_var_node->Var()->SetDataType(
740 741
                framework::TransToProtoVarType(low_precision_));
            vars_convert_to_low_precision_.insert(in_var_name);
742 743
          }
        }
744
      }
745

746
      if (GetOpOriginalType(op_node->Op()->Type()) != "fetch") {
747 748 749 750 751 752
        for (auto* out_var_node : op_node->outputs) {
          CHECK_EQ(out_var_node->IsVar(), true);

          auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
          auto out_var_name = real_out_var_node->Var()->Name();

753
          if (!IsFP32AndFP64(real_out_var_node->Var()->GetDataType())) continue;
754 755 756 757
          if (!VarNodeHasDtype(real_out_var_node)) continue;
          if (OutputVarsNotConvert(op_node, out_var_name)) continue;

          real_out_var_node->Var()->SetDataType(
758
              framework::TransToProtoVarType(low_precision_));
759
          if (real_out_var_node->Var()->Persistable()) {
760
            vars_convert_to_low_precision_.insert(out_var_name);
761 762 763 764 765 766 767 768 769 770 771 772 773 774
          }
        }
      }
    }
  }

  // This code used to precess vars with the same name. Vars with the same
  // name should have the same data type.
  for (auto* subgraph : subgraphes_) {
    for (auto* var_node : subgraph->Nodes()) {
      if (!var_node->IsVar() || !var_node->Var()->Persistable()) continue;
      if (!VarNodeHasDtype(var_node)) continue;

      auto var_name = var_node->Var()->Name();
775
      if (vars_convert_to_low_precision_.count(var_name)) {
776
        var_node->Var()->SetDataType(
777
            framework::TransToProtoVarType(low_precision_));
778 779 780 781 782
      }
    }
  }
}

783
void AutoMixedPrecisionPass::ConvertWeightsData() const {
784
  auto* scope = param_scope();
785 786 787 788
  PADDLE_ENFORCE_NOT_NULL(scope,
                          platform::errors::PreconditionNotMet(
                              "During the auto_mixed_precision_pass, the scope "
                              "should not be null."));
789 790 791

  auto var_names = scope->LocalVarNames();
  for (const auto& var_name : var_names) {
792
    if (vars_convert_to_low_precision_.count(var_name)) {
793
      VLOG(4) << var_name << "'s data type was convert to low precision";
794 795

      auto* var = scope->FindLocalVar(var_name);
796 797 798 799
      CHECK_EQ(var->IsType<phi::DenseTensor>(), true);

      auto* origin_tensor = var->GetMutable<phi::DenseTensor>();

800 801 802
      phi::DenseTensor low_precision_tensor;
      low_precision_tensor.Resize(origin_tensor->dims());
      low_precision_tensor.set_type(low_precision_);
803

804 805 806 807
      if (low_precision_ == phi::DataType::FLOAT16) {
        auto* low_precision_data =
            low_precision_tensor.mutable_data<phi::dtype::float16>(
                phi::CPUPlace{});
808 809 810
        for (int64_t i = 0; i < origin_tensor->numel(); i++) {
          if (origin_tensor->dtype() == phi::DataType::FLOAT64) {
            auto* origin_data = origin_tensor->data<double>();
811 812
            low_precision_data[i] =
                static_cast<phi::dtype::float16>(origin_data[i]);
813 814
          } else if (origin_tensor->dtype() == phi::DataType::FLOAT32) {
            auto* origin_data = origin_tensor->data<float>();
815 816
            low_precision_data[i] =
                static_cast<phi::dtype::float16>(origin_data[i]);
817 818
          }
        }
819
      } else if (low_precision_ == phi::DataType::BFLOAT16) {
820
        auto* low_precision_data =
821 822
            low_precision_tensor.mutable_data<phi::dtype::bfloat16>(
                phi::CPUPlace{});
823 824 825
        for (int64_t i = 0; i < origin_tensor->numel(); i++) {
          if (origin_tensor->dtype() == phi::DataType::FLOAT64) {
            auto* origin_data = origin_tensor->data<double>();
826 827
            low_precision_data[i] =
                static_cast<phi::dtype::bfloat16>(origin_data[i]);
828 829
          } else if (origin_tensor->dtype() == phi::DataType::FLOAT32) {
            auto* origin_data = origin_tensor->data<float>();
830 831
            low_precision_data[i] =
                static_cast<phi::dtype::bfloat16>(origin_data[i]);
832
          }
833 834
        }
      }
835 836
      origin_tensor->clear();
      paddle::framework::TensorCopySync(
837
          low_precision_tensor, phi::CPUPlace{}, origin_tensor);
838 839 840 841
    }
  }
}

842
void AutoMixedPrecisionPass::InsertCastOp() const {
843 844 845 846 847 848 849 850 851 852
  int suffix = 0;
  std::unordered_map<Node*, Node*> cache;

  for (size_t i = 0; i < all_op_nodes_.size(); i++) {
    auto* block_desc = all_op_nodes_[i][0]->Op()->Block();
    CHECK_NOTNULL(block_desc);
    for (auto* op_node : all_op_nodes_[i]) {
      auto op_type = op_node->Op()->Type();

      if (GetOpOriginalType(op_type) == "feed") continue;
853 854 855
      if (op_node->Op()->HasAttr("sub_block") &&
          GetOpOriginalType(op_type) != "tensorrt_engine")
        continue;
856 857

      VLOG(4) << "process op: " << op_type
858
              << " run low precision: " << op_run_low_precision_.count(op_type);
859 860 861 862 863 864 865 866 867 868 869 870 871 872

      auto inputs = op_node->inputs;
      for (auto* in_var_node : inputs) {
        if (!in_var_node->IsVar()) continue;
        if (!VarNodeHasDtype(in_var_node)) continue;
        if (in_var_node->Var()->Persistable()) continue;

        auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];

        auto in_var_type = real_in_var_node->Var()->GetDataType();

        VLOG(4) << "process var: " << real_in_var_node->Var()->Name()
                << " with type " << in_var_type;

873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892
        if (IsFP32AndFP64(in_var_type) &&
            op_run_low_precision_.count(op_type)) {
          auto to_type = framework::TransToProtoVarType(low_precision_);
          auto* prev_op =
              in_var_node->inputs.empty() ? nullptr : in_var_node->inputs[0];
          if (prev_op && GetOpOriginalType(prev_op->Op()->Type()) == "cast") {
            in_var_node->Var()->SetDataType(to_type);
            prev_op->Op()->SetAttr("out_dtype", static_cast<int>(to_type));
            prev_op->Op()->Flush();
          } else {
            DoInsertCastOp(subgraphes_[i],
                           in_var_node,
                           op_node,
                           in_var_type,
                           to_type,
                           block_desc,
                           &suffix,
                           &cache);
          }
        } else if (IsFP16AndBFP16(in_var_type) &&
893
                   op_run_low_precision_.count(op_type) == 0) {
894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910
          auto to_type = VarType::FP32;
          auto* prev_op =
              in_var_node->inputs.empty() ? nullptr : in_var_node->inputs[0];
          if (prev_op && GetOpOriginalType(prev_op->Op()->Type()) == "cast") {
            in_var_node->Var()->SetDataType(to_type);
            prev_op->Op()->SetAttr("out_dtype", static_cast<int>(to_type));
            prev_op->Op()->Flush();
          } else {
            DoInsertCastOp(subgraphes_[i],
                           in_var_node,
                           op_node,
                           in_var_type,
                           to_type,
                           block_desc,
                           &suffix,
                           &cache);
          }
911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933
        }
      }

      // Special op.
      // fused_multi_transformer's input(CacheKV) and output(CacheKVOut) vars
      // have same name.
      if (GetOpOriginalType(op_type) == "fused_multi_transformer") {
        auto cache_kv_inputs = op_node->Op()->Input("CacheKV");
        auto cache_kv_outputs = op_node->Op()->Output("CacheKVOut");
        CHECK_EQ(cache_kv_inputs.size(), cache_kv_outputs.size());
        for (size_t i = 0; i < cache_kv_inputs.size(); ++i) {
          op_node->Op()->RenameOutput(cache_kv_outputs[i], cache_kv_inputs[i]);
        }
      }
    }
  }
  VLOG(4) << "insert number of cast op: " << cache.size();
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

934 935
// Registers AutoMixedPrecisionPass under the name `auto_mixed_precision_pass`
// via the REGISTER_PASS macro.
REGISTER_PASS(auto_mixed_precision_pass,
              paddle::framework::ir::AutoMixedPrecisionPass);