convert_to_mixed_precision.cc 29.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"

#include <algorithm>
#include <fstream>
#include <initializer_list>
#include <iterator>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/place.h"

namespace paddle {
namespace inference {
namespace analysis {

namespace {
47 48
using VarType = framework::proto::VarType;

W
Wilber 已提交
49
bool PhiKernelSupportPrecision(
50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65
    const std::string& op_type,
    phi::Backend backend,
    phi::DataType data_type,
    phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) {
  auto kernels = phi::KernelFactory::Instance().kernels();
  if (kernels.find(op_type) == kernels.end()) {
    return false;
  }
  phi::KernelKey kernel_key(backend, layout, data_type);
  return phi::KernelFactory::Instance().HasKernel(op_type, kernel_key);
}

// Returns true when `op_type` has a GPU kernel with the requested precision,
// checking the phi registry (GPU and GPUDNN backends) first and falling back
// to the legacy fluid op-kernel registry.
bool GpuKernelSupportPrecision(
    const std::string& op_type,
    phi::DataType data_type,
    phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) {
  auto phi_op_type = phi::TransToPhiKernelName(op_type);
  bool res = PhiKernelSupportPrecision(
      phi_op_type, phi::Backend::GPU, data_type, layout);
  res |= PhiKernelSupportPrecision(
      phi_op_type, phi::Backend::GPUDNN, data_type, layout);

  if (!res) {
    // Fall back to fluid kernels not yet migrated to phi.
    // NOTE(review): this fallback only matches FP16 kernels regardless of
    // `data_type` — presumably fluid registered no BF16 GPU kernels; confirm.
    auto& all_kernels = framework::OperatorWithKernel::AllOpKernels();
    auto it = all_kernels.find(op_type);
    if (it != all_kernels.end()) {
      for (auto& kern_pair : it->second) {
        if (platform::is_gpu_place(kern_pair.first.place_) &&
            kern_pair.first.data_type_ == VarType::FP16) {
          res = true;
          break;
        }
      }
    }
  }
  return res;
}

88
// Offline pass that loads a saved inference model, rewrites eligible weights
// and intermediate tensors to fp16/bf16 (inserting cast ops where needed),
// and serializes the converted model and parameters to new files.
class ConvertToMixedPrecisionPass {
  using BlockID = size_t;

 public:
  explicit ConvertToMixedPrecisionPass(
      const std::string& model_file,
      const std::string& params_file,
      const std::string& mixed_model_file,
      const std::string& mixed_params_file,
      phi::DataType mixed_precision,
      phi::Backend backend,
      bool keep_io_types,
      const std::unordered_set<std::string>& black_list)
      : model_file_(model_file),
        params_file_(params_file),
        mixed_model_file_(mixed_model_file),
        mixed_params_file_(mixed_params_file),
        mixed_precision_(mixed_precision),
        backend_(backend),
        keep_io_types_(keep_io_types),
        black_list_(black_list),
        place_(paddle::CPUPlace()),
        executor_(place_) {
    VLOG(4) << "black_list has ";
    for (auto& name : black_list_) {
      VLOG(4) << " - " << name;
    }
  }

  // Runs the full conversion pipeline and writes the mixed model/params.
  void Run();

 private:
  // Loads program/params into scope_, builds graphs and the name lookup.
  void LoadAndPrepare();
  // True when the var's type can carry a tensor dtype worth converting.
  inline bool VarNodeHasDtype(framework::ir::Node* node);
  // Downgrades fp64 vars/attrs in `graph` to fp32 before mixed conversion.
  void ConvertAllFp64ToFp32(framework::ir::Graph* graph);
  // Re-derives cast ops' in/out dtype attrs after vars were retyped.
  void FixCastAttr(framework::ir::Graph* graph);
  // Serializes the converted program and (possibly converted) weights.
  void SaveMixedModel();
  // Core rewrite for one sub-graph: retype tensors and insert casts.
  void ConvertTensorDtype(BlockID block_idx);
  // Converts or casts one op input depending on kernel precision support.
  void ProcessInputNode(bool support_precision,
                        framework::ir::Node* in_node,
                        framework::ir::Node* op_node,
                        int* suffix,
                        framework::BlockDesc* block_desc,
                        VarType::Type to_type,
                        BlockID block_idx);

