From d373f4ff7bf4b4b439e7c4facf3d28186698bf91 Mon Sep 17 00:00:00 2001 From: Wilber Date: Thu, 21 Jul 2022 14:15:55 +0800 Subject: [PATCH] fix some convert error found in tipc. (#44457) * fix some error found in tipc. * update --- .../passes/convert_to_mixed_precision.cc | 190 ++++++++++++++++-- 1 file changed, 177 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc index 48d2cefe4a..ca8ab8aa71 100644 --- a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc +++ b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h" +#include #include #include "paddle/fluid/framework/block_desc.h" @@ -22,6 +23,7 @@ #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/inference/io.h" @@ -89,6 +91,31 @@ bool OutShouldNotConvert(ir::Node* var_node) { return false; } +// Get weight names which appear in multiple block (block 0 and block n). +std::unordered_set GetMultiBlockPersistableNames( + framework::ProgramDesc* program_desc) { + std::unordered_set special_weights; + size_t block_size = program_desc->Size(); + + std::unordered_set block_0_weights; + for (auto var : program_desc->Block(0).AllVars()) { + if (var->Persistable()) block_0_weights.insert(var->Name()); + } + + for (size_t i = 1; i < block_size; ++i) { + // std::cout << program_desc->MutableBlock(i)->Proto()->DebugString() << + // std::endl;; + auto all_ops = program_desc->Block(i).AllOps(); + for (auto op : all_ops) { + for (auto name : op->InputArgumentNames()) { + if (block_0_weights.count(name)) special_weights.insert(name); + } + } + } + + return special_weights; +} + // Just process special cases for weights conversion. bool WeightsShouldNotConvert(ir::Node* var_node) { auto op_nodes = var_node->outputs; @@ -116,19 +143,139 @@ bool WeightsShouldNotConvert(ir::Node* var_node) { } } + // If cur_op's next is condition_flow op, then cur op should be fp32. Note, we + // now only convert to mixed in block 0. + for (auto* op_node : op_nodes) { + for (auto var : op_node->outputs) { + for (auto next_op : var->outputs) { + if (next_op->Op()->HasAttr("sub_block")) { + return true; + } + } + } + } + return false; } inline bool IsFloatVarType(framework::proto::VarType::Type type) { if (type == framework::proto::VarType::FP16 || type == framework::proto::VarType::FP32 || - type == framework::proto::VarType::BF16 || - type == framework::proto::VarType::FP64) + type == framework::proto::VarType::BF16) return true; return false; } -void ConvertTensorDtype(framework::ir::Graph* graph, +void ConvertAllFp64ToFp32(framework::ir::Graph* graph) { + auto op_nodes = framework::ir::TopologySortOperations(*graph); + for (auto* op_node : op_nodes) { + if (!op_node->IsOp()) continue; + auto op_type = op_node->Op()->Type(); + if (op_type == "feed" || op_type == "fetch") continue; + + if (op_type == "fill_constant") { + if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) == + static_cast(framework::proto::VarType::FP64)) + op_node->Op()->SetAttr( + "dtype", static_cast(framework::proto::VarType::FP32)); + } else if (op_type == "assign_value") { + if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) == + static_cast(framework::proto::VarType::FP64)) + op_node->Op()->SetAttr( + "dtype", static_cast(framework::proto::VarType::FP32)); + } else if (op_type == "eye") { + if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) == + static_cast(framework::proto::VarType::FP64)) + op_node->Op()->SetAttr( + "dtype", static_cast(framework::proto::VarType::FP32)); + } else if (op_type == "fill_any_like") { + if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) == + static_cast(framework::proto::VarType::FP64)) + op_node->Op()->SetAttr( + "dtype", static_cast(framework::proto::VarType::FP32)); + } else if (op_type == "cast") { + if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("in_dtype")) == + static_cast(framework::proto::VarType::FP64)) + op_node->Op()->SetAttr( + "in_dtype", static_cast(framework::proto::VarType::FP32)); + if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("out_dtype")) == + static_cast(framework::proto::VarType::FP64)) + op_node->Op()->SetAttr( + "out_dtype", static_cast(framework::proto::VarType::FP32)); + } + + auto inputs = op_node->inputs; + for (auto* in_node : inputs) { + if (in_node->IsCtrlVar()) continue; + auto* in_var = in_node->Var(); + if (!in_var->Persistable() && + in_var->GetDataType() == framework::proto::VarType::FP64) { + in_var->SetDataType(framework::proto::VarType::FP32); + } + } + } +} + +// Handle special ops which contains dtype attribute. e.g., fill_constant, +// assign_value. +void HandleSpecialOps(framework::OpDesc* op_desc) { + if (op_desc->Type() == "fill_constant") { + if (PADDLE_GET_CONST(int, op_desc->GetAttr("dtype")) == + static_cast(framework::proto::VarType::FP32)) + op_desc->SetAttr("dtype", + static_cast(framework::proto::VarType::FP16)); + } else if (op_desc->Type() == "assign_value") { + if (PADDLE_GET_CONST(int, op_desc->GetAttr("dtype")) == + static_cast(framework::proto::VarType::FP32)) + op_desc->SetAttr("dtype", + static_cast(framework::proto::VarType::FP16)); + } else if (op_desc->Type() == "eye") { + if (PADDLE_GET_CONST(int, op_desc->GetAttr("dtype")) == + static_cast(framework::proto::VarType::FP32)) + op_desc->SetAttr("dtype", + static_cast(framework::proto::VarType::FP16)); + } else if (op_desc->Type() == "fill_any_like") { + if (PADDLE_GET_CONST(int, op_desc->GetAttr("dtype")) == + static_cast(framework::proto::VarType::FP32)) + op_desc->SetAttr("dtype", + static_cast(framework::proto::VarType::FP16)); + } +} + +// We modify op's input output precision, and we need to fix cast op in_dtype +// and out_dtype attribute. +void FixCastAttr(framework::ir::Graph* graph) { + auto op_nodes = framework::ir::TopologySortOperations(*graph); + for (auto* op_node : op_nodes) { + if (!op_node->IsOp()) continue; + auto op_type = op_node->Op()->Type(); + if (op_type != "cast") continue; + + auto input = op_node->inputs[0]; + auto output = op_node->outputs[0]; + op_node->Op()->SetAttr("in_dtype", + static_cast(input->Var()->GetDataType())); + op_node->Op()->SetAttr("out_dtype", + static_cast(output->Var()->GetDataType())); + } +} + +// If op's output var is condition flow op's input, then the op must be fp32 +// precision. +bool NextOpIncludesConditionFlowOp(framework::ir::Node* cur_op_node) { + auto cur_op_outs = cur_op_node->outputs; + for (auto out_var : cur_op_outs) { + for (auto next_op_node : out_var->outputs) { + if (next_op_node->Op()->HasAttr("sub_block")) { + return true; + } + } + } + return false; +} + +void ConvertTensorDtype(framework::ProgramDesc* program_desc, + framework::ir::Graph* graph, const std::unordered_set& blacklist, bool keep_io_types, phi::Backend backend, @@ -145,13 +292,14 @@ void ConvertTensorDtype(framework::ir::Graph* graph, static_cast(tensor_dtype))); } + auto weight_name_in_multi_block = GetMultiBlockPersistableNames(program_desc); int num_low_precision = 0; int suffix = 0; framework::BlockDesc* block_desc{nullptr}; std::vector output_nodes; std::unordered_map cast_map; - - for (auto* op_node : framework::ir::TopologySortOperations(*graph)) { + auto op_nodes = framework::ir::TopologySortOperations(*graph); + for (auto* op_node : op_nodes) { if (!op_node->IsOp()) continue; auto op_type = op_node->Op()->Type(); auto phi_op_type = phi::TransToPhiKernelName(op_type); @@ -167,18 +315,29 @@ void ConvertTensorDtype(framework::ir::Graph* graph, auto* fetch_var = op_node->inputs[0]; output_nodes.push_back(fetch_var); continue; + } else if (op_type == "cast") { + continue; } // 2. if op support fp16/bf16 and not in blacklist. // - cast weight to fp16/bf16. // - add cast op if the input dtype is not fp16/bf16. // - set output dtype. - else if (blacklist.count(phi_op_type) == 0) { // NOLINT + else if (blacklist.count(phi_op_type) == 0 && // NOLINT + !NextOpIncludesConditionFlowOp(op_node)) { bool support_precision = OpSupportPrecision(phi_op_type, backend, tensor_dtype, blacklist); - VLOG(2) << "phi_op_type " << phi_op_type << " support low precision " - << support_precision; + VLOG(2) << "op_type " << op_type << ", phi_op_type " << phi_op_type + << " support low precision " << support_precision << ", " + << reinterpret_cast(op_node->Op()->Block()); + + for (auto in_node : op_node->inputs) { + if (weight_name_in_multi_block.count(in_node->Name())) + support_precision = false; + } + if (support_precision) { + HandleSpecialOps(op_node->Op()); ++num_low_precision; auto inputs = op_node->inputs; for (auto* in_node : inputs) { @@ -232,9 +391,8 @@ void ConvertTensorDtype(framework::ir::Graph* graph, // 3. check op not support fp16/bf16 or in blacklist. // - add cast op if the input dtype is not fp32. else { // NOLINT - // trt pass should explicitle add cast op is input is bf16/tf32, etc. - if (op_node->Name() == "tensorrt_engine") continue; - for (auto* in_node : op_node->inputs) { + auto ins = op_node->inputs; + for (auto* in_node : ins) { if (in_node->IsCtrlVar()) continue; auto* in_var = in_node->Var(); if (in_var->GetDataType() == to_type) { @@ -366,8 +524,14 @@ void ConvertToMixedPrecision(const std::string& model_file, auto graph = std::unique_ptr( new framework::ir::Graph(*program_desc)); - ConvertTensorDtype( - graph.get(), black_list, keep_io_types, backend, mixed_precision); + ConvertAllFp64ToFp32(graph.get()); + ConvertTensorDtype(program_desc.get(), + graph.get(), + black_list, + keep_io_types, + backend, + mixed_precision); + FixCastAttr(graph.get()); framework::ProgramDesc mixed_program_desc; framework::ir::GraphToProgram(*graph, &mixed_program_desc); -- GitLab