diff --git a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc
index bc753636d2c1a175f98cd36af3c63bde55558dc3..48d2cefe4a720ea7648d9bbbb6ab09b7ce54d924 100644
--- a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc
+++ b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc
@@ -119,6 +119,15 @@ bool WeightsShouldNotConvert(ir::Node* var_node) {
   return false;
 }
 
+inline bool IsFloatVarType(framework::proto::VarType::Type type) {
+  if (type == framework::proto::VarType::FP16 ||
+      type == framework::proto::VarType::FP32 ||
+      type == framework::proto::VarType::BF16 ||
+      type == framework::proto::VarType::FP64)
+    return true;
+  return false;
+}
+
 void ConvertTensorDtype(framework::ir::Graph* graph,
                         const std::unordered_set<std::string>& blacklist,
                         bool keep_io_types,
@@ -146,8 +155,6 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
     if (!op_node->IsOp()) continue;
     auto op_type = op_node->Op()->Type();
     auto phi_op_type = phi::TransToPhiKernelName(op_type);
-    // LOG(INFO) << "process op " << op_type << ", corresponding phi type is "
-    //           << phi_op_type;
     // 1. set input dtype.
     if (op_type == "feed") {
       block_desc = op_node->Op()->Block();
@@ -175,12 +182,14 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
       ++num_low_precision;
       auto inputs = op_node->inputs;
       for (auto* in_node : inputs) {
+        if (in_node->IsCtrlVar()) continue;
         auto* in_var = in_node->Var();
         if (in_var->Persistable() &&
             in_var->GetDataType() == framework::proto::VarType::FP32) {
           if (WeightsShouldNotConvert(in_node)) continue;
           in_var->SetDataType(to_type);
         } else if (!in_var->Persistable() &&
+                   IsFloatVarType(in_var->GetDataType()) &&
                    in_var->GetDataType() != to_type) {
           AddCastOp(graph,
                     in_node,
@@ -193,6 +202,7 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
         }
       }
       for (auto* out_node : op_node->outputs) {
+        if (out_node->IsCtrlVar()) continue;
         auto* out_var = out_node->Var();
         if (out_var->GetDataType() == framework::proto::VarType::FP32) {
           if (OutShouldNotConvert(out_node)) continue;
@@ -202,8 +212,9 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
     } else {
       auto inputs = op_node->inputs;
       for (auto* in_node : inputs) {
+        if (in_node->IsCtrlVar()) continue;
         auto* in_var = in_node->Var();
-        if (!in_var->Persistable() &&
+        if (!in_var->Persistable() && IsFloatVarType(in_var->GetDataType()) &&
            in_var->GetDataType() != framework::proto::VarType::FP32) {
           AddCastOp(graph,
                     in_node,
@@ -224,6 +235,7 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
    // trt pass should explicitle add cast op is input is bf16/tf32, etc.
    if (op_node->Name() == "tensorrt_engine") continue;
    for (auto* in_node : op_node->inputs) {
+      if (in_node->IsCtrlVar()) continue;
      auto* in_var = in_node->Var();
      if (in_var->GetDataType() == to_type) {
        AddCastOp(graph,
@@ -242,6 +254,7 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
   // 4. if output_op's dtype is not compatible to output dtype, then just
   // insert cast.
   for (auto* node : output_nodes) {
+    if (node->IsCtrlVar()) continue;
     auto var = node->Var();
     if (keep_io_types && var->GetDataType() == to_type) {
       // fp16/bf16 -> fp32.
@@ -381,7 +394,7 @@ void ConvertToMixedPrecision(const std::string& model_file,
   std::unordered_set<std::string> weights_should_be_fp32;
   for (auto* node : graph->Nodes()) {
-    if (!node->IsVar()) continue;
+    if (!(node->IsVar() && !node->IsCtrlVar())) continue;
     if (node->Var()->GetType() ==
            paddle::framework::proto::VarType::SELECTED_ROWS ||
        node->Var()->GetType() ==