  // Retypes one op output to `to_type` when safe.
  void ProcessOutputNode(BlockID block_idx,
                         framework::ir::Node* var_node,
                         VarType::Type to_type);
  inline bool IsFloatVarType(VarType::Type type);

  bool OutShouldNotConvert(framework::ir::Node* var_node);
  // Just process special cases for weights conversion.
  bool WeightsShouldNotConvert(framework::ir::Node* var_node);

  // Return Node* which first appers in block.
  framework::ir::Node* GetRealVarNode(framework::ir::Node* node);

  // Fallback to fp32 dtype when encounter circle (Not a DAG graph).
  void ProcessCircleCases();

 private:
  std::string model_file_;
  std::string params_file_;
  std::string mixed_model_file_;
  std::string mixed_params_file_;
  phi::DataType mixed_precision_;
  phi::Backend backend_;
  bool keep_io_types_;
  std::unordered_set<std::string> black_list_;
  paddle::CPUPlace place_;
  framework::Executor executor_;
  framework::Scope scope_;

  // First node seen for each var name, across all sub-graphs.
  std::unordered_map<std::string, framework::ir::Node*> name2node_;
  // Cache of var node -> already-created cast output node.
  std::unordered_map<framework::ir::Node*, framework::ir::Node*> cast_map_;
  // Monotonic counter used to generate unique cast output var names.
  int suffix_{0};

  // Var names appearing as both input and output of some op (cycles).
  std::set<std::string> var_names_in_circles_;

  std::unique_ptr<framework::ProgramDesc> program_desc_{nullptr};
  std::unique_ptr<framework::ir::Graph> main_graph_{nullptr};
  std::vector<framework::ir::Graph*> graphes_;
};

173
// Returns the node that first appeared in the graphs for this var name, or
// `var_node` itself when the name is unknown.
framework::ir::Node* ConvertToMixedPrecisionPass::GetRealVarNode(
    framework::ir::Node* var_node) {
  CHECK_EQ(var_node->IsVar(), true);
  // Single find() instead of count() + operator[] (one hash lookup, not two).
  auto it = name2node_.find(var_node->Name());
  return it != name2node_.end() ? it->second : var_node;
}

180 181 182 183 184 185 186
// True when the variable's kind is one that carries a convertible dtype.
inline bool ConvertToMixedPrecisionPass::VarNodeHasDtype(
    framework::ir::Node* var_node) {
  CHECK_EQ(var_node->IsVar(), true);
  switch (var_node->Var()->GetType()) {
    case VarType::SELECTED_ROWS:
    case VarType::LOD_TENSOR:
    case VarType::LOD_TENSOR_ARRAY:
    case VarType::STRINGS:
    case VarType::VOCAB:
      return true;
    default:
      return false;
  }
}

void ConvertToMixedPrecisionPass::ProcessInputNode(
    bool support_precision,
191 192
    framework::ir::Node* in_node,
    framework::ir::Node* op_node,
193 194
    int* suffix,
    framework::BlockDesc* block_desc,
195 196 197
    VarType::Type to_type,
    BlockID block_idx) {
  if (!in_node->IsVar()) return;
198
  auto* real_node = GetRealVarNode(in_node);
199 200
  if (!VarNodeHasDtype(real_node)) return;
  auto* graph = graphes_[block_idx];
201 202 203 204 205
  auto* in_var = real_node->Var();
  auto in_var_type = in_var->GetDataType();
  auto prev_type = in_var_type;

  if (support_precision) {
206
    if (in_var->Persistable() && in_var_type == VarType::FP32) {
207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242
      if (WeightsShouldNotConvert(in_node)) return;
      in_var->SetDataType(to_type);
      in_var_type = to_type;
      VLOG(3) << "   in_node name " << in_var->Name() << " from " << prev_type
              << " to " << to_type;
    } else if (!in_var->Persistable() && IsFloatVarType(in_var_type) &&
               in_var_type != to_type) {
      AddCastOp(graph,
                in_node,
                op_node,
                in_var_type,
                to_type,
                suffix,
                block_desc,
                &cast_map_);
      VLOG(3) << "   in_node name " << in_var->Name() << "(" << prev_type
              << ") to " << cast_map_[in_node]->Name() << "(" << to_type << ")";
    }
  } else {
    if (!in_var->Persistable() && IsFloatVarType(in_var_type) &&
        in_var_type != to_type) {
      AddCastOp(graph,
                in_node,
                op_node,
                in_var_type,
                to_type,
                suffix,
                block_desc,
                &cast_map_);
      VLOG(3) << "   in_node name " << in_var->Name() << "(" << prev_type
              << ") to " << cast_map_[in_node]->Name() << "(" << to_type << ")";
    }
  }
}

void ConvertToMixedPrecisionPass::ProcessOutputNode(
243 244
    BlockID block_idx, framework::ir::Node* var_node, VarType::Type to_type) {
  if (!var_node->IsVar()) return;
245
  auto* real_node = GetRealVarNode(var_node);
246
  if (!VarNodeHasDtype(real_node)) return;
247 248
  auto* out_var = real_node->Var();
  auto prev_type = out_var->GetDataType();
249
  if (out_var->GetDataType() == VarType::FP32) {
250 251 252 253 254 255 256
    if (OutShouldNotConvert(var_node)) return;
    out_var->SetDataType(to_type);
  }
  VLOG(3) << "   out_node name " << var_node->Name() << " from dtype "
          << prev_type << " to " << out_var->GetDataType();
}

257
// Just process special cases.
258 259
bool ConvertToMixedPrecisionPass::OutShouldNotConvert(
    framework::ir::Node* var_node) {
260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284
  auto op_node = var_node->inputs[0];
  auto* op_desc = op_node->Op();

  // batch_norm's input and output (variance and mean) are the same.
  if (op_desc->Type() == "batch_norm") {
    auto vecs = op_desc->Output("MeanOut");
    if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
      return true;
    }
    vecs = op_desc->Output("VarianceOut");
    if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
      return true;
    }
    vecs = op_desc->Output("SavedMean");
    if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
      return true;
    }
    vecs = op_desc->Output("SavedVariance");
    if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
      return true;
    }
  }

