From 0972d6ac78e8e7696c256da8dc961b1d7ed8fe93 Mon Sep 17 00:00:00 2001 From: Yuanle Liu Date: Thu, 27 Oct 2022 20:56:40 +0800 Subject: [PATCH] [Paddle Inference] improve convert_to_mixed_precision (#47333) --- .../passes/convert_to_mixed_precision.cc | 440 +++++++++--------- .../passes/convert_to_mixed_precision.h | 4 +- 2 files changed, 223 insertions(+), 221 deletions(-) diff --git a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc index 219ce6d17e..9d0e6ecf49 100644 --- a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc +++ b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc @@ -42,13 +42,13 @@ #include "paddle/phi/common/place.h" #include "paddle/phi/core/tensor_meta.h" -using namespace paddle::framework; // NOLINT - namespace paddle { namespace inference { namespace analysis { namespace { +using VarType = framework::proto::VarType; + bool PhiKernelSupportPrecision( const std::string& op_type, phi::Backend backend, @@ -73,13 +73,14 @@ bool GpuKernelSupportPrecision( phi_op_type, phi::Backend::GPUDNN, data_type, layout); if (!res) { - auto& all_kernels = OperatorWithKernel::AllOpKernels(); + auto& all_kernels = framework::OperatorWithKernel::AllOpKernels(); auto it = all_kernels.find(op_type); if (it != all_kernels.end()) { for (auto& kern_pair : it->second) { if (platform::is_gpu_place(kern_pair.first.place_) && - kern_pair.first.data_type_ == framework::proto::VarType::FP16) { + kern_pair.first.data_type_ == VarType::FP16) { res = true; + break; } } } @@ -88,6 +89,8 @@ bool GpuKernelSupportPrecision( } class ConvertToMixedPrecisionPass { + using BlockID = size_t; + public: explicit ConvertToMixedPrecisionPass( const std::string& model_file, @@ -97,7 +100,7 @@ class ConvertToMixedPrecisionPass { phi::DataType mixed_precision, phi::Backend backend, bool keep_io_types, - std::unordered_set black_list) + const std::unordered_set& black_list) : model_file_(model_file), params_file_(params_file), mixed_model_file_(mixed_model_file), @@ -107,45 +110,40 @@ class ConvertToMixedPrecisionPass { keep_io_types_(keep_io_types), black_list_(black_list), place_(paddle::CPUPlace()), - executor_(place_) { - black_list_.insert("assign"); - black_list_.insert("fill_constant"); - black_list_.insert("assign_value"); - black_list_.insert("eye"); - black_list_.insert("fill_any_like"); - black_list_.insert("fill_constant_batch_size_like"); - } + executor_(place_) {} + void Run(); private: void LoadAndPrepare(); - inline bool NodeVarHasDtype(framework::ir::Node* node); + inline bool VarNodeHasDtype(framework::ir::Node* node); void ConvertAllFp64ToFp32(framework::ir::Graph* graph); void FixCastAttr(framework::ir::Graph* graph); void SaveMixedModel(); - void ConvertTensorDtype(int block_idx); + void ConvertTensorDtype(BlockID block_idx); void ProcessInputNode(bool support_precision, - ir::Node* in_node, - ir::Node* op_node, + framework::ir::Node* in_node, + framework::ir::Node* op_node, int* suffix, framework::BlockDesc* block_desc, - framework::proto::VarType::Type to_type, - int block_idx); + VarType::Type to_type, + BlockID block_idx); - void ProcessOutputNode(int block_idx, - ir::Node* var_node, - framework::proto::VarType::Type to_type); - inline bool IsFloatVarType(framework::proto::VarType::Type type); + void ProcessOutputNode(BlockID block_idx, + framework::ir::Node* var_node, + VarType::Type to_type); + inline bool IsFloatVarType(VarType::Type type); - bool 
OutShouldNotConvert(ir::Node* var_node);
+  bool OutShouldNotConvert(framework::ir::Node* var_node);
   // Just process special cases for weights conversion.
-  bool WeightsShouldNotConvert(ir::Node* var_node);
+  bool WeightsShouldNotConvert(framework::ir::Node* var_node);
 
   // To support multi block, we need to consider a lot of special cases.
   // Return Node* which first appers in block.
-  framework::ir::Node* GetRealNode(int block_idx, framework::ir::Node* node);
+  framework::ir::Node* GetRealVarNode(BlockID block_idx,
+                                      framework::ir::Node* node);
   void FindVarsInMultiBlock();
-  inline bool VarIsMultiPrecisionOpsOut(int block_idx,
+  inline bool VarIsMultiPrecisionOpsOut(BlockID block_idx,
                                         framework::ir::Node* op_node);
 
  private:
@@ -167,11 +165,10 @@ class ConvertToMixedPrecisionPass {
   framework::Scope scope_;
 
   std::unordered_map<framework::ir::Node*, framework::ir::Node*> cast_map_;
-  std::unordered_map<std::string,
-                     std::pair<framework::proto::VarType::Type, int>>
-      vars_in_multi_block_map_;
-  std::vector<std::unordered_map<std::string, std::vector<std::string>>>
-      vars_appear_multi_in_one_block_;
+  std::unordered_map<std::string, std::pair<VarType::Type, BlockID>>
+      vars_in_multi_block_with_pair_;
+  std::unordered_map<std::string, std::vector<std::string>>
+      vars_in_multi_block_with_ops_;
   int suffix_{0};
 
   std::unique_ptr<framework::ProgramDesc> program_desc_{nullptr};
@@ -179,91 +176,84 @@ class ConvertToMixedPrecisionPass {
   std::vector<framework::ir::Graph*> graphes_;
 };
 
-framework::ir::Node* ConvertToMixedPrecisionPass::GetRealNode(
-    int block_idx, framework::ir::Node* node) {
-  if (vars_in_multi_block_map_.count(node->Name())) {
-    int var_origin_block_id = vars_in_multi_block_map_.at(node->Name()).second;
-    if (block_idx != var_origin_block_id) {
-      auto graph = graphes_[var_origin_block_id];
-      for (auto nd : graph->Nodes()) {
-        if (nd->Name() == node->Name()) {
-          return nd;
+framework::ir::Node* ConvertToMixedPrecisionPass::GetRealVarNode(
+    BlockID block_idx, framework::ir::Node* var_node) {
+  CHECK_EQ(var_node->IsVar(), true);
+
+  if (vars_in_multi_block_with_pair_.count(var_node->Name())) {
+    auto origin_blockId =
+        vars_in_multi_block_with_pair_.at(var_node->Name()).second;
+    if (block_idx != origin_blockId) {
+      auto* graph = graphes_[origin_blockId];
+      for (auto* node : graph->Nodes()) {
+        if (node->Name() == var_node->Name()) {
+          return node;
         }
       }
     }
   }
 
-  return node;
+  return var_node;
 }
 
-inline bool ConvertToMixedPrecisionPass::NodeVarHasDtype(
-    framework::ir::Node* node) {
-  if (node->IsVar() &&
-      (node->Var()->GetType() ==
-           paddle::framework::proto::VarType::SELECTED_ROWS ||
-       node->Var()->GetType() ==
-           paddle::framework::proto::VarType::LOD_TENSOR ||
-       node->Var()->GetType() ==
-           paddle::framework::proto::VarType::LOD_TENSOR_ARRAY ||
-       node->Var()->GetType() == paddle::framework::proto::VarType::STRINGS ||
-       node->Var()->GetType() == paddle::framework::proto::VarType::VOCAB)) {
-    return true;
-  }
-
-  return false;
+inline bool ConvertToMixedPrecisionPass::VarNodeHasDtype(
+    framework::ir::Node* var_node) {
+  CHECK_EQ(var_node->IsVar(), true);
+  auto type = var_node->Var()->GetType();
+  return (type == VarType::SELECTED_ROWS) || (type == VarType::LOD_TENSOR) ||
+         (type == VarType::LOD_TENSOR_ARRAY) || (type == VarType::STRINGS) ||
+         (type == VarType::VOCAB);
 }
 
 // op1(fp32) -> var1, op2(fp16) -> var1
 // if and only if op1 and op2 both support fp16, we convert op1 and op2's
 // precision.
inline bool ConvertToMixedPrecisionPass::VarIsMultiPrecisionOpsOut( - int block_idx, framework::ir::Node* op_node) { + BlockID block_idx, framework::ir::Node* op_node) { CHECK_EQ(op_node->IsOp(), true); - bool ret{false}; - - for (auto* out : op_node->outputs) { - auto* real_node = GetRealNode(block_idx, out); - if (!real_node->Var()->Persistable() && - vars_appear_multi_in_one_block_[block_idx].count(out->Name())) { - for (auto op_type : - vars_appear_multi_in_one_block_[block_idx].at(out->Name())) { - if (OpSupportPrecision( + + for (auto* var_node : op_node->outputs) { + if (!var_node->IsVar()) continue; + auto* real_var_node = GetRealVarNode(block_idx, var_node); + if (!real_var_node->Var()->Persistable() && + vars_in_multi_block_with_ops_.count(var_node->Name())) { + for (const auto& op_type : + vars_in_multi_block_with_ops_.at(var_node->Name())) { + if (!OpSupportPrecision( op_type, backend_, mixed_precision_, black_list_)) { - ret = true; - VLOG(2) << out->Name() + VLOG(2) << var_node->Name() << " is multi precision op's out, so we skip convert to fp16"; - break; + return true; } } } - if (ret) break; } - return ret; + return false; } void ConvertToMixedPrecisionPass::ProcessInputNode( bool support_precision, - ir::Node* in_node, - ir::Node* op_node, + framework::ir::Node* in_node, + framework::ir::Node* op_node, int* suffix, framework::BlockDesc* block_desc, - framework::proto::VarType::Type to_type, - int block_idx) { - auto* real_node = GetRealNode(block_idx, in_node); - if (!NodeVarHasDtype(real_node)) return; - auto graph = graphes_[block_idx]; + VarType::Type to_type, + BlockID block_idx) { + if (!in_node->IsVar()) return; + auto* real_node = GetRealVarNode(block_idx, in_node); + if (!VarNodeHasDtype(real_node)) return; + auto* graph = graphes_[block_idx]; bool is_main_block = block_idx == 0; auto* in_var = real_node->Var(); auto in_var_type = in_var->GetDataType(); auto prev_type = in_var_type; - bool is_in_multi_block = vars_in_multi_block_map_.count(in_var->Name()); + bool is_in_multi_block = vars_in_multi_block_with_pair_.count(in_var->Name()); if (!is_main_block && is_in_multi_block) { - in_var_type = vars_in_multi_block_map_.at(in_var->Name()).first; + in_var_type = vars_in_multi_block_with_pair_.at(in_var->Name()).first; } if (support_precision) { - if (in_var->Persistable() && - in_var_type == framework::proto::VarType::FP32) { + if (in_var->Persistable() && in_var_type == VarType::FP32) { if (WeightsShouldNotConvert(in_node)) return; in_var->SetDataType(to_type); in_var_type = to_type; @@ -300,14 +290,13 @@ void ConvertToMixedPrecisionPass::ProcessInputNode( } void ConvertToMixedPrecisionPass::ProcessOutputNode( - int block_idx, - ir::Node* var_node, - framework::proto::VarType::Type to_type) { - auto* real_node = GetRealNode(block_idx, var_node); - if (!NodeVarHasDtype(real_node)) return; + BlockID block_idx, framework::ir::Node* var_node, VarType::Type to_type) { + if (!var_node->IsVar()) return; + auto* real_node = GetRealVarNode(block_idx, var_node); + if (!VarNodeHasDtype(real_node)) return; auto* out_var = real_node->Var(); auto prev_type = out_var->GetDataType(); - if (out_var->GetDataType() == framework::proto::VarType::FP32) { + if (out_var->GetDataType() == VarType::FP32) { if (OutShouldNotConvert(var_node)) return; out_var->SetDataType(to_type); } @@ -316,7 +305,8 @@ void ConvertToMixedPrecisionPass::ProcessOutputNode( } // Just process special cases. 
-bool ConvertToMixedPrecisionPass::OutShouldNotConvert(ir::Node* var_node) {
+bool ConvertToMixedPrecisionPass::OutShouldNotConvert(
+    framework::ir::Node* var_node) {
   auto op_node = var_node->inputs[0];
   auto* op_desc = op_node->Op();
 
@@ -343,7 +333,8 @@ bool ConvertToMixedPrecisionPass::OutShouldNotConvert(ir::Node* var_node) {
   return false;
 }
 
-bool ConvertToMixedPrecisionPass::WeightsShouldNotConvert(ir::Node* var_node) {
+bool ConvertToMixedPrecisionPass::WeightsShouldNotConvert(
+    framework::ir::Node* var_node) {
   auto op_nodes = var_node->outputs;
   for (auto* op_node : op_nodes) {
     auto* op_desc = op_node->Op();
@@ -391,13 +382,10 @@ bool ConvertToMixedPrecisionPass::WeightsShouldNotConvert(ir::Node* var_node) {
   return false;
 }
 
-inline bool ConvertToMixedPrecisionPass::IsFloatVarType(
-    framework::proto::VarType::Type type) {
-  if (type == framework::proto::VarType::FP16 ||
-      type == framework::proto::VarType::FP32 ||
-      type == framework::proto::VarType::BF16)
-    return true;
-  return false;
+
+inline bool ConvertToMixedPrecisionPass::IsFloatVarType(VarType::Type type) {
+  return (type == VarType::FP16) || (type == VarType::FP32) ||
+         (type == VarType::BF16);
 }
 
 void ConvertToMixedPrecisionPass::LoadAndPrepare() {
@@ -405,6 +393,10 @@ void ConvertToMixedPrecisionPass::LoadAndPrepare() {
       inference::Load(&executor_, &scope_, model_file_, params_file_);
   main_graph_ = std::unique_ptr<framework::ir::Graph>(
       new framework::ir::Graph(*program_desc_));
+  for (size_t i = 0; i < main_graph_->SubGraphsSize(); ++i) {
+    auto* graph = main_graph_->GetSubGraph(i);
+    graphes_.push_back(graph);
+  }
 
   // Remove all control var
   IrInferCleanGraphPass pass;
@@ -412,41 +404,45 @@ void ConvertToMixedPrecisionPass::LoadAndPrepare() {
   arg.SetMainGraphNotOwned(main_graph_.get());
   pass.Run(&arg);
 
-  vars_appear_multi_in_one_block_.resize(program_desc_->Size());
   FindVarsInMultiBlock();
 }
 
 void ConvertToMixedPrecisionPass::FindVarsInMultiBlock() {
-  std::vector<std::set<std::string>> block_var_names_set(program_desc_->Size());
-  for (size_t i = 0; i < program_desc_->Size(); ++i) {
-    for (auto op : program_desc_->Block(i).AllOps()) {
-      auto in_names = op->InputArgumentNames();
-      block_var_names_set[i].insert(in_names.begin(), in_names.end());
-      auto out_names = op->OutputArgumentNames();
+  std::unordered_set<std::string> all_var_names_set;
+  std::vector<std::unordered_set<std::string>> block_var_names_set(
+      program_desc_->Size());
+  for (BlockID idx = 0; idx < program_desc_->Size(); ++idx) {
+    for (auto* op : program_desc_->Block(idx).AllOps()) {
+      const auto& in_names = op->InputArgumentNames();
+      block_var_names_set[idx].insert(in_names.begin(), in_names.end());
+      const auto& out_names = op->OutputArgumentNames();
+      block_var_names_set[idx].insert(out_names.begin(), out_names.end());
+
       if (op->HasAttr("sub_block") == false) {
-        for (auto& n : out_names) {
-          if (block_var_names_set[i].count(n)) {
-            vars_appear_multi_in_one_block_[i][n].push_back(op->Type());
+        for (const auto& name : out_names) {
+          if (all_var_names_set.count(name)) {
+            vars_in_multi_block_with_ops_[name].push_back(op->Type());
          }
        }
      }
-      block_var_names_set[i].insert(out_names.begin(), out_names.end());
+      all_var_names_set.insert(block_var_names_set[idx].begin(),
+                               block_var_names_set[idx].end());
    }
  }
 
-  for (size_t i = 0; i < program_desc_->Size() - 1; ++i) {
-    for (size_t j = i + 1; j < program_desc_->Size(); ++j) {
-      std::set<std::string> vars_in_multi_block;
-      std::set_intersection(
-          block_var_names_set[i].begin(),
-          block_var_names_set[i].end(),
-          block_var_names_set[j].begin(),
-          block_var_names_set[j].end(),
-          std::inserter(vars_in_multi_block, vars_in_multi_block.begin()));
-
-      for (auto name : vars_in_multi_block) {
-        vars_in_multi_block_map_.emplace(
-            name, std::make_pair(framework::proto::VarType::FP32, i));
+  CHECK_GT(program_desc_->Size(), 0U);
+  for (BlockID idx = 0; idx < program_desc_->Size() - 1; ++idx) {
+    for (BlockID jdx = idx + 1; jdx < program_desc_->Size(); ++jdx) {
+      std::vector<std::string> vars_in_multi_block;
+      std::set_intersection(block_var_names_set[idx].begin(),
+                            block_var_names_set[idx].end(),
+                            block_var_names_set[jdx].begin(),
+                            block_var_names_set[jdx].end(),
+                            std::back_inserter(vars_in_multi_block));
+
+      for (const auto& name : vars_in_multi_block) {
+        vars_in_multi_block_with_pair_.emplace(
+            name, std::make_pair(VarType::FP32, idx));
      }
    }
  }
@@ -462,41 +458,34 @@ void ConvertToMixedPrecisionPass::ConvertAllFp64ToFp32(
     if (op_type == "fill_constant") {
       if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "dtype", static_cast<int>(framework::proto::VarType::FP32));
+          static_cast<int>(VarType::FP64))
+        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
     } else if (op_type == "assign_value") {
       if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "dtype", static_cast<int>(framework::proto::VarType::FP32));
+          static_cast<int>(VarType::FP64))
+        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
     } else if (op_type == "eye") {
       if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "dtype", static_cast<int>(framework::proto::VarType::FP32));
+          static_cast<int>(VarType::FP64))
+        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
     } else if (op_type == "fill_any_like") {
       if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "dtype", static_cast<int>(framework::proto::VarType::FP32));
+          static_cast<int>(VarType::FP64))
+        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
     } else if (op_type == "cast") {
       if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("in_dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "in_dtype", static_cast<int>(framework::proto::VarType::FP32));
+          static_cast<int>(VarType::FP64))
+        op_node->Op()->SetAttr("in_dtype", static_cast<int>(VarType::FP32));
       if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("out_dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "out_dtype", static_cast<int>(framework::proto::VarType::FP32));
+          static_cast<int>(VarType::FP64))
+        op_node->Op()->SetAttr("out_dtype", static_cast<int>(VarType::FP32));
    }
 
     auto inputs = op_node->inputs;
     for (auto* in_node : inputs) {
       auto* in_var = in_node->Var();
-      if (!in_var->Persistable() &&
-          in_var->GetDataType() == framework::proto::VarType::FP64) {
-        in_var->SetDataType(framework::proto::VarType::FP32);
+      if (!in_var->Persistable() && in_var->GetDataType() == VarType::FP64) {
+        in_var->SetDataType(VarType::FP32);
      }
    }
  }
@@ -505,9 +494,8 @@ void ConvertToMixedPrecisionPass::ConvertAllFp64ToFp32(
 void ConvertToMixedPrecisionPass::Run() {
   LoadAndPrepare();
 
-  for (size_t i = 0; i < main_graph_->SubGraphsSize(); ++i) {
-    auto graph = main_graph_->GetSubGraph(i);
-    graphes_.push_back(graph);
+  for (size_t i = 0; i < graphes_.size(); ++i) {
+    auto* graph = graphes_[i];
     VLOG(2) << " -------- handle subgraph " << i << ", has "
             << graph->Nodes().size() << " nodes --------";
 
@@
-518,19 +506,19 @@ void ConvertToMixedPrecisionPass::Run() { // A trick PatchForStrangeOp(); - CHECK_EQ(ir::VarDescIsConsistency(*graph), true); + CHECK_EQ(framework::ir::VarDescIsConsistency(*graph), true); } SaveMixedModel(); } -void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) { - auto graph = graphes_[block_idx]; - framework::proto::VarType::Type to_type; +void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) { + auto* graph = graphes_[block_idx]; + VarType::Type to_type; if (mixed_precision_ == phi::DataType::FLOAT16) { - to_type = framework::proto::VarType::FP16; + to_type = VarType::FP16; } else if (mixed_precision_ == phi::DataType::BFLOAT16) { - to_type = framework::proto::VarType::BF16; + to_type = VarType::BF16; } else { PADDLE_THROW(paddle::platform::errors::InvalidArgument( "mixed_precision currently not supported dtype %d, we now only " @@ -551,8 +539,7 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) { // 1. set input dtype. if (op_type == "feed") { auto feed_var = op_node->outputs[0]->Var(); - if (!keep_io_types_ && - feed_var->GetDataType() == framework::proto::VarType::FP32) { + if (!keep_io_types_ && feed_var->GetDataType() == VarType::FP32) { feed_var->SetDataType(to_type); } } else if (op_type == "fetch") { @@ -568,15 +555,17 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) { // same name. std::unordered_map in_name_to_node; for (auto* in : op_node->inputs) { - auto* real_node = GetRealNode(block_idx, in); - if (NodeVarHasDtype(real_node)) { + if (!in->IsVar()) continue; + auto* real_node = GetRealVarNode(block_idx, in); + if (VarNodeHasDtype(real_node)) { in_name_to_node[in->Name()] = in; } } - for (auto out : op_node->outputs) { - auto* real_node = GetRealNode(block_idx, out); - if (NodeVarHasDtype(real_node)) { + for (auto* out : op_node->outputs) { + if (!out->IsVar()) continue; + auto* real_node = GetRealVarNode(block_idx, out); + if (VarNodeHasDtype(real_node)) { if (in_name_to_node.count(out->Name())) real_node->Var()->SetDataType( in_name_to_node[out->Name()]->Var()->GetDataType()); @@ -591,32 +580,46 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) { // - add cast op if the input dtype is not fp16/bf16. // - set output dtype. // - // If a var(op's out var) appears multiple times in a block, we should not + // If a var(op's out var) appears multiple times in graph, we should not // convert to fp16. else if (black_list_.count(op_type) == 0 && // NOLINT !VarIsMultiPrecisionOpsOut(block_idx, op_node)) { bool support_precision = OpSupportPrecision(op_type, backend_, mixed_precision_, black_list_); - // if op not has float input, we will not choose the low precision kernel. + // If the op has no input and output of float type, we will not choose the + // low precision kernel. 
{ - bool has_float_input{false}; - for (auto in_node : op_node->inputs) { - auto* real_node = GetRealNode(block_idx, in_node); - if (real_node->Var()->GetDataType() == proto::VarType::FP16 || - real_node->Var()->GetDataType() == proto::VarType::FP32 || - real_node->Var()->GetDataType() == proto::VarType::FP64 || - real_node->Var()->GetDataType() == proto::VarType::BF16) { - has_float_input = true; + bool has_float_input_and_output{false}; + for (auto* in_node : op_node->inputs) { + if (!in_node->IsVar()) continue; + auto* real_node = GetRealVarNode(block_idx, in_node); + if (real_node->Var()->GetDataType() == VarType::FP16 || + real_node->Var()->GetDataType() == VarType::FP32 || + real_node->Var()->GetDataType() == VarType::FP64 || + real_node->Var()->GetDataType() == VarType::BF16) { + has_float_input_and_output = true; break; } } - if (!has_float_input) { + for (auto* out_node : op_node->outputs) { + if (!out_node->IsVar()) continue; + auto* real_node = GetRealVarNode(block_idx, out_node); + if (real_node->Var()->GetDataType() == VarType::FP16 || + real_node->Var()->GetDataType() == VarType::FP32 || + real_node->Var()->GetDataType() == VarType::FP64 || + real_node->Var()->GetDataType() == VarType::BF16) { + has_float_input_and_output = true; + break; + } + } + if (!has_float_input_and_output) { support_precision = false; - VLOG(2) << " op doesn't has float input, just skip."; + VLOG(2) << " op doesn't has float input and output, just skip."; } } - VLOG(2) << " support low precision " << support_precision; + VLOG(2) << "op type: " << op_type + << " support low precision: " << support_precision; if (support_precision) { VLOG(2) << " process input nodes:"; @@ -626,8 +629,8 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) { // Just for paddle's terriable case: op's input and output has the same // name. std::unordered_map names_map; - for (auto out_node : op_node->outputs) { - for (auto in_node : op_node->inputs) { + for (auto* out_node : op_node->outputs) { + for (auto* in_node : op_node->inputs) { if (out_node->Name() == in_node->Name()) { names_map[out_node->Name()] = in_node->Name(); } @@ -655,7 +658,7 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) { op_node, &suffix_, block_desc, - framework::proto::VarType::FP32, + VarType::FP32, block_idx); } } @@ -665,21 +668,19 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) { // - add cast op if the input dtype is not fp32. else { // NOLINT VLOG(3) << "not to run fp16 op_type: " << op_type; - auto ins = op_node->inputs; - for (auto* in_node : ins) { + for (auto* in_node : op_node->inputs) { auto* in_var = in_node->Var(); if (in_var->GetDataType() == to_type) { AddCastOp(graph, in_node, op_node, to_type, - framework::proto::VarType::FP32, + VarType::FP32, &suffix_, block_desc, &cast_map_); VLOG(3) << "-- " << in_node->Name() << "(" << to_type << ") to " - << cast_map_[in_node]->Name() << "(" - << framework::proto::VarType::FP32 << ")"; + << cast_map_[in_node]->Name() << "(" << VarType::FP32 << ")"; } } } @@ -688,31 +689,30 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) { // 4. if output_op's dtype is not compatible to output dtype, then just // insert cast. 
   for (auto* node : output_nodes) {
-    ir::Node* fetch_op{nullptr};
+    framework::ir::Node* fetch_op{nullptr};
     for (auto* op_node : node->outputs) {
       if (op_node->IsOp() && op_node->Op()->Type() == "fetch") {
         fetch_op = op_node;
       }
     }
     CHECK_NOTNULL(fetch_op);
-    auto var = node->Var();
+    auto* var = node->Var();
     if (keep_io_types_ && var->GetDataType() == to_type) {
       // fp16/bf16 -> fp32.
       AddCastOp(graph,
                 node,
                 fetch_op,
                 to_type,
-                framework::proto::VarType::FP32,
+                VarType::FP32,
                 &suffix_,
                 block_desc,
                 &cast_map_);
-    } else if (!keep_io_types_ &&
-               var->GetDataType() == framework::proto::VarType::FP32) {
+    } else if (!keep_io_types_ && var->GetDataType() == VarType::FP32) {
       // fp32 -> fp16/bf16
       AddCastOp(graph,
                 node,
                 fetch_op,
-                framework::proto::VarType::FP32,
+                VarType::FP32,
                 to_type,
                 &suffix_,
                 block_desc,
@@ -720,13 +720,15 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
    }
  }
 
-  for (auto node : graph->Nodes()) {
-    auto* real_node = GetRealNode(block_idx, node);
-    if (!NodeVarHasDtype(real_node)) continue;
+  for (auto* node : graph->Nodes()) {
+    if (!node->IsVar()) continue;
+    auto* real_node = GetRealVarNode(block_idx, node);
+    if (!VarNodeHasDtype(real_node)) continue;
 
-    if (vars_in_multi_block_map_.count(real_node->Name()) &&
-        vars_in_multi_block_map_.at(real_node->Name()).second == block_idx) {
-      vars_in_multi_block_map_.at(real_node->Name()).first =
+    if (vars_in_multi_block_with_pair_.count(real_node->Name()) &&
+        vars_in_multi_block_with_pair_.at(real_node->Name()).second ==
+            block_idx) {
+      vars_in_multi_block_with_pair_.at(real_node->Name()).first =
           real_node->Var()->GetDataType();
    }
  }
@@ -757,17 +759,15 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
   framework::ProgramDesc mixed_program_desc;
   framework::ir::GraphToProgram(*main_graph_, &mixed_program_desc);
 
-  paddle::CPUPlace place;
   auto parameters = scope_.LocalVarNames();
   std::sort(parameters.begin(), parameters.end());
 
   std::unordered_set<std::string> weights_should_be_fp32;
   for (auto* node : main_graph_->Nodes()) {
-    if (!(node->IsVar())) continue;
-    if (NodeVarHasDtype(node)) {
+    if (!node->IsVar()) continue;
+    if (VarNodeHasDtype(node)) {
       if (node->Var()->Persistable() &&
-          node->Var()->GetDataType() ==
-              paddle::framework::proto::VarType::FP32) {
+          node->Var()->GetDataType() == VarType::FP32) {
         VLOG(2) << "weights keep to fp32: " << node->Name();
         weights_should_be_fp32.insert(node->Name());
      }
@@ -777,26 +777,27 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
 
 #define CONVERT_TENSOR_DTYPE(DTYPE, dtype)                                    \
   mixed_tensor.set_type(DTYPE);                                               \
   auto* mixed_data = mixed_tensor.mutable_data<dtype>(platform::CPUPlace());  \
-  for (int i = 0; i < t->numel(); i++) {                                      \
-    mixed_data[i] = static_cast<dtype>(data[i]);                              \
+  for (int64_t i = 0; i < origin_tensor->numel(); i++) {                      \
+    mixed_data[i] = static_cast<dtype>(origin_data[i]);                       \
   }                                                                           \
-  t->clear();                                                                 \
-  paddle::framework::TensorCopySync(mixed_tensor, place, t)
+  origin_tensor->clear();                                                     \
+  paddle::framework::TensorCopySync(                                          \
+      mixed_tensor, platform::CPUPlace(), origin_tensor)
 
   for (const auto& param_name : parameters) {
+    if (weights_should_be_fp32.count(param_name)) continue;
     auto* var = scope_.FindLocalVar(param_name);
     if (var->IsType<phi::DenseTensor>()) {
-      auto* t = var->GetMutable<phi::DenseTensor>();
-      if (t->dtype() != phi::DataType::FLOAT32) continue;
+      auto* origin_tensor = var->GetMutable<phi::DenseTensor>();
+      if (origin_tensor->dtype() != phi::DataType::FLOAT32) continue;
       phi::DenseTensor mixed_tensor;
-      mixed_tensor.Resize(t->dims());
-      auto* data = t->mutable_data<float>(platform::CPUPlace());
-      if (mixed_precision_ == phi::DataType::FLOAT16 &&
-          !weights_should_be_fp32.count(param_name)) {
+      mixed_tensor.Resize(origin_tensor->dims());
+      auto* origin_data =
+          origin_tensor->mutable_data<float>(platform::CPUPlace());
+      if (mixed_precision_ == phi::DataType::FLOAT16) {
         CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::FLOAT16,
                              phi::dtype::float16);
-      } else if (mixed_precision_ == phi::DataType::BFLOAT16 &&
-                 !weights_should_be_fp32.count(param_name)) {
+      } else if (mixed_precision_ == phi::DataType::BFLOAT16) {
         CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::BFLOAT16,
                              phi::dtype::bfloat16);
       }
@@ -851,8 +852,8 @@ void AddCastOp(
     framework::ir::Graph* graph,
     framework::ir::Node* node,
     framework::ir::Node* next_op,
-    framework::proto::VarType::Type from_type,
-    framework::proto::VarType::Type to_type,
+    VarType::Type from_type,
+    VarType::Type to_type,
     int* suffix,
     framework::BlockDesc* block_desc,
     std::unordered_map<framework::ir::Node*, framework::ir::Node*>* map) {
@@ -913,14 +914,15 @@ bool OpSupportPrecision(const std::string& op_type,
   return support_precision;
 }
 
-void ConvertToMixedPrecision(const std::string& model_file,
-                             const std::string& params_file,
-                             const std::string& mixed_model_file,
-                             const std::string& mixed_params_file,
-                             phi::DataType mixed_precision,
-                             phi::Backend backend,
-                             bool keep_io_types,
-                             std::unordered_set<std::string> black_list) {
+void ConvertToMixedPrecision(
+    const std::string& model_file,
+    const std::string& params_file,
+    const std::string& mixed_model_file,
+    const std::string& mixed_params_file,
+    phi::DataType mixed_precision,
+    phi::Backend backend,
+    bool keep_io_types,
+    const std::unordered_set<std::string>& black_list) {
   ConvertToMixedPrecisionPass pass(model_file,
                                    params_file,
                                    mixed_model_file,
diff --git a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h
index 3b763a4420..583512408c 100644
--- a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h
+++ b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h
@@ -51,8 +51,8 @@ void ConvertToMixedPrecision(const std::string& model_file,
                              const std::string& mixed_params_file,
                              phi::DataType mixed_precision,
                              phi::Backend backend,
-                             bool keep_io_types = true,
-                             std::unordered_set<std::string> black_list = {});
+                             bool keep_io_types,
+                             const std::unordered_set<std::string>& black_list);
 
 }  // namespace analysis
 }  // namespace inference
-- 
GitLab
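
Usage sketch (not part of the patch): the header change removes the default values of keep_io_types and black_list, and the constructor no longer force-inserts ops such as assign or fill_constant into the black list, so a caller now passes both arguments explicitly and supplies any ops it still wants kept in FP32. The file paths and the black-list contents below are hypothetical placeholders, not values taken from the patch.

#include <string>
#include <unordered_set>

#include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"
#include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h"

// Hypothetical driver for the updated API; paths and black-list entries are
// illustrative only.
void ConvertModelToFp16() {
  // Ops listed here are skipped during conversion and stay in FP32.
  std::unordered_set<std::string> black_list{"fill_constant", "assign_value"};
  paddle::inference::analysis::ConvertToMixedPrecision(
      "model/inference.pdmodel",    // model_file
      "model/inference.pdiparams",  // params_file
      "mixed/inference.pdmodel",    // mixed_model_file
      "mixed/inference.pdiparams",  // mixed_params_file
      phi::DataType::FLOAT16,       // mixed_precision
      phi::Backend::GPU,            // backend
      true,                         // keep_io_types (no longer defaulted)
      black_list);                  // black_list (now a required argument)
}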