auto_mixed_precision_pass.cc 27.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/fluid/framework/ir/auto_mixed_precision_pass.h"
16 17 18

#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/operator.h"
19 20 21 22 23 24
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
25 26 27 28 29 30 31

namespace paddle {
namespace framework {
namespace ir {

namespace {

32
using VarType = AutoMixedPrecisionPass::VarType;
33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73

bool PhiKernelSupportPrecision(
    const std::string& op_type,
    phi::Backend backend,
    phi::DataType data_type,
    phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) {
  const auto& kernels = phi::KernelFactory::Instance().kernels();
  if (kernels.count(op_type) == 0) {
    return false;
  }
  phi::KernelKey kernel_key(backend, layout, data_type);
  return phi::KernelFactory::Instance().HasKernel(op_type, kernel_key);
}

// Returns true when `op_type` has a GPU kernel for `precision`.
//
// The phi registry is consulted first (both the GPU and GPUDNN backends,
// under the translated phi kernel name). Only when neither matches are the
// fluid kernels registered in OperatorWithKernel::AllOpKernels() scanned for
// an entry placed on GPU with the matching proto data type.
bool GpuKernelSupportPrecision(
    const std::string& op_type,
    phi::DataType precision,
    phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) {
  const auto phi_name = phi::TransToPhiKernelName(op_type);
  const bool has_gpu_kernel = PhiKernelSupportPrecision(
      phi_name, phi::Backend::GPU, precision, layout);
  const bool has_gpudnn_kernel = PhiKernelSupportPrecision(
      phi_name, phi::Backend::GPUDNN, precision, layout);
  if (has_gpu_kernel || has_gpudnn_kernel) {
    return true;
  }

  // Fall back to the fluid kernel registry, keyed by the fluid op type.
  const auto& fluid_kernels = framework::OperatorWithKernel::AllOpKernels();
  const auto found = fluid_kernels.find(op_type);
  if (found == fluid_kernels.end()) {
    return false;
  }
  for (const auto& entry : found->second) {
    const auto& kernel_key = entry.first;
    if (platform::is_gpu_place(kernel_key.place_) &&
        kernel_key.data_type_ == framework::TransToProtoVarType(precision)) {
      return true;
    }
  }
  return false;
}

74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
// Returns true when the variable behind `var_node` is one of the variable
// kinds this pass treats as carrying a data type: SELECTED_ROWS, LOD_TENSOR,
// LOD_TENSOR_ARRAY, STRINGS or VOCAB.
inline bool VarNodeHasDtype(Node* var_node) {
  switch (var_node->Var()->GetType()) {
    case VarType::SELECTED_ROWS:
    case VarType::LOD_TENSOR:
    case VarType::LOD_TENSOR_ARRAY:
    case VarType::STRINGS:
    case VarType::VOCAB:
      return true;
    default:
      return false;
  }
}

// Returns true for the full-precision floating point types (FP32 and FP64).
inline bool IsFloatType(VarType::Type type) {
  switch (type) {
    case VarType::FP32:
    case VarType::FP64:
      return true;
    default:
      return false;
  }
}

// Returns true for the 16-bit floating point types (FP16 and BF16).
inline bool IsHalfType(VarType::Type type) {
  switch (type) {
    case VarType::FP16:
    case VarType::BF16:
      return true;
    default:
      return false;
  }
}

};  // namespace

91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142
// Routes the edge var_node -> op_node through a `cast` op converting
// `from_type` to `to_type`. Does nothing when the two types are equal.
//
// graph      : graph that receives the new cast op / cast output nodes.
// var_node   : variable currently consumed by op_node.
// op_node    : consumer whose input is redirected to the cast output.
// block_desc : block in which the cast op desc and output var are created.
// suffix     : counter used to build unique cast output names; incremented
//              each time a new cast op is created.
// cache      : var_node -> cast output node. A cast is created once per
//              var_node and reused for later consumers. The key ignores
//              `to_type`, so callers must request a single target type per
//              var_node.
void DoInsertCastOp(Graph* graph,
                    Node* var_node,
                    Node* op_node,
                    VarType::Type from_type,
                    VarType::Type to_type,
                    framework::BlockDesc* block_desc,
                    int* suffix,
                    std::unordered_map<Node*, Node*>* cache) {
  if (from_type == to_type) return;

  auto cached = cache->find(var_node);
  if (cached == cache->end()) {
    // First consumer needing a cast: build the cast op and its output var.
    const std::string cast_input_name = var_node->Var()->Name();
    const std::string cast_output_name =
        cast_input_name + "_cast.tmp_" + std::to_string((*suffix)++);

    framework::OpDesc cast_op_desc(block_desc);
    cast_op_desc.SetType("cast");
    cast_op_desc.SetInput("X", {cast_input_name});
    cast_op_desc.SetOutput("Out", {cast_output_name});
    cast_op_desc.SetAttr("in_dtype", static_cast<int>(from_type));
    cast_op_desc.SetAttr("out_dtype", static_cast<int>(to_type));
    cast_op_desc.SetAttr("use_mkldnn", false);
    cast_op_desc.SetAttr("with_quant_attr", false);
    cast_op_desc.Flush();
    auto* cast_op_node = graph->CreateOpNode(&cast_op_desc);

    auto* cast_output_vardesc = block_desc->Var(cast_output_name);
    cast_output_vardesc->SetPersistable(false);
    cast_output_vardesc->SetDataType(to_type);
    cast_output_vardesc->SetShape(var_node->Var()->GetShape());
    auto* cast_output_node = graph->CreateVarNode(cast_output_vardesc);

    // Link var -> cast op exactly once, when the cast op is created.
    // Re-linking on every cache hit would add duplicate edges between
    // var_node and the cast op.
    IR_NODE_LINK_TO(var_node, cast_op_node);
    IR_NODE_LINK_TO(cast_op_node, cast_output_node);
    cached = cache->emplace(var_node, cast_output_node).first;
  }

  // Redirect this consumer from the original var to the cast output.
  Node* cast_output_node = cached->second;
  op_node->Op()->Rename(var_node->Name(), cast_output_node->Name());
  IR_NODE_LINK_TO(cast_output_node, op_node);
  IR_NODE_UNLINK(var_node, op_node);
}

143 144 145 146 147 148 149 150 151 152 153 154 155 156
// Decides whether `op_type` may run at `precision` on `backend`.
//
// Ops listed in `black_list` never qualify and yield false before the
// backend is inspected. For any other op only phi::Backend::GPU is handled;
// every other backend raises InvalidArgument.
bool OpSupportPrecision(const std::string& op_type,
                        phi::Backend backend,
                        phi::DataType precision,
                        const std::unordered_set<std::string>& black_list) {
  if (black_list.count(op_type) != 0) {
    return false;
  }
  if (backend != phi::Backend::GPU) {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "Now, only support backend of GPU."));
  }
  return GpuKernelSupportPrecision(op_type, precision);
}

// The set of ops that support fp16 calculation and are considered
// numerically-dangerous, slower and whose effects may also be observed in
// downstream ops.
162
void AutoMixedPrecisionPass::SetDefaultBlacklist() const {
163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192
  black_list_.insert({
      // numerically-dangerous
      "acos",
      "asin",
      "cosh",
      "tan",
      "exp",
      "expm1",
      "square",
      "log",
      "log2",
      "log10",
      "log1p",
      "logsumexp",
      "mean",
      "rsqrt",
      "sum",
      "cos_sim",
      "softmax_with_cross_entropy",
      "sigmoid_cross_entropy_with_logits",
      "c_softmax_with_cross_entropy",
      "cross_entropy",
      "cross_entropy2",
      // slower than fp32
      "conv2d_transpose",
      // default fp32 can avoid return inf when the sum value large than 65504
      "reduce_sum",
  });
}

193 194 195 196 197 198 199 200 201 202
// Reads the pass attributes and caches per-subgraph information.
//
// Attributes read: "enable_gpu_mixed" (bool), "mixed_precision_mode" (int,
// cast to phi::DataType), "mixed_black_list" (set of op names) and the
// optional "keep_io_types" (bool, defaults to true). The pass is marked as
// skipped when "enable_gpu_mixed" is false.
//
// For every subgraph the topologically sorted op nodes are stored in
// all_op_nodes_, and the first var node seen for each variable name is
// recorded in real_vars_.
void AutoMixedPrecisionPass::Init(Graph* graph) const {
  const bool enable_gpu_mixed = Get<bool>("enable_gpu_mixed");
  if (enable_gpu_mixed) {
    backend_ = phi::Backend::GPU;
  }
  skip_pass_ = !enable_gpu_mixed;

  low_precision_ = static_cast<phi::DataType>(Get<int>("mixed_precision_mode"));

  // User supplied entries first, then the built-in defaults are merged in.
  black_list_ = Get<std::unordered_set<std::string>>("mixed_black_list");
  SetDefaultBlacklist();
  VLOG(4) << "black_list has ";
  for (const auto& name : black_list_) {
    VLOG(4) << " - " << name;
  }

  keep_io_types_ = true;
  if (Has("keep_io_types")) {
    keep_io_types_ = Get<bool>("keep_io_types");
  }

  auto graph_size = graph->SubGraphsSize();
  VLOG(4) << "graph size: " << graph_size;
  subgraphes_.resize(graph_size);
  all_op_nodes_.resize(graph_size);

  for (size_t i = 0; i < graph_size; i++) {
    subgraphes_[i] = graph->GetSubGraph(i);
    all_op_nodes_[i] = TopologySortOperations(*subgraphes_[i]);
    VLOG(4) << "subgraph " << i << " has " << all_op_nodes_[i].size()
            << " op nodes";
    for (auto* var_node : subgraphes_[i]->Nodes()) {
      if (!var_node->IsVar()) continue;

      // Only the first node seen for a name becomes its "real" var node.
      auto var_name = var_node->Var()->Name();
      if (real_vars_.count(var_name) == 0) {
        real_vars_[var_name] = var_node;
        VLOG(4) << var_name << " is in graph " << i;
      }
    }
  }
}

237 238 239 240 241 242 243 244 245 246
// Entry point of the pass.
//
// Requires a non-null main graph. After Init() the pass returns early when
// it is marked as skipped; otherwise the stages below run in this fixed
// order, each logging its completion at VLOG(4).
void AutoMixedPrecisionPass::ApplyImpl(Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(graph,
                          platform::errors::PreconditionNotMet(
                              "During the auto_mixed_precision_pass, the graph "
                              "should not be nullptr."));
  PADDLE_ENFORCE_EQ(graph->IsMainGraph(),
                    true,
                    platform::errors::PreconditionNotMet(
                        "During the auto_mixed_precision_pass, the graph "
                        "should be main graph."));

  FusePassBase::Init("auto_mixed_precision", graph);

  Init(graph);
  VLOG(4) << "Init done";

  if (skip_pass_) {
    VLOG(3) << "Skip auto_mixed_precision_pass.";
    return;
  }

  // Op types are made unique first so that later stages can address
  // individual ops by type name; they are restored in the last stage.
  SetOpUniqueType();
  VLOG(4) << "SetOpUniqueType done";
  GetOpPrecision();
  VLOG(4) << "GetOpPrecision done";
  UpdateOpPrecision();
  VLOG(4) << "UpdateOpPrecision done";
  SetVarPrecision();
  VLOG(4) << "SetVarPrecision done";
  ConvertWeightsData();
  VLOG(4) << "ConvertWeightsData done";
  ProcessOpWithDtypeAttr();
  VLOG(4) << "ProcessOpWithDtypeAttr done";
  InsertCastOp();
  VLOG(4) << "InsertCastOp done";
  RestoreOpOriginType();
  VLOG(4) << "RestoreOpOriginType done";
}

276
// Renames every op (except feed and fetch) to "<type>_<n>", where n is a
// counter shared across all subgraphs, and records the mapping back to the
// original type in op_original_type_.
void AutoMixedPrecisionPass::SetOpUniqueType() const {
  int suffix = 0;
  for (const auto& nodes : all_op_nodes_) {
    for (auto* op_node : nodes) {
      auto* op_desc = op_node->Op();
      auto op_type = op_desc->Type();

      if (op_type == "feed" || op_type == "fetch") continue;

      std::string unique_type = op_type + "_" + std::to_string(suffix++);
      op_original_type_[unique_type] = op_type;
      op_desc->SetType(unique_type);
      op_desc->Flush();
      VLOG(4) << "change op type: " << op_type << " ---> " << unique_type;
    }
  }
}

293
void AutoMixedPrecisionPass::RestoreOpOriginType() const {
294 295 296 297 298 299 300 301 302 303 304
  for (const auto& nodes : all_op_nodes_) {
    for (auto* op_node : nodes) {
      auto op_type = op_node->Op()->Type();
      op_node->Op()->SetType(GetOpOriginalType(op_type));
      op_node->Op()->Flush();
      VLOG(4) << "restore op type: " << op_type << " ---> "
              << op_node->Op()->Type();
    }
  }
}

305
inline std::string AutoMixedPrecisionPass::GetOpOriginalType(
306 307 308 309 310 311 312
    const std::string& op_type) const {
  if (op_original_type_.count(op_type)) {
    return op_original_type_.at(op_type);
  }
  return op_type;
}

313
void AutoMixedPrecisionPass::ProcessOpWithDtypeAttr() const {
314 315 316
  for (const auto& nodes : all_op_nodes_) {
    for (auto* op_node : nodes) {
      auto op_type = op_node->Op()->Type();
317
      if (op_run_low_precision_.count(op_type) == 0) continue;
318 319 320 321 322 323

      if (op_node->Op()->HasAttr("dtype")) {
        auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
        if (IsFloatType(static_cast<VarType::Type>(dtype))) {
          op_node->Op()->SetAttr(
              "dtype",
324
              static_cast<int>(framework::TransToProtoVarType(low_precision_)));
325 326
          op_node->Op()->Flush();
          VLOG(4) << "process op with dtype attr: " << op_type << " ( " << dtype
327
                  << " --->" << static_cast<int>(low_precision_) << " )";
328 329 330 331 332 333 334
        }
      }
      if (op_node->Op()->HasAttr("out_dtype")) {
        auto out_dtype = op_node->Op()->GetAttrIfExists<int>("out_dtype");
        if (IsFloatType(static_cast<VarType::Type>(out_dtype))) {
          op_node->Op()->SetAttr(
              "out_dtype",
335
              static_cast<int>(framework::TransToProtoVarType(low_precision_)));
336 337
          op_node->Op()->Flush();
          VLOG(4) << "process op with out_dtype attr: " << op_type << " ( "
338
                  << out_dtype << " --->" << static_cast<int>(low_precision_)
339 340 341 342 343 344 345
                  << " )";
        }
      }
    }
  }
}

346
void AutoMixedPrecisionPass::GetOpPrecision() const {
347 348 349
  for (const auto& nodes : all_op_nodes_) {
    for (auto* op_node : nodes) {
      auto op_type = op_node->Op()->Type();
350
      bool support_low_precision = true;
351 352
      if (GetOpOriginalType(op_type) == "feed" ||
          GetOpOriginalType(op_type) == "fetch") {
353
        support_low_precision = !keep_io_types_;
354
      } else {
355 356
        support_low_precision = OpSupportPrecision(
            GetOpOriginalType(op_type), backend_, low_precision_, black_list_);
357 358 359 360
      }

      if (op_node->Op()->HasAttr("dtype")) {
        auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
361 362
        support_low_precision = support_low_precision &&
                                IsFloatType(static_cast<VarType::Type>(dtype));
363 364
      } else if (op_node->Op()->HasAttr("out_dtype")) {
        auto out_dtype = op_node->Op()->GetAttrIfExists<int>("out_dtype");
365 366 367
        support_low_precision =
            support_low_precision &&
            IsFloatType(static_cast<VarType::Type>(out_dtype));
368 369
      } else {
        // if op's input var and output var is not dense tensor, the op should
370
        // not run at low precision.
371 372 373 374 375
        for (auto* in_var_node : op_node->inputs) {
          CHECK_EQ(in_var_node->IsVar(), true);
          auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
          if (real_in_var_node->Var()->Persistable()) continue;

376 377 378
          support_low_precision =
              support_low_precision &&
              (real_in_var_node->Var()->GetType() == VarType::LOD_TENSOR);
379 380 381 382 383 384 385
        }

        for (auto* out_var_node : op_node->outputs) {
          CHECK_EQ(out_var_node->IsVar(), true);
          auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
          if (real_out_var_node->Var()->Persistable()) continue;

386 387 388
          support_low_precision =
              support_low_precision &&
              (real_out_var_node->Var()->GetType() == VarType::LOD_TENSOR);
389 390 391
        }
      }

392 393 394
      if (support_low_precision) {
        op_run_low_precision_.insert(op_type);
        VLOG(4) << "support precision: " << op_type << " run at low precision";
395
      } else {
396 397
        VLOG(4) << "support precision: " << op_type
                << " not run at low precision";
398 399 400 401 402
      }
    }
  }
}

403 404
void AutoMixedPrecisionPass::UpdateOpPrecision() const {
  std::unordered_set<std::string> vars_should_not_low_precision;
405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426

  // var -> the var's all input op
  std::unordered_map<std::string, std::vector<Node*>> var_input_ops;

  auto GetVarInputOps = [&] {
    for (const auto& nodes : all_op_nodes_) {
      for (auto* op_node : nodes) {
        auto op_type = op_node->Op()->Type();

        if (GetOpOriginalType(op_type) == "fetch") continue;
        if (op_node->Op()->HasAttr("sub_block")) continue;

        for (auto* var_node : op_node->outputs) {
          CHECK_EQ(var_node->IsVar(), true);
          if (var_node->Var()->Persistable()) continue;
          if (!VarNodeHasDtype(var_node)) continue;

          var_input_ops[var_node->Var()->Name()].push_back(op_node);
          VLOG(4) << "var input ops: " << var_node->Var()->Name()
                  << " is output of " << op_type;
        }

427 428 429
        // the select_input op's input var should not convert to low precision.
        // when op's output var is select_input op's input var, the op should
        // not run at low precision.
430 431 432 433 434 435
        if (GetOpOriginalType(op_node->Op()->Type()) == "select_input") {
          for (auto* in_var_node : op_node->inputs) {
            CHECK_EQ(in_var_node->IsVar(), true);
            if (in_var_node->Var()->Persistable()) continue;
            if (!VarNodeHasDtype(in_var_node)) continue;

436
            vars_should_not_low_precision.insert(in_var_node->Var()->Name());
437 438
          }
        }
439 440 441 442 443 444 445 446 447 448 449 450 451 452

        // when op_1 only support cpu kernel. if op_2's intput var is op_1's
        // output var, then op_2 should not run half.
        if (GetOpOriginalType(op_type) != "feed" &&
            !GpuKernelSupportPrecision(GetOpOriginalType(op_type),
                                       phi::DataType::FLOAT32)) {
          for (auto* out_var_node : op_node->outputs) {
            CHECK_EQ(out_var_node->IsVar(), true);
            if (out_var_node->Var()->Persistable()) continue;
            if (!VarNodeHasDtype(out_var_node)) continue;

            vars_should_not_low_precision.insert(out_var_node->Var()->Name());
          }
        }
453 454 455 456 457 458 459 460 461 462
      }
    }
  };
  GetVarInputOps();

  bool precision_updated = false;
  do {
    precision_updated = false;
    for (const auto& nodes : all_op_nodes_) {
      for (auto* op_node : nodes) {
463
        if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) continue;
464

465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483
        for (auto* in_var_node : op_node->inputs) {
          CHECK_EQ(in_var_node->IsVar(), true);
          if (!VarNodeHasDtype(in_var_node)) continue;

          auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
          if (real_in_var_node->Var()->Persistable()) continue;

          if (vars_should_not_low_precision.count(
                  real_in_var_node->Var()->Name())) {
            op_run_low_precision_.erase(op_node->Op()->Type());
            precision_updated = true;
            VLOG(4) << op_node->Op()->Type()
                    << " should not run at low precision.";
            break;
          }
        }

        if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) continue;

484 485 486 487 488 489 490
        for (auto* out_var_node : op_node->outputs) {
          CHECK_EQ(out_var_node->IsVar(), true);
          if (!VarNodeHasDtype(out_var_node)) continue;

          auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
          if (real_out_var_node->Var()->Persistable()) continue;

491
          bool not_run_low_precision = false;
492 493
          const auto& input_op_nodes =
              var_input_ops[real_out_var_node->Var()->Name()];
494 495 496
          if (vars_should_not_low_precision.count(
                  real_out_var_node->Var()->Name())) {
            not_run_low_precision = true;
497 498
          } else {
            for (auto* node : input_op_nodes) {
499 500
              if (op_run_low_precision_.count(node->Op()->Type()) == 0) {
                not_run_low_precision = true;
501 502 503 504
                break;
              }
            }
          }
505 506
          if (not_run_low_precision) {
            op_run_low_precision_.erase(op_node->Op()->Type());
507 508
            precision_updated = true;
            VLOG(4) << op_node->Op()->Type()
509
                    << " should not run at low precision.";
510 511 512 513 514 515 516 517 518
            break;
          }
        }
      }
    }
  } while (precision_updated);
}