  return false;
}
285

286 287
// Returns true when `var_node` is a weight consumed through an input slot
// whose kernel only accepts float32, so the weight must not be converted.
bool ConvertToMixedPrecisionPass::WeightsShouldNotConvert(
    framework::ir::Node* var_node) {
  // True when var_node's name appears in any of the given input slots.
  auto used_in_any_slot = [&](framework::OpDesc* op_desc,
                              std::initializer_list<const char*> slots) {
    for (const char* slot : slots) {
      auto vecs = op_desc->Input(slot);
      if (std::find(vecs.begin(), vecs.end(), var_node->Name()) !=
          vecs.end()) {
        return true;
      }
    }
    return false;
  };

  for (auto* op_node : var_node->outputs) {
    auto* op_desc = op_node->Op();
    // batch_norm op's bias, mean, scale and variance just be float32, so we
    // can not convert the dtype.
    if (op_desc->Type() == "batch_norm" &&
        used_in_any_slot(op_desc, {"Bias", "Mean", "Scale", "Variance"})) {
      return true;
    }
    // fused_multi_transformer's layer-norm params must also stay float32.
    if (op_desc->Type() == "fused_multi_transformer" &&
        used_in_any_slot(op_desc,
                         {"LnScale", "LnBias", "FFNLnScale", "FFNLnBias"})) {
      return true;
    }
  }

  return false;
}
335 336 337 338

// True for the floating dtypes this pass converts between.
inline bool ConvertToMixedPrecisionPass::IsFloatVarType(VarType::Type type) {
  switch (type) {
    case VarType::FP16:
    case VarType::FP32:
    case VarType::BF16:
      return true;
    default:
      return false;
  }
}
340

