Unverified commit c5a45cc6, authored by Yuanle Liu and committed by GitHub

[Paddle Inference] Add float_to_half_pass to support inference with mixed precision (#47993)

Parent 54b756e2
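
From the user's side, the new capability is reached through the added precision argument of AnalysisConfig::EnableUseGpu: passing kHalf or kBf16 enables the float_to_half_pass during analysis. A minimal usage sketch, mirroring the gpu_ernie_half_test added further down (the model path is a placeholder):

#include "paddle/fluid/inference/api/paddle_inference_api.h"

paddle::AnalysisConfig config;
config.SetModel("./ernie_model");  // placeholder model directory
// 512 MB initial GPU memory pool on device 0; run supported ops in FP16.
config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kHalf);
auto predictor = paddle::CreatePaddlePredictor(config);
// Prepare inputs, call predictor->Run(...), and read outputs as usual.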
@@ -104,6 +104,7 @@ pass_library(delete_c_identity_op_pass inference)
 pass_library(preln_residual_bias_fuse_pass inference)
 pass_library(delete_fill_constant_op_pass inference)
 pass_library(constant_folding_pass inference)
+pass_library(float_to_half_pass inference)
 pass_library(conv2d_fusion_layout_transfer_pass inference)
 pass_library(simplify_with_basic_ops_pass base)
 pass_library(fc_elementwise_layernorm_fuse_pass base)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/float_to_half_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/common/data_type.h"
namespace paddle {
namespace framework {
namespace ir {
namespace {
using VarType = FloatToHalfPass::VarType;
bool PhiKernelSupportPrecision(
const std::string& op_type,
phi::Backend backend,
phi::DataType data_type,
phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) {
const auto& kernels = phi::KernelFactory::Instance().kernels();
if (kernels.count(op_type) == 0) {
return false;
}
phi::KernelKey kernel_key(backend, layout, data_type);
return phi::KernelFactory::Instance().HasKernel(op_type, kernel_key);
}
bool GpuKernelSupportPrecision(
const std::string& op_type,
phi::DataType precision,
phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) {
auto phi_op_type = phi::TransToPhiKernelName(op_type);
bool support = PhiKernelSupportPrecision(
phi_op_type, phi::Backend::GPU, precision, layout);
support |= PhiKernelSupportPrecision(
phi_op_type, phi::Backend::GPUDNN, precision, layout);
if (!support) {
const auto& all_kernels = framework::OperatorWithKernel::AllOpKernels();
auto it = all_kernels.find(op_type);
if (it != all_kernels.end()) {
for (const auto& kern_pair : it->second) {
if (platform::is_gpu_place(kern_pair.first.place_) &&
kern_pair.first.data_type_ ==
framework::TransToProtoVarType(precision)) {
support = true;
break;
}
}
}
}
return support;
}
void DoInsertCastOp(Graph* graph,
Node* var_node,
Node* op_node,
VarType::Type from_type,
VarType::Type to_type,
framework::BlockDesc* block_desc,
int* suffix,
std::unordered_map<Node*, Node*>* cache) {
if (from_type == to_type) return;
auto update_cast_desc = [&](framework::OpDesc& desc,
const std::string& x_name,
const std::string& out_name,
const int in_dtype,
const int out_dtype) {
desc.SetType("cast");
desc.SetInput("X", {x_name});
desc.SetOutput("Out", {out_name});
desc.SetAttr("in_dtype", in_dtype);
desc.SetAttr("out_dtype", out_dtype);
desc.SetAttr("use_mkldnn", false);
desc.SetAttr("with_quant_attr", false);
desc.Flush();
};
if (cache->count(var_node) == 0) {
// insert cast op between var_node and op_node
std::string cast_input_name = var_node->Var()->Name();
std::string cast_output_name =
var_node->Var()->Name() + "_cast.tmp_" + std::to_string((*suffix)++);
framework::OpDesc cast_op_desc(block_desc);
update_cast_desc(cast_op_desc,
cast_input_name,
cast_output_name,
static_cast<int>(from_type),
static_cast<int>(to_type));
auto* cast_op_node = graph->CreateOpNode(&cast_op_desc);
auto* cast_output_vardesc = block_desc->Var(cast_output_name);
cast_output_vardesc->SetPersistable(false);
cast_output_vardesc->SetDataType(to_type);
cast_output_vardesc->SetShape(var_node->Var()->GetShape());
auto* cast_output_node = graph->CreateVarNode(cast_output_vardesc);
IR_NODE_LINK_TO(cast_op_node, cast_output_node);
(*cache)[var_node] = cast_output_node;
}
op_node->Op()->Rename(var_node->Name(), cache->at(var_node)->Name());
IR_NODE_LINK_TO(var_node, cache->at(var_node)->inputs[0]);
IR_NODE_LINK_TO(cache->at(var_node), op_node);
IR_NODE_UNLINK(var_node, op_node);
}
inline bool VarNodeHasDtype(Node* var_node) {
auto type = var_node->Var()->GetType();
return (type == VarType::SELECTED_ROWS) || (type == VarType::LOD_TENSOR) ||
(type == VarType::LOD_TENSOR_ARRAY) || (type == VarType::STRINGS) ||
(type == VarType::VOCAB);
}
inline bool IsFloatType(VarType::Type type) {
return (type == VarType::FP64) || (type == VarType::FP32);
}
inline bool IsHalfType(VarType::Type type) {
return (type == VarType::FP16) || (type == VarType::BF16);
}
}; // namespace
// The default blacklist: ops that support fp16 computation but are
// numerically dangerous or slower in fp16, and whose errors may also
// propagate to downstream ops.
void FloatToHalfPass::SetDefaultBlacklist() const {
black_list_.insert({
// numerically-dangerous
"acos",
"asin",
"cosh",
"tan",
"exp",
"expm1",
"square",
"log",
"log2",
"log10",
"log1p",
"logsumexp",
"mean",
"rsqrt",
"sum",
"cos_sim",
"softmax",
"softmax_with_cross_entropy",
"sigmoid_cross_entropy_with_logits",
"c_softmax_with_cross_entropy",
"cross_entropy",
"cross_entropy2",
// slower than fp32
"conv2d_transpose",
// defaulting to fp32 avoids returning inf when the summed value exceeds 65504
"reduce_sum",
});
}
void FloatToHalfPass::Init(Graph* graph) const {
keep_io_types_ = true;
half_precision_ =
static_cast<phi::DataType>(Get<int>("mixed_precision_mode"));
black_list_ = Get<std::unordered_set<std::string>>("mixed_black_list");
SetDefaultBlacklist();
auto graph_size = graph->SubGraphsSize();
VLOG(4) << "graph size: " << graph_size;
subgraphes_.resize(graph_size);
all_op_nodes_.resize(graph_size);
for (size_t i = 0; i < graph_size; i++) {
subgraphes_[i] = graph->GetSubGraph(i);
all_op_nodes_[i] = TopologySortOperations(*subgraphes_[i]);
VLOG(4) << "subgraph " << i << " has " << all_op_nodes_[i].size()
<< "op nodes";
for (auto* var_node : subgraphes_[i]->Nodes()) {
if (!var_node->IsVar()) continue;
auto var_name = var_node->Var()->Name();
if (real_vars_.count(var_name) == 0) {
real_vars_[var_name] = var_node;
VLOG(4) << var_name << " is in graph " << i;
}
}
}
}
void FloatToHalfPass::ApplyImpl(Graph* graph) const {
auto enable_gpu_half = Get<bool>("enable_gpu_half");
if (!enable_gpu_half) return;
PADDLE_ENFORCE_NOT_NULL(
graph,
platform::errors::PreconditionNotMet(
"During the float to half pass, the graph should not be nullptr."));
PADDLE_ENFORCE_EQ(
graph->IsMainGraph(),
true,
platform::errors::PreconditionNotMet(
"During the float to half pass, the graph should be main graph."));
FusePassBase::Init("float_to_half", graph);
Init(graph);
VLOG(4) << "Init done";
SetOpUniqueType();
VLOG(4) << "SetOpUniqueType done";
GetOpPrecision();
VLOG(4) << "GetOpPrecision done";
UpdateOpPrecision();
VLOG(4) << "UpdateOpPrecision done";
SetVarPrecision();
VLOG(4) << "SetVarPrecision done";
ConvertWeightsData();
VLOG(4) << "ConvertWeightsData done";
ProcessOpWithDtypeAttr();
VLOG(4) << "ProcessOpWithDtypeAttr done";
InsertCastOp();
VLOG(4) << "InsertCastOp done";
RestoreOpOriginType();
VLOG(4) << "RestoreOpOriginType done";
}
bool FloatToHalfPass::OpSupportPrecision(const std::string& op_type,
phi::DataType precision,
phi::Backend backend) const {
bool support = false;
if (black_list_.count(op_type) == 0) {
if (backend == phi::Backend::GPU) {
support = GpuKernelSupportPrecision(op_type, precision);
}
}
return support;
}
void FloatToHalfPass::SetOpUniqueType() const {
int suffix = 0;
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
auto op_type = op_node->Op()->Type();
if (op_type == "feed" || op_type == "fetch") continue;
std::string unique_type = op_type + "_" + std::to_string(suffix++);
op_original_type_[unique_type] = op_type;
op_node->Op()->SetType(unique_type);
op_node->Op()->Flush();
VLOG(4) << "change op type: " << op_type << " ---> " << unique_type;
}
}
}
void FloatToHalfPass::RestoreOpOriginType() const {
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
auto op_type = op_node->Op()->Type();
op_node->Op()->SetType(GetOpOriginalType(op_type));
op_node->Op()->Flush();
VLOG(4) << "restore op type: " << op_type << " ---> "
<< op_node->Op()->Type();
}
}
}
inline std::string FloatToHalfPass::GetOpOriginalType(
const std::string& op_type) const {
if (op_original_type_.count(op_type)) {
return op_original_type_.at(op_type);
}
return op_type;
}
void FloatToHalfPass::ProcessOpWithDtypeAttr() const {
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
auto op_type = op_node->Op()->Type();
if (op_run_half_.count(op_type) == 0) continue;
if (op_node->Op()->HasAttr("dtype")) {
auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
if (IsFloatType(static_cast<VarType::Type>(dtype))) {
op_node->Op()->SetAttr(
"dtype",
static_cast<int>(
framework::TransToProtoVarType(half_precision_)));
op_node->Op()->Flush();
VLOG(4) << "process op with dtype attr: " << op_type << " ( " << dtype
<< " --->" << static_cast<int>(half_precision_) << " )";
}
}
if (op_node->Op()->HasAttr("out_dtype")) {
auto out_dtype = op_node->Op()->GetAttrIfExists<int>("out_dtype");
if (IsFloatType(static_cast<VarType::Type>(out_dtype))) {
op_node->Op()->SetAttr(
"out_dtype",
static_cast<int>(
framework::TransToProtoVarType(half_precision_)));
op_node->Op()->Flush();
VLOG(4) << "process op with out_dtype attr: " << op_type << " ( "
<< out_dtype << " --->" << static_cast<int>(half_precision_)
<< " )";
}
}
}
}
}
void FloatToHalfPass::GetOpPrecision() const {
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
auto op_type = op_node->Op()->Type();
bool support_half = true;
if (GetOpOriginalType(op_type) == "feed" ||
GetOpOriginalType(op_type) == "fetch") {
support_half = !keep_io_types_;
} else {
support_half =
OpSupportPrecision(GetOpOriginalType(op_type), half_precision_);
}
if (op_node->Op()->HasAttr("dtype")) {
auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
support_half =
support_half && IsFloatType(static_cast<VarType::Type>(dtype));
} else if (op_node->Op()->HasAttr("out_dtype")) {
auto out_dtype = op_node->Op()->GetAttrIfExists<int>("out_dtype");
support_half =
support_half && IsFloatType(static_cast<VarType::Type>(out_dtype));
} else {
// If the op's input or output vars are not dense tensors, the op should
// not run at half precision.
for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true);
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
if (real_in_var_node->Var()->Persistable()) continue;
support_half = support_half && (real_in_var_node->Var()->GetType() ==
VarType::LOD_TENSOR);
}
for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true);
auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
if (real_out_var_node->Var()->Persistable()) continue;
support_half = support_half && (real_out_var_node->Var()->GetType() ==
VarType::LOD_TENSOR);
}
}
if (support_half) {
op_run_half_.insert(op_type);
VLOG(4) << "support precision: " << op_type << " run at half";
} else {
VLOG(4) << "support precision: " << op_type << " not run at half";
}
}
}
}
void FloatToHalfPass::UpdateOpPrecision() const {
std::unordered_set<std::string> vars_should_not_half;
// var name -> all ops that produce this var
std::unordered_map<std::string, std::vector<Node*>> var_input_ops;
auto GetVarInputOps = [&] {
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
auto op_type = op_node->Op()->Type();
if (GetOpOriginalType(op_type) == "fetch") continue;
if (op_node->Op()->HasAttr("sub_block")) continue;
for (auto* var_node : op_node->outputs) {
CHECK_EQ(var_node->IsVar(), true);
if (var_node->Var()->Persistable()) continue;
if (!VarNodeHasDtype(var_node)) continue;
var_input_ops[var_node->Var()->Name()].push_back(op_node);
VLOG(4) << "var input ops: " << var_node->Var()->Name()
<< " is output of " << op_type;
}
// The select_input op's input vars should not be converted to half. When an
// op's output var is a select_input op's input var, that op should not run
// at half precision.
if (GetOpOriginalType(op_node->Op()->Type()) == "select_input") {
for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true);
if (in_var_node->Var()->Persistable()) continue;
if (!VarNodeHasDtype(in_var_node)) continue;
vars_should_not_half.insert(in_var_node->Var()->Name());
}
}
// When op_1 only supports a CPU kernel and op_2's input var is op_1's
// output var, op_2 should not run at half precision.
if (GetOpOriginalType(op_type) != "feed" &&
!GpuKernelSupportPrecision(GetOpOriginalType(op_type),
phi::DataType::FLOAT32)) {
for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true);
if (out_var_node->Var()->Persistable()) continue;
if (!VarNodeHasDtype(out_var_node)) continue;
vars_should_not_half.insert(out_var_node->Var()->Name());
}
}
}
}
};
GetVarInputOps();
bool precision_updated = false;
do {
precision_updated = false;
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
if (op_run_half_.count(op_node->Op()->Type()) == 0) continue;
for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true);
if (!VarNodeHasDtype(in_var_node)) continue;
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
if (real_in_var_node->Var()->Persistable()) continue;
if (vars_should_not_half.count(real_in_var_node->Var()->Name())) {
op_run_half_.erase(op_node->Op()->Type());
precision_updated = true;
VLOG(4) << op_node->Op()->Type()
<< " should not support half precision.";
break;
}
}
if (op_run_half_.count(op_node->Op()->Type()) == 0) continue;
for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true);
if (!VarNodeHasDtype(out_var_node)) continue;
auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
if (real_out_var_node->Var()->Persistable()) continue;
bool not_run_half = false;
const auto& input_op_nodes =
var_input_ops[real_out_var_node->Var()->Name()];
if (vars_should_not_half.count(real_out_var_node->Var()->Name())) {
not_run_half = true;
} else {
for (auto* node : input_op_nodes) {
if (op_run_half_.count(node->Op()->Type()) == 0) {
not_run_half = true;
break;
}
}
}
if (not_run_half) {
op_run_half_.erase(op_node->Op()->Type());
precision_updated = true;
VLOG(4) << op_node->Op()->Type()
<< " should not support half precision.";
break;
}
}
}
}
} while (precision_updated);
}
// Special ops whose weights should not be converted to low precision.
bool FloatToHalfPass::InputVarsNotConvert(Node* op_node,
const std::string& var_name) const {
auto* op_desc = op_node->Op();
if (GetOpOriginalType(op_desc->Type()) == "batch_norm") {
auto vecs = op_desc->Input("Bias");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("Mean");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("Scale");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("Variance");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
} else if (GetOpOriginalType(op_desc->Type()) == "fused_multi_transformer") {
auto vecs = op_desc->Input("LnScale");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("LnBias");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("FFNLnScale");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("FFNLnBias");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
}
return false;
}
bool FloatToHalfPass::OutputVarsNotConvert(Node* op_node,
const std::string& var_name) const {
auto* op_desc = op_node->Op();
// batch_norm's input and output (variance and mean) are the same.
if (GetOpOriginalType(op_desc->Type()) == "batch_norm") {
auto vecs = op_desc->Output("MeanOut");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Output("VarianceOut");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Output("SavedMean");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Output("SavedVariance");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
}
return false;
}
void FloatToHalfPass::SetVarPrecision() const {
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
if (op_run_half_.count(op_node->Op()->Type())) {
for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true);
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
auto in_var_name = real_in_var_node->Var()->Name();
if (!IsFloatType(real_in_var_node->Var()->GetDataType())) continue;
if (!VarNodeHasDtype(real_in_var_node)) continue;
if (InputVarsNotConvert(op_node, in_var_name)) continue;
if (real_in_var_node->Var()->Persistable()) {
real_in_var_node->Var()->SetDataType(
framework::TransToProtoVarType(half_precision_));
vars_convert_to_half_.insert(in_var_name);
}
}
for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true);
auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
auto out_var_name = real_out_var_node->Var()->Name();
if (!IsFloatType(real_out_var_node->Var()->GetDataType())) continue;
if (!VarNodeHasDtype(real_out_var_node)) continue;
if (OutputVarsNotConvert(op_node, out_var_name)) continue;
real_out_var_node->Var()->SetDataType(
framework::TransToProtoVarType(half_precision_));
if (real_out_var_node->Var()->Persistable()) {
vars_convert_to_half_.insert(out_var_name);
}
}
}
}
}
// This code is used to process vars with the same name. Vars with the same
// name should have the same data type.
for (auto* subgraph : subgraphes_) {
for (auto* var_node : subgraph->Nodes()) {
if (!var_node->IsVar() || !var_node->Var()->Persistable()) continue;
if (!VarNodeHasDtype(var_node)) continue;
auto var_name = var_node->Var()->Name();
if (vars_convert_to_half_.count(var_name)) {
var_node->Var()->SetDataType(
framework::TransToProtoVarType(half_precision_));
}
}
}
}
void FloatToHalfPass::ConvertWeightsData() const {
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope,
platform::errors::PreconditionNotMet(
"During the float to half pass, the scope should not be null."));
auto var_names = scope->LocalVarNames();
for (const auto& var_name : var_names) {
if (vars_convert_to_half_.count(var_name)) {
VLOG(4) << var_name << "'s data type was convert to half";
#define CONVERT_TENSOR_DTYPE(DTYPE, dtype) \
half_tensor.set_type(DTYPE); \
auto* half_data = half_tensor.mutable_data<dtype>(platform::CPUPlace()); \
for (int64_t i = 0; i < origin_tensor->numel(); i++) { \
half_data[i] = static_cast<dtype>(origin_data[i]); \
} \
origin_tensor->clear(); \
paddle::framework::TensorCopySync( \
half_tensor, platform::CPUPlace(), origin_tensor)
auto* var = scope->FindLocalVar(var_name);
if (var->IsType<phi::DenseTensor>()) {
auto* origin_tensor = var->GetMutable<phi::DenseTensor>();
phi::DenseTensor half_tensor;
half_tensor.Resize(origin_tensor->dims());
auto* origin_data =
origin_tensor->mutable_data<float>(platform::CPUPlace());
if (half_precision_ == phi::DataType::FLOAT16) {
CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::FLOAT16,
phi::dtype::float16);
} else if (half_precision_ == phi::DataType::BFLOAT16) {
CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::BFLOAT16,
phi::dtype::bfloat16);
}
}
}
#undef CONVERT_TENSOR_DTYPE
}
}
void FloatToHalfPass::InsertCastOp() const {
int suffix = 0;
std::unordered_map<Node*, Node*> cache;
for (size_t i = 0; i < all_op_nodes_.size(); i++) {
auto* block_desc = all_op_nodes_[i][0]->Op()->Block();
CHECK_NOTNULL(block_desc);
for (auto* op_node : all_op_nodes_[i]) {
auto op_type = op_node->Op()->Type();
if (GetOpOriginalType(op_type) == "feed") continue;
if (op_node->Op()->HasAttr("sub_block")) continue;
VLOG(4) << "process op: " << op_type
<< " run half: " << op_run_half_.count(op_type);
auto inputs = op_node->inputs;
for (auto* in_var_node : inputs) {
if (!in_var_node->IsVar()) continue;
if (!VarNodeHasDtype(in_var_node)) continue;
if (in_var_node->Var()->Persistable()) continue;
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
auto in_var_type = real_in_var_node->Var()->GetDataType();
VLOG(4) << "process var: " << real_in_var_node->Var()->Name()
<< " with type " << in_var_type;
if (IsFloatType(in_var_type) && op_run_half_.count(op_type)) {
DoInsertCastOp(subgraphes_[i],
in_var_node,
op_node,
in_var_type,
framework::TransToProtoVarType(half_precision_),
block_desc,
&suffix,
&cache);
} else if (IsHalfType(in_var_type) &&
op_run_half_.count(op_type) == 0) {
DoInsertCastOp(subgraphes_[i],
in_var_node,
op_node,
in_var_type,
VarType::FP32,
block_desc,
&suffix,
&cache);
}
}
// Special op: fused_multi_transformer's input (CacheKV) and output
// (CacheKVOut) vars have the same name.
if (GetOpOriginalType(op_type) == "fused_multi_transformer") {
auto cache_kv_inputs = op_node->Op()->Input("CacheKV");
auto cache_kv_outputs = op_node->Op()->Output("CacheKVOut");
CHECK_EQ(cache_kv_inputs.size(), cache_kv_outputs.size());
for (size_t i = 0; i < cache_kv_inputs.size(); ++i) {
op_node->Op()->RenameOutput(cache_kv_outputs[i], cache_kv_inputs[i]);
}
}
}
}
VLOG(4) << "insert number of cast op: " << cache.size();
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(float_to_half_pass, paddle::framework::ir::FloatToHalfPass);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/place.h"
namespace paddle {
namespace framework {
namespace ir {
class FloatToHalfPass : public FusePassBase {
public:
using VarType = framework::proto::VarType;
public:
FloatToHalfPass() = default;
~FloatToHalfPass() = default;
protected:
void ApplyImpl(Graph* graph) const override;
private:
void Init(Graph* graph) const;
void SetDefaultBlacklist() const;
bool OpSupportPrecision(const std::string& op_type,
phi::DataType precision,
phi::Backend backend = phi::Backend::GPU) const;
void SetOpUniqueType() const;
void RestoreOpOriginType() const;
inline std::string GetOpOriginalType(const std::string& op_type) const;
void GetOpPrecision() const;
void UpdateOpPrecision() const;
void InsertCastOp() const;
void ProcessOpWithDtypeAttr() const;
bool InputVarsNotConvert(Node* op_node, const std::string& var_name) const;
bool OutputVarsNotConvert(Node* op_node, const std::string& var_name) const;
void SetVarPrecision() const;
void ConvertWeightsData() const;
private:
mutable bool keep_io_types_;
// float16 or bfloat16 now
mutable phi::DataType half_precision_;
mutable std::unordered_set<std::string> black_list_;
// subgraph id -> pointer to subgraph
mutable std::vector<Graph*> subgraphes_;
// var name -> real var node
mutable std::unordered_map<std::string, Node*> real_vars_;
// subgraph id -> all op nodes in subgraph
mutable std::vector<std::vector<Node*>> all_op_nodes_;
// op's unique type -> the op's origin type
mutable std::unordered_map<std::string, std::string> op_original_type_;
// the set of unique op types that run at half precision
mutable std::unordered_set<std::string> op_run_half_;
mutable std::unordered_set<std::string> vars_convert_to_half_;
};
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -365,6 +365,8 @@ struct Argument {
   DECL_ARGUMENT_FIELD(mixed_black_list,
                       MixedBlackList,
                       std::unordered_set<std::string>);
+  DECL_ARGUMENT_FIELD(enable_gpu_half, EnableGPUHalf, bool);
+  DECL_ARGUMENT_FIELD(mixed_precision_mode, MixedPrecisionMode, int);
 
  private:
   std::unordered_set<std::string> valid_fields_;
......
@@ -86,10 +86,14 @@ void IRPassManager::CreatePasses(Argument *argument,
         argument->tensorrt_tuned_dynamic_shape();
     pass->Set("with_dynamic_shape", new bool(with_dynamic_shape));
 
+    // mixed precision related
     pass->Set("model_precision", new int(argument->model_precision()));
     pass->Set(
         "mixed_black_list",
         new std::unordered_set<std::string>(argument->mixed_black_list()));
+    pass->Set("enable_gpu_half", new bool(argument->enable_gpu_half()));
+    pass->Set("mixed_precision_mode",
+              new int(argument->mixed_precision_mode()));
 
     if (pass_name == "graph_viz_pass") {
       std::string optim_cache_dir = argument->optim_cache_dir();
......
@@ -85,16 +85,29 @@ void AnalysisConfig::SetModel(const std::string &prog_file_path,
   Update();
 }
 
 void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb,
-                                  int device_id) {
+                                  int device_id,
+                                  Precision precision_mode) {
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   use_gpu_ = true;
   memory_pool_init_size_mb_ = memory_pool_init_size_mb;
   FLAGS_initial_gpu_memory_in_mb = memory_pool_init_size_mb_;
   gpu_device_id_ = device_id;
+  mixed_precision_mode_ = precision_mode;
+  if (precision_mode == Precision::kFloat32) {
+    // default
+  } else if (precision_mode == Precision::kHalf ||
+             precision_mode == Precision::kBf16) {
+    enable_gpu_half_ = true;
+  } else {
+    LOG(ERROR)
+        << "The Paddle-GPU inference currently only supports "
+           "float32/float16/bfloat16 precision. Please check the parameters "
+           "you specified in EnableUseGpu or enable_use_gpu function.";
+  }
 #else
-  LOG(ERROR) << "Please compile with gpu to EnableGpu()";
+  LOG(ERROR) << "Please use PaddlePaddle with GPU version.";
+  use_gpu_ = false;
 #endif
 
   Update();
@@ -381,8 +394,10 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
   CP_MEMBER(gpu_device_id_);
   CP_MEMBER(memory_pool_init_size_mb_);
 
-  // Mixed related.
+  // Mixed precision related.
   CP_MEMBER(mixed_black_list_);
+  CP_MEMBER(enable_gpu_half_);
+  CP_MEMBER(mixed_precision_mode_);
 
   CP_MEMBER(enable_memory_optim_);
   // TensorRT related.
@@ -996,6 +1011,7 @@ std::string AnalysisConfig::SerializeInfoCache() {
   ss << params_file_;
 
   ss << use_gpu_;
+  ss << enable_gpu_half_;
   ss << use_external_stream_;
   ss << exec_stream_;
   ss << use_fc_padding_;
@@ -1212,6 +1228,7 @@ std::string AnalysisConfig::Summary() {
   os.InsertRow({"use_gpu", use_gpu_ ? "true" : "false"});
   if (use_gpu_) {
     os.InsertRow({"gpu_device_id", std::to_string(gpu_device_id_)});
+    os.InsertRow({"enable_gpu_half_", std::to_string(enable_gpu_half_)});
     os.InsertRow({"memory_pool_init_size",
                   std::to_string(memory_pool_init_size_mb_) + "MB"});
     os.InsertRow(
@@ -1407,7 +1424,7 @@ bool AnalysisConfig::trt_allow_build_at_runtime() const {
   return trt_allow_build_at_runtime_;
 }
 
-void AnalysisConfig::Exp_SetBlackListOpsForMixedModel(
+void AnalysisConfig::Exp_DisableMixedInferOps(
     const std::unordered_set<std::string> &black_list) {
   mixed_black_list_ = black_list;
 }
......
@@ -1257,12 +1257,26 @@ void AnalysisPredictor::PrepareArgument() {
       }
     }
   }
-  if (config_.ir_debug_) {
-    pass_builder->TurnOnDebug();
-  }
   if (!config_.ir_optim()) {
     argument_.SetEnableIrOptim(false);
-    LOG(INFO) << "ir_optim is turned off, no IR pass will be executed";
+    if (config_.enable_gpu_half_) {
+      argument_.SetEnableIrOptim(true);
+      pass_builder->ClearPasses();
+      pass_builder->AppendPass("float_to_half_pass");
+      LOG(INFO)
+          << "This model run in Paddle-GPU mixed precision mode with no ir "
+             "optimization.";
+    } else {
+      LOG(INFO) << "ir_optim is turned off, no IR pass will be executed.";
+    }
+  } else {
+    if (config_.ir_debug_) {
+      pass_builder->TurnOnDebug();
+    }
+    if (config_.enable_gpu_half_) {
+      LOG(INFO) << "This model run in Paddle-GPU mixed precision mode.";
+    }
   }
   argument_.SetDisableLogs(config_.glog_info_disabled());
   argument_.SetIrAnalysisPasses(pass_builder->AllPasses());
@@ -1272,6 +1286,9 @@ void AnalysisPredictor::PrepareArgument() {
   // mixed precison.
   argument_.SetModelPrecision(static_cast<int>(model_precision_));
   argument_.SetMixedBlackList(config_.mixed_black_list_);
+  argument_.SetEnableGPUHalf(config_.enable_gpu_half_);
+  argument_.SetMixedPrecisionMode(static_cast<int>(
+      paddle::ConvertPrecision(config_.mixed_precision_mode_)));
 }
 
 // NOTE All the members in AnalysisConfig should be copied to Argument.
......
@@ -247,8 +247,12 @@ struct PD_INFER_DECL AnalysisConfig {
   ///
   /// \param memory_pool_init_size_mb initial size of the GPU memory pool in MB.
   /// \param device_id device_id the GPU card to use (default is 0).
+  /// \param precision the precision used in Paddle-GPU inference.
   ///
-  void EnableUseGpu(uint64_t memory_pool_init_size_mb, int device_id = 0);
+  void EnableUseGpu(uint64_t memory_pool_init_size_mb,
+                    int device_id = 0,
+                    Precision precision_mode = Precision::kFloat32);
+
   ///
   /// \brief Turn off GPU.
   ///
@@ -1005,7 +1009,7 @@ struct PD_INFER_DECL AnalysisConfig {
   /// interface is in the experimental stage and may change in the future. Note
   /// that the blacklist must be the same as the model conversion blacklist.
   ///
-  void Exp_SetBlackListOpsForMixedModel(
+  void Exp_DisableMixedInferOps(
       const std::unordered_set<std::string>& black_list);
 
   void SetApplyOptim(bool value) { apply_optim_ = value; }
@@ -1024,13 +1028,15 @@ struct PD_INFER_DECL AnalysisConfig {
   mutable std::string prog_file_;
   mutable std::string params_file_;
 
-  // Mixed precision.
+  // Mixed precision related.
+  Precision mixed_precision_mode_{Precision::kFloat32};
   std::unordered_set<std::string> mixed_black_list_;
 
   // GPU related.
   bool use_gpu_{false};
   int gpu_device_id_{0};
   uint64_t memory_pool_init_size_mb_{100};  // initial size is 100MB.
+  bool enable_gpu_half_{false};
   bool thread_local_stream_{false};
 
   bool use_cudnn_{false};
......
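
For reference, a small sketch of how the renamed blacklist interface declared above can be combined with mixed precision. The op names here are illustrative only (both already sit in the pass's default blacklist); any op listed this way stays in FP32 even when GPU mixed precision is enabled:

paddle::AnalysisConfig config;
config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kHalf);
// Ops listed here are kept in FP32 during mixed precision inference.
config.Exp_DisableMixedInferOps({"softmax", "reduce_sum"});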
@@ -246,9 +246,10 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
         "conv_elementwise_add_fuse_pass",      //
 #endif                                         //
         "transpose_flatten_concat_fuse_pass",  //
-        "constant_folding_pass",
+        "constant_folding_pass",               //
         // following pass should be located in the last, since it will
         // work on all fused ops.
+        "float_to_half_pass",                  //
         "runtime_context_cache_pass"
   });
......
@@ -416,6 +416,9 @@ download_result(${ERNIE_INSTALL_DIR} "Ernie_result.txt.tar.gz"
 if(WITH_GPU)
   inference_analysis_api_test(test_analyzer_ernie ${ERNIE_INSTALL_DIR}
                               analyzer_ernie_tester.cc)
+  inference_analysis_api_test(gpu_ernie_half_test ${ERNIE_INSTALL_DIR}
+                              gpu_ernie_half_test.cc)
+  set_tests_properties(gpu_ernie_half_test PROPERTIES TIMEOUT 40)
 endif()
 inference_analysis_api_int8_test(test_analyzer_ernie_int8 ${ERNIE_INSTALL_DIR}
                                  analyzer_ernie_int8_tester.cc)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/tests/api/tester_helper.h"
namespace paddle {
namespace inference {
using paddle::PaddleTensor;
template <typename T>
void GetValueFromStream(std::stringstream *ss, T *t) {
(*ss) >> (*t);
}
template <>
void GetValueFromStream<std::string>(std::stringstream *ss, std::string *t) {
*t = ss->str();
}
// Split string to vector
template <typename T>
void Split(const std::string &line, char sep, std::vector<T> *v) {
std::stringstream ss;
T t;
for (auto c : line) {
if (c != sep) {
ss << c;
} else {
GetValueFromStream<T>(&ss, &t);
v->push_back(std::move(t));
ss.str({});
ss.clear();
}
}
if (!ss.str().empty()) {
GetValueFromStream<T>(&ss, &t);
v->push_back(std::move(t));
ss.str({});
ss.clear();
}
}
// Parse tensor from string
template <typename T>
bool ParseTensor(const std::string &field, paddle::PaddleTensor *tensor) {
std::vector<std::string> data;
Split(field, ':', &data);
if (data.size() < 2) return false;
std::string shape_str = data[0];
std::vector<int> shape;
Split(shape_str, ' ', &shape);
std::string mat_str = data[1];
std::vector<T> mat;
Split(mat_str, ' ', &mat);
tensor->shape = shape;
auto size =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()) *
sizeof(T);
tensor->data.Resize(size);
std::copy(mat.begin(), mat.end(), static_cast<T *>(tensor->data.data()));
tensor->dtype = GetPaddleDType<T>();
return true;
}
// Parse input tensors from string
bool ParseLine(const std::string &line,
std::vector<paddle::PaddleTensor> *tensors) {
std::vector<std::string> fields;
Split(line, ';', &fields);
tensors->clear();
tensors->reserve(4);
int i = 0;
auto input_name = FLAGS_ernie_large ? "eval_placeholder_" : "placeholder_";
for (; i < 3; i++) {
paddle::PaddleTensor temp;
ParseTensor<int64_t>(fields[i], &temp);
temp.name = input_name + std::to_string(i);
tensors->push_back(temp);
}
// input_mask
paddle::PaddleTensor input_mask;
ParseTensor<float>(fields[i], &input_mask);
input_mask.name = input_name + std::to_string(i);
tensors->push_back(input_mask);
return true;
}
bool LoadInputData(std::vector<std::vector<paddle::PaddleTensor>> *inputs,
int batch_size = 1) {
if (FLAGS_infer_data.empty()) {
LOG(ERROR) << "please set input data path";
return false;
}
std::ifstream fin(FLAGS_infer_data);
std::string line;
int sample = 0;
// The unit-test dataset only has 10 samples; each sample has 5 feeds.
while (std::getline(fin, line)) {
std::vector<paddle::PaddleTensor> feed_data;
ParseLine(line, &feed_data);
inputs->push_back(std::move(feed_data));
sample++;
if (!FLAGS_test_all_data && sample == batch_size) break;
}
LOG(INFO) << "number of samples: " << sample;
return true;
}
// Compare results
TEST(Ernie_gpu_fp16_no_ir, compare_results) {
AnalysisConfig config;
config.SetModel(FLAGS_infer_model);
config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kHalf);
config.SwitchIrOptim(false);
auto predictor = CreatePaddlePredictor(config);
std::vector<std::vector<PaddleTensor>> input_slots_all;
LoadInputData(&input_slots_all);
std::ifstream fin(FLAGS_refer_result);
std::string line;
std::vector<float> ref;
while (std::getline(fin, line)) {
Split(line, ' ', &ref);
}
std::vector<PaddleTensor> outputs;
for (size_t i = 0; i < input_slots_all.size(); i++) {
outputs.clear();
predictor->Run(input_slots_all[i], &outputs);
auto output = outputs.front();
size_t outputs_size = 1;
for (auto dim : output.shape) {
outputs_size *= dim;
}
float *result = reinterpret_cast<float *>(output.data.data());
for (size_t j = 0; j < outputs_size; ++j) {
EXPECT_NEAR(ref[i * outputs_size + j], result[j], 5e-2);
}
}
}
// Compare results
TEST(Ernie_gpu_fp16_with_ir, compare_results) {
AnalysisConfig config;
config.SetModel(FLAGS_infer_model);
config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kHalf);
config.SwitchIrOptim(true);
// The fc_fuse_pass introduces accuracy diffs here; it will be fixed later.
config.pass_builder()->DeletePass("fc_fuse_pass");
// There is a problem with the model itself, which has nothing to do with
// constant_folding_pass.
config.pass_builder()->DeletePass("constant_folding_pass");
auto predictor = CreatePaddlePredictor(config);
std::vector<std::vector<PaddleTensor>> input_slots_all;
LoadInputData(&input_slots_all);
std::ifstream fin(FLAGS_refer_result);
std::string line;
std::vector<float> ref;
while (std::getline(fin, line)) {
Split(line, ' ', &ref);
}
std::vector<PaddleTensor> outputs;
for (size_t i = 0; i < input_slots_all.size(); i++) {
outputs.clear();
predictor->Run(input_slots_all[i], &outputs);
auto output = outputs.front();
size_t outputs_size = 1;
for (auto dim : output.shape) {
outputs_size *= dim;
}
float *result = reinterpret_cast<float *>(output.data.data());
for (size_t j = 0; j < outputs_size; ++j) {
EXPECT_NEAR(ref[i * outputs_size + j], result[j], 5e-2);
}
}
}
// Compare results
TEST(Ernie_gpu_bf16_no_ir, compare_results) {
AnalysisConfig config;
config.SetModel(FLAGS_infer_model);
config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kBf16);
config.SwitchIrOptim(false);
auto predictor = CreatePaddlePredictor(config);
std::vector<std::vector<PaddleTensor>> input_slots_all;
LoadInputData(&input_slots_all);
std::ifstream fin(FLAGS_refer_result);
std::string line;
std::vector<float> ref;
while (std::getline(fin, line)) {
Split(line, ' ', &ref);
}
std::vector<PaddleTensor> outputs;
for (size_t i = 0; i < input_slots_all.size(); i++) {
outputs.clear();
predictor->Run(input_slots_all[i], &outputs);
auto output = outputs.front();
size_t outputs_size = 1;
for (auto dim : output.shape) {
outputs_size *= dim;
}
float *result = reinterpret_cast<float *>(output.data.data());
for (size_t j = 0; j < outputs_size; ++j) {
EXPECT_NEAR(ref[i * outputs_size + j], result[j], 7e-2);
}
}
}
// Compare results
TEST(Ernie_gpu_bf16_with_ir, compare_results) {
AnalysisConfig config;
config.SetModel(FLAGS_infer_model);
config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kBf16);
config.SwitchIrOptim(true);
// The fc_fuse_pass introduces accuracy diffs here; it will be fixed later.
config.pass_builder()->DeletePass("fc_fuse_pass");
// There is a problem with the model itself, which has nothing to do with
// constant_folding_pass.
config.pass_builder()->DeletePass("constant_folding_pass");
auto predictor = CreatePaddlePredictor(config);
std::vector<std::vector<PaddleTensor>> input_slots_all;
LoadInputData(&input_slots_all);
std::ifstream fin(FLAGS_refer_result);
std::string line;
std::vector<float> ref;
while (std::getline(fin, line)) {
Split(line, ' ', &ref);
}
std::vector<PaddleTensor> outputs;
for (size_t i = 0; i < input_slots_all.size(); i++) {
outputs.clear();
predictor->Run(input_slots_all[i], &outputs);
auto output = outputs.front();
size_t outputs_size = 1;
for (auto dim : output.shape) {
outputs_size *= dim;
}
float *result = reinterpret_cast<float *>(output.data.data());
for (size_t j = 0; j < outputs_size; ++j) {
EXPECT_NEAR(ref[i * outputs_size + j], result[j], 7e-2);
}
}
}
} // namespace inference
} // namespace paddle
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -12,15 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include <cuda_runtime.h>
 #include <glog/logging.h>
 #include <gtest/gtest.h>
 
-#include <cstring>
-#include <numeric>
-
 #include "gflags/gflags.h"
 
-#include "paddle/fluid/inference/tests/api/trt_test_helper.h"
+#include "paddle/fluid/inference/tests/api/tester_helper.h"
 
 namespace paddle_infer {
......
@@ -644,7 +644,8 @@ void BindAnalysisConfig(py::module *m) {
       .def("enable_use_gpu",
            &AnalysisConfig::EnableUseGpu,
            py::arg("memory_pool_init_size_mb"),
-           py::arg("device_id") = 0)
+           py::arg("device_id") = 0,
+           py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32)
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
       .def("set_exec_stream",
            [](AnalysisConfig &self, phi::CUDAStream &stream) {
......