// Special ops whose listed weight inputs should not be converted to low
// precision.
519 520
// Returns true when `var_name` is fed to `op_node` through an input slot
// that must keep its original precision:
//  - batch_norm: Bias, Mean, Scale, Variance
//  - fused_multi_transformer: LnScale, LnBias, FFNLnScale, FFNLnBias
// Returns false for every other op type.
bool AutoMixedPrecisionPass::InputVarsNotConvert(
    Node* op_node, const std::string& var_name) const {
  auto* op_desc = op_node->Op();

  // True when `var_name` is among the arguments of input slot `slot`.
  auto slot_has_var = [op_desc, &var_name](const char* slot) {
    const auto& names = op_desc->Input(slot);
    return std::find(names.begin(), names.end(), var_name) != names.end();
  };

  // Slots are queried in the listed order and evaluation stops at the
  // first match.
  const std::string origin_type = GetOpOriginalType(op_desc->Type());
  if (origin_type == "batch_norm") {
    return slot_has_var("Bias") || slot_has_var("Mean") ||
           slot_has_var("Scale") || slot_has_var("Variance");
  }
  if (origin_type == "fused_multi_transformer") {
    return slot_has_var("LnScale") || slot_has_var("LnBias") ||
           slot_has_var("FFNLnScale") || slot_has_var("FFNLnBias");
  }
  return false;
}