341 342 343 344 345
// Loads the program and parameters into scope_, builds the IR graph, and
// prepares the per-sub-graph list plus the var-name lookup table.
void ConvertToMixedPrecisionPass::LoadAndPrepare() {
  program_desc_ =
      inference::Load(&executor_, &scope_, model_file_, params_file_);
  main_graph_ = std::unique_ptr<framework::ir::Graph>(
      new framework::ir::Graph(*program_desc_));

  for (size_t i = 0; i < main_graph_->SubGraphsSize(); ++i) {
    auto* graph = main_graph_->GetSubGraph(i);
    graphes_.push_back(graph);

    // Remember the first node seen for each var name; GetRealVarNode uses
    // this table to resolve duplicate var nodes across sub-graphs.
    for (auto* node : graph->Nodes()) {
      if (!node->IsVar()) continue;
      if (!name2node_.count(node->Name())) {
        name2node_[node->Name()] = node;
      }
    }
  }

  // Pre-compute vars involved in cycles so conversion can skip them.
  ProcessCircleCases();
}

362 363 364 365
// Find var names which in circles.
void ConvertToMixedPrecisionPass::ProcessCircleCases() {
  std::vector<std::string> vars_in_circles;
  for (size_t idx = 0; idx < program_desc_->Size(); ++idx) {
366
    for (auto* op : program_desc_->Block(idx).AllOps()) {
367 368 369
      // TODO(inference): batch_norm has circle, but we need to fuse it in conv
      // op.
      if (op->Type() == "batch_norm") continue;
370 371
      const auto& in_names = op->InputArgumentNames();
      const auto& out_names = op->OutputArgumentNames();
372 373 374 375 376 377 378
      std::set<std::string> in_names_set(in_names.begin(), in_names.end());
      std::set<std::string> out_names_set(out_names.begin(), out_names.end());
      std::set_intersection(in_names_set.begin(),
                            in_names_set.end(),
                            out_names_set.begin(),
                            out_names_set.end(),
                            std::back_inserter(vars_in_circles));
W
Wilber 已提交
379
    }
380 381
  }

382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420
  for (auto& name : vars_in_circles) {
    var_names_in_circles_.insert(name);
  }
  for (auto& name : var_names_in_circles_) {
    LOG(INFO) << name
              << " in circles, so we will skip process those vars and ops.";
  }
}

// Rewrites the dtype attribute(s) of constant-producing ops (and cast) from
// `from_type` to `to_type`, so their outputs match the converted graph.
inline void ProcessConstantOpAttr(framework::ir::Node* op_node,
                                  VarType::Type from_type,
                                  VarType::Type to_type) {
  if (!op_node->IsOp()) return;
  auto op_type = op_node->Op()->Type();
  if (op_type == "feed" || op_type == "fetch") return;

  // Replaces an int dtype attribute in place when it equals `from_type`.
  auto rewrite_dtype_attr = [&](const char* attr_name) {
    if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr(attr_name)) ==
        static_cast<int>(from_type)) {
      op_node->Op()->SetAttr(attr_name, static_cast<int>(to_type));
    }
  };

  // These four ops all carry a single "dtype" attribute; the original code
  // repeated the identical branch for each of them.
  if (op_type == "fill_constant" || op_type == "assign_value" ||
      op_type == "eye" || op_type == "fill_any_like") {
    rewrite_dtype_attr("dtype");
  } else if (op_type == "cast") {
    rewrite_dtype_attr("in_dtype");
    rewrite_dtype_attr("out_dtype");
  }
}
W
Wilber 已提交
423

424 425
// Downgrades fp64 to fp32 throughout `graph`: constant-op dtype attributes
// and every non-persistable fp64 input var. Mixed-precision conversion then
// starts from a pure-fp32 graph.
void ConvertToMixedPrecisionPass::ConvertAllFp64ToFp32(
    framework::ir::Graph* graph) {
  auto op_nodes = framework::ir::TopologySortOperations(*graph);
  for (auto* op_node : op_nodes) {
    if (!op_node->IsOp()) continue;
    auto op_type = op_node->Op()->Type();  // NOTE(review): currently unused.
    ProcessConstantOpAttr(op_node, VarType::FP64, VarType::FP32);
    auto inputs = op_node->inputs;
    for (auto* in_node : inputs) {
      auto* in_var = in_node->Var();
      if (!in_var->Persistable() && in_var->GetDataType() == VarType::FP64) {
        in_var->SetDataType(VarType::FP32);
      }
    }
  }
}

