Unverified Commit c5a45cc6 authored by Yuanle Liu, committed by GitHub

[Paddle Inference] Add float_to_half_pass to support inference with mixed precision (#47993)

Parent 54b756e2
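
For reference, a minimal usage sketch of the new mixed precision mode, based on the tests added in this commit (FLAGS_infer_model comes from the test helper; the blacklisted op name is a hypothetical placeholder):

AnalysisConfig config;
config.SetModel(FLAGS_infer_model);  // directory of the saved model
// 512 MB initial GPU memory pool, device 0, run in float16 (or kBf16).
config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kHalf);
config.SwitchIrOptim(true);  // float_to_half_pass runs during IR optimization
// Optionally force selected ops to stay in fp32 (hypothetical op name):
// config.Exp_DisableMixedInferOps({"some_op_type"});
auto predictor = CreatePaddlePredictor(config);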
@@ -104,6 +104,7 @@ pass_library(delete_c_identity_op_pass inference)
pass_library(preln_residual_bias_fuse_pass inference)
pass_library(delete_fill_constant_op_pass inference)
pass_library(constant_folding_pass inference)
pass_library(float_to_half_pass inference)
pass_library(conv2d_fusion_layout_transfer_pass inference)
pass_library(simplify_with_basic_ops_pass base)
pass_library(fc_elementwise_layernorm_fuse_pass base)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/float_to_half_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/common/data_type.h"
namespace paddle {
namespace framework {
namespace ir {
namespace {
using VarType = FloatToHalfPass::VarType;
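// Returns true when the phi kernel library has a kernel of op_type
// registered for the given backend, data type and layout.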
bool PhiKernelSupportPrecision(
const std::string& op_type,
phi::Backend backend,
phi::DataType data_type,
phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) {
const auto& kernels = phi::KernelFactory::Instance().kernels();
if (kernels.count(op_type) == 0) {
return false;
}
phi::KernelKey kernel_key(backend, layout, data_type);
return phi::KernelFactory::Instance().HasKernel(op_type, kernel_key);
}
bool GpuKernelSupportPrecision(
const std::string& op_type,
phi::DataType precision,
phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) {
auto phi_op_type = phi::TransToPhiKernelName(op_type);
bool support = PhiKernelSupportPrecision(
phi_op_type, phi::Backend::GPU, precision, layout);
support |= PhiKernelSupportPrecision(
phi_op_type, phi::Backend::GPUDNN, precision, layout);
if (!support) {
const auto& all_kernels = framework::OperatorWithKernel::AllOpKernels();
auto it = all_kernels.find(op_type);
if (it != all_kernels.end()) {
for (const auto& kern_pair : it->second) {
if (platform::is_gpu_place(kern_pair.first.place_) &&
kern_pair.first.data_type_ ==
framework::TransToProtoVarType(precision)) {
support = true;
break;
}
}
}
}
return support;
}
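// Insert a cast op converting var_node from from_type to to_type in front of
// op_node. Cast outputs are cached so that a var consumed by several ops is
// only converted once.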
void DoInsertCastOp(Graph* graph,
Node* var_node,
Node* op_node,
VarType::Type from_type,
VarType::Type to_type,
framework::BlockDesc* block_desc,
int* suffix,
std::unordered_map<Node*, Node*>* cache) {
if (from_type == to_type) return;
auto update_cast_desc = [&](framework::OpDesc& desc,
const std::string& x_name,
const std::string& out_name,
const int in_dtype,
const int out_dtype) {
desc.SetType("cast");
desc.SetInput("X", {x_name});
desc.SetOutput("Out", {out_name});
desc.SetAttr("in_dtype", in_dtype);
desc.SetAttr("out_dtype", out_dtype);
desc.SetAttr("use_mkldnn", false);
desc.SetAttr("with_quant_attr", false);
desc.Flush();
};
if (cache->count(var_node) == 0) {
// insert cast op between var_node and op_node
std::string cast_input_name = var_node->Var()->Name();
std::string cast_output_name =
var_node->Var()->Name() + "_cast.tmp_" + std::to_string((*suffix)++);
framework::OpDesc cast_op_desc(block_desc);
update_cast_desc(cast_op_desc,
cast_input_name,
cast_output_name,
static_cast<int>(from_type),
static_cast<int>(to_type));
auto* cast_op_node = graph->CreateOpNode(&cast_op_desc);
auto* cast_output_vardesc = block_desc->Var(cast_output_name);
cast_output_vardesc->SetPersistable(false);
cast_output_vardesc->SetDataType(to_type);
cast_output_vardesc->SetShape(var_node->Var()->GetShape());
auto* cast_output_node = graph->CreateVarNode(cast_output_vardesc);
IR_NODE_LINK_TO(cast_op_node, cast_output_node);
(*cache)[var_node] = cast_output_node;
}
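// Redirect op_node's input from var_node to the cached cast output and
// rewire the graph links accordingly.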
op_node->Op()->Rename(var_node->Name(), cache->at(var_node)->Name());
IR_NODE_LINK_TO(var_node, cache->at(var_node)->inputs[0]);
IR_NODE_LINK_TO(cache->at(var_node), op_node);
IR_NODE_UNLINK(var_node, op_node);
}
inline bool VarNodeHasDtype(Node* var_node) {
auto type = var_node->Var()->GetType();
return (type == VarType::SELECTED_ROWS) || (type == VarType::LOD_TENSOR) ||
(type == VarType::LOD_TENSOR_ARRAY) || (type == VarType::STRINGS) ||
(type == VarType::VOCAB);
}
inline bool IsFloatType(VarType::Type type) {
return (type == VarType::FP64) || (type == VarType::FP32);
}
inline bool IsHalfType(VarType::Type type) {
return (type == VarType::FP16) || (type == VarType::BF16);
}
}  // namespace
// The set of ops that support fp16 computation but are numerically
// dangerous, slower than fp32, or whose effects may be observed in
// downstream ops.
void FloatToHalfPass::SetDefaultBlacklist() const {
black_list_.insert({
// numerically-dangerous
"acos",
"asin",
"cosh",
"tan",
"exp",
"expm1",
"square",
"log",
"log2",
"log10",
"log1p",
"logsumexp",
"mean",
"rsqrt",
"sum",
"cos_sim",
"softmax",
"softmax_with_cross_entropy",
"sigmoid_cross_entropy_with_logits",
"c_softmax_with_cross_entropy",
"cross_entropy",
"cross_entropy2",
// slower than fp32
"conv2d_transpose",
// defaulting to fp32 avoids returning inf when the summed value exceeds 65504
"reduce_sum",
});
}
void FloatToHalfPass::Init(Graph* graph) const {
keep_io_types_ = true;
half_precision_ =
static_cast<phi::DataType>(Get<int>("mixed_precision_mode"));
black_list_ = Get<std::unordered_set<std::string>>("mixed_black_list");
SetDefaultBlacklist();
auto graph_size = graph->SubGraphsSize();
VLOG(4) << "graph size: " << graph_size;
subgraphes_.resize(graph_size);
all_op_nodes_.resize(graph_size);
for (size_t i = 0; i < graph_size; i++) {
subgraphes_[i] = graph->GetSubGraph(i);
all_op_nodes_[i] = TopologySortOperations(*subgraphes_[i]);
VLOG(4) << "subgraph " << i << " has " << all_op_nodes_[i].size()
<< "op nodes";
for (auto* var_node : subgraphes_[i]->Nodes()) {
if (!var_node->IsVar()) continue;
auto var_name = var_node->Var()->Name();
if (real_vars_.count(var_name) == 0) {
real_vars_[var_name] = var_node;
VLOG(4) << var_name << " is in graph " << i;
}
}
}
}
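// Pipeline overview: give each op instance a unique type, decide which ops
// can run at half precision, propagate fp32 constraints, rewrite var dtypes
// and weight data, patch dtype attributes, insert cast ops at precision
// boundaries, and finally restore the original op types.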
void FloatToHalfPass::ApplyImpl(Graph* graph) const {
auto enable_gpu_half = Get<bool>("enable_gpu_half");
if (!enable_gpu_half) return;
PADDLE_ENFORCE_NOT_NULL(
graph,
platform::errors::PreconditionNotMet(
"During the float to half pass, the graph should not be nullptr."));
PADDLE_ENFORCE_EQ(
graph->IsMainGraph(),
true,
platform::errors::PreconditionNotMet(
"During the float to half pass, the graph should be main graph."));
FusePassBase::Init("float_to_half", graph);
Init(graph);
VLOG(4) << "Init done";
SetOpUniqueType();
VLOG(4) << "SetOpUniqueType done";
GetOpPrecision();
VLOG(4) << "GetOpPrecision done";
UpdateOpPrecision();
VLOG(4) << "UpdateOpPrecision done";
SetVarPrecision();
VLOG(4) << "SetVarPrecision done";
ConvertWeightsData();
VLOG(4) << "ConvertWeightsData done";
ProcessOpWithDtypeAttr();
VLOG(4) << "ProcessOpWithDtypeAttr done";
InsertCastOp();
VLOG(4) << "InsertCastOp done";
RestoreOpOriginType();
VLOG(4) << "RestoreOpOriginType done";
}
bool FloatToHalfPass::OpSupportPrecision(const std::string& op_type,
phi::DataType precision,
phi::Backend backend) const {
bool support = false;
if (black_list_.count(op_type) == 0) {
if (backend == phi::Backend::GPU) {
support = GpuKernelSupportPrecision(op_type, precision);
}
}
return support;
}
void FloatToHalfPass::SetOpUniqueType() const {
int suffix = 0;
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
auto op_type = op_node->Op()->Type();
if (op_type == "feed" || op_type == "fetch") continue;
std::string unique_type = op_type + "_" + std::to_string(suffix++);
op_original_type_[unique_type] = op_type;
op_node->Op()->SetType(unique_type);
op_node->Op()->Flush();
VLOG(4) << "change op type: " << op_type << " ---> " << unique_type;
}
}
}
void FloatToHalfPass::RestoreOpOriginType() const {
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
auto op_type = op_node->Op()->Type();
op_node->Op()->SetType(GetOpOriginalType(op_type));
op_node->Op()->Flush();
VLOG(4) << "restore op type: " << op_type << " ---> "
<< op_node->Op()->Type();
}
}
}
inline std::string FloatToHalfPass::GetOpOriginalType(
const std::string& op_type) const {
if (op_original_type_.count(op_type)) {
return op_original_type_.at(op_type);
}
return op_type;
}
void FloatToHalfPass::ProcessOpWithDtypeAttr() const {
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
auto op_type = op_node->Op()->Type();
if (op_run_half_.count(op_type) == 0) continue;
if (op_node->Op()->HasAttr("dtype")) {
auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
if (IsFloatType(static_cast<VarType::Type>(dtype))) {
op_node->Op()->SetAttr(
"dtype",
static_cast<int>(
framework::TransToProtoVarType(half_precision_)));
op_node->Op()->Flush();
VLOG(4) << "process op with dtype attr: " << op_type << " ( " << dtype
<< " --->" << static_cast<int>(half_precision_) << " )";
}
}
if (op_node->Op()->HasAttr("out_dtype")) {
auto out_dtype = op_node->Op()->GetAttrIfExists<int>("out_dtype");
if (IsFloatType(static_cast<VarType::Type>(out_dtype))) {
op_node->Op()->SetAttr(
"out_dtype",
static_cast<int>(
framework::TransToProtoVarType(half_precision_)));
op_node->Op()->Flush();
VLOG(4) << "process op with out_dtype attr: " << op_type << " ( "
<< out_dtype << " --->" << static_cast<int>(half_precision_)
<< " )";
}
}
}
}
}
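// First pass: mark every op instance that could run at half precision,
// based on kernel support, dtype attributes, and the types of its
// input/output vars.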
void FloatToHalfPass::GetOpPrecision() const {
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
auto op_type = op_node->Op()->Type();
bool support_half = true;
if (GetOpOriginalType(op_type) == "feed" ||
GetOpOriginalType(op_type) == "fetch") {
support_half = !keep_io_types_;
} else {
support_half =
OpSupportPrecision(GetOpOriginalType(op_type), half_precision_);
}
if (op_node->Op()->HasAttr("dtype")) {
auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
support_half =
support_half && IsFloatType(static_cast<VarType::Type>(dtype));
} else if (op_node->Op()->HasAttr("out_dtype")) {
auto out_dtype = op_node->Op()->GetAttrIfExists<int>("out_dtype");
support_half =
support_half && IsFloatType(static_cast<VarType::Type>(out_dtype));
} else {
// If the op's input and output vars are not dense tensors, the op should
// not run at half precision.
for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true);
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
if (real_in_var_node->Var()->Persistable()) continue;
support_half = support_half && (real_in_var_node->Var()->GetType() ==
VarType::LOD_TENSOR);
}
for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true);
auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
if (real_out_var_node->Var()->Persistable()) continue;
support_half = support_half && (real_out_var_node->Var()->GetType() ==
VarType::LOD_TENSOR);
}
}
if (support_half) {
op_run_half_.insert(op_type);
VLOG(4) << "support precision: " << op_type << " run at half";
} else {
VLOG(4) << "support precision: " << op_type << " not run at half";
}
}
}
}
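// Second pass: demote ops back to fp32 when their vars must stay fp32,
// e.g. inputs of select_input or outputs of CPU-only ops.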
void FloatToHalfPass::UpdateOpPrecision() const {
std::unordered_set<std::string> vars_should_not_half;
// var name -> all the ops that produce this var
std::unordered_map<std::string, std::vector<Node*>> var_input_ops;
auto GetVarInputOps = [&] {
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
auto op_type = op_node->Op()->Type();
if (GetOpOriginalType(op_type) == "fetch") continue;
if (op_node->Op()->HasAttr("sub_block")) continue;
for (auto* var_node : op_node->outputs) {
CHECK_EQ(var_node->IsVar(), true);
if (var_node->Var()->Persistable()) continue;
if (!VarNodeHasDtype(var_node)) continue;
var_input_ops[var_node->Var()->Name()].push_back(op_node);
VLOG(4) << "var input ops: " << var_node->Var()->Name()
<< " is output of " << op_type;
}
// The select_input op's input vars should not be converted to half. When
// an op's output var is a select_input op's input var, that op should not
// run at half precision.
if (GetOpOriginalType(op_node->Op()->Type()) == "select_input") {
for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true);
if (in_var_node->Var()->Persistable()) continue;
if (!VarNodeHasDtype(in_var_node)) continue;
vars_should_not_half.insert(in_var_node->Var()->Name());
}
}
// When op_1 only supports a CPU kernel and op_2's input var is op_1's
// output var, op_2 should not run at half precision.
if (GetOpOriginalType(op_type) != "feed" &&
!GpuKernelSupportPrecision(GetOpOriginalType(op_type),
phi::DataType::FLOAT32)) {
for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true);
if (out_var_node->Var()->Persistable()) continue;
if (!VarNodeHasDtype(out_var_node)) continue;
vars_should_not_half.insert(out_var_node->Var()->Name());
}
}
}
}
};
GetVarInputOps();
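// Iterate to a fixed point: demoting one op to fp32 may force ops that
// share its vars to fall back to fp32 as well.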
bool precision_updated = false;
do {
precision_updated = false;
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
if (op_run_half_.count(op_node->Op()->Type()) == 0) continue;
for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true);
if (!VarNodeHasDtype(in_var_node)) continue;
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
if (real_in_var_node->Var()->Persistable()) continue;
if (vars_should_not_half.count(real_in_var_node->Var()->Name())) {
op_run_half_.erase(op_node->Op()->Type());
precision_updated = true;
VLOG(4) << op_node->Op()->Type()
<< " should not support half precision.";
break;
}
}
if (op_run_half_.count(op_node->Op()->Type()) == 0) continue;
for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true);
if (!VarNodeHasDtype(out_var_node)) continue;
auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
if (real_out_var_node->Var()->Persistable()) continue;
bool not_run_half = false;
const auto& input_op_nodes =
var_input_ops[real_out_var_node->Var()->Name()];
if (vars_should_not_half.count(real_out_var_node->Var()->Name())) {
not_run_half = true;
} else {
for (auto* node : input_op_nodes) {
if (op_run_half_.count(node->Op()->Type()) == 0) {
not_run_half = true;
break;
}
}
}
if (not_run_half) {
op_run_half_.erase(op_node->Op()->Type());
precision_updated = true;
VLOG(4) << op_node->Op()->Type()
<< " should not support half precision.";
break;
}
}
}
}
} while (precision_updated);
}
// Special ops whose weights should not be converted to low precision.
bool FloatToHalfPass::InputVarsNotConvert(Node* op_node,
const std::string& var_name) const {
auto* op_desc = op_node->Op();
if (GetOpOriginalType(op_desc->Type()) == "batch_norm") {
auto vecs = op_desc->Input("Bias");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("Mean");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("Scale");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("Variance");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
} else if (GetOpOriginalType(op_desc->Type()) == "fused_multi_transformer") {
auto vecs = op_desc->Input("LnScale");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("LnBias");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("FFNLnScale");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("FFNLnBias");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
}
return false;
}
bool FloatToHalfPass::OutputVarsNotConvert(Node* op_node,
const std::string& var_name) const {
auto* op_desc = op_node->Op();
// batch_norm's input and output (variance and mean) are the same.
if (GetOpOriginalType(op_desc->Type()) == "batch_norm") {
auto vecs = op_desc->Output("MeanOut");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Output("VarianceOut");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Output("SavedMean");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Output("SavedVariance");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
}
return false;
}
void FloatToHalfPass::SetVarPrecision() const {
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
if (op_run_half_.count(op_node->Op()->Type())) {
for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true);
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
auto in_var_name = real_in_var_node->Var()->Name();
if (!IsFloatType(real_in_var_node->Var()->GetDataType())) continue;
if (!VarNodeHasDtype(real_in_var_node)) continue;
if (InputVarsNotConvert(op_node, in_var_name)) continue;
if (real_in_var_node->Var()->Persistable()) {
real_in_var_node->Var()->SetDataType(
framework::TransToProtoVarType(half_precision_));
vars_convert_to_half_.insert(in_var_name);
}
}
for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true);
auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
auto out_var_name = real_out_var_node->Var()->Name();
if (!IsFloatType(real_out_var_node->Var()->GetDataType())) continue;
if (!VarNodeHasDtype(real_out_var_node)) continue;
if (OutputVarsNotConvert(op_node, out_var_name)) continue;
real_out_var_node->Var()->SetDataType(
framework::TransToProtoVarType(half_precision_));
if (real_out_var_node->Var()->Persistable()) {
vars_convert_to_half_.insert(out_var_name);
}
}
}
}
}
// This code is used to process vars with the same name. Vars with the same
// name should have the same data type.
for (auto* subgraph : subgraphes_) {
for (auto* var_node : subgraph->Nodes()) {
if (!var_node->IsVar() || !var_node->Var()->Persistable()) continue;
if (!VarNodeHasDtype(var_node)) continue;
auto var_name = var_node->Var()->Name();
if (vars_convert_to_half_.count(var_name)) {
var_node->Var()->SetDataType(
framework::TransToProtoVarType(half_precision_));
}
}
}
}
void FloatToHalfPass::ConvertWeightsData() const {
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope,
platform::errors::PreconditionNotMet(
"During the float to half pass, the scope should not be null."));
auto var_names = scope->LocalVarNames();
for (const auto& var_name : var_names) {
if (vars_convert_to_half_.count(var_name)) {
VLOG(4) << var_name << "'s data type was converted to half";
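// Cast the fp32 weight tensor element-wise into the target half precision
// dtype on CPU, then copy the result back over the original tensor.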
#define CONVERT_TENSOR_DTYPE(DTYPE, dtype) \
half_tensor.set_type(DTYPE); \
auto* half_data = half_tensor.mutable_data<dtype>(platform::CPUPlace()); \
for (int64_t i = 0; i < origin_tensor->numel(); i++) { \
half_data[i] = static_cast<dtype>(origin_data[i]); \
} \
origin_tensor->clear(); \
paddle::framework::TensorCopySync( \
half_tensor, platform::CPUPlace(), origin_tensor)
auto* var = scope->FindLocalVar(var_name);
if (var->IsType<phi::DenseTensor>()) {
auto* origin_tensor = var->GetMutable<phi::DenseTensor>();
phi::DenseTensor half_tensor;
half_tensor.Resize(origin_tensor->dims());
auto* origin_data =
origin_tensor->mutable_data<float>(platform::CPUPlace());
if (half_precision_ == phi::DataType::FLOAT16) {
CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::FLOAT16,
phi::dtype::float16);
} else if (half_precision_ == phi::DataType::BFLOAT16) {
CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::BFLOAT16,
phi::dtype::bfloat16);
}
}
}
#undef CONVERT_TENSOR_DTYPE
}
}
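// Insert cast ops at precision boundaries: float inputs feeding half ops
// are cast down, half inputs feeding float ops are cast back to fp32.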
void FloatToHalfPass::InsertCastOp() const {
int suffix = 0;
std::unordered_map<Node*, Node*> cache;
for (size_t i = 0; i < all_op_nodes_.size(); i++) {
auto* block_desc = all_op_nodes_[i][0]->Op()->Block();
CHECK_NOTNULL(block_desc);
for (auto* op_node : all_op_nodes_[i]) {
auto op_type = op_node->Op()->Type();
if (GetOpOriginalType(op_type) == "feed") continue;
if (op_node->Op()->HasAttr("sub_block")) continue;
VLOG(4) << "process op: " << op_type
<< " run half: " << op_run_half_.count(op_type);
auto inputs = op_node->inputs;
for (auto* in_var_node : inputs) {
if (!in_var_node->IsVar()) continue;
if (!VarNodeHasDtype(in_var_node)) continue;
if (in_var_node->Var()->Persistable()) continue;
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
auto in_var_type = real_in_var_node->Var()->GetDataType();
VLOG(4) << "process var: " << real_in_var_node->Var()->Name()
<< " with type " << in_var_type;
if (IsFloatType(in_var_type) && op_run_half_.count(op_type)) {
DoInsertCastOp(subgraphes_[i],
in_var_node,
op_node,
in_var_type,
framework::TransToProtoVarType(half_precision_),
block_desc,
&suffix,
&cache);
} else if (IsHalfType(in_var_type) &&
op_run_half_.count(op_type) == 0) {
DoInsertCastOp(subgraphes_[i],
in_var_node,
op_node,
in_var_type,
VarType::FP32,
block_desc,
&suffix,
&cache);
}
}
// Special op: fused_multi_transformer's input (CacheKV) and output
// (CacheKVOut) vars have the same name.
if (GetOpOriginalType(op_type) == "fused_multi_transformer") {
auto cache_kv_inputs = op_node->Op()->Input("CacheKV");
auto cache_kv_outputs = op_node->Op()->Output("CacheKVOut");
CHECK_EQ(cache_kv_inputs.size(), cache_kv_outputs.size());
for (size_t i = 0; i < cache_kv_inputs.size(); ++i) {
op_node->Op()->RenameOutput(cache_kv_outputs[i], cache_kv_inputs[i]);
}
}
}
}
VLOG(4) << "insert number of cast op: " << cache.size();
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(float_to_half_pass, paddle::framework::ir::FloatToHalfPass);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/place.h"
namespace paddle {
namespace framework {
namespace ir {
class FloatToHalfPass : public FusePassBase {
public:
using VarType = framework::proto::VarType;
public:
FloatToHalfPass() = default;
~FloatToHalfPass() = default;
protected:
void ApplyImpl(Graph* graph) const override;
private:
void Init(Graph* graph) const;
void SetDefaultBlacklist() const;
bool OpSupportPrecision(const std::string& op_type,
phi::DataType precision,
phi::Backend backend = phi::Backend::GPU) const;
void SetOpUniqueType() const;
void RestoreOpOriginType() const;
inline std::string GetOpOriginalType(const std::string& op_type) const;
void GetOpPrecision() const;
void UpdateOpPrecision() const;
void InsertCastOp() const;
void ProcessOpWithDtypeAttr() const;
bool InputVarsNotConvert(Node* op_node, const std::string& var_name) const;
bool OutputVarsNotConvert(Node* op_node, const std::string& var_name) const;
void SetVarPrecision() const;
void ConvertWeightsData() const;
private:
mutable bool keep_io_types_;
// float16 or bfloat16 now
mutable phi::DataType half_precision_;
mutable std::unordered_set<std::string> black_list_;
// subgraph id -> pointer to subgraph
mutable std::vector<Graph*> subgraphes_;
// var name -> real var node
mutable std::unordered_map<std::string, Node*> real_vars_;
// subgraph id -> all op nodes in subgraph
mutable std::vector<std::vector<Node*>> all_op_nodes_;
// op's unique type -> the op's origin type
mutable std::unordered_map<std::string, std::string> op_original_type_;
// the set of unique op types that run at half precision
mutable std::unordered_set<std::string> op_run_half_;
mutable std::unordered_set<std::string> vars_convert_to_half_;
};
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -365,6 +365,8 @@ struct Argument {
DECL_ARGUMENT_FIELD(mixed_black_list,
MixedBlackList,
std::unordered_set<std::string>);
DECL_ARGUMENT_FIELD(enable_gpu_half, EnableGPUHalf, bool);
DECL_ARGUMENT_FIELD(mixed_precision_mode, MixedPrecisionMode, int);
private:
std::unordered_set<std::string> valid_fields_;
......
@@ -86,10 +86,14 @@ void IRPassManager::CreatePasses(Argument *argument,
argument->tensorrt_tuned_dynamic_shape();
pass->Set("with_dynamic_shape", new bool(with_dynamic_shape));
// mixed precision related
pass->Set("model_precision", new int(argument->model_precision()));
pass->Set(
"mixed_black_list",
new std::unordered_set<std::string>(argument->mixed_black_list()));
pass->Set("enable_gpu_half", new bool(argument->enable_gpu_half()));
pass->Set("mixed_precision_mode",
new int(argument->mixed_precision_mode()));
if (pass_name == "graph_viz_pass") {
std::string optim_cache_dir = argument->optim_cache_dir();
......
@@ -85,16 +85,29 @@ void AnalysisConfig::SetModel(const std::string &prog_file_path,
Update();
}
void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb,
int device_id,
Precision precision_mode) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
use_gpu_ = true;
memory_pool_init_size_mb_ = memory_pool_init_size_mb;
FLAGS_initial_gpu_memory_in_mb = memory_pool_init_size_mb_;
gpu_device_id_ = device_id;
mixed_precision_mode_ = precision_mode;
if (precision_mode == Precision::kFloat32) {
// default
} else if (precision_mode == Precision::kHalf ||
precision_mode == Precision::kBf16) {
enable_gpu_half_ = true;
} else {
LOG(ERROR)
<< "The Paddle-GPU inference currently only supports "
"float32/float16/bfloat16 precision. Please check the parameters "
"you specified in EnableUseGpu or enable_use_gpu function.";
}
#else
LOG(ERROR) << "Please compile with gpu to EnableGpu()";
use_gpu_ = false;
LOG(ERROR) << "Please use PaddlePaddle with GPU version.";
#endif
Update();
@@ -381,8 +394,10 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
CP_MEMBER(gpu_device_id_);
CP_MEMBER(memory_pool_init_size_mb_);
// Mixed precision related.
CP_MEMBER(mixed_black_list_);
CP_MEMBER(enable_gpu_half_);
CP_MEMBER(mixed_precision_mode_);
CP_MEMBER(enable_memory_optim_);
// TensorRT related.
@@ -996,6 +1011,7 @@ std::string AnalysisConfig::SerializeInfoCache() {
ss << params_file_;
ss << use_gpu_;
ss << enable_gpu_half_;
ss << use_external_stream_;
ss << exec_stream_;
ss << use_fc_padding_;
@@ -1212,6 +1228,7 @@ std::string AnalysisConfig::Summary() {
os.InsertRow({"use_gpu", use_gpu_ ? "true" : "false"});
if (use_gpu_) {
os.InsertRow({"gpu_device_id", std::to_string(gpu_device_id_)});
os.InsertRow({"enable_gpu_half_", std::to_string(enable_gpu_half_)});
os.InsertRow({"memory_pool_init_size",
std::to_string(memory_pool_init_size_mb_) + "MB"});
os.InsertRow(
@@ -1407,7 +1424,7 @@ bool AnalysisConfig::trt_allow_build_at_runtime() const {
return trt_allow_build_at_runtime_;
}
void AnalysisConfig::Exp_DisableMixedInferOps(
const std::unordered_set<std::string> &black_list) {
mixed_black_list_ = black_list;
}
......
@@ -1257,12 +1257,26 @@ void AnalysisPredictor::PrepareArgument() {
}
}
}
if (!config_.ir_optim()) {
argument_.SetEnableIrOptim(false);
LOG(INFO) << "ir_optim is turned off, no IR pass will be executed";
if (config_.enable_gpu_half_) {
argument_.SetEnableIrOptim(true);
pass_builder->ClearPasses();
pass_builder->AppendPass("float_to_half_pass");
LOG(INFO) << "This model runs in Paddle-GPU mixed precision mode with no "
"IR optimization.";
} else {
LOG(INFO) << "ir_optim is turned off, no IR pass will be executed.";
}
} else {
if (config_.ir_debug_) {
pass_builder->TurnOnDebug();
}
if (config_.enable_gpu_half_) {
LOG(INFO) << "This model run in Paddle-GPU mixed precision mode.";
}
}
argument_.SetDisableLogs(config_.glog_info_disabled());
argument_.SetIrAnalysisPasses(pass_builder->AllPasses());
@@ -1272,6 +1286,9 @@ void AnalysisPredictor::PrepareArgument() {
// mixed precision.
argument_.SetModelPrecision(static_cast<int>(model_precision_));
argument_.SetMixedBlackList(config_.mixed_black_list_);
argument_.SetEnableGPUHalf(config_.enable_gpu_half_);
argument_.SetMixedPrecisionMode(static_cast<int>(
paddle::ConvertPrecision(config_.mixed_precision_mode_)));
}
// NOTE All the members in AnalysisConfig should be copied to Argument.
......
@@ -247,8 +247,12 @@ struct PD_INFER_DECL AnalysisConfig {
///
/// \param memory_pool_init_size_mb initial size of the GPU memory pool in MB.
/// \param device_id device_id the GPU card to use (default is 0).
/// \param precision the precision used in Paddle-GPU inference.
///
void EnableUseGpu(uint64_t memory_pool_init_size_mb,
int device_id = 0,
Precision precision_mode = Precision::kFloat32);
///
/// \brief Turn off GPU.
///
@@ -1005,7 +1009,7 @@ struct PD_INFER_DECL AnalysisConfig {
/// interface is in the experimental stage and may change in the future. Note
/// that the blacklist must be the same as the model conversion blacklist.
///
void Exp_DisableMixedInferOps(
const std::unordered_set<std::string>& black_list);
void SetApplyOptim(bool value) { apply_optim_ = value; }
@@ -1024,13 +1028,15 @@ struct PD_INFER_DECL AnalysisConfig {
mutable std::string prog_file_;
mutable std::string params_file_;
// Mixed precision related.
Precision mixed_precision_mode_{Precision::kFloat32};
std::unordered_set<std::string> mixed_black_list_;
// GPU related.
bool use_gpu_{false};
int gpu_device_id_{0};
uint64_t memory_pool_init_size_mb_{100}; // initial size is 100MB.
bool enable_gpu_half_{false};
bool thread_local_stream_{false};
bool use_cudnn_{false};
......
@@ -246,9 +246,10 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"conv_elementwise_add_fuse_pass", //
#endif //
"transpose_flatten_concat_fuse_pass", //
"constant_folding_pass",
"constant_folding_pass", //
// following pass should be located in the last, since it will
// work on all fused ops.
"float_to_half_pass", //
"runtime_context_cache_pass"
});
......
@@ -416,6 +416,9 @@ download_result(${ERNIE_INSTALL_DIR} "Ernie_result.txt.tar.gz"
if(WITH_GPU)
inference_analysis_api_test(test_analyzer_ernie ${ERNIE_INSTALL_DIR}
analyzer_ernie_tester.cc)
inference_analysis_api_test(gpu_ernie_half_test ${ERNIE_INSTALL_DIR}
gpu_ernie_half_test.cc)
set_tests_properties(gpu_ernie_half_test PROPERTIES TIMEOUT 40)
endif()
inference_analysis_api_int8_test(test_analyzer_ernie_int8 ${ERNIE_INSTALL_DIR}
analyzer_ernie_int8_tester.cc)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/tests/api/tester_helper.h"
namespace paddle {
namespace inference {
using paddle::PaddleTensor;
template <typename T>
void GetValueFromStream(std::stringstream *ss, T *t) {
(*ss) >> (*t);
}
template <>
void GetValueFromStream<std::string>(std::stringstream *ss, std::string *t) {
*t = ss->str();
}
// Split string to vector
template <typename T>
void Split(const std::string &line, char sep, std::vector<T> *v) {
std::stringstream ss;
T t;
for (auto c : line) {
if (c != sep) {
ss << c;
} else {
GetValueFromStream<T>(&ss, &t);
v->push_back(std::move(t));
ss.str({});
ss.clear();
}
}
if (!ss.str().empty()) {
GetValueFromStream<T>(&ss, &t);
v->push_back(std::move(t));
ss.str({});
ss.clear();
}
}
// Parse tensor from string
template <typename T>
bool ParseTensor(const std::string &field, paddle::PaddleTensor *tensor) {
std::vector<std::string> data;
Split(field, ':', &data);
if (data.size() < 2) return false;
std::string shape_str = data[0];
std::vector<int> shape;
Split(shape_str, ' ', &shape);
std::string mat_str = data[1];
std::vector<T> mat;
Split(mat_str, ' ', &mat);
tensor->shape = shape;
auto size =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()) *
sizeof(T);
tensor->data.Resize(size);
std::copy(mat.begin(), mat.end(), static_cast<T *>(tensor->data.data()));
tensor->dtype = GetPaddleDType<T>();
return true;
}
// Parse input tensors from string
bool ParseLine(const std::string &line,
std::vector<paddle::PaddleTensor> *tensors) {
std::vector<std::string> fields;
Split(line, ';', &fields);
tensors->clear();
tensors->reserve(4);
int i = 0;
auto input_name = FLAGS_ernie_large ? "eval_placeholder_" : "placeholder_";
for (; i < 3; i++) {
paddle::PaddleTensor temp;
ParseTensor<int64_t>(fields[i], &temp);
temp.name = input_name + std::to_string(i);
tensors->push_back(temp);
}
// input_mask
paddle::PaddleTensor input_mask;
ParseTensor<float>(fields[i], &input_mask);
input_mask.name = input_name + std::to_string(i);
tensors->push_back(input_mask);
return true;
}
bool LoadInputData(std::vector<std::vector<paddle::PaddleTensor>> *inputs,
int batch_size = 1) {
if (FLAGS_infer_data.empty()) {
LOG(ERROR) << "please set input data path";
return false;
}
std::ifstream fin(FLAGS_infer_data);
std::string line;
int sample = 0;
// The unit-test dataset only has 10 samples; each sample has 5 feeds.
while (std::getline(fin, line)) {
std::vector<paddle::PaddleTensor> feed_data;
ParseLine(line, &feed_data);
inputs->push_back(std::move(feed_data));
sample++;
if (!FLAGS_test_all_data && sample == batch_size) break;
}
LOG(INFO) << "number of samples: " << sample;
return true;
}
// Compare results
TEST(Ernie_gpu_fp16_no_ir, compare_results) {
AnalysisConfig config;
config.SetModel(FLAGS_infer_model);
config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kHalf);
config.SwitchIrOptim(false);
auto predictor = CreatePaddlePredictor(config);
std::vector<std::vector<PaddleTensor>> input_slots_all;
LoadInputData(&input_slots_all);
std::ifstream fin(FLAGS_refer_result);
std::string line;
std::vector<float> ref;
while (std::getline(fin, line)) {
Split(line, ' ', &ref);
}
std::vector<PaddleTensor> outputs;
for (size_t i = 0; i < input_slots_all.size(); i++) {
outputs.clear();
predictor->Run(input_slots_all[i], &outputs);
auto output = outputs.front();
size_t outputs_size = 1;
for (auto dim : output.shape) {
outputs_size *= dim;
}
float *result = reinterpret_cast<float *>(output.data.data());
for (size_t j = 0; j < outputs_size; ++j) {
EXPECT_NEAR(ref[i * outputs_size + j], result[j], 5e-2);
}
}
}
// Compare results
TEST(Ernie_gpu_fp16_with_ir, compare_results) {
AnalysisConfig config;
config.SetModel(FLAGS_infer_model);
config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kHalf);
config.SwitchIrOptim(true);
// The fc_fuse_pass introduces numerical diffs, which will be fixed later.
config.pass_builder()->DeletePass("fc_fuse_pass");
// There is a problem with the model itself, which has nothing to do with
// constant_folding_pass.
config.pass_builder()->DeletePass("constant_folding_pass");
auto predictor = CreatePaddlePredictor(config);
std::vector<std::vector<PaddleTensor>> input_slots_all;
LoadInputData(&input_slots_all);
std::ifstream fin(FLAGS_refer_result);
std::string line;
std::vector<float> ref;
while (std::getline(fin, line)) {
Split(line, ' ', &ref);
}
std::vector<PaddleTensor> outputs;
for (size_t i = 0; i < input_slots_all.size(); i++) {
outputs.clear();
predictor->Run(input_slots_all[i], &outputs);
auto output = outputs.front();
size_t outputs_size = 1;
for (auto dim : output.shape) {
outputs_size *= dim;
}
float *result = reinterpret_cast<float *>(output.data.data());
for (size_t j = 0; j < outputs_size; ++j) {
EXPECT_NEAR(ref[i * outputs_size + j], result[j], 5e-2);
}
}
}
// Compare results
TEST(Ernie_gpu_bf16_no_ir, compare_results) {
AnalysisConfig config;
config.SetModel(FLAGS_infer_model);
config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kBf16);
config.SwitchIrOptim(false);
auto predictor = CreatePaddlePredictor(config);
std::vector<std::vector<PaddleTensor>> input_slots_all;
LoadInputData(&input_slots_all);
std::ifstream fin(FLAGS_refer_result);
std::string line;
std::vector<float> ref;
while (std::getline(fin, line)) {
Split(line, ' ', &ref);
}
std::vector<PaddleTensor> outputs;
for (size_t i = 0; i < input_slots_all.size(); i++) {
outputs.clear();
predictor->Run(input_slots_all[i], &outputs);
auto output = outputs.front();
size_t outputs_size = 1;
for (auto dim : output.shape) {
outputs_size *= dim;
}
float *result = reinterpret_cast<float *>(output.data.data());
for (size_t j = 0; j < outputs_size; ++j) {
EXPECT_NEAR(ref[i * outputs_size + j], result[j], 7e-2);
}
}
}
// Compare results
TEST(Ernie_gpu_bf16_with_ir, compare_results) {
AnalysisConfig config;
config.SetModel(FLAGS_infer_model);
config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kBf16);
config.SwitchIrOptim(true);
// The fc_fuse_pass introduces numerical diffs, which will be fixed later.
config.pass_builder()->DeletePass("fc_fuse_pass");
// There is a problem with the model itself, which has nothing to do with
// constant_folding_pass.
config.pass_builder()->DeletePass("constant_folding_pass");
auto predictor = CreatePaddlePredictor(config);
std::vector<std::vector<PaddleTensor>> input_slots_all;
LoadInputData(&input_slots_all);
std::ifstream fin(FLAGS_refer_result);
std::string line;
std::vector<float> ref;
while (std::getline(fin, line)) {
Split(line, ' ', &ref);
}
std::vector<PaddleTensor> outputs;
for (size_t i = 0; i < input_slots_all.size(); i++) {
outputs.clear();
predictor->Run(input_slots_all[i], &outputs);
auto output = outputs.front();
size_t outputs_size = 1;
for (auto dim : output.shape) {
outputs_size *= dim;
}
float *result = reinterpret_cast<float *>(output.data.data());
for (size_t j = 0; j < outputs_size; ++j) {
EXPECT_NEAR(ref[i * outputs_size + j], result[j], 7e-2);
}
}
}
} // namespace inference
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -12,15 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cuda_runtime.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <cstring>
#include <numeric>
#include "gflags/gflags.h"
#include "paddle/fluid/inference/tests/api/trt_test_helper.h"
#include "paddle/fluid/inference/tests/api/tester_helper.h"
namespace paddle_infer {
......
@@ -644,7 +644,8 @@ void BindAnalysisConfig(py::module *m) {
.def("enable_use_gpu",
&AnalysisConfig::EnableUseGpu,
py::arg("memory_pool_init_size_mb"),
py::arg("device_id") = 0)
py::arg("device_id") = 0,
py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
.def("set_exec_stream",
[](AnalysisConfig &self, phi::CUDAStream &stream) {
......