560 561
// Returns true when `var_name` is produced by `op_node` through an output
// slot that must keep its original precision. Only batch_norm has such
// slots: MeanOut, VarianceOut, SavedMean and SavedVariance.
bool AutoMixedPrecisionPass::OutputVarsNotConvert(
    Node* op_node, const std::string& var_name) const {
  auto* op_desc = op_node->Op();

  // True when `var_name` is among the arguments of output slot `slot`.
  auto slot_has_var = [op_desc, &var_name](const char* slot) {
    const auto& names = op_desc->Output(slot);
    return std::find(names.begin(), names.end(), var_name) != names.end();
  };

  // batch_norm's input and output (variance and mean) are the same.
  if (GetOpOriginalType(op_desc->Type()) == "batch_norm") {
    return slot_has_var("MeanOut") || slot_has_var("VarianceOut") ||
           slot_has_var("SavedMean") || slot_has_var("SavedVariance");
  }
  return false;
}

585
void AutoMixedPrecisionPass::SetVarPrecision() const {
586 587
  for (const auto& nodes : all_op_nodes_) {
    for (auto* op_node : nodes) {
588 589 590 591 592
      if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) {
        continue;
      }

      if (GetOpOriginalType(op_node->Op()->Type()) != "feed") {
593 594 595 596 597 598 599 600 601 602 603 604
        for (auto* in_var_node : op_node->inputs) {
          CHECK_EQ(in_var_node->IsVar(), true);

          auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
          auto in_var_name = real_in_var_node->Var()->Name();

          if (!IsFloatType(real_in_var_node->Var()->GetDataType())) continue;
          if (!VarNodeHasDtype(real_in_var_node)) continue;
          if (InputVarsNotConvert(op_node, in_var_name)) continue;

          if (real_in_var_node->Var()->Persistable()) {
            real_in_var_node->Var()->SetDataType(
605 606
                framework::TransToProtoVarType(low_precision_));
            vars_convert_to_low_precision_.insert(in_var_name);
607 608
          }
        }
609
      }
610

611
      if (GetOpOriginalType(op_node->Op()->Type()) != "fetch") {
612 613 614 615 616 617 618 619 620 621 622
        for (auto* out_var_node : op_node->outputs) {
          CHECK_EQ(out_var_node->IsVar(), true);

          auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
          auto out_var_name = real_out_var_node->Var()->Name();

          if (!IsFloatType(real_out_var_node->Var()->GetDataType())) continue;
          if (!VarNodeHasDtype(real_out_var_node)) continue;
          if (OutputVarsNotConvert(op_node, out_var_name)) continue;

          real_out_var_node->Var()->SetDataType(
623
              framework::TransToProtoVarType(low_precision_));
624
          if (real_out_var_node->Var()->Persistable()) {
625
            vars_convert_to_low_precision_.insert(out_var_name);
626 627 628 629 630 631 632 633 634 635 636 637 638 639
          }
        }
      }
    }
  }

  // This code used to precess vars with the same name. Vars with the same
  // name should have the same data type.
  for (auto* subgraph : subgraphes_) {
    for (auto* var_node : subgraph->Nodes()) {
      if (!var_node->IsVar() || !var_node->Var()->Persistable()) continue;
      if (!VarNodeHasDtype(var_node)) continue;

      auto var_name = var_node->Var()->Name();
640
      if (vars_convert_to_low_precision_.count(var_name)) {
641
        var_node->Var()->SetDataType(
642
            framework::TransToProtoVarType(low_precision_));
643 644 645 646 647
      }
    }
  }
}