441 442
// Pipeline entry point: load the model, then per sub-graph (1) fold fp64 to
// fp32, (2) rewrite tensor dtypes to the target mixed precision, and
// (3) repair cast op attributes; finally save the converted model.
void ConvertToMixedPrecisionPass::Run() {
  LoadAndPrepare();

  for (size_t i = 0; i < graphes_.size(); ++i) {
    auto* graph = graphes_[i];
    VLOG(2) << " --------  handle subgraph " << i << ", has "
            << graph->Nodes().size() << " nodes --------";

    ConvertAllFp64ToFp32(graph);
    ConvertTensorDtype(i);
    FixCastAttr(graph);

    // Sanity check: all var descs sharing a name must agree on dtype.
    CHECK_EQ(framework::ir::VarDescIsConsistency(*graph), true);
  }

  SaveMixedModel();
}

459 460 461
// Core rewrite for one sub-graph: walks ops in topological order, converts
// supported ops' weights/activations to fp16/bf16, inserts cast ops around
// unsupported ops, and finally reconciles fetch outputs with keep_io_types_.
void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
  auto* graph = graphes_[block_idx];
  VarType::Type to_type;
  if (mixed_precision_ == phi::DataType::FLOAT16) {
    to_type = VarType::FP16;
  } else if (mixed_precision_ == phi::DataType::BFLOAT16) {
    to_type = VarType::BF16;
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "mixed_precision currently not supported dtype %d, we now only "
        "support fp16 and bf16.",
        static_cast<int>(mixed_precision_)));
  }

  auto op_nodes = framework::ir::TopologySortOperations(*graph);
  auto* block_desc = op_nodes[0]->Op()->Block();
  int num_low_precision = 0;
  std::vector<framework::ir::Node*> output_nodes;

  for (auto* op_node : op_nodes) {
    if (!op_node->IsOp()) continue;
    auto op_type = op_node->Op()->Type();
    VLOG(3) << "-------------------- op_type " << op_type << ", phi_type "
            << phi::TransToPhiKernelName(op_type);
    // 1. set input dtype.
    if (op_type == "feed") {
      auto feed_var = op_node->outputs[0]->Var();
      if (!keep_io_types_ && feed_var->GetDataType() == VarType::FP32) {
        feed_var->SetDataType(to_type);
      }
    } else if (op_type == "fetch") {
      // Defer fetch handling to step 4 below.
      auto* fetch_var = op_node->inputs[0];
      output_nodes.push_back(fetch_var);
      continue;
    } else if (op_type == "cast") {
      // Existing casts are repaired later by FixCastAttr.
      continue;
    }

    // We can not add cast operator before ops who have sub_block, as in
    // sub_block we may get a var which may be transformer by cast op.
    else if (op_node->Op()->HasAttr("sub_block")) {  // NOLINT
      continue;
    }

    // 2. if op support fp16/bf16 and not in blacklist.
    //      - cast weight to fp16/bf16.
    //      - add cast op if the input dtype is not fp16/bf16.
    //      - set output dtype.
    else if (black_list_.count(op_type) == 0) {  // NOLINT
      bool support_precision =
          OpSupportPrecision(op_type, backend_, mixed_precision_, black_list_);

      // If op's output in circle, we should not convert to fp16.
      for (auto* out_node : op_node->outputs) {
        if (var_names_in_circles_.count(out_node->Name())) {
          support_precision = false;
          VLOG(2) << " op's output " << out_node->Name()
                  << " is in circle, we can not support this case, just skip.";
          break;
        }
      }

      // If the op has no input or output of float type, we will not choose the
      // low precision kernel.
      if (support_precision) {
        bool has_float_in_out{false};
        for (auto* in_node : op_node->inputs) {
          if (!in_node->IsVar()) continue;
          // Tensor-array inputs are not supported in low precision.
          if (in_node->Var()->GetType() != VarType::LOD_TENSOR) {
            support_precision = false;
            VLOG(2) << " op has tensor array input[" << in_node->Name()
                    << "], just skip.";
            break;
          }
          auto* real_node = GetRealVarNode(in_node);
          if (real_node->Var()->GetDataType() == VarType::FP16 ||
              real_node->Var()->GetDataType() == VarType::FP32 ||
              real_node->Var()->GetDataType() == VarType::FP64 ||
              real_node->Var()->GetDataType() == VarType::BF16) {
            has_float_in_out = true;
            break;
          }
        }
        for (auto* out_node : op_node->outputs) {
          if (!out_node->IsVar()) continue;
          auto* real_node = GetRealVarNode(out_node);
          if (real_node->Var()->GetDataType() == VarType::FP16 ||
              real_node->Var()->GetDataType() == VarType::FP32 ||
              real_node->Var()->GetDataType() == VarType::FP64 ||
              real_node->Var()->GetDataType() == VarType::BF16) {
            has_float_in_out = true;
            break;
          }
        }

        if (!has_float_in_out) {
          support_precision = false;
          VLOG(2) << " op doesn't has float input and output, just skip.";
        }
      }

      VLOG(2) << "op type: " << op_type
              << " support low precision: " << support_precision;

      if (support_precision) {
        ProcessConstantOpAttr(op_node, VarType::FP32, to_type);
        VLOG(2) << " process input nodes:";
        ++num_low_precision;
        auto inputs = op_node->inputs;
        for (auto* in_node : inputs) {
          ProcessInputNode(
              true, in_node, op_node, &suffix_, block_desc, to_type, block_idx);
        }

        VLOG(2) << " process output nodes:";
        auto outputs = op_node->outputs;
        for (auto* out_node : outputs) {
          ProcessOutputNode(block_idx, out_node, to_type);
        }
      } else {
        // Unsupported op: force its inputs back to fp32 via casts.
        auto inputs = op_node->inputs;
        for (auto* in_node : inputs) {
          ProcessInputNode(false,
                           in_node,
                           op_node,
                           &suffix_,
                           block_desc,
                           VarType::FP32,
                           block_idx);
        }
      }
    }

    // 3. check op not support fp16/bf16 or in blacklist.
    //      - add cast op if the input dtype is not fp32.
    else {  // NOLINT
      VLOG(3) << "not to run fp16 op_type: " << op_type << ", node input size "
              << op_node->inputs.size();
      auto in_nodes = op_node->inputs;
      for (auto* in_node : in_nodes) {
        auto* in_var = in_node->Var();
        if (in_var->GetDataType() == to_type) {
          AddCastOp(graph,
                    in_node,
                    op_node,
                    to_type,
                    VarType::FP32,
                    &suffix_,
                    block_desc,
                    &cast_map_);
          VLOG(3) << "-- " << in_node->Name() << "(" << to_type << ") to "
                  << cast_map_[in_node]->Name() << "(" << VarType::FP32 << ")";
        }
      }
    }
  }

  // 4. if output_op's dtype is not compatible to output dtype, then just
  // insert cast.
  for (auto* node : output_nodes) {
    framework::ir::Node* fetch_op{nullptr};
    for (auto* op_node : node->outputs) {
      if (op_node->IsOp() && op_node->Op()->Type() == "fetch") {
        fetch_op = op_node;
      }
    }
    CHECK_NOTNULL(fetch_op);
    auto* var = node->Var();
    if (keep_io_types_ && var->GetDataType() == to_type) {
      // fp16/bf16 -> fp32.
      AddCastOp(graph,
                node,
                fetch_op,
                to_type,
                VarType::FP32,
                &suffix_,
                block_desc,
                &cast_map_);
    } else if (!keep_io_types_ && var->GetDataType() == VarType::FP32) {
      // fp32 -> fp16/bf16
      AddCastOp(graph,
                node,
                fetch_op,
                VarType::FP32,
                to_type,
                &suffix_,
                block_desc,
                &cast_map_);
    }
  }

  if (num_low_precision)
    LOG(INFO) << "---  detected " << num_low_precision
              << " low precision ops in " << block_idx << " subgraph";
}

