From a00aebe1ab38117a5cf4c20a4e0e53a5073009e6 Mon Sep 17 00:00:00 2001
From: Wilber
Date: Tue, 15 Nov 2022 13:55:17 +0800
Subject: [PATCH] [convert_to_mixed_precision] fallback to fp32 when encounter circle (#47902)

---
 .../passes/convert_to_mixed_precision.cc      | 343 +++++++-----------
 1 file changed, 127 insertions(+), 216 deletions(-)

diff --git a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc
index e9b188d78f..a37cfda021 100644
--- a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc
+++ b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc
@@ -40,7 +40,6 @@
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/common/layout.h"
 #include "paddle/phi/common/place.h"
-#include "paddle/phi/core/tensor_meta.h"
 
 namespace paddle {
 namespace inference {
@@ -111,12 +110,10 @@ class ConvertToMixedPrecisionPass {
         black_list_(black_list),
         place_(paddle::CPUPlace()),
         executor_(place_) {
-    // black_list_.insert("assign");
-    black_list_.insert("fill_constant");
-    black_list_.insert("assign_value");
-    black_list_.insert("eye");
-    black_list_.insert("fill_any_like");
-    black_list_.insert("fill_constant_batch_size_like");
+    VLOG(4) << "black_list has ";
+    for (auto& name : black_list_) {
+      VLOG(4) << " - " << name;
+    }
   }
 
   void Run();
@@ -145,18 +142,11 @@ class ConvertToMixedPrecisionPass {
   // Just process special cases for weights conversion.
   bool WeightsShouldNotConvert(framework::ir::Node* var_node);
 
-  // To support multi block, we need to consider a lot of special cases.
   // Return the Node* which first appears in a block.
-  framework::ir::Node* GetRealVarNode(BlockID block_idx,
-                                      framework::ir::Node* node);
-  void FindVarsInMultiBlock();
-  inline bool VarIsMultiPrecisionOpsOut(BlockID block_idx,
-                                        framework::ir::Node* op_node);
+  framework::ir::Node* GetRealVarNode(framework::ir::Node* node);
 
- private:
-  // A trick. Patch for strange ops whose input name equals the output name,
-  // such as `fused_multi_transformer`.
-  void PatchForStrangeOp();
+  // Fall back to fp32 dtype when we encounter a circle (the graph is not a DAG).
+  void ProcessCircleCases();
 
  private:
   std::string model_file_;
@@ -171,35 +161,21 @@ class ConvertToMixedPrecisionPass {
   framework::Executor executor_;
   framework::Scope scope_;
 
+  std::unordered_map<std::string, framework::ir::Node*> name2node_;
   std::unordered_map<framework::ir::Node*, framework::ir::Node*> cast_map_;
-  std::unordered_map<std::string, std::pair<VarType::Type, BlockID>>
-      vars_in_multi_block_with_pair_;
-  std::unordered_map<std::string, std::vector<std::string>>
-      vars_in_multi_block_with_ops_;
   int suffix_{0};
 
+  std::set<std::string> var_names_in_circles_;
+
   std::unique_ptr<framework::ProgramDesc> program_desc_{nullptr};
   std::unique_ptr<framework::ir::Graph> main_graph_{nullptr};
   std::vector<framework::ir::Graph*> graphes_;
 };
 
 framework::ir::Node* ConvertToMixedPrecisionPass::GetRealVarNode(
-    BlockID block_idx, framework::ir::Node* var_node) {
+    framework::ir::Node* var_node) {
   CHECK_EQ(var_node->IsVar(), true);
-
-  if (vars_in_multi_block_with_pair_.count(var_node->Name())) {
-    auto origin_blockId =
-        vars_in_multi_block_with_pair_.at(var_node->Name()).second;
-    if (block_idx != origin_blockId) {
-      auto* graph = graphes_[origin_blockId];
-      for (auto* node : graph->Nodes()) {
-        if (node->Name() == var_node->Name()) {
-          return node;
-        }
-      }
-    }
-  }
-
+  if (name2node_.count(var_node->Name())) return name2node_[var_node->Name()];
   return var_node;
 }
 
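[Editor's note — illustration, not part of the patch] The hunks above replace the per-block scan in `GetRealVarNode` with a single `name2node_` map that is filled once in `LoadAndPrepare` (first registration wins). A minimal standalone sketch of that lookup contract, with a hypothetical simplified `Node` type standing in for `framework::ir::Node`:

```cpp
#include <string>
#include <unordered_map>

struct Node {  // hypothetical stand-in for framework::ir::Node
  std::string name;
};

// The canonical node for a name is whichever node was registered first
// across all sub-graphs; names that were never registered fall through
// unchanged, exactly like the new GetRealVarNode.
Node* GetRealVarNode(const std::unordered_map<std::string, Node*>& name2node,
                     Node* var) {
  auto it = name2node.find(var->name);
  return it == name2node.end() ? var : it->second;
}
```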
@@ -212,32 +188,6 @@ inline bool ConvertToMixedPrecisionPass::VarNodeHasDtype(
          (type == VarType::VOCAB);
 }
 
-// op1(fp32) -> var1, op2(fp16) -> var1
-// if and only if op1 and op2 both support fp16, we convert op1 and op2's
-// precision.
-inline bool ConvertToMixedPrecisionPass::VarIsMultiPrecisionOpsOut(
-    BlockID block_idx, framework::ir::Node* op_node) {
-  CHECK_EQ(op_node->IsOp(), true);
-
-  for (auto* var_node : op_node->outputs) {
-    if (!var_node->IsVar()) continue;
-    auto* real_var_node = GetRealVarNode(block_idx, var_node);
-    if (!real_var_node->Var()->Persistable() &&
-        vars_in_multi_block_with_ops_.count(var_node->Name())) {
-      for (const auto& op_type :
-           vars_in_multi_block_with_ops_.at(var_node->Name())) {
-        if (!OpSupportPrecision(
-                op_type, backend_, mixed_precision_, black_list_)) {
-          VLOG(2) << var_node->Name()
-                  << " is multi precision op's out, so we skip convert to fp16";
-          return true;
-        }
-      }
-    }
-  }
-  return false;
-}
-
 void ConvertToMixedPrecisionPass::ProcessInputNode(
     bool support_precision,
     framework::ir::Node* in_node,
@@ -247,18 +197,13 @@ void ConvertToMixedPrecisionPass::ProcessInputNode(
     VarType::Type to_type,
     BlockID block_idx) {
   if (!in_node->IsVar()) return;
-  auto* real_node = GetRealVarNode(block_idx, in_node);
+  auto* real_node = GetRealVarNode(in_node);
   if (!VarNodeHasDtype(real_node)) return;
   auto* graph = graphes_[block_idx];
-  bool is_main_block = block_idx == 0;
   auto* in_var = real_node->Var();
   auto in_var_type = in_var->GetDataType();
   auto prev_type = in_var_type;
-  bool is_in_multi_block = vars_in_multi_block_with_pair_.count(in_var->Name());
-  if (!is_main_block && is_in_multi_block) {
-    in_var_type = vars_in_multi_block_with_pair_.at(in_var->Name()).first;
-  }
+
   if (support_precision) {
     if (in_var->Persistable() && in_var_type == VarType::FP32) {
       if (WeightsShouldNotConvert(in_node)) return;
@@ -299,7 +244,7 @@
 void ConvertToMixedPrecisionPass::ProcessOutputNode(
     BlockID block_idx, framework::ir::Node* var_node, VarType::Type to_type) {
   if (!var_node->IsVar()) return;
-  auto* real_node = GetRealVarNode(block_idx, var_node);
+  auto* real_node = GetRealVarNode(var_node);
   if (!VarNodeHasDtype(real_node)) return;
   auto* out_var = real_node->Var();
   auto prev_type = out_var->GetDataType();
@@ -400,9 +345,17 @@ void ConvertToMixedPrecisionPass::LoadAndPrepare() {
   inference::Load(&executor_, &scope_, model_file_, params_file_);
   main_graph_ = std::unique_ptr<framework::ir::Graph>(
       new framework::ir::Graph(*program_desc_));
+
   for (size_t i = 0; i < main_graph_->SubGraphsSize(); ++i) {
     auto* graph = main_graph_->GetSubGraph(i);
     graphes_.push_back(graph);
+
+    for (auto* node : graph->Nodes()) {
+      if (!node->IsVar()) continue;
+      if (!name2node_.count(node->Name())) {
+        name2node_[node->Name()] = node;
+      }
+    }
   }
 
   // Remove all control var
@@ -411,46 +364,68 @@ void ConvertToMixedPrecisionPass::LoadAndPrepare() {
   arg.SetMainGraphNotOwned(main_graph_.get());
   pass.Run(&arg);
 
-  FindVarsInMultiBlock();
+  ProcessCircleCases();
 }
 
-void ConvertToMixedPrecisionPass::FindVarsInMultiBlock() {
-  std::unordered_set<std::string> all_var_names_set;
-  std::vector<std::set<std::string>> block_var_names_set(program_desc_->Size());
-  for (BlockID idx = 0; idx < program_desc_->Size(); ++idx) {
+// Find var names which are in circles.
+void ConvertToMixedPrecisionPass::ProcessCircleCases() {
+  std::vector<std::string> vars_in_circles;
+  for (size_t idx = 0; idx < program_desc_->Size(); ++idx) {
     for (auto* op : program_desc_->Block(idx).AllOps()) {
+      // TODO(inference): batch_norm has a circle, but we need to fuse it into
+      // the conv op.
+      if (op->Type() == "batch_norm") continue;
       const auto& in_names = op->InputArgumentNames();
-      block_var_names_set[idx].insert(in_names.begin(), in_names.end());
       const auto& out_names = op->OutputArgumentNames();
-      block_var_names_set[idx].insert(out_names.begin(), out_names.end());
-
-      if (op->HasAttr("sub_block") == false) {
-        for (const auto& name : out_names) {
-          if (all_var_names_set.count(name)) {
-            vars_in_multi_block_with_ops_[name].push_back(op->Type());
-          }
-        }
-      }
-      all_var_names_set.insert(block_var_names_set[idx].begin(),
-                               block_var_names_set[idx].end());
+      std::set<std::string> in_names_set(in_names.begin(), in_names.end());
+      std::set<std::string> out_names_set(out_names.begin(), out_names.end());
+      std::set_intersection(in_names_set.begin(),
+                            in_names_set.end(),
+                            out_names_set.begin(),
+                            out_names_set.end(),
+                            std::back_inserter(vars_in_circles));
     }
   }
 
-  CHECK_GT(program_desc_->Size(), 0U);
-  for (BlockID idx = 0; idx < program_desc_->Size() - 1; ++idx) {
-    for (BlockID jdx = idx + 1; jdx < program_desc_->Size(); ++jdx) {
-      std::vector<std::string> vars_in_multi_block;
-      std::set_intersection(block_var_names_set[idx].begin(),
-                            block_var_names_set[idx].end(),
-                            block_var_names_set[jdx].begin(),
-                            block_var_names_set[jdx].end(),
-                            std::back_inserter(vars_in_multi_block));
-
-      for (const auto& name : vars_in_multi_block) {
-        vars_in_multi_block_with_pair_.emplace(
-            name, std::make_pair(VarType::Type(), idx));
-      }
-    }
+  for (auto& name : vars_in_circles) {
+    var_names_in_circles_.insert(name);
+  }
+  for (auto& name : var_names_in_circles_) {
+    LOG(INFO) << name
+              << " is in a circle, so we will skip processing those vars and ops.";
+  }
+}
+
+inline void ProcessConstantOpAttr(framework::ir::Node* op_node,
+                                  VarType::Type from_type,
+                                  VarType::Type to_type) {
+  if (!op_node->IsOp()) return;
+  auto op_type = op_node->Op()->Type();
+  if (op_type == "feed" || op_type == "fetch") return;
+
+  if (op_type == "fill_constant") {
+    if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
+        static_cast<int>(from_type))
+      op_node->Op()->SetAttr("dtype", static_cast<int>(to_type));
+  } else if (op_type == "assign_value") {
+    if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
+        static_cast<int>(from_type))
+      op_node->Op()->SetAttr("dtype", static_cast<int>(to_type));
+  } else if (op_type == "eye") {
+    if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
+        static_cast<int>(from_type))
+      op_node->Op()->SetAttr("dtype", static_cast<int>(to_type));
+  } else if (op_type == "fill_any_like") {
+    if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
+        static_cast<int>(from_type))
+      op_node->Op()->SetAttr("dtype", static_cast<int>(to_type));
+  } else if (op_type == "cast") {
+    if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("in_dtype")) ==
+        static_cast<int>(from_type))
+      op_node->Op()->SetAttr("in_dtype", static_cast<int>(to_type));
+    if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("out_dtype")) ==
+        static_cast<int>(from_type))
+      op_node->Op()->SetAttr("out_dtype", static_cast<int>(to_type));
+  }
 }
 
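[Editor's note — illustration, not part of the patch] `ProcessCircleCases` reduces circle detection to a per-op test: a var name occurring in both the input and output argument lists of a single op is a self-loop, so the graph is not a DAG at that point. A self-contained demo of that intersection test (the names used are illustrative only):

```cpp
#include <algorithm>
#include <iterator>
#include <set>
#include <string>
#include <vector>

// Returns the var names an op both reads and writes, i.e. its self-loops.
std::vector<std::string> SelfLoopVars(const std::set<std::string>& in_names,
                                      const std::set<std::string>& out_names) {
  std::vector<std::string> loops;
  std::set_intersection(in_names.begin(), in_names.end(),
                        out_names.begin(), out_names.end(),
                        std::back_inserter(loops));
  return loops;
}

int main() {
  // e.g. fused_multi_transformer (mentioned in the removed PatchForStrangeOp)
  // reuses CacheKV as both input and output, so CacheKV would be reported
  // here and kept in fp32 by the pass.
  auto loops = SelfLoopVars({"X", "CacheKV"}, {"Out", "CacheKV"});
  return loops.size() == 1 && loops[0] == "CacheKV" ? 0 : 1;
}
```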
@@ -460,33 +435,7 @@
 void ConvertToMixedPrecisionPass::ConvertAllFp64ToFp32(
     framework::ir::Graph* graph) {
   auto op_nodes = framework::ir::TopologySortOperations(*graph);
   for (auto* op_node : op_nodes) {
     if (!op_node->IsOp()) continue;
     auto op_type = op_node->Op()->Type();
-    if (op_type == "feed" || op_type == "fetch") continue;
-
-    if (op_type == "fill_constant") {
-      if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(VarType::FP64))
-        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
-    } else if (op_type == "assign_value") {
-      if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(VarType::FP64))
-        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
-    } else if (op_type == "eye") {
-      if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(VarType::FP64))
-        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
-    } else if (op_type == "fill_any_like") {
-      if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(VarType::FP64))
-        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
-    } else if (op_type == "cast") {
-      if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("in_dtype")) ==
-          static_cast<int>(VarType::FP64))
-        op_node->Op()->SetAttr("in_dtype", static_cast<int>(VarType::FP32));
-      if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("out_dtype")) ==
-          static_cast<int>(VarType::FP64))
-        op_node->Op()->SetAttr("out_dtype", static_cast<int>(VarType::FP32));
-    }
-
+    ProcessConstantOpAttr(op_node, VarType::FP64, VarType::FP32);
     auto inputs = op_node->inputs;
     for (auto* in_node : inputs) {
       auto* in_var = in_node->Var();
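[Editor's note — illustration, not part of the patch] The new `ProcessConstantOpAttr` repeats the same compare-and-set for every constant-producing op; `ConvertAllFp64ToFp32` above now simply delegates to it. A hypothetical table-driven equivalent, written against only the `OpDesc::Type`/`GetAttr`/`SetAttr` usage and `PADDLE_GET_CONST` macro already visible in the patch (a sketch, not compiled here):

```cpp
#include <map>
#include <string>
#include <vector>

// Each entry lists the dtype-carrying attributes of one op type.
const std::map<std::string, std::vector<std::string>> kDtypeAttrs{
    {"fill_constant", {"dtype"}},
    {"assign_value", {"dtype"}},
    {"eye", {"dtype"}},
    {"fill_any_like", {"dtype"}},
    {"cast", {"in_dtype", "out_dtype"}},
};

// Rewrites every matching dtype attribute from from_type to to_type.
void RewriteDtypeAttrs(framework::OpDesc* op, int from_type, int to_type) {
  auto it = kDtypeAttrs.find(op->Type());
  if (it == kDtypeAttrs.end()) return;
  for (const auto& attr : it->second) {
    if (PADDLE_GET_CONST(int, op->GetAttr(attr)) == from_type)
      op->SetAttr(attr, to_type);
  }
}
```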
@@ -509,9 +458,6 @@ void ConvertToMixedPrecisionPass::Run() {
     ConvertTensorDtype(i);
     FixCastAttr(graph);
 
-    // A trick
-    PatchForStrangeOp();
-
     CHECK_EQ(framework::ir::VarDescIsConsistency(*graph), true);
   }
 
@@ -556,28 +502,9 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
       continue;
     }
 
+    // We can not add a cast operator before ops which have a sub_block, as in
+    // the sub_block we may get a var which may be transformed by a cast op.
     else if (op_node->Op()->HasAttr("sub_block")) {  // NOLINT
-      // sub_block op's output dtype should be the same as the input dtype, if
-      // they have the same name.
-      std::unordered_map<std::string, framework::ir::Node*> in_name_to_node;
-      for (auto* in : op_node->inputs) {
-        if (!in->IsVar()) continue;
-        auto* real_node = GetRealVarNode(block_idx, in);
-        if (VarNodeHasDtype(real_node)) {
-          in_name_to_node[in->Name()] = in;
-        }
-      }
-
-      for (auto* out : op_node->outputs) {
-        if (!out->IsVar()) continue;
-        auto* real_node = GetRealVarNode(block_idx, out);
-        if (VarNodeHasDtype(real_node)) {
-          if (in_name_to_node.count(out->Name()))
-            real_node->Var()->SetDataType(
-                in_name_to_node[out->Name()]->Var()->GetDataType());
-        }
-      }
-
       continue;
     }
 
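[Editor's note — illustration, not part of the patch] The sub_block comment above is the crux of the control-flow handling: a cast inserted in the parent block renames a var, but ops inside the sub_block (e.g. while, conditional_block) still refer to it by the old name. The patch therefore skips such ops wholesale instead of propagating dtypes through them. A minimal sketch of the guard, assuming only the Node/OpDesc API used throughout the file:

```cpp
// Ops owning a sub_block attribute are left untouched: converting their
// surroundings could rename vars that the sub_block still references by
// the old, un-casted name.
bool ShouldSkipForSubBlock(framework::ir::Node* op_node) {
  return op_node->IsOp() && op_node->Op()->HasAttr("sub_block");
}
```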
@@ -585,65 +512,75 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
     //   - cast weight to fp16/bf16.
     //   - add cast op if the input dtype is not fp16/bf16.
     //   - set output dtype.
-    //
-    // If a var (an op's out var) appears multiple times in the graph, we
-    // should not convert it to fp16.
-    else if (black_list_.count(op_type) == 0 &&  // NOLINT
-             !VarIsMultiPrecisionOpsOut(block_idx, op_node)) {
+    else if (black_list_.count(op_type) == 0) {  // NOLINT
       bool support_precision =
           OpSupportPrecision(op_type, backend_, mixed_precision_, black_list_);
 
-      // If the op has no input of float type, we will not choose the
+      // If the op's output is in a circle, we should not convert it to fp16.
+      for (auto* out_node : op_node->outputs) {
+        if (var_names_in_circles_.count(out_node->Name())) {
+          support_precision = false;
+          VLOG(2) << " op's output " << out_node->Name()
+                  << " is in a circle, we can not support this case, just skip.";
+          break;
+        }
+      }
+
+      // If the op has no input or output of float type, we will not choose the
       // low precision kernel.
-      {
-        bool has_float_input{false};
-        for (auto* in_node : op_node->inputs) {
-          if (!in_node->IsVar()) continue;
-          auto* real_node = GetRealVarNode(block_idx, in_node);
+      if (support_precision) {
+        bool has_float_in_out{false};
+        for (auto* in_node : op_node->inputs) {
+          if (!in_node->IsVar()) continue;
+          if (in_node->Var()->GetType() != VarType::LOD_TENSOR) {
+            support_precision = false;
+            VLOG(2) << " op has tensor array input[" << in_node->Name()
+                    << "], just skip.";
+            break;
+          }
+          auto* real_node = GetRealVarNode(in_node);
+          if (real_node->Var()->GetDataType() == VarType::FP16 ||
+              real_node->Var()->GetDataType() == VarType::FP32 ||
+              real_node->Var()->GetDataType() == VarType::FP64 ||
+              real_node->Var()->GetDataType() == VarType::BF16) {
+            has_float_in_out = true;
+            break;
+          }
+        }
+        for (auto* out_node : op_node->outputs) {
+          if (!out_node->IsVar()) continue;
+          auto* real_node = GetRealVarNode(out_node);
           if (real_node->Var()->GetDataType() == VarType::FP16 ||
               real_node->Var()->GetDataType() == VarType::FP32 ||
               real_node->Var()->GetDataType() == VarType::FP64 ||
               real_node->Var()->GetDataType() == VarType::BF16) {
-            has_float_input = true;
+            has_float_in_out = true;
             break;
           }
         }
-        if (!has_float_input) {
+
+        if (!has_float_in_out) {
           support_precision = false;
-          VLOG(2) << " op doesn't has float input, just skip.";
+          VLOG(2) << " op doesn't have float input or output, just skip.";
        }
       }
+
       VLOG(2) << "op type: " << op_type
               << " support low precision: " << support_precision;
 
       if (support_precision) {
+        ProcessConstantOpAttr(op_node, VarType::FP32, to_type);
         VLOG(2) << " process input nodes:";
         ++num_low_precision;
         auto inputs = op_node->inputs;
-
-        // Just for paddle's terrible case: an op's input and output have the
-        // same name.
-        std::unordered_map<std::string, std::string> names_map;
-        for (auto* out_node : op_node->outputs) {
-          for (auto* in_node : op_node->inputs) {
-            if (out_node->Name() == in_node->Name()) {
-              names_map[out_node->Name()] = in_node->Name();
-            }
-          }
-        }
-
-        // Process inputs.
         for (auto* in_node : inputs) {
           ProcessInputNode(
               true, in_node, op_node, &suffix_, block_desc, to_type, block_idx);
-          if (names_map.count(in_node->Name()) && cast_map_.count(in_node)) {
-            names_map[in_node->Name()] = cast_map_[in_node]->Name();
-          }
         }
+
         VLOG(2) << " process output nodes:";
-        // Process outputs.
-        for (auto* out_node : op_node->outputs) {
+        auto outputs = op_node->outputs;
+        for (auto* out_node : outputs) {
           ProcessOutputNode(block_idx, out_node, to_type);
         }
       } else {
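[Editor's note — illustration, not part of the patch] The eligibility test added above can be read as a small predicate: keep the low-precision kernel only if some float tensor flows in or out of the op. A self-contained sketch with a simplified dtype enum (illustrative only; the real code also rejects tensor-array inputs):

```cpp
#include <algorithm>
#include <vector>

enum class DType { FP16, FP32, FP64, BF16, INT32, INT64 };

bool IsFloat(DType t) {
  return t == DType::FP16 || t == DType::FP32 || t == DType::FP64 ||
         t == DType::BF16;
}

// Mirrors has_float_in_out: true if any input or output carries a float.
bool HasFloatInOut(const std::vector<DType>& ins,
                   const std::vector<DType>& outs) {
  auto any_float = [](const std::vector<DType>& v) {
    return std::any_of(v.begin(), v.end(), IsFloat);
  };
  return any_float(ins) || any_float(outs);
}
```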
@@ -663,8 +600,10 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
     // 3. check op not support fp16/bf16 or in blacklist.
     //   - add cast op if the input dtype is not fp32.
     else {  // NOLINT
-      VLOG(3) << "not to run fp16 op_type: " << op_type;
-      for (auto* in_node : op_node->inputs) {
+      VLOG(3) << "not to run fp16 op_type: " << op_type << ", node input size "
+              << op_node->inputs.size();
+      auto in_nodes = op_node->inputs;
+      for (auto* in_node : in_nodes) {
         auto* in_var = in_node->Var();
         if (in_var->GetDataType() == to_type) {
           AddCastOp(graph,
@@ -716,21 +655,6 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
     }
   }
 
-  for (auto* node : graph->Nodes()) {
-    if (!node->IsVar()) continue;
-    auto* real_node = GetRealVarNode(block_idx, node);
-    if (!VarNodeHasDtype(real_node)) continue;
-
-    if (vars_in_multi_block_with_pair_.count(real_node->Name()) &&
-        vars_in_multi_block_with_pair_.at(real_node->Name()).second ==
-            block_idx &&
-        vars_in_multi_block_with_pair_.at(real_node->Name()).first ==
-            VarType::Type()) {
-      vars_in_multi_block_with_pair_.at(real_node->Name()).first =
-          real_node->Var()->GetDataType();
-    }
-  }
-
   if (num_low_precision)
     LOG(INFO) << "--- detected " << num_low_precision
               << " low precision ops in " << block_idx << " subgraph";
@@ -738,6 +662,7 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
 
 // We modify op's input output precision, and we need to fix cast op in_dtype
 // and out_dtype attribute.
+// TODO(inference): we need a cast elimination pass.
 void ConvertToMixedPrecisionPass::FixCastAttr(framework::ir::Graph* graph) {
   auto op_nodes = framework::ir::TopologySortOperations(*graph);
   for (auto* op_node : op_nodes) {
@@ -766,7 +691,8 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
     if (VarNodeHasDtype(node)) {
       if (node->Var()->Persistable() &&
           node->Var()->GetDataType() == VarType::FP32) {
-        VLOG(2) << "weights keep to fp32: " << node->Name();
+        VLOG(2) << "weights keep to fp32: " << node->Name() << ", ptr "
+                << reinterpret_cast<void*>(node->Var());
         weights_should_be_fp32.insert(node->Name());
       }
     }
@@ -808,7 +734,6 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
   std::ostringstream os;
   phi::CPUContext ctx;
   for (const auto& param : parameters) {
-    VLOG(3) << "Serialize param: " << param;
     PADDLE_ENFORCE_NOT_NULL(
         scope_.FindVar(param),
         platform::errors::NotFound(
@@ -829,21 +754,6 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
                  mixed_program_desc.Proto()->SerializeAsString());
   StrToBinary(mixed_params_file_, SerializeParams());
 }
-
-void ConvertToMixedPrecisionPass::PatchForStrangeOp() {
-  for (auto* graph : graphes_) {
-    for (auto op_node : framework::ir::TopologySortOperations(*graph)) {
-      if (op_node->Name() == "fused_multi_transformer") {
-        auto cache_kv_inputs = op_node->Op()->Input("CacheKV");
-        auto cache_kv_outputs = op_node->Op()->Output("CacheKVOut");
-        CHECK_EQ(cache_kv_inputs.size(), cache_kv_outputs.size());
-        for (size_t i = 0; i < cache_kv_inputs.size(); ++i) {
-          op_node->Op()->RenameOutput(cache_kv_outputs[i], cache_kv_inputs[i]);
-        }
-      }
-    }
-  }
-}
 }  // namespace
 
 void AddCastOp(
@@ -893,6 +803,7 @@ void AddCastOp(
   }
   next_op->Op()->Rename(node->Name(), map->at(node)->Name());
   IR_NODE_LINK_TO(node, map->at(node)->inputs[0]);
+  IR_NODE_UNLINK(node, next_op);
   IR_NODE_LINK_TO(map->at(node), next_op);
 }
-- 
GitLab
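[Editor's note — illustration, not part of the patch] The one-line `IR_NODE_UNLINK(node, next_op)` addition is the subtle part of `AddCastOp`: after splicing `var -> cast_op -> casted_var -> next_op`, the stale direct edge `var -> next_op` must be removed, otherwise the graph keeps a redundant second path into `next_op`. A toy model of the relink using a hypothetical `Node` with plain adjacency vectors:

```cpp
#include <algorithm>
#include <vector>

struct Node {
  std::vector<Node*> inputs, outputs;
};

void Link(Node* a, Node* b) {
  a->outputs.push_back(b);
  b->inputs.push_back(a);
}

void Unlink(Node* a, Node* b) {
  a->outputs.erase(std::remove(a->outputs.begin(), a->outputs.end(), b),
                   a->outputs.end());
  b->inputs.erase(std::remove(b->inputs.begin(), b->inputs.end(), a),
                  b->inputs.end());
}

// Splice a cast between var and next_op, then drop the stale direct edge --
// the step this patch adds via IR_NODE_UNLINK.
void InsertCast(Node* var, Node* next_op, Node* cast_op, Node* casted_var) {
  Link(var, cast_op);
  Link(cast_op, casted_var);
  Link(casted_var, next_op);
  Unlink(var, next_op);
}
```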