648
void AutoMixedPrecisionPass::ConvertWeightsData() const {
649
  auto* scope = param_scope();
650 651 652 653
  PADDLE_ENFORCE_NOT_NULL(scope,
                          platform::errors::PreconditionNotMet(
                              "During the auto_mixed_precision_pass, the scope "
                              "should not be null."));
654 655 656

  auto var_names = scope->LocalVarNames();
  for (const auto& var_name : var_names) {
657
    if (vars_convert_to_low_precision_.count(var_name)) {
658 659 660
      VLOG(4) << var_name << "'s data type was convert to half";

      auto* var = scope->FindLocalVar(var_name);
661 662 663 664
      CHECK_EQ(var->IsType<phi::DenseTensor>(), true);

      auto* origin_tensor = var->GetMutable<phi::DenseTensor>();

665 666 667
      phi::DenseTensor low_precision_tensor;
      low_precision_tensor.Resize(origin_tensor->dims());
      low_precision_tensor.set_type(low_precision_);
668

669 670 671 672
      if (low_precision_ == phi::DataType::FLOAT16) {
        auto* low_precision_data =
            low_precision_tensor.mutable_data<phi::dtype::float16>(
                phi::CPUPlace{});
673 674 675
        for (int64_t i = 0; i < origin_tensor->numel(); i++) {
          if (origin_tensor->dtype() == phi::DataType::FLOAT64) {
            auto* origin_data = origin_tensor->data<double>();
676 677
            low_precision_data[i] =
                static_cast<phi::dtype::float16>(origin_data[i]);
678 679
          } else if (origin_tensor->dtype() == phi::DataType::FLOAT32) {
            auto* origin_data = origin_tensor->data<float>();
680 681
            low_precision_data[i] =
                static_cast<phi::dtype::float16>(origin_data[i]);
682 683
          }
        }
684
      } else if (low_precision_ == phi::DataType::BFLOAT16) {
685
        auto* half_data =
686 687
            low_precision_tensor.mutable_data<phi::dtype::bfloat16>(
                phi::CPUPlace{});
688 689 690 691 692 693 694 695
        for (int64_t i = 0; i < origin_tensor->numel(); i++) {
          if (origin_tensor->dtype() == phi::DataType::FLOAT64) {
            auto* origin_data = origin_tensor->data<double>();
            half_data[i] = static_cast<phi::dtype::bfloat16>(origin_data[i]);
          } else if (origin_tensor->dtype() == phi::DataType::FLOAT32) {
            auto* origin_data = origin_tensor->data<float>();
            half_data[i] = static_cast<phi::dtype::bfloat16>(origin_data[i]);
          }
696 697
        }
      }
698 699
      origin_tensor->clear();
      paddle::framework::TensorCopySync(
700
          low_precision_tensor, phi::CPUPlace{}, origin_tensor);
701 702 703 704
    }
  }
}