655 656
// We modify op's input output precision, and we need to fix cast op in_dtype
// and out_dtype attribute.
657
// TODO(inference): we need a cast elimination pass.
658 659 660 661 662 663 664 665 666 667 668 669
void ConvertToMixedPrecisionPass::FixCastAttr(framework::ir::Graph* graph) {
  auto op_nodes = framework::ir::TopologySortOperations(*graph);
  for (auto* op_node : op_nodes) {
    if (!op_node->IsOp()) continue;
    auto op_type = op_node->Op()->Type();
    if (op_type != "cast") continue;
    auto input = op_node->inputs[0];
    auto output = op_node->outputs[0];
    op_node->Op()->SetAttr("in_dtype",
                           static_cast<int>(input->Var()->GetDataType()));
    op_node->Op()->SetAttr("out_dtype",
                           static_cast<int>(output->Var()->GetDataType()));
670 671 672
  }
}

673 674 675 676 677 678 679 680 681
// Serializes the converted graph back to a ProgramDesc, converts eligible
// fp32 weight tensors in scope_ to fp16/bf16, and writes the model and
// params binaries to mixed_model_file_ / mixed_params_file_.
void ConvertToMixedPrecisionPass::SaveMixedModel() {
  framework::ProgramDesc mixed_program_desc;
  framework::ir::GraphToProgram(*main_graph_, &mixed_program_desc);

  auto parameters = scope_.LocalVarNames();
  std::sort(parameters.begin(), parameters.end());

  // Weights whose var desc is still fp32 after conversion must keep their
  // fp32 data (e.g. batch_norm statistics, blacklisted ops' weights).
  std::unordered_set<std::string> weights_should_be_fp32;
  for (auto* node : main_graph_->Nodes()) {
    if (!node->IsVar()) continue;
    if (VarNodeHasDtype(node)) {
      if (node->Var()->Persistable() &&
          node->Var()->GetDataType() == VarType::FP32) {
        VLOG(2) << "weights keep to fp32: " << node->Name() << ", ptr "
                << reinterpret_cast<void*>(node->Var());
        weights_should_be_fp32.insert(node->Name());
      }
    }
  }

// Converts `origin_tensor` (fp32) element-wise into `mixed_tensor` of the
// given dtype, then copies the result back over the original tensor.
#define CONVERT_TENSOR_DTYPE(DTYPE, dtype)                                   \
  mixed_tensor.set_type(DTYPE);                                              \
  auto* mixed_data = mixed_tensor.mutable_data<dtype>(platform::CPUPlace()); \
  for (int64_t i = 0; i < origin_tensor->numel(); i++) {                     \
    mixed_data[i] = static_cast<dtype>(origin_data[i]);                      \
  }                                                                          \
  origin_tensor->clear();                                                    \
  paddle::framework::TensorCopySync(                                         \
      mixed_tensor, platform::CPUPlace(), origin_tensor)

  for (const auto& param_name : parameters) {
    if (weights_should_be_fp32.count(param_name)) continue;
    auto* var = scope_.FindLocalVar(param_name);
    if (var->IsType<phi::DenseTensor>()) {
      auto* origin_tensor = var->GetMutable<phi::DenseTensor>();
      if (origin_tensor->dtype() != phi::DataType::FLOAT32) continue;
      phi::DenseTensor mixed_tensor;
      mixed_tensor.Resize(origin_tensor->dims());
      auto* origin_data =
          origin_tensor->mutable_data<float>(platform::CPUPlace());
      if (mixed_precision_ == phi::DataType::FLOAT16) {
        CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::FLOAT16,
                             phi::dtype::float16);
      } else if (mixed_precision_ == phi::DataType::BFLOAT16) {
        CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::BFLOAT16,
                             phi::dtype::bfloat16);
      }
    }
  }

