diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 18ffa2661da48e5c10a8e462925cc37114232c28..4e09b4922b0ba38b174bda535cc4bc7b51360927 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -104,6 +104,7 @@ pass_library(delete_c_identity_op_pass inference) pass_library(preln_residual_bias_fuse_pass inference) pass_library(delete_fill_constant_op_pass inference) pass_library(constant_folding_pass inference) +pass_library(float_to_half_pass inference) pass_library(conv2d_fusion_layout_transfer_pass inference) pass_library(simplify_with_basic_ops_pass base) pass_library(fc_elementwise_layernorm_fuse_pass base) diff --git a/paddle/fluid/framework/ir/float_to_half_pass.cc b/paddle/fluid/framework/ir/float_to_half_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..ec94728fb3c64175bdf44eab7ed2683f1ba4ce75 --- /dev/null +++ b/paddle/fluid/framework/ir/float_to_half_pass.cc @@ -0,0 +1,725 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/float_to_half_pass.h" + +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/phi/common/data_type.h" + +namespace paddle { +namespace framework { +namespace ir { + +namespace { + +using VarType = FloatToHalfPass::VarType; + +bool PhiKernelSupportPrecision( + const std::string& op_type, + phi::Backend backend, + phi::DataType data_type, + phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) { + const auto& kernels = phi::KernelFactory::Instance().kernels(); + if (kernels.count(op_type) == 0) { + return false; + } + phi::KernelKey kernel_key(backend, layout, data_type); + return phi::KernelFactory::Instance().HasKernel(op_type, kernel_key); +} + +bool GpuKernelSupportPrecision( + const std::string& op_type, + phi::DataType precision, + phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) { + auto phi_op_type = phi::TransToPhiKernelName(op_type); + bool support = PhiKernelSupportPrecision( + phi_op_type, phi::Backend::GPU, precision, layout); + support |= PhiKernelSupportPrecision( + phi_op_type, phi::Backend::GPUDNN, precision, layout); + + if (!support) { + const auto& all_kernels = framework::OperatorWithKernel::AllOpKernels(); + auto it = all_kernels.find(op_type); + if (it != all_kernels.end()) { + for (const auto& kern_pair : it->second) { + if (platform::is_gpu_place(kern_pair.first.place_) && + kern_pair.first.data_type_ == + framework::TransToProtoVarType(precision)) { + support = true; + break; + } + } + } + } + return support; +} + +void DoInsertCastOp(Graph* graph, + Node* var_node, + Node* op_node, + VarType::Type from_type, + VarType::Type to_type, + framework::BlockDesc* block_desc, + int* suffix, + std::unordered_map* cache) { + if (from_type == to_type) return; + + auto update_cast_desc = [&](framework::OpDesc& desc, + const std::string& x_name, + 
const std::string& out_name, + const int in_dtype, + const int out_dtype) { + desc.SetType("cast"); + desc.SetInput("X", {x_name}); + desc.SetOutput("Out", {out_name}); + desc.SetAttr("in_dtype", in_dtype); + desc.SetAttr("out_dtype", out_dtype); + desc.SetAttr("use_mkldnn", false); + desc.SetAttr("with_quant_attr", false); + desc.Flush(); + }; + + if (cache->count(var_node) == 0) { + // insert cast op between var_node and op_node + std::string cast_input_name = var_node->Var()->Name(); + std::string cast_output_name = + var_node->Var()->Name() + "_cast.tmp_" + std::to_string((*suffix)++); + framework::OpDesc cast_op_desc(block_desc); + update_cast_desc(cast_op_desc, + cast_input_name, + cast_output_name, + static_cast(from_type), + static_cast(to_type)); + auto* cast_op_node = graph->CreateOpNode(&cast_op_desc); + auto* cast_output_vardesc = block_desc->Var(cast_output_name); + cast_output_vardesc->SetPersistable(false); + cast_output_vardesc->SetDataType(to_type); + cast_output_vardesc->SetShape(var_node->Var()->GetShape()); + auto* cast_output_node = graph->CreateVarNode(cast_output_vardesc); + IR_NODE_LINK_TO(cast_op_node, cast_output_node); + (*cache)[var_node] = cast_output_node; + } + op_node->Op()->Rename(var_node->Name(), cache->at(var_node)->Name()); + IR_NODE_LINK_TO(var_node, cache->at(var_node)->inputs[0]); + IR_NODE_LINK_TO(cache->at(var_node), op_node); + + IR_NODE_UNLINK(var_node, op_node); +} + +inline bool VarNodeHasDtype(Node* var_node) { + auto type = var_node->Var()->GetType(); + return (type == VarType::SELECTED_ROWS) || (type == VarType::LOD_TENSOR) || + (type == VarType::LOD_TENSOR_ARRAY) || (type == VarType::STRINGS) || + (type == VarType::VOCAB); +} + +inline bool IsFloatType(VarType::Type type) { + return (type == VarType::FP64) || (type == VarType::FP32); +} + +inline bool IsHalfType(VarType::Type type) { + return (type == VarType::FP16) || (type == VarType::BF16); +} + +}; // namespace + +// The set of ops that support fp16 calculation and are considered +// numerically-dangerous, slower and whose effects may also be observed in +// downstream ops. 
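+// Entries listed here are merged with the user-provided "mixed_black_list"
+// attribute (set through AnalysisConfig::Exp_DisableMixedInferOps) when
+// Init() runs, so extra ops can be kept in FP32. For illustration only (the
+// op names below are just examples):
+//   AnalysisConfig config;
+//   config.EnableUseGpu(512, 0, AnalysisConfig::Precision::kHalf);
+//   config.Exp_DisableMixedInferOps({"softmax", "layer_norm"});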
+void FloatToHalfPass::SetDefaultBlacklist() const { + black_list_.insert({ + // numerically-dangerous + "acos", + "asin", + "cosh", + "tan", + "exp", + "expm1", + "square", + "log", + "log2", + "log10", + "log1p", + "logsumexp", + "mean", + "rsqrt", + "sum", + "cos_sim", + "softmax", + "softmax_with_cross_entropy", + "sigmoid_cross_entropy_with_logits", + "c_softmax_with_cross_entropy", + "cross_entropy", + "cross_entropy2", + // slower than fp32 + "conv2d_transpose", + // default fp32 can avoid return inf when the sum value large than 65504 + "reduce_sum", + }); +} + +void FloatToHalfPass::Init(Graph* graph) const { + keep_io_types_ = true; + half_precision_ = + static_cast(Get("mixed_precision_mode")); + black_list_ = Get>("mixed_black_list"); + SetDefaultBlacklist(); + + auto graph_size = graph->SubGraphsSize(); + VLOG(4) << "graph size: " << graph_size; + subgraphes_.resize(graph_size); + all_op_nodes_.resize(graph_size); + + for (size_t i = 0; i < graph_size; i++) { + subgraphes_[i] = graph->GetSubGraph(i); + all_op_nodes_[i] = TopologySortOperations(*subgraphes_[i]); + VLOG(4) << "subgraph " << i << " has " << all_op_nodes_[i].size() + << "op nodes"; + for (auto* var_node : subgraphes_[i]->Nodes()) { + if (!var_node->IsVar()) continue; + + auto var_name = var_node->Var()->Name(); + if (real_vars_.count(var_name) == 0) { + real_vars_[var_name] = var_node; + VLOG(4) << var_name << " is in graph " << i; + } + } + } +} + +void FloatToHalfPass::ApplyImpl(Graph* graph) const { + auto enable_gpu_half = Get("enable_gpu_half"); + if (!enable_gpu_half) return; + + PADDLE_ENFORCE_NOT_NULL( + graph, + platform::errors::PreconditionNotMet( + "During the float to half pass, the graph should not be nullptr.")); + PADDLE_ENFORCE_EQ( + graph->IsMainGraph(), + true, + platform::errors::PreconditionNotMet( + "During the float to half pass, the graph should be main graph.")); + + FusePassBase::Init("float_to_half", graph); + + Init(graph); + VLOG(4) << "Init done"; + SetOpUniqueType(); + VLOG(4) << "SetOpUniqueType done"; + GetOpPrecision(); + VLOG(4) << "GetOpPrecision done"; + UpdateOpPrecision(); + VLOG(4) << "UpdateOpPrecision done"; + SetVarPrecision(); + VLOG(4) << "SetVarPrecision done"; + ConvertWeightsData(); + VLOG(4) << "ConvertWeightsData done"; + ProcessOpWithDtypeAttr(); + VLOG(4) << "ProcessOpWithDtypeAttr done"; + InsertCastOp(); + VLOG(4) << "InsertCastOp done"; + RestoreOpOriginType(); + VLOG(4) << "RestoreOpOriginType done"; +} + +bool FloatToHalfPass::OpSupportPrecision(const std::string& op_type, + phi::DataType precision, + phi::Backend backend) const { + bool support = false; + if (black_list_.count(op_type) == 0) { + if (backend == phi::Backend::GPU) { + support = GpuKernelSupportPrecision(op_type, precision); + } + } + return support; +} + +void FloatToHalfPass::SetOpUniqueType() const { + int suffix = 0; + for (const auto& nodes : all_op_nodes_) { + for (auto* op_node : nodes) { + auto op_type = op_node->Op()->Type(); + + if (op_type == "feed" || op_type == "fetch") continue; + + std::string unique_type = op_type + "_" + std::to_string(suffix++); + op_original_type_[unique_type] = op_type; + op_node->Op()->SetType(unique_type); + op_node->Op()->Flush(); + VLOG(4) << "change op type: " << op_type << " ---> " << unique_type; + } + } +} + +void FloatToHalfPass::RestoreOpOriginType() const { + for (const auto& nodes : all_op_nodes_) { + for (auto* op_node : nodes) { + auto op_type = op_node->Op()->Type(); + op_node->Op()->SetType(GetOpOriginalType(op_type)); + 
op_node->Op()->Flush(); + VLOG(4) << "restore op type: " << op_type << " ---> " + << op_node->Op()->Type(); + } + } +} + +inline std::string FloatToHalfPass::GetOpOriginalType( + const std::string& op_type) const { + if (op_original_type_.count(op_type)) { + return op_original_type_.at(op_type); + } + return op_type; +} + +void FloatToHalfPass::ProcessOpWithDtypeAttr() const { + for (const auto& nodes : all_op_nodes_) { + for (auto* op_node : nodes) { + auto op_type = op_node->Op()->Type(); + if (op_run_half_.count(op_type) == 0) continue; + + if (op_node->Op()->HasAttr("dtype")) { + auto dtype = op_node->Op()->GetAttrIfExists("dtype"); + if (IsFloatType(static_cast(dtype))) { + op_node->Op()->SetAttr( + "dtype", + static_cast( + framework::TransToProtoVarType(half_precision_))); + op_node->Op()->Flush(); + VLOG(4) << "process op with dtype attr: " << op_type << " ( " << dtype + << " --->" << static_cast(half_precision_) << " )"; + } + } + if (op_node->Op()->HasAttr("out_dtype")) { + auto out_dtype = op_node->Op()->GetAttrIfExists("out_dtype"); + if (IsFloatType(static_cast(out_dtype))) { + op_node->Op()->SetAttr( + "out_dtype", + static_cast( + framework::TransToProtoVarType(half_precision_))); + op_node->Op()->Flush(); + VLOG(4) << "process op with out_dtype attr: " << op_type << " ( " + << out_dtype << " --->" << static_cast(half_precision_) + << " )"; + } + } + } + } +} + +void FloatToHalfPass::GetOpPrecision() const { + for (const auto& nodes : all_op_nodes_) { + for (auto* op_node : nodes) { + auto op_type = op_node->Op()->Type(); + bool support_half = true; + if (GetOpOriginalType(op_type) == "feed" || + GetOpOriginalType(op_type) == "fetch") { + support_half = !keep_io_types_; + } else { + support_half = + OpSupportPrecision(GetOpOriginalType(op_type), half_precision_); + } + + if (op_node->Op()->HasAttr("dtype")) { + auto dtype = op_node->Op()->GetAttrIfExists("dtype"); + support_half = + support_half && IsFloatType(static_cast(dtype)); + } else if (op_node->Op()->HasAttr("out_dtype")) { + auto out_dtype = op_node->Op()->GetAttrIfExists("out_dtype"); + support_half = + support_half && IsFloatType(static_cast(out_dtype)); + } else { + // if op's input var and output var is not dense tensor, the op should + // not run half. 
+        for (auto* in_var_node : op_node->inputs) {
+          CHECK_EQ(in_var_node->IsVar(), true);
+          auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
+          if (real_in_var_node->Var()->Persistable()) continue;
+
+          support_half = support_half && (real_in_var_node->Var()->GetType() ==
+                                          VarType::LOD_TENSOR);
+        }
+
+        for (auto* out_var_node : op_node->outputs) {
+          CHECK_EQ(out_var_node->IsVar(), true);
+          auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
+          if (real_out_var_node->Var()->Persistable()) continue;
+
+          support_half = support_half && (real_out_var_node->Var()->GetType() ==
+                                          VarType::LOD_TENSOR);
+        }
+      }
+
+      if (support_half) {
+        op_run_half_.insert(op_type);
+        VLOG(4) << "support precision: " << op_type << " run at half";
+      } else {
+        VLOG(4) << "support precision: " << op_type << " not run at half";
+      }
+    }
+  }
+}
+
+void FloatToHalfPass::UpdateOpPrecision() const {
+  std::unordered_set<std::string> vars_should_not_half;
+
+  // var name -> all the ops that output this var (the var's input ops)
+  std::unordered_map<std::string, std::vector<Node*>> var_input_ops;
+
+  auto GetVarInputOps = [&] {
+    for (const auto& nodes : all_op_nodes_) {
+      for (auto* op_node : nodes) {
+        auto op_type = op_node->Op()->Type();
+
+        if (GetOpOriginalType(op_type) == "fetch") continue;
+        if (op_node->Op()->HasAttr("sub_block")) continue;
+
+        for (auto* var_node : op_node->outputs) {
+          CHECK_EQ(var_node->IsVar(), true);
+          if (var_node->Var()->Persistable()) continue;
+          if (!VarNodeHasDtype(var_node)) continue;
+
+          var_input_ops[var_node->Var()->Name()].push_back(op_node);
+          VLOG(4) << "var input ops: " << var_node->Var()->Name()
+                  << " is output of " << op_type;
+        }
+
+        // A select_input op's input vars should not be converted to half, and
+        // an op whose output var is a select_input op's input var should not
+        // run at half precision.
+        if (GetOpOriginalType(op_node->Op()->Type()) == "select_input") {
+          for (auto* in_var_node : op_node->inputs) {
+            CHECK_EQ(in_var_node->IsVar(), true);
+            if (in_var_node->Var()->Persistable()) continue;
+            if (!VarNodeHasDtype(in_var_node)) continue;
+
+            vars_should_not_half.insert(in_var_node->Var()->Name());
+          }
+        }
+
+        // When op_1 only supports a CPU kernel, any op_2 whose input var is
+        // op_1's output var should not run at half precision.
+ if (GetOpOriginalType(op_type) != "feed" && + !GpuKernelSupportPrecision(GetOpOriginalType(op_type), + phi::DataType::FLOAT32)) { + for (auto* out_var_node : op_node->outputs) { + CHECK_EQ(out_var_node->IsVar(), true); + if (out_var_node->Var()->Persistable()) continue; + if (!VarNodeHasDtype(out_var_node)) continue; + + vars_should_not_half.insert(out_var_node->Var()->Name()); + } + } + } + } + }; + GetVarInputOps(); + + bool precision_updated = false; + do { + precision_updated = false; + for (const auto& nodes : all_op_nodes_) { + for (auto* op_node : nodes) { + if (op_run_half_.count(op_node->Op()->Type()) == 0) continue; + + for (auto* in_var_node : op_node->inputs) { + CHECK_EQ(in_var_node->IsVar(), true); + if (!VarNodeHasDtype(in_var_node)) continue; + + auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()]; + if (real_in_var_node->Var()->Persistable()) continue; + + if (vars_should_not_half.count(real_in_var_node->Var()->Name())) { + op_run_half_.erase(op_node->Op()->Type()); + precision_updated = true; + VLOG(4) << op_node->Op()->Type() + << " should not support half precision."; + break; + } + } + + if (op_run_half_.count(op_node->Op()->Type()) == 0) continue; + + for (auto* out_var_node : op_node->outputs) { + CHECK_EQ(out_var_node->IsVar(), true); + if (!VarNodeHasDtype(out_var_node)) continue; + + auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()]; + if (real_out_var_node->Var()->Persistable()) continue; + + bool not_run_half = false; + const auto& input_op_nodes = + var_input_ops[real_out_var_node->Var()->Name()]; + if (vars_should_not_half.count(real_out_var_node->Var()->Name())) { + not_run_half = true; + } else { + for (auto* node : input_op_nodes) { + if (op_run_half_.count(node->Op()->Type()) == 0) { + not_run_half = true; + break; + } + } + } + if (not_run_half) { + op_run_half_.erase(op_node->Op()->Type()); + precision_updated = true; + VLOG(4) << op_node->Op()->Type() + << " should not support half precision."; + break; + } + } + } + } + } while (precision_updated); +} + +// special ops, its weights should not be low precision. 
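+// Currently these are batch_norm (Bias/Mean/Scale/Variance inputs and the
+// mean/variance outputs) and fused_multi_transformer (LnScale/LnBias/
+// FFNLnScale/FFNLnBias); keeping such normalization parameters in float
+// follows common AMP practice.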
+bool FloatToHalfPass::InputVarsNotConvert(Node* op_node, + const std::string& var_name) const { + auto* op_desc = op_node->Op(); + if (GetOpOriginalType(op_desc->Type()) == "batch_norm") { + auto vecs = op_desc->Input("Bias"); + if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { + return true; + } + vecs = op_desc->Input("Mean"); + if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { + return true; + } + vecs = op_desc->Input("Scale"); + if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { + return true; + } + vecs = op_desc->Input("Variance"); + if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { + return true; + } + } else if (GetOpOriginalType(op_desc->Type()) == "fused_multi_transformer") { + auto vecs = op_desc->Input("LnScale"); + if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { + return true; + } + vecs = op_desc->Input("LnBias"); + if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { + return true; + } + vecs = op_desc->Input("FFNLnScale"); + if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { + return true; + } + vecs = op_desc->Input("FFNLnBias"); + if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { + return true; + } + } + return false; +} + +bool FloatToHalfPass::OutputVarsNotConvert(Node* op_node, + const std::string& var_name) const { + auto* op_desc = op_node->Op(); + // batch_norm's input and output (variance and mean) are the same. + if (GetOpOriginalType(op_desc->Type()) == "batch_norm") { + auto vecs = op_desc->Output("MeanOut"); + if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { + return true; + } + vecs = op_desc->Output("VarianceOut"); + if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { + return true; + } + vecs = op_desc->Output("SavedMean"); + if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { + return true; + } + vecs = op_desc->Output("SavedVariance"); + if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { + return true; + } + } + return false; +} + +void FloatToHalfPass::SetVarPrecision() const { + for (const auto& nodes : all_op_nodes_) { + for (auto* op_node : nodes) { + if (op_run_half_.count(op_node->Op()->Type())) { + for (auto* in_var_node : op_node->inputs) { + CHECK_EQ(in_var_node->IsVar(), true); + + auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()]; + auto in_var_name = real_in_var_node->Var()->Name(); + + if (!IsFloatType(real_in_var_node->Var()->GetDataType())) continue; + if (!VarNodeHasDtype(real_in_var_node)) continue; + if (InputVarsNotConvert(op_node, in_var_name)) continue; + + if (real_in_var_node->Var()->Persistable()) { + real_in_var_node->Var()->SetDataType( + framework::TransToProtoVarType(half_precision_)); + vars_convert_to_half_.insert(in_var_name); + } + } + + for (auto* out_var_node : op_node->outputs) { + CHECK_EQ(out_var_node->IsVar(), true); + + auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()]; + auto out_var_name = real_out_var_node->Var()->Name(); + + if (!IsFloatType(real_out_var_node->Var()->GetDataType())) continue; + if (!VarNodeHasDtype(real_out_var_node)) continue; + if (OutputVarsNotConvert(op_node, out_var_name)) continue; + + real_out_var_node->Var()->SetDataType( + framework::TransToProtoVarType(half_precision_)); + if (real_out_var_node->Var()->Persistable()) { + vars_convert_to_half_.insert(out_var_name); + } + } + } + } + } + + // This code used to precess vars with the same name. 
Vars with the same + // name should have the same data type. + for (auto* subgraph : subgraphes_) { + for (auto* var_node : subgraph->Nodes()) { + if (!var_node->IsVar() || !var_node->Var()->Persistable()) continue; + if (!VarNodeHasDtype(var_node)) continue; + + auto var_name = var_node->Var()->Name(); + if (vars_convert_to_half_.count(var_name)) { + var_node->Var()->SetDataType( + framework::TransToProtoVarType(half_precision_)); + } + } + } +} + +void FloatToHalfPass::ConvertWeightsData() const { + auto* scope = param_scope(); + PADDLE_ENFORCE_NOT_NULL( + scope, + platform::errors::PreconditionNotMet( + "During the float to half pass, the scope should not be null.")); + + auto var_names = scope->LocalVarNames(); + for (const auto& var_name : var_names) { + if (vars_convert_to_half_.count(var_name)) { + VLOG(4) << var_name << "'s data type was convert to half"; +#define CONVERT_TENSOR_DTYPE(DTYPE, dtype) \ + half_tensor.set_type(DTYPE); \ + auto* half_data = half_tensor.mutable_data(platform::CPUPlace()); \ + for (int64_t i = 0; i < origin_tensor->numel(); i++) { \ + half_data[i] = static_cast(origin_data[i]); \ + } \ + origin_tensor->clear(); \ + paddle::framework::TensorCopySync( \ + half_tensor, platform::CPUPlace(), origin_tensor) + + auto* var = scope->FindLocalVar(var_name); + + if (var->IsType()) { + auto* origin_tensor = var->GetMutable(); + phi::DenseTensor half_tensor; + half_tensor.Resize(origin_tensor->dims()); + auto* origin_data = + origin_tensor->mutable_data(platform::CPUPlace()); + if (half_precision_ == phi::DataType::FLOAT16) { + CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::FLOAT16, + phi::dtype::float16); + } else if (half_precision_ == phi::DataType::BFLOAT16) { + CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::BFLOAT16, + phi::dtype::bfloat16); + } + } + } +#undef CONVERT_TENSOR_DTYPE + } +} + +void FloatToHalfPass::InsertCastOp() const { + int suffix = 0; + std::unordered_map cache; + + for (size_t i = 0; i < all_op_nodes_.size(); i++) { + auto* block_desc = all_op_nodes_[i][0]->Op()->Block(); + CHECK_NOTNULL(block_desc); + for (auto* op_node : all_op_nodes_[i]) { + auto op_type = op_node->Op()->Type(); + + if (GetOpOriginalType(op_type) == "feed") continue; + if (op_node->Op()->HasAttr("sub_block")) continue; + + VLOG(4) << "process op: " << op_type + << " run half: " << op_run_half_.count(op_type); + + auto inputs = op_node->inputs; + for (auto* in_var_node : inputs) { + if (!in_var_node->IsVar()) continue; + if (!VarNodeHasDtype(in_var_node)) continue; + if (in_var_node->Var()->Persistable()) continue; + + auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()]; + + auto in_var_type = real_in_var_node->Var()->GetDataType(); + + VLOG(4) << "process var: " << real_in_var_node->Var()->Name() + << " with type " << in_var_type; + + if (IsFloatType(in_var_type) && op_run_half_.count(op_type)) { + DoInsertCastOp(subgraphes_[i], + in_var_node, + op_node, + in_var_type, + framework::TransToProtoVarType(half_precision_), + block_desc, + &suffix, + &cache); + } else if (IsHalfType(in_var_type) && + op_run_half_.count(op_type) == 0) { + DoInsertCastOp(subgraphes_[i], + in_var_node, + op_node, + in_var_type, + VarType::FP32, + block_desc, + &suffix, + &cache); + } + } + + // Special op. + // fused_multi_transformer's input(CacheKV) and output(CacheKVOut) vars + // have same name. 
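+      // Resetting each CacheKVOut name to the matching CacheKV input name
+      // below keeps the in-place cache referring to a single variable, even
+      // if cast insertion renamed that input.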
+ if (GetOpOriginalType(op_type) == "fused_multi_transformer") { + auto cache_kv_inputs = op_node->Op()->Input("CacheKV"); + auto cache_kv_outputs = op_node->Op()->Output("CacheKVOut"); + CHECK_EQ(cache_kv_inputs.size(), cache_kv_outputs.size()); + for (size_t i = 0; i < cache_kv_inputs.size(); ++i) { + op_node->Op()->RenameOutput(cache_kv_outputs[i], cache_kv_inputs[i]); + } + } + } + } + VLOG(4) << "insert number of cast op: " << cache.size(); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(float_to_half_pass, paddle::framework::ir::FloatToHalfPass); diff --git a/paddle/fluid/framework/ir/float_to_half_pass.h b/paddle/fluid/framework/ir/float_to_half_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..a274dc9a53c61a1490c96d60ba96e49608fe446b --- /dev/null +++ b/paddle/fluid/framework/ir/float_to_half_pass.h @@ -0,0 +1,98 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/node.h" +#include "paddle/phi/common/backend.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/common/layout.h" +#include "paddle/phi/common/place.h" + +namespace paddle { +namespace framework { +namespace ir { + +class FloatToHalfPass : public FusePassBase { + public: + using VarType = framework::proto::VarType; + + public: + FloatToHalfPass() = default; + ~FloatToHalfPass() = default; + + protected: + void ApplyImpl(Graph* graph) const override; + + private: + void Init(Graph* graph) const; + + void SetDefaultBlacklist() const; + + bool OpSupportPrecision(const std::string& op_type, + phi::DataType precision, + phi::Backend backend = phi::Backend::GPU) const; + + void SetOpUniqueType() const; + + void RestoreOpOriginType() const; + + inline std::string GetOpOriginalType(const std::string& op_type) const; + + void GetOpPrecision() const; + + void UpdateOpPrecision() const; + + void InsertCastOp() const; + + void ProcessOpWithDtypeAttr() const; + + bool InputVarsNotConvert(Node* op_node, const std::string& var_name) const; + + bool OutputVarsNotConvert(Node* op_node, const std::string& var_name) const; + + void SetVarPrecision() const; + + void ConvertWeightsData() const; + + private: + mutable bool keep_io_types_; + // float16 or bfloat16 now + mutable phi::DataType half_precision_; + + mutable std::unordered_set black_list_; + + // subgraph id -> pointer to subgraph + mutable std::vector subgraphes_; + // var name -> real var node + mutable std::unordered_map real_vars_; + // subgraph id -> all op nodes in subgraph + mutable std::vector> all_op_nodes_; + // op's unique type -> the op's origin type + mutable std::unordered_map op_original_type_; + // op's unique type -> whether the op run at half precision + mutable std::unordered_set op_run_half_; + + mutable std::unordered_set 
vars_convert_to_half_; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h index a8d1067c55471507bf783b53e1ab078d6a5d11ff..8750a9afb44e48fe29b4e33eea61ccc06e083bfe 100644 --- a/paddle/fluid/inference/analysis/argument.h +++ b/paddle/fluid/inference/analysis/argument.h @@ -365,6 +365,8 @@ struct Argument { DECL_ARGUMENT_FIELD(mixed_black_list, MixedBlackList, std::unordered_set); + DECL_ARGUMENT_FIELD(enable_gpu_half, EnableGPUHalf, bool); + DECL_ARGUMENT_FIELD(mixed_precision_mode, MixedPrecisionMode, int); private: std::unordered_set valid_fields_; diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc index b31f28a6a602f969cb8990641cd633ccdd71a0a4..cbcc48a7f68e85e4797716e1838215443e4c1983 100644 --- a/paddle/fluid/inference/analysis/ir_pass_manager.cc +++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc @@ -86,10 +86,14 @@ void IRPassManager::CreatePasses(Argument *argument, argument->tensorrt_tuned_dynamic_shape(); pass->Set("with_dynamic_shape", new bool(with_dynamic_shape)); + // mixed precision related pass->Set("model_precision", new int(argument->model_precision())); pass->Set( "mixed_black_list", new std::unordered_set(argument->mixed_black_list())); + pass->Set("enable_gpu_half", new bool(argument->enable_gpu_half())); + pass->Set("mixed_precision_mode", + new int(argument->mixed_precision_mode())); if (pass_name == "graph_viz_pass") { std::string optim_cache_dir = argument->optim_cache_dir(); diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index 7720fab31e29ee3636d7fa0c5fe1a7ee31f1f2cb..c5e648dffc0bfc8b0ab939ca897e69b4a2883c47 100755 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -85,16 +85,29 @@ void AnalysisConfig::SetModel(const std::string &prog_file_path, Update(); } + void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb, - int device_id) { + int device_id, + Precision precision_mode) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) use_gpu_ = true; memory_pool_init_size_mb_ = memory_pool_init_size_mb; FLAGS_initial_gpu_memory_in_mb = memory_pool_init_size_mb_; gpu_device_id_ = device_id; + mixed_precision_mode_ = precision_mode; + if (precision_mode == Precision::kFloat32) { + // default + } else if (precision_mode == Precision::kHalf || + precision_mode == Precision::kBf16) { + enable_gpu_half_ = true; + } else { + LOG(ERROR) + << "The Paddle-GPU inference currently only supports " + "float32/float16/bfloat16 precision. Please check the parameters " + "you specified in EnableUseGpu or enable_use_gpu function."; + } #else - LOG(ERROR) << "Please compile with gpu to EnableGpu()"; - use_gpu_ = false; + LOG(ERROR) << "Please use PaddlePaddle with GPU version."; #endif Update(); @@ -381,8 +394,10 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { CP_MEMBER(gpu_device_id_); CP_MEMBER(memory_pool_init_size_mb_); - // Mixed related. + // Mixed precision related. CP_MEMBER(mixed_black_list_); + CP_MEMBER(enable_gpu_half_); + CP_MEMBER(mixed_precision_mode_); CP_MEMBER(enable_memory_optim_); // TensorRT related. 
@@ -996,6 +1011,7 @@ std::string AnalysisConfig::SerializeInfoCache() { ss << params_file_; ss << use_gpu_; + ss << enable_gpu_half_; ss << use_external_stream_; ss << exec_stream_; ss << use_fc_padding_; @@ -1212,6 +1228,7 @@ std::string AnalysisConfig::Summary() { os.InsertRow({"use_gpu", use_gpu_ ? "true" : "false"}); if (use_gpu_) { os.InsertRow({"gpu_device_id", std::to_string(gpu_device_id_)}); + os.InsertRow({"enable_gpu_half_", std::to_string(enable_gpu_half_)}); os.InsertRow({"memory_pool_init_size", std::to_string(memory_pool_init_size_mb_) + "MB"}); os.InsertRow( @@ -1407,7 +1424,7 @@ bool AnalysisConfig::trt_allow_build_at_runtime() const { return trt_allow_build_at_runtime_; } -void AnalysisConfig::Exp_SetBlackListOpsForMixedModel( +void AnalysisConfig::Exp_DisableMixedInferOps( const std::unordered_set &black_list) { mixed_black_list_ = black_list; } diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index da09ecff079bd33e6b47bdd36360c6e3596f4e90..35005a2f676716fac6991e82d37a855a56d62968 100755 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -1257,12 +1257,26 @@ void AnalysisPredictor::PrepareArgument() { } } } - if (config_.ir_debug_) { - pass_builder->TurnOnDebug(); - } + if (!config_.ir_optim()) { argument_.SetEnableIrOptim(false); - LOG(INFO) << "ir_optim is turned off, no IR pass will be executed"; + if (config_.enable_gpu_half_) { + argument_.SetEnableIrOptim(true); + pass_builder->ClearPasses(); + pass_builder->AppendPass("float_to_half_pass"); + LOG(INFO) + << "This model run in Paddle-GPU mixed precision mode with no ir " + "optimization."; + } else { + LOG(INFO) << "ir_optim is turned off, no IR pass will be executed."; + } + } else { + if (config_.ir_debug_) { + pass_builder->TurnOnDebug(); + } + if (config_.enable_gpu_half_) { + LOG(INFO) << "This model run in Paddle-GPU mixed precision mode."; + } } argument_.SetDisableLogs(config_.glog_info_disabled()); argument_.SetIrAnalysisPasses(pass_builder->AllPasses()); @@ -1272,6 +1286,9 @@ void AnalysisPredictor::PrepareArgument() { // mixed precison. argument_.SetModelPrecision(static_cast(model_precision_)); argument_.SetMixedBlackList(config_.mixed_black_list_); + argument_.SetEnableGPUHalf(config_.enable_gpu_half_); + argument_.SetMixedPrecisionMode(static_cast( + paddle::ConvertPrecision(config_.mixed_precision_mode_))); } // NOTE All the members in AnalysisConfig should be copied to Argument. diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h index a8f645680a962c4fb850ba6f2deceac45753c629..f8ddcbdaa8f39a4346e3e5c8fdd13627a0e7d055 100644 --- a/paddle/fluid/inference/api/paddle_analysis_config.h +++ b/paddle/fluid/inference/api/paddle_analysis_config.h @@ -247,8 +247,12 @@ struct PD_INFER_DECL AnalysisConfig { /// /// \param memory_pool_init_size_mb initial size of the GPU memory pool in MB. /// \param device_id device_id the GPU card to use (default is 0). + /// \param precision the precision used in Paddle-GPU inference. /// - void EnableUseGpu(uint64_t memory_pool_init_size_mb, int device_id = 0); + void EnableUseGpu(uint64_t memory_pool_init_size_mb, + int device_id = 0, + Precision precision_mode = Precision::kFloat32); + /// /// \brief Turn off GPU. /// @@ -1005,7 +1009,7 @@ struct PD_INFER_DECL AnalysisConfig { /// interface is in the experimental stage and may change in the future. 
Note /// that the blacklist must be the same as the model conversion blacklist. /// - void Exp_SetBlackListOpsForMixedModel( + void Exp_DisableMixedInferOps( const std::unordered_set& black_list); void SetApplyOptim(bool value) { apply_optim_ = value; } @@ -1024,13 +1028,15 @@ struct PD_INFER_DECL AnalysisConfig { mutable std::string prog_file_; mutable std::string params_file_; - // Mixed precision. + // Mixed precision related. + Precision mixed_precision_mode_{Precision::kFloat32}; std::unordered_set mixed_black_list_; // GPU related. bool use_gpu_{false}; int gpu_device_id_{0}; uint64_t memory_pool_init_size_mb_{100}; // initial size is 100MB. + bool enable_gpu_half_{false}; bool thread_local_stream_{false}; bool use_cudnn_{false}; diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index ce55ba81200c62cf743e0b44a45f92adab40c030..4ac91231121d13f76c0f9f32c5a2913964ee5b34 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -246,9 +246,10 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { "conv_elementwise_add_fuse_pass", // #endif // "transpose_flatten_concat_fuse_pass", // - "constant_folding_pass", + "constant_folding_pass", // // following pass should be located in the last, since it will // work on all fused ops. + "float_to_half_pass", // "runtime_context_cache_pass" }); diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index 7a281ddfcdf6a000f28db9d76fe0e2cc5f6693ed..7398b9c2c01361ada531505e71563fad42172ab4 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -416,6 +416,9 @@ download_result(${ERNIE_INSTALL_DIR} "Ernie_result.txt.tar.gz" if(WITH_GPU) inference_analysis_api_test(test_analyzer_ernie ${ERNIE_INSTALL_DIR} analyzer_ernie_tester.cc) + inference_analysis_api_test(gpu_ernie_half_test ${ERNIE_INSTALL_DIR} + gpu_ernie_half_test.cc) + set_tests_properties(gpu_ernie_half_test PROPERTIES TIMEOUT 40) endif() inference_analysis_api_int8_test(test_analyzer_ernie_int8 ${ERNIE_INSTALL_DIR} analyzer_ernie_int8_tester.cc) diff --git a/paddle/fluid/inference/tests/api/gpu_ernie_half_test.cc b/paddle/fluid/inference/tests/api/gpu_ernie_half_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6354ee47a18f689a41a3c28d542cb8861c2fc1a0 --- /dev/null +++ b/paddle/fluid/inference/tests/api/gpu_ernie_half_test.cc @@ -0,0 +1,294 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
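+
+// These tests run an Ernie model with GPU mixed precision (FP16 and BF16),
+// with and without IR optimization, and compare the outputs against FP32
+// reference results within a small tolerance.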
+ +#include "paddle/fluid/inference/api/paddle_inference_api.h" +#include "paddle/fluid/inference/tests/api/tester_helper.h" + +namespace paddle { +namespace inference { + +using paddle::PaddleTensor; + +template +void GetValueFromStream(std::stringstream *ss, T *t) { + (*ss) >> (*t); +} + +template <> +void GetValueFromStream(std::stringstream *ss, std::string *t) { + *t = ss->str(); +} + +// Split string to vector +template +void Split(const std::string &line, char sep, std::vector *v) { + std::stringstream ss; + T t; + for (auto c : line) { + if (c != sep) { + ss << c; + } else { + GetValueFromStream(&ss, &t); + v->push_back(std::move(t)); + ss.str({}); + ss.clear(); + } + } + + if (!ss.str().empty()) { + GetValueFromStream(&ss, &t); + v->push_back(std::move(t)); + ss.str({}); + ss.clear(); + } +} + +// Parse tensor from string +template +bool ParseTensor(const std::string &field, paddle::PaddleTensor *tensor) { + std::vector data; + Split(field, ':', &data); + if (data.size() < 2) return false; + + std::string shape_str = data[0]; + + std::vector shape; + Split(shape_str, ' ', &shape); + + std::string mat_str = data[1]; + + std::vector mat; + Split(mat_str, ' ', &mat); + + tensor->shape = shape; + auto size = + std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()) * + sizeof(T); + tensor->data.Resize(size); + std::copy(mat.begin(), mat.end(), static_cast(tensor->data.data())); + tensor->dtype = GetPaddleDType(); + + return true; +} + +// Parse input tensors from string +bool ParseLine(const std::string &line, + std::vector *tensors) { + std::vector fields; + Split(line, ';', &fields); + + tensors->clear(); + tensors->reserve(4); + + int i = 0; + auto input_name = FLAGS_ernie_large ? "eval_placeholder_" : "placeholder_"; + for (; i < 3; i++) { + paddle::PaddleTensor temp; + ParseTensor(fields[i], &temp); + temp.name = input_name + std::to_string(i); + tensors->push_back(temp); + } + + // input_mask + paddle::PaddleTensor input_mask; + ParseTensor(fields[i], &input_mask); + input_mask.name = input_name + std::to_string(i); + tensors->push_back(input_mask); + + return true; +} + +bool LoadInputData(std::vector> *inputs, + int batch_size = 1) { + if (FLAGS_infer_data.empty()) { + LOG(ERROR) << "please set input data path"; + return false; + } + + std::ifstream fin(FLAGS_infer_data); + std::string line; + int sample = 0; + + // The unit-test dataset only have 10 samples, each sample have 5 feeds. 
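+  // Each line is one sample: fields are separated by ';', and every field is
+  // "<space-separated shape>:<space-separated values>", as parsed by
+  // ParseLine / ParseTensor above.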
+ while (std::getline(fin, line)) { + std::vector feed_data; + ParseLine(line, &feed_data); + inputs->push_back(std::move(feed_data)); + sample++; + if (!FLAGS_test_all_data && sample == batch_size) break; + } + LOG(INFO) << "number of samples: " << sample; + return true; +} + +// Compare results +TEST(Ernie_gpu_fp16_no_ir, compare_results) { + AnalysisConfig config; + config.SetModel(FLAGS_infer_model); + config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kHalf); + config.SwitchIrOptim(false); + + auto predictor = CreatePaddlePredictor(config); + + std::vector> input_slots_all; + LoadInputData(&input_slots_all); + + std::ifstream fin(FLAGS_refer_result); + std::string line; + std::vector ref; + + while (std::getline(fin, line)) { + Split(line, ' ', &ref); + } + + std::vector outputs; + for (size_t i = 0; i < input_slots_all.size(); i++) { + outputs.clear(); + predictor->Run(input_slots_all[i], &outputs); + + auto output = outputs.front(); + size_t outputs_size = 1; + for (auto dim : output.shape) { + outputs_size *= dim; + } + float *result = reinterpret_cast(output.data.data()); + for (size_t j = 0; j < outputs_size; ++j) { + EXPECT_NEAR(ref[i * outputs_size + j], result[j], 5e-2); + } + } +} + +// Compare results +TEST(Ernie_gpu_fp16_with_ir, compare_results) { + AnalysisConfig config; + config.SetModel(FLAGS_infer_model); + config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kHalf); + config.SwitchIrOptim(true); + // The fc_fuse_pass has diff, which will be repaired later. + config.pass_builder()->DeletePass("fc_fuse_pass"); + // There is a problem with the model itself, which has nothing to do with + // constant_folding_pass. + config.pass_builder()->DeletePass("constant_folding_pass"); + + auto predictor = CreatePaddlePredictor(config); + + std::vector> input_slots_all; + LoadInputData(&input_slots_all); + + std::ifstream fin(FLAGS_refer_result); + std::string line; + std::vector ref; + + while (std::getline(fin, line)) { + Split(line, ' ', &ref); + } + + std::vector outputs; + for (size_t i = 0; i < input_slots_all.size(); i++) { + outputs.clear(); + predictor->Run(input_slots_all[i], &outputs); + + auto output = outputs.front(); + size_t outputs_size = 1; + for (auto dim : output.shape) { + outputs_size *= dim; + } + float *result = reinterpret_cast(output.data.data()); + for (size_t j = 0; j < outputs_size; ++j) { + EXPECT_NEAR(ref[i * outputs_size + j], result[j], 5e-2); + } + } +} + +// Compare results +TEST(Ernie_gpu_bf16_no_ir, compare_results) { + AnalysisConfig config; + config.SetModel(FLAGS_infer_model); + config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kBf16); + config.SwitchIrOptim(false); + + auto predictor = CreatePaddlePredictor(config); + + std::vector> input_slots_all; + LoadInputData(&input_slots_all); + + std::ifstream fin(FLAGS_refer_result); + std::string line; + std::vector ref; + + while (std::getline(fin, line)) { + Split(line, ' ', &ref); + } + + std::vector outputs; + for (size_t i = 0; i < input_slots_all.size(); i++) { + outputs.clear(); + predictor->Run(input_slots_all[i], &outputs); + + auto output = outputs.front(); + size_t outputs_size = 1; + for (auto dim : output.shape) { + outputs_size *= dim; + } + float *result = reinterpret_cast(output.data.data()); + for (size_t j = 0; j < outputs_size; ++j) { + EXPECT_NEAR(ref[i * outputs_size + j], result[j], 7e-2); + } + } +} + +// Compare results +TEST(Ernie_gpu_bf16_with_ir, compare_results) { + AnalysisConfig config; + config.SetModel(FLAGS_infer_model); + config.EnableUseGpu(512, 0, 
paddle_infer::PrecisionType::kBf16); + config.SwitchIrOptim(true); + // The fc_fuse_pass has diff, which will be repaired later. + config.pass_builder()->DeletePass("fc_fuse_pass"); + // There is a problem with the model itself, which has nothing to do with + // constant_folding_pass. + config.pass_builder()->DeletePass("constant_folding_pass"); + + auto predictor = CreatePaddlePredictor(config); + + std::vector> input_slots_all; + LoadInputData(&input_slots_all); + + std::ifstream fin(FLAGS_refer_result); + std::string line; + std::vector ref; + + while (std::getline(fin, line)) { + Split(line, ' ', &ref); + } + + std::vector outputs; + for (size_t i = 0; i < input_slots_all.size(); i++) { + outputs.clear(); + predictor->Run(input_slots_all[i], &outputs); + + auto output = outputs.front(); + size_t outputs_size = 1; + for (auto dim : output.shape) { + outputs_size *= dim; + } + float *result = reinterpret_cast(output.data.data()); + for (size_t j = 0; j < outputs_size; ++j) { + EXPECT_NEAR(ref[i * outputs_size + j], result[j], 7e-2); + } + } +} + +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tests/api/paddle_infer_api_test.cc b/paddle/fluid/inference/tests/api/paddle_infer_api_test.cc index 8cff649b97092aea6850a33f3535e468b3ca8886..9029cefc9a424f201055d8f824c3ba1c42ee45e4 100644 --- a/paddle/fluid/inference/tests/api/paddle_infer_api_test.cc +++ b/paddle/fluid/inference/tests/api/paddle_infer_api_test.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,15 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include #include #include -#include -#include - #include "gflags/gflags.h" -#include "paddle/fluid/inference/tests/api/trt_test_helper.h" +#include "paddle/fluid/inference/tests/api/tester_helper.h" namespace paddle_infer { diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc index 1524c1f29d67b8188727eea30ef727498d18de15..9a791e4f2e36243931216b409e03d83de8e26865 100644 --- a/paddle/fluid/pybind/inference_api.cc +++ b/paddle/fluid/pybind/inference_api.cc @@ -644,7 +644,8 @@ void BindAnalysisConfig(py::module *m) { .def("enable_use_gpu", &AnalysisConfig::EnableUseGpu, py::arg("memory_pool_init_size_mb"), - py::arg("device_id") = 0) + py::arg("device_id") = 0, + py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) .def("set_exec_stream", [](AnalysisConfig &self, phi::CUDAStream &stream) {