705
void AutoMixedPrecisionPass::InsertCastOp() const {
706 707 708 709 710 711 712 713 714 715 716 717 718
  int suffix = 0;
  std::unordered_map<Node*, Node*> cache;

  for (size_t i = 0; i < all_op_nodes_.size(); i++) {
    auto* block_desc = all_op_nodes_[i][0]->Op()->Block();
    CHECK_NOTNULL(block_desc);
    for (auto* op_node : all_op_nodes_[i]) {
      auto op_type = op_node->Op()->Type();

      if (GetOpOriginalType(op_type) == "feed") continue;
      if (op_node->Op()->HasAttr("sub_block")) continue;

      VLOG(4) << "process op: " << op_type
719
              << " run low precision: " << op_run_low_precision_.count(op_type);
720 721 722 723 724 725 726 727 728 729 730 731 732 733

      auto inputs = op_node->inputs;
      for (auto* in_var_node : inputs) {
        if (!in_var_node->IsVar()) continue;
        if (!VarNodeHasDtype(in_var_node)) continue;
        if (in_var_node->Var()->Persistable()) continue;

        auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];

        auto in_var_type = real_in_var_node->Var()->GetDataType();

        VLOG(4) << "process var: " << real_in_var_node->Var()->Name()
                << " with type " << in_var_type;

734
        if (IsFloatType(in_var_type) && op_run_low_precision_.count(op_type)) {
735 736 737 738
          DoInsertCastOp(subgraphes_[i],
                         in_var_node,
                         op_node,
                         in_var_type,
739
                         framework::TransToProtoVarType(low_precision_),
740 741 742 743
                         block_desc,
                         &suffix,
                         &cache);
        } else if (IsHalfType(in_var_type) &&
744
                   op_run_low_precision_.count(op_type) == 0) {
745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775
          DoInsertCastOp(subgraphes_[i],
                         in_var_node,
                         op_node,
                         in_var_type,
                         VarType::FP32,
                         block_desc,
                         &suffix,
                         &cache);
        }
      }

      // Special op.
      // fused_multi_transformer's input(CacheKV) and output(CacheKVOut) vars
      // have same name.
      if (GetOpOriginalType(op_type) == "fused_multi_transformer") {
        auto cache_kv_inputs = op_node->Op()->Input("CacheKV");
        auto cache_kv_outputs = op_node->Op()->Output("CacheKVOut");
        CHECK_EQ(cache_kv_inputs.size(), cache_kv_outputs.size());
        for (size_t i = 0; i < cache_kv_inputs.size(); ++i) {
          op_node->Op()->RenameOutput(cache_kv_outputs[i], cache_kv_inputs[i]);
        }
      }
    }
  }
  VLOG(4) << "insert number of cast op: " << cache.size();
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

776 777
// Registers AutoMixedPrecisionPass under the name
// "auto_mixed_precision_pass".
REGISTER_PASS(auto_mixed_precision_pass,
              paddle::framework::ir::AutoMixedPrecisionPass);