#undef CONVERT_TENSOR_DTYPE

  // Serializes all parameter tensors (sorted by name) into one byte stream.
  auto SerializeParams = [&]() -> std::string {
    std::ostringstream os;
    phi::CPUContext ctx;
    for (const auto& param : parameters) {
      PADDLE_ENFORCE_NOT_NULL(
          scope_.FindVar(param),
          platform::errors::NotFound(
              "Block should already have a '%s' variable", param));
      auto* tensor = scope_.FindVar(param)->GetMutable<phi::DenseTensor>();
      framework::SerializeToStream(os, *tensor, ctx);
    }
    return os.str();
  };

  auto StrToBinary = [](const std::string& path, const std::string& str) {
    std::ofstream file(path.c_str(), std::ios::binary);
    file.write(str.c_str(), str.size());
    file.close();
  };

  StrToBinary(mixed_model_file_,
              mixed_program_desc.Proto()->SerializeAsString());
  StrToBinary(mixed_params_file_, SerializeParams());
}
}  // namespace

751 752 753 754
// Inserts (or reuses) a cast op converting `node` from `from_type` to
// `to_type`, and rewires `next_op` to consume the cast output instead of
// `node`. `map` caches node -> cast-output so each var is cast at most once;
// `suffix` generates unique names for the cast outputs.
void AddCastOp(
    framework::ir::Graph* graph,
    framework::ir::Node* node,
    framework::ir::Node* next_op,
    VarType::Type from_type,
    VarType::Type to_type,
    int* suffix,
    framework::BlockDesc* block_desc,
    std::unordered_map<framework::ir::Node*, framework::ir::Node*>* map) {
  // Fills `desc` with a complete cast-op definition.
  auto update_cast_desc = [&](framework::OpDesc& desc,
                              const std::string& x_name,
                              const std::string& out_name,
                              const int in_dtype,
                              const int out_dtype) {
    desc.SetType("cast");
    desc.SetInput("X", {x_name});
    desc.SetOutput("Out", {out_name});
    desc.SetAttr("in_dtype", in_dtype);
    desc.SetAttr("out_dtype", out_dtype);
    desc.SetAttr("use_mkldnn", false);
    desc.SetAttr("with_quant_attr", false);
    desc.Flush();
  };

  if (map->count(node) == 0) {
    // insert cast op before node.
    std::string cast_input_name = node->Var()->Name();
    std::string cast_output_name =
        node->Var()->Name() + "_cast.tmp_" + std::to_string((*suffix)++);
    CHECK_NOTNULL(block_desc);
    framework::OpDesc cast_op_desc(block_desc);
    update_cast_desc(cast_op_desc,
                     cast_input_name,
                     cast_output_name,
                     static_cast<int>(from_type),
                     static_cast<int>(to_type));
    auto* cast_op_node = graph->CreateOpNode(&cast_op_desc);
    auto* cast_output_vardesc = block_desc->Var(cast_output_name);
    cast_output_vardesc->SetPersistable(false);
    cast_output_vardesc->SetDataType(to_type);
    cast_output_vardesc->SetShape(node->Var()->GetShape());
    auto* cast_output_node = graph->CreateVarNode(cast_output_vardesc);
    IR_NODE_LINK_TO(cast_op_node, cast_output_node);
    (*map)[node] = cast_output_node;
  }
  // Rewire edges: node -> cast_op -> cast_output -> next_op, and rename the
  // consumed input inside next_op's desc accordingly.
  next_op->Op()->Rename(node->Name(), map->at(node)->Name());
  IR_NODE_LINK_TO(node, map->at(node)->inputs[0]);
  IR_NODE_UNLINK(node, next_op);
  IR_NODE_LINK_TO(map->at(node), next_op);
}

802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817
// Returns true when `op_type` (not blacklisted) has a kernel registered for
// `precision` on `backend`; the GPU backend also consults the legacy fluid
// kernel registry.
bool OpSupportPrecision(const std::string& op_type,
                        phi::Backend backend,
                        phi::DataType precision,
                        const std::unordered_set<std::string>& blacklist) {
  if (blacklist.count(op_type) != 0) {
    return false;
  }
  if (backend == phi::Backend::GPU) {
    return GpuKernelSupportPrecision(op_type, precision);
  }
  auto phi_op_type = phi::TransToPhiKernelName(op_type);
  return PhiKernelSupportPrecision(phi_op_type, backend, precision);
}

818 819 820 821 822 823 824 825 826
void ConvertToMixedPrecision(
    const std::string& model_file,
    const std::string& params_file,
    const std::string& mixed_model_file,
    const std::string& mixed_params_file,
    phi::DataType mixed_precision,
    phi::Backend backend,
    bool keep_io_types,
    const std::unordered_set<std::string>& black_list) {
827 828 829 830 831 832 833 834 835
  ConvertToMixedPrecisionPass pass(model_file,
                                   params_file,
                                   mixed_model_file,
                                   mixed_params_file,
                                   mixed_precision,
                                   backend,
                                   keep_io_types,
                                   black_list);
  pass.Run();
836 837 838 839 840
}

}  // namespace analysis
}  // namespace inference
}  // namespace paddle