Unverified commit ddcd1b61, authored by Yuanle Liu, committed by GitHub

[cherry-pick][Inference] support mixed precision inference (#49077)

* [Release2.4] Revert python link prs (#48573)

* Revert "Fix mac link python (#48017)"

This reverts commit 3fa7a736.

* Revert "[Cherry-pick] Fix python link error (#47811)"

This reverts commit ff642c68.

* Update config.go

* [Paddle Inference] Add float_to_half_pass to support  inference with mixed precision (#47993)

* [Inference] optimize some code and fix some bug (#48780)

* clean ir_pass_manager and fix map_depthwise_conv_to_conv_pass

* fix unitest timeout

* [Paddle Inference] clean unused code  (#48392)

* fix

* update

* update
Co-authored-by: Chen Weihang <chenweihang@baidu.com>
Parent 9e2ba9b9
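The new auto_mixed_precision_pass is driven through the regular IR pass interface. Below is a minimal sketch of how it can be configured and applied; it mirrors ConvertToMixedPrecisionPass::Run further down in this diff, and assumes the graph was built from the main ProgramDesc and that the parameter scope is available (the function name and setup are illustrative only).

#include <unordered_set>

#include "paddle/fluid/framework/ir/auto_mixed_precision_pass.h"
#include "paddle/fluid/framework/scope.h"

// Sketch only: configure and apply auto_mixed_precision_pass on a main graph.
void ApplyAutoMixedPrecision(paddle::framework::ir::Graph* graph,
                             paddle::framework::Scope* scope) {
  namespace ir = paddle::framework::ir;
  // The pass converts weight data in place, so it needs the parameter scope.
  graph->SetNotOwned(ir::kParamScopeAttr, scope);

  ir::AutoMixedPrecisionPass pass;
  // FLOAT16 or BFLOAT16 are the supported low-precision modes.
  pass.Set("mixed_precision_mode",
           new int{static_cast<int>(phi::DataType::FLOAT16)});
  pass.Set("mixed_black_list", new std::unordered_set<std::string>{});
  pass.Set("enable_gpu_mixed", new bool{true});
  pass.Set("keep_io_types", new bool{true});
  // Rewrites var/weight dtypes and inserts cast ops where precision changes.
  pass.Apply(graph);
}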
......@@ -148,6 +148,7 @@ pass_library(delete_c_identity_op_pass inference)
pass_library(preln_residual_bias_fuse_pass inference)
pass_library(delete_fill_constant_op_pass inference)
pass_library(constant_folding_pass inference)
pass_library(auto_mixed_precision_pass inference)
pass_library(simplify_with_basic_ops_pass base)
pass_library(fc_elementwise_layernorm_fuse_pass base)
pass_library(skip_layernorm_fuse_pass base)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/auto_mixed_precision_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
namespace paddle {
namespace framework {
namespace ir {
namespace {
using VarType = AutoMixedPrecisionPass::VarType;
bool PhiKernelSupportPrecision(
const std::string& op_type,
phi::Backend backend,
phi::DataType data_type,
phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) {
const auto& kernels = phi::KernelFactory::Instance().kernels();
if (kernels.count(op_type) == 0) {
return false;
}
phi::KernelKey kernel_key(backend, layout, data_type);
return phi::KernelFactory::Instance().HasKernel(op_type, kernel_key);
}
bool GpuKernelSupportPrecision(
const std::string& op_type,
phi::DataType precision,
phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) {
auto phi_op_type = phi::TransToPhiKernelName(op_type);
bool support = PhiKernelSupportPrecision(
phi_op_type, phi::Backend::GPU, precision, layout);
support |= PhiKernelSupportPrecision(
phi_op_type, phi::Backend::GPUDNN, precision, layout);
if (!support) {
const auto& all_kernels = framework::OperatorWithKernel::AllOpKernels();
auto it = all_kernels.find(op_type);
if (it != all_kernels.end()) {
for (const auto& kern_pair : it->second) {
if (platform::is_gpu_place(kern_pair.first.place_) &&
kern_pair.first.data_type_ ==
framework::TransToProtoVarType(precision)) {
support = true;
break;
}
}
}
}
return support;
}
inline bool VarNodeHasDtype(Node* var_node) {
auto type = var_node->Var()->GetType();
return (type == VarType::SELECTED_ROWS) || (type == VarType::LOD_TENSOR) ||
(type == VarType::LOD_TENSOR_ARRAY) || (type == VarType::STRINGS) ||
(type == VarType::VOCAB);
}
inline bool IsFloatType(VarType::Type type) {
return (type == VarType::FP64) || (type == VarType::FP32);
}
inline bool IsHalfType(VarType::Type type) {
return (type == VarType::FP16) || (type == VarType::BF16);
}
}; // namespace
void DoInsertCastOp(Graph* graph,
Node* var_node,
Node* op_node,
VarType::Type from_type,
VarType::Type to_type,
framework::BlockDesc* block_desc,
int* suffix,
std::unordered_map<Node*, Node*>* cache) {
if (from_type == to_type) return;
auto update_cast_desc = [&](framework::OpDesc& desc,
const std::string& x_name,
const std::string& out_name,
const int in_dtype,
const int out_dtype) {
desc.SetType("cast");
desc.SetInput("X", {x_name});
desc.SetOutput("Out", {out_name});
desc.SetAttr("in_dtype", in_dtype);
desc.SetAttr("out_dtype", out_dtype);
desc.SetAttr("use_mkldnn", false);
desc.SetAttr("with_quant_attr", false);
desc.Flush();
};
if (cache->count(var_node) == 0) {
// insert cast op between var_node and op_node
std::string cast_input_name = var_node->Var()->Name();
std::string cast_output_name =
var_node->Var()->Name() + "_cast.tmp_" + std::to_string((*suffix)++);
framework::OpDesc cast_op_desc(block_desc);
update_cast_desc(cast_op_desc,
cast_input_name,
cast_output_name,
static_cast<int>(from_type),
static_cast<int>(to_type));
auto* cast_op_node = graph->CreateOpNode(&cast_op_desc);
auto* cast_output_vardesc = block_desc->Var(cast_output_name);
cast_output_vardesc->SetPersistable(false);
cast_output_vardesc->SetDataType(to_type);
cast_output_vardesc->SetShape(var_node->Var()->GetShape());
auto* cast_output_node = graph->CreateVarNode(cast_output_vardesc);
IR_NODE_LINK_TO(cast_op_node, cast_output_node);
(*cache)[var_node] = cast_output_node;
}
op_node->Op()->Rename(var_node->Name(), cache->at(var_node)->Name());
IR_NODE_LINK_TO(var_node, cache->at(var_node)->inputs[0]);
IR_NODE_LINK_TO(cache->at(var_node), op_node);
IR_NODE_UNLINK(var_node, op_node);
}
bool OpSupportPrecision(const std::string& op_type,
phi::Backend backend,
phi::DataType precision,
const std::unordered_set<std::string>& black_list) {
bool support = false;
if (black_list.count(op_type) == 0) {
if (backend == phi::Backend::GPU) {
support = GpuKernelSupportPrecision(op_type, precision);
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Now, only support backend of GPU."));
}
}
return support;
}
// The default blacklist: ops that support fp16 but are numerically dangerous
// or slower in fp16, and whose effects may also be observed in downstream ops.
void AutoMixedPrecisionPass::SetDefaultBlacklist() const {
black_list_.insert({
// numerically-dangerous
"acos",
"asin",
"cosh",
"tan",
"exp",
"expm1",
"square",
"log",
"log2",
"log10",
"log1p",
"logsumexp",
"mean",
"rsqrt",
"sum",
"cos_sim",
"softmax",
"softmax_with_cross_entropy",
"sigmoid_cross_entropy_with_logits",
"c_softmax_with_cross_entropy",
"cross_entropy",
"cross_entropy2",
// slower than fp32
"conv2d_transpose",
// keeping the default fp32 avoids returning inf when the sum exceeds 65504
"reduce_sum",
});
}
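// Illustration (not from the original source): float16's largest finite value
// is 65504, so accumulating a large reduce_sum in fp16 easily overflows to inf,
// which is why reduce_sum is kept in the default blacklist above. For example:
//   phi::dtype::float16 acc(65000.0f);
//   acc = acc + phi::dtype::float16(1000.0f);  // exceeds 65504, rounds to +inf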
void AutoMixedPrecisionPass::Init(Graph* graph) const {
bool enable_gpu_mixed = Get<bool>("enable_gpu_mixed");
if (enable_gpu_mixed) {
backend_ = phi::Backend::GPU;
}
skip_pass_ = !enable_gpu_mixed;
low_precision_ = static_cast<phi::DataType>(Get<int>("mixed_precision_mode"));
black_list_ = Get<std::unordered_set<std::string>>("mixed_black_list");
SetDefaultBlacklist();
VLOG(4) << "black_list has ";
for (const auto& name : black_list_) {
VLOG(4) << " - " << name;
}
keep_io_types_ = true;
if (Has("keep_io_types")) {
keep_io_types_ = Get<bool>("keep_io_types");
}
auto graph_size = graph->SubGraphsSize();
VLOG(4) << "graph size: " << graph_size;
subgraphes_.resize(graph_size);
all_op_nodes_.resize(graph_size);
for (size_t i = 0; i < graph_size; i++) {
subgraphes_[i] = graph->GetSubGraph(i);
all_op_nodes_[i] = TopologySortOperations(*subgraphes_[i]);
VLOG(4) << "subgraph " << i << " has " << all_op_nodes_[i].size()
<< "op nodes";
for (auto* var_node : subgraphes_[i]->Nodes()) {
if (!var_node->IsVar()) continue;
auto var_name = var_node->Var()->Name();
if (real_vars_.count(var_name) == 0) {
real_vars_[var_name] = var_node;
VLOG(4) << var_name << " is in graph " << i;
}
}
}
}
void AutoMixedPrecisionPass::ApplyImpl(Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(graph,
platform::errors::PreconditionNotMet(
"During the auto_mixed_precision_pass, the graph "
"should not be nullptr."));
PADDLE_ENFORCE_EQ(graph->IsMainGraph(),
true,
platform::errors::PreconditionNotMet(
"During the auto_mixed_precision_pass, the graph "
"should be main graph."));
FusePassBase::Init("auto_mixed_precision", graph);
Init(graph);
VLOG(4) << "Init done";
if (skip_pass_) {
VLOG(3) << "Skip auto_mixed_precision_pass.";
return;
}
SetOpUniqueType();
VLOG(4) << "SetOpUniqueType done";
GetOpPrecision();
VLOG(4) << "GetOpPrecision done";
UpdateOpPrecision();
VLOG(4) << "UpdateOpPrecision done";
SetVarPrecision();
VLOG(4) << "SetVarPrecision done";
ConvertWeightsData();
VLOG(4) << "ConvertWeightsData done";
ProcessOpWithDtypeAttr();
VLOG(4) << "ProcessOpWithDtypeAttr done";
InsertCastOp();
VLOG(4) << "InsertCastOp done";
RestoreOpOriginType();
VLOG(4) << "RestoreOpOriginType done";
}
void AutoMixedPrecisionPass::SetOpUniqueType() const {
int suffix = 0;
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
auto op_type = op_node->Op()->Type();
if (op_type == "feed" || op_type == "fetch") continue;
std::string unique_type = op_type + "_" + std::to_string(suffix++);
op_original_type_[unique_type] = op_type;
op_node->Op()->SetType(unique_type);
op_node->Op()->Flush();
VLOG(4) << "change op type: " << op_type << " ---> " << unique_type;
}
}
}
void AutoMixedPrecisionPass::RestoreOpOriginType() const {
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
auto op_type = op_node->Op()->Type();
op_node->Op()->SetType(GetOpOriginalType(op_type));
op_node->Op()->Flush();
VLOG(4) << "restore op type: " << op_type << " ---> "
<< op_node->Op()->Type();
}
}
}
inline std::string AutoMixedPrecisionPass::GetOpOriginalType(
const std::string& op_type) const {
if (op_original_type_.count(op_type)) {
return op_original_type_.at(op_type);
}
return op_type;
}
void AutoMixedPrecisionPass::ProcessOpWithDtypeAttr() const {
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
auto op_type = op_node->Op()->Type();
if (op_run_low_precision_.count(op_type) == 0) continue;
if (op_node->Op()->HasAttr("dtype")) {
auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
if (IsFloatType(static_cast<VarType::Type>(dtype))) {
op_node->Op()->SetAttr(
"dtype",
static_cast<int>(framework::TransToProtoVarType(low_precision_)));
op_node->Op()->Flush();
VLOG(4) << "process op with dtype attr: " << op_type << " ( " << dtype
<< " --->" << static_cast<int>(low_precision_) << " )";
}
}
if (op_node->Op()->HasAttr("out_dtype")) {
auto out_dtype = op_node->Op()->GetAttrIfExists<int>("out_dtype");
if (IsFloatType(static_cast<VarType::Type>(out_dtype))) {
op_node->Op()->SetAttr(
"out_dtype",
static_cast<int>(framework::TransToProtoVarType(low_precision_)));
op_node->Op()->Flush();
VLOG(4) << "process op with out_dtype attr: " << op_type << " ( "
<< out_dtype << " --->" << static_cast<int>(low_precision_)
<< " )";
}
}
}
}
}
void AutoMixedPrecisionPass::GetOpPrecision() const {
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
auto op_type = op_node->Op()->Type();
bool support_low_precision = true;
if (GetOpOriginalType(op_type) == "feed" ||
GetOpOriginalType(op_type) == "fetch") {
support_low_precision = !keep_io_types_;
} else {
support_low_precision = OpSupportPrecision(
GetOpOriginalType(op_type), backend_, low_precision_, black_list_);
}
if (op_node->Op()->HasAttr("dtype")) {
auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
support_low_precision = support_low_precision &&
IsFloatType(static_cast<VarType::Type>(dtype));
} else if (op_node->Op()->HasAttr("out_dtype")) {
auto out_dtype = op_node->Op()->GetAttrIfExists<int>("out_dtype");
support_low_precision =
support_low_precision &&
IsFloatType(static_cast<VarType::Type>(out_dtype));
} else {
// If the op's input and output vars are not dense tensors, the op should
// not run at low precision.
for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true);
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
if (real_in_var_node->Var()->Persistable()) continue;
support_low_precision =
support_low_precision &&
(real_in_var_node->Var()->GetType() == VarType::LOD_TENSOR);
}
for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true);
auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
if (real_out_var_node->Var()->Persistable()) continue;
support_low_precision =
support_low_precision &&
(real_out_var_node->Var()->GetType() == VarType::LOD_TENSOR);
}
}
if (support_low_precision) {
op_run_low_precision_.insert(op_type);
VLOG(4) << "support precision: " << op_type << " run at low precision";
} else {
VLOG(4) << "support precision: " << op_type
<< " not run at low precision";
}
}
}
}
void AutoMixedPrecisionPass::UpdateOpPrecision() const {
std::unordered_set<std::string> vars_should_not_low_precision;
// var name -> all the ops that produce this var
std::unordered_map<std::string, std::vector<Node*>> var_input_ops;
auto GetVarInputOps = [&] {
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
auto op_type = op_node->Op()->Type();
if (GetOpOriginalType(op_type) == "fetch") continue;
if (op_node->Op()->HasAttr("sub_block")) continue;
for (auto* var_node : op_node->outputs) {
CHECK_EQ(var_node->IsVar(), true);
if (var_node->Var()->Persistable()) continue;
if (!VarNodeHasDtype(var_node)) continue;
var_input_ops[var_node->Var()->Name()].push_back(op_node);
VLOG(4) << "var input ops: " << var_node->Var()->Name()
<< " is output of " << op_type;
}
// The select_input op's input vars should not be converted to low precision;
// when an op's output var is a select_input op's input var, that op should
// not run at low precision.
if (GetOpOriginalType(op_node->Op()->Type()) == "select_input") {
for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true);
if (in_var_node->Var()->Persistable()) continue;
if (!VarNodeHasDtype(in_var_node)) continue;
vars_should_not_low_precision.insert(in_var_node->Var()->Name());
}
}
}
}
};
GetVarInputOps();
bool precision_updated = false;
do {
precision_updated = false;
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) continue;
for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true);
if (!VarNodeHasDtype(out_var_node)) continue;
auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
if (real_out_var_node->Var()->Persistable()) continue;
bool not_run_low_precision = false;
const auto& input_op_nodes =
var_input_ops[real_out_var_node->Var()->Name()];
if (vars_should_not_low_precision.count(
real_out_var_node->Var()->Name())) {
not_run_low_precision = true;
} else {
for (auto* node : input_op_nodes) {
if (op_run_low_precision_.count(node->Op()->Type()) == 0) {
not_run_low_precision = true;
break;
}
}
}
if (not_run_low_precision) {
op_run_low_precision_.erase(op_node->Op()->Type());
precision_updated = true;
VLOG(4) << op_node->Op()->Type()
<< " should not run at low precision.";
break;
}
}
}
}
} while (precision_updated);
}
// Special ops whose weights should not be converted to low precision.
bool AutoMixedPrecisionPass::InputVarsNotConvert(
Node* op_node, const std::string& var_name) const {
auto* op_desc = op_node->Op();
if (GetOpOriginalType(op_desc->Type()) == "batch_norm") {
auto vecs = op_desc->Input("Bias");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("Mean");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("Scale");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("Variance");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
} else if (GetOpOriginalType(op_desc->Type()) == "fused_multi_transformer") {
auto vecs = op_desc->Input("LnScale");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("LnBias");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("FFNLnScale");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("FFNLnBias");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
}
return false;
}
bool AutoMixedPrecisionPass::OutputVarsNotConvert(
Node* op_node, const std::string& var_name) const {
auto* op_desc = op_node->Op();
// batch_norm's input and output (variance and mean) are the same.
if (GetOpOriginalType(op_desc->Type()) == "batch_norm") {
auto vecs = op_desc->Output("MeanOut");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Output("VarianceOut");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Output("SavedMean");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Output("SavedVariance");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
}
return false;
}
void AutoMixedPrecisionPass::SetVarPrecision() const {
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) {
continue;
}
if (GetOpOriginalType(op_node->Op()->Type()) != "feed") {
for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true);
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
auto in_var_name = real_in_var_node->Var()->Name();
if (!IsFloatType(real_in_var_node->Var()->GetDataType())) continue;
if (!VarNodeHasDtype(real_in_var_node)) continue;
if (InputVarsNotConvert(op_node, in_var_name)) continue;
if (real_in_var_node->Var()->Persistable()) {
real_in_var_node->Var()->SetDataType(
framework::TransToProtoVarType(low_precision_));
vars_convert_to_low_precision_.insert(in_var_name);
}
}
}
if (GetOpOriginalType(op_node->Op()->Type()) != "fetch") {
for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true);
auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
auto out_var_name = real_out_var_node->Var()->Name();
if (!IsFloatType(real_out_var_node->Var()->GetDataType())) continue;
if (!VarNodeHasDtype(real_out_var_node)) continue;
if (OutputVarsNotConvert(op_node, out_var_name)) continue;
real_out_var_node->Var()->SetDataType(
framework::TransToProtoVarType(low_precision_));
if (real_out_var_node->Var()->Persistable()) {
vars_convert_to_low_precision_.insert(out_var_name);
}
}
}
}
}
// This code is used to process vars with the same name. Vars with the same
// name should have the same data type.
for (auto* subgraph : subgraphes_) {
for (auto* var_node : subgraph->Nodes()) {
if (!var_node->IsVar() || !var_node->Var()->Persistable()) continue;
if (!VarNodeHasDtype(var_node)) continue;
auto var_name = var_node->Var()->Name();
if (vars_convert_to_low_precision_.count(var_name)) {
var_node->Var()->SetDataType(
framework::TransToProtoVarType(low_precision_));
}
}
}
}
void AutoMixedPrecisionPass::ConvertWeightsData() const {
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(scope,
platform::errors::PreconditionNotMet(
"During the auto_mixed_precision_pass, the scope "
"should not be null."));
auto var_names = scope->LocalVarNames();
for (const auto& var_name : var_names) {
if (vars_convert_to_low_precision_.count(var_name)) {
VLOG(4) << var_name << "'s data type was converted to low precision";
auto* var = scope->FindLocalVar(var_name);
CHECK_EQ(var->IsType<phi::DenseTensor>(), true);
auto* origin_tensor = var->GetMutable<phi::DenseTensor>();
phi::DenseTensor low_precision_tensor;
low_precision_tensor.Resize(origin_tensor->dims());
low_precision_tensor.set_type(low_precision_);
if (low_precision_ == phi::DataType::FLOAT16) {
auto* low_precision_data =
low_precision_tensor.mutable_data<phi::dtype::float16>(
phi::CPUPlace{});
for (int64_t i = 0; i < origin_tensor->numel(); i++) {
if (origin_tensor->dtype() == phi::DataType::FLOAT64) {
auto* origin_data = origin_tensor->data<double>();
low_precision_data[i] =
static_cast<phi::dtype::float16>(origin_data[i]);
} else if (origin_tensor->dtype() == phi::DataType::FLOAT32) {
auto* origin_data = origin_tensor->data<float>();
low_precision_data[i] =
static_cast<phi::dtype::float16>(origin_data[i]);
}
}
} else if (low_precision_ == phi::DataType::BFLOAT16) {
auto* half_data =
low_precision_tensor.mutable_data<phi::dtype::bfloat16>(
phi::CPUPlace{});
for (int64_t i = 0; i < origin_tensor->numel(); i++) {
if (origin_tensor->dtype() == phi::DataType::FLOAT64) {
auto* origin_data = origin_tensor->data<double>();
half_data[i] = static_cast<phi::dtype::bfloat16>(origin_data[i]);
} else if (origin_tensor->dtype() == phi::DataType::FLOAT32) {
auto* origin_data = origin_tensor->data<float>();
half_data[i] = static_cast<phi::dtype::bfloat16>(origin_data[i]);
}
}
}
origin_tensor->clear();
paddle::framework::TensorCopySync(
low_precision_tensor, phi::CPUPlace{}, origin_tensor);
}
}
}
void AutoMixedPrecisionPass::InsertCastOp() const {
int suffix = 0;
std::unordered_map<Node*, Node*> cache;
for (size_t i = 0; i < all_op_nodes_.size(); i++) {
auto* block_desc = all_op_nodes_[i][0]->Op()->Block();
CHECK_NOTNULL(block_desc);
for (auto* op_node : all_op_nodes_[i]) {
auto op_type = op_node->Op()->Type();
if (GetOpOriginalType(op_type) == "feed") continue;
if (op_node->Op()->HasAttr("sub_block")) continue;
VLOG(4) << "process op: " << op_type
<< " run low precision: " << op_run_low_precision_.count(op_type);
auto inputs = op_node->inputs;
for (auto* in_var_node : inputs) {
if (!in_var_node->IsVar()) continue;
if (!VarNodeHasDtype(in_var_node)) continue;
if (in_var_node->Var()->Persistable()) continue;
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
auto in_var_type = real_in_var_node->Var()->GetDataType();
VLOG(4) << "process var: " << real_in_var_node->Var()->Name()
<< " with type " << in_var_type;
if (IsFloatType(in_var_type) && op_run_low_precision_.count(op_type)) {
DoInsertCastOp(subgraphes_[i],
in_var_node,
op_node,
in_var_type,
framework::TransToProtoVarType(low_precision_),
block_desc,
&suffix,
&cache);
} else if (IsHalfType(in_var_type) &&
op_run_low_precision_.count(op_type) == 0) {
DoInsertCastOp(subgraphes_[i],
in_var_node,
op_node,
in_var_type,
VarType::FP32,
block_desc,
&suffix,
&cache);
}
}
// Special op.
// fused_multi_transformer's input(CacheKV) and output(CacheKVOut) vars
// have the same name.
if (GetOpOriginalType(op_type) == "fused_multi_transformer") {
auto cache_kv_inputs = op_node->Op()->Input("CacheKV");
auto cache_kv_outputs = op_node->Op()->Output("CacheKVOut");
CHECK_EQ(cache_kv_inputs.size(), cache_kv_outputs.size());
for (size_t i = 0; i < cache_kv_inputs.size(); ++i) {
op_node->Op()->RenameOutput(cache_kv_outputs[i], cache_kv_inputs[i]);
}
}
}
}
VLOG(4) << "insert number of cast op: " << cache.size();
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(auto_mixed_precision_pass,
paddle::framework::ir::AutoMixedPrecisionPass);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h"
namespace paddle {
namespace framework {
namespace ir {
class AutoMixedPrecisionPass : public FusePassBase {
public:
using VarType = framework::proto::VarType;
public:
AutoMixedPrecisionPass() = default;
~AutoMixedPrecisionPass() = default;
protected:
void ApplyImpl(Graph* graph) const override;
private:
void Init(Graph* graph) const;
void SetDefaultBlacklist() const;
void SetOpUniqueType() const;
void RestoreOpOriginType() const;
inline std::string GetOpOriginalType(const std::string& op_type) const;
void GetOpPrecision() const;
void UpdateOpPrecision() const;
void InsertCastOp() const;
void ProcessOpWithDtypeAttr() const;
bool InputVarsNotConvert(Node* op_node, const std::string& var_name) const;
bool OutputVarsNotConvert(Node* op_node, const std::string& var_name) const;
void SetVarPrecision() const;
void ConvertWeightsData() const;
private:
mutable bool skip_pass_{false};
mutable bool keep_io_types_{false};
// float16 or bfloat16 now
mutable phi::DataType low_precision_{phi::DataType::FLOAT16};
mutable phi::Backend backend_{phi::Backend::GPU};
mutable std::unordered_set<std::string> black_list_;
// subgraph id -> pointer to subgraph
mutable std::vector<Graph*> subgraphes_;
// var name -> real var node
mutable std::unordered_map<std::string, Node*> real_vars_;
// subgraph id -> all op nodes in subgraph
mutable std::vector<std::vector<Node*>> all_op_nodes_;
// op's unique type -> the op's origin type
mutable std::unordered_map<std::string, std::string> op_original_type_;
// the set of op unique types that run at low precision
mutable std::unordered_set<std::string> op_run_low_precision_;
mutable std::unordered_set<std::string> vars_convert_to_low_precision_;
};
bool OpSupportPrecision(const std::string& op_type,
phi::Backend backend,
phi::DataType precision,
const std::unordered_set<std::string>& black_list);
void DoInsertCastOp(Graph* graph,
Node* var_node,
Node* op_node,
proto::VarType::Type from_type,
proto::VarType::Type to_type,
framework::BlockDesc* block_desc,
int* suffix,
std::unordered_map<Node*, Node*>* cache);
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -29,6 +29,11 @@ void FillConstData(LoDTensor* out_t, T value) {
}
void DeleteFillConstantOpPass::ApplyImpl(ir::Graph* graph) const {
bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
// Not supported with dynamic shape.
if (with_dynamic_shape) {
return;
}
FusePassBase::Init("delete_fill_constant_op_pass", graph);
GraphPatternDetector detector;
auto fill_constant_op =
......
......@@ -75,7 +75,6 @@ Graph::Graph(const ProgramDesc &program,
}
} else {
auto var_nodes = InitFromProgram(program_, start_op_index, end_op_index);
ResolveHazard(var_nodes);
}
}
......@@ -88,7 +87,6 @@ Graph::Graph(const BlockDesc &block,
const int64_t end_op_index)
: main_graph_(main_graph) {
auto var_nodes = InitFromBlock(block, start_op_index, end_op_index);
ResolveHazard(var_nodes);
}
// TODO(levi): delete this interface after when we can convert all
......
......@@ -130,86 +130,6 @@ TEST(GraphTest, Basic) {
ASSERT_EQ(nodes.size(), 5UL);
}
TEST(GraphTest, WriteAfterRead) {
// void Test() {
ProgramDesc prog;
auto *op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum");
op->SetInput("X", {"a"});
op->SetOutput("Out", {"b"});
op->SetAttr("op_role", 1);
op = prog.MutableBlock(0)->AppendOp();
op->SetType("dummy");
op->SetInput("X", {"c"});
op->SetOutput("Out", {"a"});
op->SetAttr("op_role", 1);
prog.MutableBlock(0)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
ir::Node *control_dep1 = nullptr;
ir::Node *control_dep2 = nullptr;
for (ir::Node *n : g->Nodes()) {
if (n->Name() == "sum") {
ASSERT_EQ(n->outputs[0]->Name(), "b");
ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1]));
control_dep1 = n->outputs[1];
ASSERT_EQ(n->outputs.size(), 2UL);
}
if (n->Name() == "dummy") {
ASSERT_EQ(n->inputs[0]->Name(), "c");
ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1]));
control_dep2 = n->inputs[1];
ASSERT_EQ(n->inputs.size(), 2UL);
}
}
ASSERT_EQ(control_dep1, control_dep2);
}
TEST(GraphTest, WriteAfterWrite) {
// void Test() {
ProgramDesc prog;
auto *op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum");
op->SetInput("X", {"a"});
op->SetOutput("Out", {"b"});
op->SetAttr("op_role", 1);
op = prog.MutableBlock(0)->AppendOp();
op->SetType("dummy");
op->SetInput("X", {"c"});
op->SetOutput("Out", {"b"});
op->SetAttr("op_role", 1);
prog.MutableBlock(0)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
ir::Node *control_dep1 = nullptr;
ir::Node *control_dep2 = nullptr;
for (ir::Node *n : g->Nodes()) {
if (n->Name() == "sum") {
ASSERT_EQ(n->outputs[0]->Name(), "b");
ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1]));
ASSERT_EQ(n->outputs.size(), 2UL);
control_dep1 = n->outputs[1];
}
if (n->Name() == "dummy") {
ASSERT_EQ(n->inputs[0]->Name(), "c");
ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1]));
control_dep2 = n->inputs[1];
ASSERT_EQ(n->inputs.size(), 2UL);
}
}
ASSERT_NE(control_dep1, nullptr);
ASSERT_NE(control_dep2, nullptr);
ASSERT_EQ(control_dep1, control_dep2);
}
TEST(GraphTest, TestException) {
ProgramDesc prog;
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
......@@ -350,12 +270,13 @@ TEST(GraphTest, TestMultiBlock) {
op = prog.MutableBlock(1)->AppendOp();
op->SetType("dummy");
op->SetInput("X", {"c"});
op->SetOutput("Out", {"a"});
op->SetOutput("Out", {"d"});
op->SetAttr("op_role", 1);
prog.MutableBlock(1)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(1)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(1)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(1)->Var("d")->SetType(proto::VarType::LOD_TENSOR);
// Set contents in block_2.
op = prog.MutableBlock(2)->AppendOp();
......@@ -367,12 +288,13 @@ TEST(GraphTest, TestMultiBlock) {
op = prog.MutableBlock(2)->AppendOp();
op->SetType("dummy");
op->SetInput("X", {"c"});
op->SetOutput("Out", {"b"});
op->SetOutput("Out", {"d"});
op->SetAttr("op_role", 1);
prog.MutableBlock(2)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(2)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(2)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(1)->Var("d")->SetType(proto::VarType::LOD_TENSOR);
// Step2: Convert program into graph; 3 blocks correspond to 3 sub_graphs.
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
......@@ -399,45 +321,29 @@ TEST(GraphTest, TestMultiBlock) {
// Check contents in sub_graph_1.
const ir::Graph *g1 = g->GetSubGraph(1);
ir::Node *control_dep1 = nullptr;
ir::Node *control_dep2 = nullptr;
for (ir::Node *n : g1->Nodes()) {
if (n->Name() == "sum") {
ASSERT_EQ(n->outputs[0]->Name(), "b");
ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1]));
control_dep1 = n->outputs[1];
ASSERT_EQ(n->outputs.size(), 2UL);
ASSERT_EQ(n->outputs.size(), 1UL);
}
if (n->Name() == "dummy") {
ASSERT_EQ(n->inputs[0]->Name(), "c");
ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1]));
control_dep2 = n->inputs[1];
ASSERT_EQ(n->inputs.size(), 2UL);
ASSERT_EQ(n->inputs.size(), 1UL);
}
}
ASSERT_EQ(control_dep1, control_dep2);
// Check contents in sub_graph_2.
const ir::Graph *g2 = g->GetSubGraph(2);
control_dep1 = nullptr;
control_dep2 = nullptr;
for (ir::Node *n : g2->Nodes()) {
if (n->Name() == "sum") {
ASSERT_EQ(n->outputs[0]->Name(), "b");
ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1]));
ASSERT_EQ(n->outputs.size(), 2UL);
control_dep1 = n->outputs[1];
ASSERT_EQ(n->outputs.size(), 1UL);
}
if (n->Name() == "dummy") {
ASSERT_EQ(n->inputs[0]->Name(), "c");
ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1]));
control_dep2 = n->inputs[1];
ASSERT_EQ(n->inputs.size(), 2UL);
ASSERT_EQ(n->inputs.size(), 1UL);
}
}
ASSERT_NE(control_dep1, nullptr);
ASSERT_NE(control_dep2, nullptr);
ASSERT_EQ(control_dep1, control_dep2);
// Step3: Clone graph.
std::shared_ptr<ir::Graph> clone_g = g->Clone();
......
......@@ -331,8 +331,6 @@ void BatchMergePass::ApplyImpl(ir::Graph* graph) const {
copy_node(node);
}
}
result.ResolveHazard(created);
}
} // namespace ir
......
......@@ -183,5 +183,6 @@ void NaiveExecutor::ResetTrtOps(int num) {
}
#endif
}
} // namespace framework
} // namespace paddle
......@@ -38,8 +38,7 @@ void Analyzer::RunAnalysis(Argument *argument) {
if (!disable_logs) {
string::PrettyLogH1("--- Running analysis [%s]", pass);
}
if (!argument->enable_analysis_optim() && pass == "ir_analysis_pass")
continue;
if (!argument->enable_ir_optim() && pass == "ir_analysis_pass") continue;
auto *ptr = PassRegistry::Global().Retreive(pass);
PADDLE_ENFORCE_NOT_NULL(ptr,
......
......@@ -31,7 +31,7 @@ TEST(Analyzer, analysis_without_tensorrt) {
Argument argument;
argument.SetDisableLogs(false);
argument.SetModelDir(FLAGS_inference_model_dir);
argument.SetEnableAnalysisOptim(false);
argument.SetEnableIrOptim(false);
argument.SetUseGPU(false);
argument.SetAnalysisPasses({"ir_graph_build_pass",
"ir_analysis_pass",
......@@ -44,7 +44,7 @@ TEST(Analyzer, analysis_without_tensorrt) {
TEST(Analyzer, analysis_with_tensorrt) {
Argument argument;
argument.SetDisableLogs(false);
argument.SetEnableAnalysisOptim(false);
argument.SetEnableIrOptim(false);
argument.SetTensorRtMaxBatchSize(3);
argument.SetTensorRtWorkspaceSize(1 << 20);
argument.SetModelDir(FLAGS_inference_model_dir);
......
......@@ -42,8 +42,6 @@ namespace paddle {
namespace inference {
namespace analysis {
using framework::ir::Graph;
#ifdef PADDLE_WITH_MKLDNN
using VarQuantScale =
std::unordered_map<std::string, std::pair<bool, framework::LoDTensor>>;
......@@ -148,7 +146,7 @@ struct Argument {
DECL_ARGUMENT_FIELD(model_params_path, ModelParamsPath, std::string);
DECL_ARGUMENT_FIELD(model_from_memory, ModelFromMemory, bool);
DECL_ARGUMENT_FIELD(optim_cache_dir, OptimCacheDir, std::string);
DECL_ARGUMENT_FIELD(enable_analysis_optim, EnableAnalysisOptim, bool);
DECL_ARGUMENT_FIELD(enable_ir_optim, EnableIrOptim, bool);
// For JITLayer
DECL_ARGUMENT_FIELD(skip_load_params, SkipLoadParams, bool);
......@@ -362,6 +360,8 @@ struct Argument {
DECL_ARGUMENT_FIELD(mixed_black_list,
MixedBlackList,
std::unordered_set<std::string>);
DECL_ARGUMENT_FIELD(enable_gpu_mixed, EnableGPUMixed, bool);
DECL_ARGUMENT_FIELD(mixed_precision_mode, MixedPrecisionMode, int);
private:
std::unordered_set<std::string> valid_fields_;
......
......@@ -153,25 +153,6 @@ T &GetFromScope(const framework::Scope &scope, const std::string &name) {
return *var->GetMutable<T>();
}
static framework::proto::ProgramDesc LoadProgramDesc(
const std::string &model_path) {
std::ifstream fin(model_path, std::ios::in | std::ios::binary);
PADDLE_ENFORCE_EQ(
fin.is_open(),
true,
platform::errors::NotFound(
"Cannot open file %s, please confirm whether the file exists",
model_path));
fin.seekg(0, std::ios::end);
std::string buffer(fin.tellg(), ' ');
fin.seekg(0, std::ios::beg);
fin.read(&buffer[0], buffer.size());
fin.close();
framework::proto::ProgramDesc program_desc;
program_desc.ParseFromString(buffer);
return program_desc;
}
static bool FileExists(const std::string &filepath) {
std::ifstream file(filepath);
bool exists = file.is_open();
......
......@@ -27,6 +27,7 @@
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/argument.h"
#include "paddle/fluid/string/pretty_log.h"
#include "paddle/phi/core/errors.h"
namespace paddle {
namespace inference {
......@@ -36,15 +37,6 @@ using string::PrettyLogEndl;
using string::Style;
IRPassManager::IRPassManager(Argument *argument) {
ARGUMENT_CHECK_FIELD(argument, main_program);
graph_ = std::unique_ptr<Graph>(new Graph(argument->main_program()));
if (argument->Has("scope")) {
auto *scope_ptr = argument->scope_ptr();
PADDLE_ENFORCE_NOT_NULL(scope_ptr,
platform::errors::PreconditionNotMet(
"The scope ptr should not be nullptr."));
graph_->SetNotOwned(framework::ir::kParamScopeAttr, scope_ptr);
}
disable_logs_ = argument->disable_logs();
ARGUMENT_CHECK_FIELD(argument, ir_analysis_passes);
......@@ -95,10 +87,14 @@ void IRPassManager::CreatePasses(Argument *argument,
argument->tensorrt_tuned_dynamic_shape();
pass->Set("with_dynamic_shape", new bool(with_dynamic_shape));
// mixed precision related
pass->Set("model_precision", new int(argument->model_precision()));
pass->Set(
"mixed_black_list",
new std::unordered_set<std::string>(argument->mixed_black_list()));
pass->Set("enable_gpu_mixed", new bool(argument->enable_gpu_mixed()));
pass->Set("mixed_precision_mode",
new int(argument->mixed_precision_mode()));
if (pass_name == "graph_viz_pass") {
std::string optim_cache_dir = argument->optim_cache_dir();
......@@ -302,42 +298,18 @@ void IRPassManager::CreatePasses(Argument *argument,
}
std::unique_ptr<Graph> IRPassManager::Apply(std::unique_ptr<Graph> graph) {
if (passes_.empty()) {
return graph;
}
PADDLE_ENFORCE_NOT_NULL(
graph.get(),
platform::errors::PreconditionNotMet("Graph cannot be NULL."));
graph.get(), platform::errors::InvalidArgument("Graph cannot be null."));
// Apply all the passes
for (const auto &pass : passes_) {
if (pass->Type() != "graph_viz_pass" && !disable_logs_) {
PrettyLogEndl(Style::H2(), "--- Running IR pass [%s]", pass->Type());
}
// delete_fill_constant_op_pass is not applied under trt dynamic shape
if (pass->Type() == "delete_fill_constant_op_pass") {
bool use_dynamic = pass->Get<bool>("with_dynamic_shape");
if (use_dynamic) continue;
}
graph.reset(pass->Apply(graph.release()));
}
return graph;
}
framework::proto::ProgramDesc IRPassManager::AcquireProgram(
std::unique_ptr<Graph> *graph, ProgramDesc *program) const {
auto pass =
framework::ir::PassRegistry::Instance().Get("graph_to_program_pass");
// Directly using ProgramDesc desc(argument->main_program()) may cause
// incomplete copies of information.
ProgramDesc desc;
desc.CopyFrom(*program->Proto());
pass->SetNotOwned("program", &desc);
auto *the_graph = graph->release();
graph->reset(pass->Apply(the_graph));
return *desc.Proto();
}
} // namespace analysis
} // namespace inference
} // namespace paddle
......@@ -48,15 +48,9 @@ class IRPassManager final {
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph);
framework::proto::ProgramDesc AcquireProgram(std::unique_ptr<Graph> *graph,
ProgramDesc *program) const;
framework::ir::Graph &graph() const { return *graph_; }
private:
void CreatePasses(Argument *argument, const std::vector<std::string> &passes);
std::unique_ptr<Graph> graph_;
std::vector<std::unique_ptr<Pass>> passes_;
bool disable_logs_{false};
};
......
......@@ -94,13 +94,13 @@ void OutputProcess(framework::ir::Graph *graph,
backend,
precision,
blacklist)) {
AddCastOp(graph,
InsertCastOp(graph,
var_node,
next_op,
framework::proto::VarType::FP32,
to_type,
&suffix,
block_desc,
&suffix,
&var_to_cast_op_map);
var_node->Var()->SetDataType(framework::proto::VarType::FP32);
}
......
......@@ -13,7 +13,7 @@ cc_library(
cc_library(
convert_to_mixed_precision
SRCS convert_to_mixed_precision.cc
DEPS analysis_pass ir_graph_build_pass)
DEPS analysis_pass ir_graph_build_pass auto_mixed_precision_pass)
cc_library(
ir_params_sync_among_devices_pass
SRCS ir_params_sync_among_devices_pass.cc
......@@ -30,17 +30,6 @@ cc_library(
inference_op_replace_pass
SRCS inference_op_replace_pass.cc
DEPS analysis_pass graph_to_program_pass)
if(WITH_TESTING)
cc_library(
ir_graph_clean_pass
SRCS ir_graph_clean_pass.cc
DEPS analysis_pass gtest)
else()
cc_library(
ir_graph_clean_pass
SRCS ir_graph_clean_pass.cc
DEPS analysis_pass)
endif()
cc_library(
analysis_passes
......@@ -52,8 +41,7 @@ cc_library(
memory_optim_pass
convert_to_mixed_precision
inference_op_replace_pass
ir_graph_to_program_pass
ir_graph_clean_pass)
ir_graph_to_program_pass)
set(analysis_deps
${analysis_deps} analysis_passes subgraph_detector
......
......@@ -14,82 +14,17 @@
#include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"
#include <algorithm>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/auto_mixed_precision_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/inference/analysis/argument.h"
#include "paddle/fluid/inference/analysis/passes/ir_graph_clean_pass.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/tensor_meta.h"
using namespace paddle::framework; // NOLINT
#include "paddle/phi/common/backend.h"
namespace paddle {
namespace inference {
namespace analysis {
namespace {
bool PhiKernelSupportPrecision(
const std::string& op_type,
phi::Backend backend,
phi::DataType data_type,
phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) {
auto kernels = phi::KernelFactory::Instance().kernels();
if (kernels.find(op_type) == kernels.end()) {
return false;
}
phi::KernelKey kernel_key(backend, layout, data_type);
return phi::KernelFactory::Instance().HasKernel(op_type, kernel_key);
}
bool GpuKernelSupportPrecision(
const std::string& op_type,
phi::DataType data_type,
phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) {
auto phi_op_type = phi::TransToPhiKernelName(op_type);
bool res = PhiKernelSupportPrecision(
phi_op_type, phi::Backend::GPU, data_type, layout);
res |= PhiKernelSupportPrecision(
phi_op_type, phi::Backend::GPUDNN, data_type, layout);
if (!res) {
auto& all_kernels = OperatorWithKernel::AllOpKernels();
auto it = all_kernels.find(op_type);
if (it != all_kernels.end()) {
for (auto& kern_pair : it->second) {
if (platform::is_gpu_place(kern_pair.first.place_) &&
kern_pair.first.data_type_ == framework::proto::VarType::FP16) {
res = true;
}
}
}
}
return res;
}
class ConvertToMixedPrecisionPass {
public:
explicit ConvertToMixedPrecisionPass(
ConvertToMixedPrecisionPass::ConvertToMixedPrecisionPass(
const std::string& model_file,
const std::string& params_file,
const std::string& mixed_model_file,
......@@ -97,7 +32,7 @@ class ConvertToMixedPrecisionPass {
phi::DataType mixed_precision,
phi::Backend backend,
bool keep_io_types,
std::unordered_set<std::string> black_list)
const std::unordered_set<std::string>& black_list)
: model_file_(model_file),
params_file_(params_file),
mixed_model_file_(mixed_model_file),
......@@ -105,716 +40,62 @@ class ConvertToMixedPrecisionPass {
mixed_precision_(mixed_precision),
backend_(backend),
keep_io_types_(keep_io_types),
black_list_(black_list),
place_(paddle::CPUPlace()),
executor_(place_) {
black_list_.insert("assign");
black_list_.insert("fill_constant");
black_list_.insert("assign_value");
black_list_.insert("eye");
black_list_.insert("fill_any_like");
black_list_.insert("fill_constant_batch_size_like");
}
void Run();
private:
void LoadAndPrepare();
inline bool NodeVarHasDtype(framework::ir::Node* node);
void ConvertAllFp64ToFp32(framework::ir::Graph* graph);
void FixCastAttr(framework::ir::Graph* graph);
void SaveMixedModel();
void ConvertTensorDtype(int block_idx);
void ProcessInputNode(bool support_precision,
ir::Node* in_node,
ir::Node* op_node,
int* suffix,
framework::BlockDesc* block_desc,
framework::proto::VarType::Type to_type,
int block_idx);
void ProcessOutputNode(int block_idx,
ir::Node* var_node,
framework::proto::VarType::Type to_type);
inline bool IsFloatVarType(framework::proto::VarType::Type type);
bool OutShouldNotConvert(ir::Node* var_node);
// Just process special cases for weights conversion.
bool WeightsShouldNotConvert(ir::Node* var_node);
// To support multi-block models, we need to consider a lot of special cases.
// Return the Node* which first appears in a block.
framework::ir::Node* GetRealNode(int block_idx, framework::ir::Node* node);
void FindVarsInMultiBlock();
inline bool VarIsMultiPrecisionOpsOut(int block_idx,
framework::ir::Node* op_node);
private:
// A trick: patch for strange ops whose input name equals the output name, such
// as `fused_multi_transformer`.
void PatchForStrangeOp();
private:
std::string model_file_;
std::string params_file_;
std::string mixed_model_file_;
std::string mixed_params_file_;
phi::DataType mixed_precision_;
phi::Backend backend_;
bool keep_io_types_;
std::unordered_set<std::string> black_list_;
paddle::CPUPlace place_;
framework::Executor executor_;
framework::Scope scope_;
std::unordered_map<framework::ir::Node*, framework::ir::Node*> cast_map_;
std::unordered_map<std::string,
std::pair<framework::proto::VarType::Type, int>>
vars_in_multi_block_map_;
std::vector<std::unordered_map<std::string, std::vector<std::string>>>
vars_appear_multi_in_one_block_;
int suffix_{0};
std::unique_ptr<framework::ProgramDesc> program_desc_{nullptr};
std::unique_ptr<framework::ir::Graph> main_graph_{nullptr};
std::vector<framework::ir::Graph*> graphes_;
};
framework::ir::Node* ConvertToMixedPrecisionPass::GetRealNode(
int block_idx, framework::ir::Node* node) {
if (vars_in_multi_block_map_.count(node->Name())) {
int var_origin_block_id = vars_in_multi_block_map_.at(node->Name()).second;
if (block_idx != var_origin_block_id) {
auto graph = graphes_[var_origin_block_id];
for (auto nd : graph->Nodes()) {
if (nd->Name() == node->Name()) {
return nd;
}
}
}
}
return node;
}
inline bool ConvertToMixedPrecisionPass::NodeVarHasDtype(
framework::ir::Node* node) {
if (node->IsVar() &&
(node->Var()->GetType() ==
paddle::framework::proto::VarType::SELECTED_ROWS ||
node->Var()->GetType() ==
paddle::framework::proto::VarType::LOD_TENSOR ||
node->Var()->GetType() ==
paddle::framework::proto::VarType::LOD_TENSOR_ARRAY ||
node->Var()->GetType() == paddle::framework::proto::VarType::STRINGS ||
node->Var()->GetType() == paddle::framework::proto::VarType::VOCAB)) {
return true;
}
return false;
}
// op1(fp32) -> var1, op2(fp16) -> var1
// if and only if op1 and op2 both support fp16, we convert op1 and op2's
// precision.
inline bool ConvertToMixedPrecisionPass::VarIsMultiPrecisionOpsOut(
int block_idx, framework::ir::Node* op_node) {
CHECK_EQ(op_node->IsOp(), true);
bool ret{false};
for (auto* out : op_node->outputs) {
auto* real_node = GetRealNode(block_idx, out);
if (!real_node->Var()->Persistable() &&
vars_appear_multi_in_one_block_[block_idx].count(out->Name())) {
for (auto op_type :
vars_appear_multi_in_one_block_[block_idx].at(out->Name())) {
if (OpSupportPrecision(
op_type, backend_, mixed_precision_, black_list_)) {
ret = true;
VLOG(2) << out->Name()
<< " is multi precision op's out, so we skip convert to fp16";
break;
}
}
}
if (ret) break;
}
return ret;
}
void ConvertToMixedPrecisionPass::ProcessInputNode(
bool support_precision,
ir::Node* in_node,
ir::Node* op_node,
int* suffix,
framework::BlockDesc* block_desc,
framework::proto::VarType::Type to_type,
int block_idx) {
auto* real_node = GetRealNode(block_idx, in_node);
if (!NodeVarHasDtype(real_node)) return;
auto graph = graphes_[block_idx];
bool is_main_block = block_idx == 0;
auto* in_var = real_node->Var();
auto in_var_type = in_var->GetDataType();
auto prev_type = in_var_type;
bool is_in_multi_block = vars_in_multi_block_map_.count(in_var->Name());
if (!is_main_block && is_in_multi_block) {
in_var_type = vars_in_multi_block_map_.at(in_var->Name()).first;
}
if (support_precision) {
if (in_var->Persistable() &&
in_var_type == framework::proto::VarType::FP32) {
if (WeightsShouldNotConvert(in_node)) return;
in_var->SetDataType(to_type);
in_var_type = to_type;
VLOG(3) << " in_node name " << in_var->Name() << " from " << prev_type
<< " to " << to_type;
} else if (!in_var->Persistable() && IsFloatVarType(in_var_type) &&
in_var_type != to_type) {
AddCastOp(graph,
in_node,
op_node,
in_var_type,
to_type,
suffix,
block_desc,
&cast_map_);
VLOG(3) << " in_node name " << in_var->Name() << "(" << prev_type
<< ") to " << cast_map_[in_node]->Name() << "(" << to_type << ")";
}
} else {
if (!in_var->Persistable() && IsFloatVarType(in_var_type) &&
in_var_type != to_type) {
AddCastOp(graph,
in_node,
op_node,
in_var_type,
to_type,
suffix,
block_desc,
&cast_map_);
VLOG(3) << " in_node name " << in_var->Name() << "(" << prev_type
<< ") to " << cast_map_[in_node]->Name() << "(" << to_type << ")";
}
}
}
void ConvertToMixedPrecisionPass::ProcessOutputNode(
int block_idx,
ir::Node* var_node,
framework::proto::VarType::Type to_type) {
auto* real_node = GetRealNode(block_idx, var_node);
if (!NodeVarHasDtype(real_node)) return;
auto* out_var = real_node->Var();
auto prev_type = out_var->GetDataType();
if (out_var->GetDataType() == framework::proto::VarType::FP32) {
if (OutShouldNotConvert(var_node)) return;
out_var->SetDataType(to_type);
}
VLOG(3) << " out_node name " << var_node->Name() << " from dtype "
<< prev_type << " to " << out_var->GetDataType();
}
// Just process special cases.
bool ConvertToMixedPrecisionPass::OutShouldNotConvert(ir::Node* var_node) {
auto op_node = var_node->inputs[0];
auto* op_desc = op_node->Op();
// batch_norm's input and output (variance and mean) are the same.
if (op_desc->Type() == "batch_norm") {
auto vecs = op_desc->Output("MeanOut");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
vecs = op_desc->Output("VarianceOut");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
vecs = op_desc->Output("SavedMean");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
vecs = op_desc->Output("SavedVariance");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
black_list_(black_list) {
if (mixed_precision_ != phi::DataType::FLOAT16 &&
mixed_precision_ != phi::DataType::BFLOAT16) {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"mixed_precision currently not supported dtype %d, we now only "
"support fp16 and bf16.",
static_cast<int>(mixed_precision_)));
}
if (backend_ != phi::Backend::GPU) {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"mixed_precision currently not supported place %d, we now only "
"support gpu.",
static_cast<int>(backend_)));
}
return false;
}
bool ConvertToMixedPrecisionPass::WeightsShouldNotConvert(ir::Node* var_node) {
auto op_nodes = var_node->outputs;
for (auto* op_node : op_nodes) {
auto* op_desc = op_node->Op();
// batch_norm op's bias, mean, scale and variance must be float32, so we
// cannot convert their dtype.
if (op_desc->Type() == "batch_norm") {
auto vecs = op_desc->Input("Bias");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
vecs = op_desc->Input("Mean");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
vecs = op_desc->Input("Scale");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
vecs = op_desc->Input("Variance");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
} else if (op_desc->Type() == "fused_multi_transformer") {
auto vecs = op_desc->Input("LnScale");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
void ConvertToMixedPrecisionPass::LoadModel() {
framework::Executor exe{platform::CPUPlace{}};
vecs = op_desc->Input("LnBias");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
vecs = op_desc->Input("FFNLnScale");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
vecs = op_desc->Input("FFNLnBias");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
}
}
return false;
}
inline bool ConvertToMixedPrecisionPass::IsFloatVarType(
framework::proto::VarType::Type type) {
if (type == framework::proto::VarType::FP16 ||
type == framework::proto::VarType::FP32 ||
type == framework::proto::VarType::BF16)
return true;
return false;
}
void ConvertToMixedPrecisionPass::LoadAndPrepare() {
program_desc_ =
inference::Load(&executor_, &scope_, model_file_, params_file_);
auto program_desc = inference::Load(&exe, &scope_, model_file_, params_file_);
main_graph_ = std::unique_ptr<framework::ir::Graph>(
new framework::ir::Graph(*program_desc_));
// Remove all control var
IrInferCleanGraphPass pass;
Argument arg;
arg.SetMainGraphNotOwned(main_graph_.get());
pass.Run(&arg);
vars_appear_multi_in_one_block_.resize(program_desc_->Size());
FindVarsInMultiBlock();
}
void ConvertToMixedPrecisionPass::FindVarsInMultiBlock() {
std::vector<std::set<std::string>> block_var_names_set(program_desc_->Size());
for (size_t i = 0; i < program_desc_->Size(); ++i) {
for (auto op : program_desc_->Block(i).AllOps()) {
auto in_names = op->InputArgumentNames();
block_var_names_set[i].insert(in_names.begin(), in_names.end());
auto out_names = op->OutputArgumentNames();
if (op->HasAttr("sub_block") == false) {
for (auto& n : out_names) {
if (block_var_names_set[i].count(n)) {
vars_appear_multi_in_one_block_[i][n].push_back(op->Type());
}
}
}
block_var_names_set[i].insert(out_names.begin(), out_names.end());
}
}
for (size_t i = 0; i < program_desc_->Size() - 1; ++i) {
for (size_t j = i + 1; j < program_desc_->Size(); ++j) {
std::set<std::string> vars_in_multi_block;
std::set_intersection(
block_var_names_set[i].begin(),
block_var_names_set[i].end(),
block_var_names_set[j].begin(),
block_var_names_set[j].end(),
std::inserter(vars_in_multi_block, vars_in_multi_block.begin()));
for (auto name : vars_in_multi_block) {
vars_in_multi_block_map_.emplace(
name, std::make_pair(framework::proto::VarType::FP32, i));
}
}
}
}
void ConvertToMixedPrecisionPass::ConvertAllFp64ToFp32(
framework::ir::Graph* graph) {
auto op_nodes = framework::ir::TopologySortOperations(*graph);
for (auto* op_node : op_nodes) {
if (!op_node->IsOp()) continue;
auto op_type = op_node->Op()->Type();
if (op_type == "feed" || op_type == "fetch") continue;
if (op_type == "fill_constant") {
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP64))
op_node->Op()->SetAttr(
"dtype", static_cast<int>(framework::proto::VarType::FP32));
} else if (op_type == "assign_value") {
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP64))
op_node->Op()->SetAttr(
"dtype", static_cast<int>(framework::proto::VarType::FP32));
} else if (op_type == "eye") {
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP64))
op_node->Op()->SetAttr(
"dtype", static_cast<int>(framework::proto::VarType::FP32));
} else if (op_type == "fill_any_like") {
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP64))
op_node->Op()->SetAttr(
"dtype", static_cast<int>(framework::proto::VarType::FP32));
} else if (op_type == "cast") {
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("in_dtype")) ==
static_cast<int>(framework::proto::VarType::FP64))
op_node->Op()->SetAttr(
"in_dtype", static_cast<int>(framework::proto::VarType::FP32));
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("out_dtype")) ==
static_cast<int>(framework::proto::VarType::FP64))
op_node->Op()->SetAttr(
"out_dtype", static_cast<int>(framework::proto::VarType::FP32));
}
auto inputs = op_node->inputs;
for (auto* in_node : inputs) {
auto* in_var = in_node->Var();
if (!in_var->Persistable() &&
in_var->GetDataType() == framework::proto::VarType::FP64) {
in_var->SetDataType(framework::proto::VarType::FP32);
}
}
}
new framework::ir::Graph(*program_desc));
main_graph_->SetNotOwned(framework::ir::kParamScopeAttr, &scope_);
}
void ConvertToMixedPrecisionPass::Run() {
LoadAndPrepare();
for (size_t i = 0; i < main_graph_->SubGraphsSize(); ++i) {
auto graph = main_graph_->GetSubGraph(i);
graphes_.push_back(graph);
VLOG(2) << " -------- handle subgraph " << i << ", has "
<< graph->Nodes().size() << " nodes --------";
LoadModel();
ConvertAllFp64ToFp32(graph);
ConvertTensorDtype(i);
FixCastAttr(graph);
framework::ir::AutoMixedPrecisionPass pass;
pass.Set("mixed_precision_mode", new int{static_cast<int>(mixed_precision_)});
pass.Set("mixed_black_list",
new std::unordered_set<std::string>{black_list_});
pass.Set("enable_gpu_mixed", new bool{true});
pass.Set("keep_io_types", new bool{keep_io_types_});
    // A workaround for ops that need special handling (see PatchForStrangeOp).
PatchForStrangeOp();
CHECK_EQ(ir::VarDescIsConsistency(*graph), true);
}
pass.Apply(main_graph_.get());
SaveMixedModel();
}
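For reference, the same IR pass can also be driven directly on a standalone graph. The sketch below is illustrative only: the helper name is hypothetical, and the attribute names are the ones set in Run() above.
// Hedged sketch: configure and apply the auto mixed precision IR pass.
void ApplyAutoMixedPrecision(framework::ir::Graph* graph,
                             phi::DataType precision,
                             const std::unordered_set<std::string>& black_list) {
  framework::ir::AutoMixedPrecisionPass pass;
  pass.Set("mixed_precision_mode", new int{static_cast<int>(precision)});
  pass.Set("mixed_black_list",
           new std::unordered_set<std::string>{black_list});
  pass.Set("enable_gpu_mixed", new bool{true});
  pass.Set("keep_io_types", new bool{true});
  pass.Apply(graph);
}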
void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
auto graph = graphes_[block_idx];
framework::proto::VarType::Type to_type;
if (mixed_precision_ == phi::DataType::FLOAT16) {
to_type = framework::proto::VarType::FP16;
} else if (mixed_precision_ == phi::DataType::BFLOAT16) {
to_type = framework::proto::VarType::BF16;
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"mixed_precision currently not supported dtype %d, we now only "
"support fp16 and bf16.",
static_cast<int>(mixed_precision_)));
}
auto op_nodes = framework::ir::TopologySortOperations(*graph);
auto* block_desc = op_nodes[0]->Op()->Block();
int num_low_precision = 0;
std::vector<framework::ir::Node*> output_nodes;
for (auto* op_node : op_nodes) {
if (!op_node->IsOp()) continue;
auto op_type = op_node->Op()->Type();
VLOG(3) << "-------------------- op_type " << op_type << ", phi_type "
<< phi::TransToPhiKernelName(op_type);
// 1. set input dtype.
if (op_type == "feed") {
auto feed_var = op_node->outputs[0]->Var();
if (!keep_io_types_ &&
feed_var->GetDataType() == framework::proto::VarType::FP32) {
feed_var->SetDataType(to_type);
}
} else if (op_type == "fetch") {
auto* fetch_var = op_node->inputs[0];
output_nodes.push_back(fetch_var);
continue;
} else if (op_type == "cast") {
continue;
}
else if (op_node->Op()->HasAttr("sub_block")) { // NOLINT
      // A sub_block op's output dtype should match its input dtype when the
      // two share the same name.
std::unordered_map<std::string, framework::ir::Node*> in_name_to_node;
for (auto* in : op_node->inputs) {
auto* real_node = GetRealNode(block_idx, in);
if (NodeVarHasDtype(real_node)) {
in_name_to_node[in->Name()] = in;
}
}
for (auto out : op_node->outputs) {
auto* real_node = GetRealNode(block_idx, out);
if (NodeVarHasDtype(real_node)) {
if (in_name_to_node.count(out->Name()))
real_node->Var()->SetDataType(
in_name_to_node[out->Name()]->Var()->GetDataType());
}
}
continue;
}
    // 2. If the op supports fp16/bf16 and is not in the blacklist:
    //    - cast its weights to fp16/bf16.
    //    - add a cast op if an input dtype is not fp16/bf16.
    //    - set the output dtype.
    //
    // If a var (an op's output var) appears multiple times in a block, we
    // should not convert it to fp16.
else if (black_list_.count(op_type) == 0 && // NOLINT
!VarIsMultiPrecisionOpsOut(block_idx, op_node)) {
bool support_precision =
OpSupportPrecision(op_type, backend_, mixed_precision_, black_list_);
      // If the op has no float input, we will not choose the low precision kernel.
{
bool has_float_input{false};
for (auto in_node : op_node->inputs) {
auto* real_node = GetRealNode(block_idx, in_node);
if (real_node->Var()->GetDataType() == proto::VarType::FP16 ||
real_node->Var()->GetDataType() == proto::VarType::FP32 ||
real_node->Var()->GetDataType() == proto::VarType::FP64 ||
real_node->Var()->GetDataType() == proto::VarType::BF16) {
has_float_input = true;
break;
}
}
if (!has_float_input) {
support_precision = false;
VLOG(2) << " op doesn't has float input, just skip.";
}
}
VLOG(2) << " support low precision " << support_precision;
if (support_precision) {
VLOG(2) << " process input nodes:";
++num_low_precision;
auto inputs = op_node->inputs;
        // Handle Paddle's awkward case where an op's input and output share
        // the same name.
std::unordered_map<std::string, std::string> names_map;
for (auto out_node : op_node->outputs) {
for (auto in_node : op_node->inputs) {
if (out_node->Name() == in_node->Name()) {
names_map[out_node->Name()] = in_node->Name();
}
}
}
// Process inputs.
for (auto* in_node : inputs) {
ProcessInputNode(
true, in_node, op_node, &suffix_, block_desc, to_type, block_idx);
if (names_map.count(in_node->Name()) && cast_map_.count(in_node)) {
names_map[in_node->Name()] = cast_map_[in_node]->Name();
}
}
VLOG(2) << " process output nodes:";
// Process outputs.
for (auto* out_node : op_node->outputs) {
ProcessOutputNode(block_idx, out_node, to_type);
}
} else {
auto inputs = op_node->inputs;
for (auto* in_node : inputs) {
ProcessInputNode(false,
in_node,
op_node,
&suffix_,
block_desc,
framework::proto::VarType::FP32,
block_idx);
}
}
}
// 3. check op not support fp16/bf16 or in blacklist.
// - add cast op if the input dtype is not fp32.
else { // NOLINT
VLOG(3) << "not to run fp16 op_type: " << op_type;
auto ins = op_node->inputs;
for (auto* in_node : ins) {
auto* in_var = in_node->Var();
if (in_var->GetDataType() == to_type) {
AddCastOp(graph,
in_node,
op_node,
to_type,
framework::proto::VarType::FP32,
&suffix_,
block_desc,
&cast_map_);
VLOG(3) << "-- " << in_node->Name() << "(" << to_type << ") to "
<< cast_map_[in_node]->Name() << "("
<< framework::proto::VarType::FP32 << ")";
}
}
}
}
  // 4. If an output var's dtype is not compatible with the expected output
  //    dtype, insert a cast op.
for (auto* node : output_nodes) {
ir::Node* fetch_op{nullptr};
for (auto* op_node : node->outputs) {
if (op_node->IsOp() && op_node->Op()->Type() == "fetch") {
fetch_op = op_node;
}
}
CHECK_NOTNULL(fetch_op);
auto var = node->Var();
if (keep_io_types_ && var->GetDataType() == to_type) {
// fp16/bf16 -> fp32.
AddCastOp(graph,
node,
fetch_op,
to_type,
framework::proto::VarType::FP32,
&suffix_,
block_desc,
&cast_map_);
} else if (!keep_io_types_ &&
var->GetDataType() == framework::proto::VarType::FP32) {
// fp32 -> fp16/bf16
AddCastOp(graph,
node,
fetch_op,
framework::proto::VarType::FP32,
to_type,
&suffix_,
block_desc,
&cast_map_);
}
}
for (auto node : graph->Nodes()) {
auto* real_node = GetRealNode(block_idx, node);
if (!NodeVarHasDtype(real_node)) continue;
if (vars_in_multi_block_map_.count(real_node->Name()) &&
vars_in_multi_block_map_.at(real_node->Name()).second == block_idx) {
vars_in_multi_block_map_.at(real_node->Name()).first =
real_node->Var()->GetDataType();
}
}
if (num_low_precision)
LOG(INFO) << "--- detected " << num_low_precision
<< " low precision ops in " << block_idx << " subgraph";
}
// Since we modify the precision of ops' inputs and outputs, we need to fix the
// in_dtype and out_dtype attributes of cast ops.
void ConvertToMixedPrecisionPass::FixCastAttr(framework::ir::Graph* graph) {
auto op_nodes = framework::ir::TopologySortOperations(*graph);
for (auto* op_node : op_nodes) {
if (!op_node->IsOp()) continue;
auto op_type = op_node->Op()->Type();
if (op_type != "cast") continue;
auto input = op_node->inputs[0];
auto output = op_node->outputs[0];
op_node->Op()->SetAttr("in_dtype",
static_cast<int>(input->Var()->GetDataType()));
op_node->Op()->SetAttr("out_dtype",
static_cast<int>(output->Var()->GetDataType()));
}
}
void ConvertToMixedPrecisionPass::SaveMixedModel() {
framework::ProgramDesc mixed_program_desc;
framework::ir::GraphToProgram(*main_graph_, &mixed_program_desc);
paddle::CPUPlace place;
auto parameters = scope_.LocalVarNames();
std::sort(parameters.begin(), parameters.end());
std::unordered_set<std::string> weights_should_be_fp32;
for (auto* node : main_graph_->Nodes()) {
if (!(node->IsVar())) continue;
if (NodeVarHasDtype(node)) {
if (node->Var()->Persistable() &&
node->Var()->GetDataType() ==
paddle::framework::proto::VarType::FP32) {
VLOG(2) << "weights keep to fp32: " << node->Name();
weights_should_be_fp32.insert(node->Name());
}
}
}
#define CONVERT_TENSOR_DTYPE(DTYPE, dtype) \
mixed_tensor.set_type(DTYPE); \
auto* mixed_data = mixed_tensor.mutable_data<dtype>(platform::CPUPlace()); \
for (int i = 0; i < t->numel(); i++) { \
mixed_data[i] = static_cast<dtype>(data[i]); \
} \
t->clear(); \
paddle::framework::TensorCopySync(mixed_tensor, place, t)
for (const auto& param_name : parameters) {
auto* var = scope_.FindLocalVar(param_name);
if (var->IsType<phi::DenseTensor>()) {
auto* t = var->GetMutable<phi::DenseTensor>();
if (t->dtype() != phi::DataType::FLOAT32) continue;
phi::DenseTensor mixed_tensor;
mixed_tensor.Resize(t->dims());
auto* data = t->mutable_data<float>(platform::CPUPlace());
if (mixed_precision_ == phi::DataType::FLOAT16 &&
!weights_should_be_fp32.count(param_name)) {
CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::FLOAT16,
phi::dtype::float16);
} else if (mixed_precision_ == phi::DataType::BFLOAT16 &&
!weights_should_be_fp32.count(param_name)) {
CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::BFLOAT16,
phi::dtype::bfloat16);
}
}
}
#undef CONVERT_TENSOR_DTYPE
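The CONVERT_TENSOR_DTYPE macro above casts each fp32 weight element-wise on the CPU and copies the result back into the original tensor. A minimal sketch of the same pattern for a single tensor, written out without the macro (the helper name is hypothetical; it assumes the same Paddle headers as this file):
// Hedged sketch: cast one fp32 DenseTensor to fp16 in place.
void CastWeightToHalf(phi::DenseTensor* t) {
  if (t->dtype() != phi::DataType::FLOAT32) return;
  phi::DenseTensor half_tensor;
  half_tensor.Resize(t->dims());
  half_tensor.set_type(phi::DataType::FLOAT16);
  auto* src = t->mutable_data<float>(platform::CPUPlace());
  auto* dst =
      half_tensor.mutable_data<phi::dtype::float16>(platform::CPUPlace());
  for (int64_t i = 0; i < t->numel(); ++i) {
    dst[i] = static_cast<phi::dtype::float16>(src[i]);
  }
  t->clear();
  paddle::framework::TensorCopySync(half_tensor, paddle::CPUPlace(), t);
}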
auto SerializeParams = [&]() -> std::string {
std::ostringstream os;
phi::CPUContext ctx;
for (const auto& param : parameters) {
VLOG(3) << "Serialize param: " << param;
PADDLE_ENFORCE_NOT_NULL(
scope_.FindVar(param),
platform::errors::NotFound(
"Block should already have a '%s' variable", param));
auto* tensor = scope_.FindVar(param)->GetMutable<framework::LoDTensor>();
auto* tensor = scope_.FindVar(param)->GetMutable<phi::DenseTensor>();
framework::SerializeToStream(os, *tensor, ctx);
}
return os.str();
......@@ -831,96 +112,42 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
StrToBinary(mixed_params_file_, SerializeParams());
}
void ConvertToMixedPrecisionPass::PatchForStrangeOp() {
for (auto* graph : graphes_) {
for (auto op_node : framework::ir::TopologySortOperations(*graph)) {
if (op_node->Name() == "fused_multi_transformer") {
auto cache_kv_inputs = op_node->Op()->Input("CacheKV");
auto cache_kv_outputs = op_node->Op()->Output("CacheKVOut");
CHECK_EQ(cache_kv_inputs.size(), cache_kv_outputs.size());
for (size_t i = 0; i < cache_kv_inputs.size(); ++i) {
op_node->Op()->RenameOutput(cache_kv_outputs[i], cache_kv_inputs[i]);
}
}
}
}
bool OpSupportPrecision(const std::string& op_type,
phi::Backend backend,
phi::DataType precision,
const std::unordered_set<std::string>& black_list) {
return framework::ir::OpSupportPrecision(
op_type, backend, precision, black_list);
}
} // namespace
void AddCastOp(
void InsertCastOp(
framework::ir::Graph* graph,
framework::ir::Node* node,
framework::ir::Node* next_op,
framework::ir::Node* var_node,
framework::ir::Node* op_node,
framework::proto::VarType::Type from_type,
framework::proto::VarType::Type to_type,
int* suffix,
framework::BlockDesc* block_desc,
std::unordered_map<framework::ir::Node*, framework::ir::Node*>* map) {
auto update_cast_desc = [&](framework::OpDesc& desc,
const std::string& x_name,
const std::string& out_name,
const int in_dtype,
const int out_dtype) {
desc.SetType("cast");
desc.SetInput("X", {x_name});
desc.SetOutput("Out", {out_name});
desc.SetAttr("in_dtype", in_dtype);
desc.SetAttr("out_dtype", out_dtype);
desc.SetAttr("use_mkldnn", false);
desc.SetAttr("with_quant_attr", false);
desc.Flush();
};
if (map->count(node) == 0) {
// insert cast op before node.
std::string cast_input_name = node->Var()->Name();
std::string cast_output_name =
node->Var()->Name() + "_cast.tmp_" + std::to_string((*suffix)++);
CHECK_NOTNULL(block_desc);
framework::OpDesc cast_op_desc(block_desc);
update_cast_desc(cast_op_desc,
cast_input_name,
cast_output_name,
static_cast<int>(from_type),
static_cast<int>(to_type));
auto* cast_op_node = graph->CreateOpNode(&cast_op_desc);
auto* cast_output_vardesc = block_desc->Var(cast_output_name);
cast_output_vardesc->SetPersistable(false);
cast_output_vardesc->SetDataType(to_type);
cast_output_vardesc->SetShape(node->Var()->GetShape());
auto* cast_output_node = graph->CreateVarNode(cast_output_vardesc);
IR_NODE_LINK_TO(cast_op_node, cast_output_node);
(*map)[node] = cast_output_node;
}
next_op->Op()->Rename(node->Name(), map->at(node)->Name());
IR_NODE_LINK_TO(node, map->at(node)->inputs[0]);
IR_NODE_LINK_TO(map->at(node), next_op);
}
bool OpSupportPrecision(const std::string& op_type,
phi::Backend backend,
phi::DataType precision,
const std::unordered_set<std::string>& blacklist) {
auto phi_op_type = phi::TransToPhiKernelName(op_type);
bool support_precision = false;
if (blacklist.count(op_type) == 0) {
if (backend == phi::Backend::GPU)
support_precision = GpuKernelSupportPrecision(op_type, precision);
else
support_precision =
PhiKernelSupportPrecision(phi_op_type, backend, precision);
}
return support_precision;
int* suffix,
std::unordered_map<framework::ir::Node*, framework::ir::Node*>* visited) {
framework::ir::DoInsertCastOp(graph,
var_node,
op_node,
from_type,
to_type,
block_desc,
suffix,
visited);
}
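For orientation, a hedged usage fragment for the renamed InsertCastOp wrapper above; graph, var_node, op_node, and block_desc are assumed to exist in the caller, and suffix/visited should persist across calls so an already-created cast output is reused.
int suffix = 0;
std::unordered_map<framework::ir::Node*, framework::ir::Node*> visited;
// Cast an fp32 var feeding op_node to fp16 before op_node consumes it.
InsertCastOp(graph,
             var_node,
             op_node,
             framework::proto::VarType::FP32,  // from_type
             framework::proto::VarType::FP16,  // to_type
             block_desc,
             &suffix,
             &visited);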
void ConvertToMixedPrecision(const std::string& model_file,
void ConvertToMixedPrecision(
const std::string& model_file,
const std::string& params_file,
const std::string& mixed_model_file,
const std::string& mixed_params_file,
phi::DataType mixed_precision,
phi::Backend backend,
bool keep_io_types,
std::unordered_set<std::string> black_list) {
const std::unordered_set<std::string>& black_list) {
ConvertToMixedPrecisionPass pass(model_file,
params_file,
mixed_model_file,
......
......@@ -15,14 +15,12 @@
#pragma once
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h"
......@@ -30,20 +28,52 @@ namespace paddle {
namespace inference {
namespace analysis {
class ConvertToMixedPrecisionPass {
public:
explicit ConvertToMixedPrecisionPass(
const std::string& model_file,
const std::string& params_file,
const std::string& mixed_model_file,
const std::string& mixed_params_file,
phi::DataType mixed_precision,
phi::Backend backend,
bool keep_io_types,
const std::unordered_set<std::string>& black_list);
void Run();
private:
void LoadModel();
void SaveMixedModel();
private:
std::string model_file_;
std::string params_file_;
std::string mixed_model_file_;
std::string mixed_params_file_;
phi::DataType mixed_precision_;
phi::Backend backend_;
bool keep_io_types_;
std::unordered_set<std::string> black_list_;
framework::Scope scope_;
std::unique_ptr<framework::ir::Graph> main_graph_{nullptr};
};
bool OpSupportPrecision(const std::string& op_type,
phi::Backend backend,
phi::DataType precision,
const std::unordered_set<std::string>& blacklist);
const std::unordered_set<std::string>& black_list);
void AddCastOp(
void InsertCastOp(
framework::ir::Graph* graph,
framework::ir::Node* node,
framework::ir::Node* next_op,
framework::ir::Node* var_node,
framework::ir::Node* op_node,
framework::proto::VarType::Type from_type,
framework::proto::VarType::Type to_type,
int* suffix,
framework::BlockDesc* block_desc,
std::unordered_map<framework::ir::Node*, framework::ir::Node*>* map);
int* suffix,
std::unordered_map<framework::ir::Node*, framework::ir::Node*>* visited);
void ConvertToMixedPrecision(const std::string& model_file,
const std::string& params_file,
......@@ -51,8 +81,8 @@ void ConvertToMixedPrecision(const std::string& model_file,
const std::string& mixed_params_file,
phi::DataType mixed_precision,
phi::Backend backend,
bool keep_io_types = true,
std::unordered_set<std::string> black_list = {});
bool keep_io_types,
const std::unordered_set<std::string>& black_list);
} // namespace analysis
} // namespace inference
......
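As a usage reference for the declaration above, the exported ConvertToMixedPrecision entry point can be invoked offline to write a mixed precision copy of a saved model. The sketch below is hedged: file paths are placeholders and the include path is assumed.
#include <unordered_set>
#include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"  // assumed path

int main() {
  // Ops listed here are kept in fp32 during conversion.
  std::unordered_set<std::string> black_list{"softmax"};
  paddle::inference::analysis::ConvertToMixedPrecision(
      "model/inference.pdmodel",    // model_file (placeholder)
      "model/inference.pdiparams",  // params_file (placeholder)
      "mixed/inference.pdmodel",    // mixed_model_file
      "mixed/inference.pdiparams",  // mixed_params_file
      phi::DataType::FLOAT16,       // mixed_precision
      phi::Backend::GPU,            // backend
      true,                         // keep_io_types
      black_list);
  return 0;
}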
......@@ -40,7 +40,7 @@ void InferenceOpReplacePass::RunImpl(Argument* argument) {
}
std::string InferenceOpReplacePass::repr() const {
return "inference-op-replace-pass";
return "inference_op_replace_pass";
}
} // namespace analysis
......
......@@ -105,7 +105,7 @@ void IrAnalysisPass::CollectFusionStatis(Argument* argument) {
framework::ir::kFuseStatisAttr));
}
std::string IrAnalysisPass::repr() const { return "ir-analysis-pass"; }
std::string IrAnalysisPass::repr() const { return "ir_analysis_pass"; }
} // namespace analysis
} // namespace inference
......
......@@ -64,7 +64,8 @@ void IrGraphBuildPass::RunImpl(Argument *argument) {
"set."));
}
auto graph = std::unique_ptr<Graph>(new Graph(argument->main_program()));
auto graph = std::unique_ptr<framework::ir::Graph>(
new framework::ir::Graph(argument->main_program()));
argument->SetMainGraph(graph.release());
auto *scope_ptr = argument->scope_ptr();
PADDLE_ENFORCE_NOT_NULL(scope_ptr,
......@@ -125,7 +126,7 @@ std::unique_ptr<framework::ProgramDesc> IrGraphBuildPass::LoadModel(
}
}
std::string IrGraphBuildPass::repr() const { return "ir-graph-build-pass"; }
std::string IrGraphBuildPass::repr() const { return "ir_graph_build_pass"; }
} // namespace analysis
} // namespace inference
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/passes/ir_graph_clean_pass.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/node.h"
namespace paddle {
namespace inference {
namespace analysis {
void IrInferCleanGraphPass::RunImpl(Argument* argument) {
auto& graph = argument->main_graph();
auto is_valid_node = [](framework::ir::Node* x) {
return x && IsControlDepVar(*x) && x->IsVar() && !x->Var();
};
std::unordered_set<const framework::ir::Node*> invalid_nodes;
int valid_op = 0;
for (auto* node : graph.Nodes()) {
PADDLE_ENFORCE_NOT_NULL(node,
platform::errors::PreconditionNotMet(
"The node should not be nullptr."));
if (is_valid_node(node)) {
invalid_nodes.insert(node);
} else if (node->IsOp()) {
++valid_op;
}
}
GraphSafeRemoveNodes(&graph, invalid_nodes);
}
} // namespace analysis
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_set>
#include "paddle/fluid/inference/analysis/analysis_pass.h"
namespace paddle {
namespace inference {
namespace analysis {
struct Argument;
class IrInferCleanGraphPass : public AnalysisPass {
public:
void RunImpl(Argument *argument) override;
std::string repr() const override { return "ir_graph_clean_pass"; }
};
} // namespace analysis
} // namespace inference
} // namespace paddle
......@@ -31,7 +31,7 @@ void IrGraphToProgramPass::RunImpl(Argument *argument) {
new int(argument->memory_optim_sort_kind()));
}
std::unique_ptr<Graph> graph(argument->main_graph_ptr());
std::unique_ptr<framework::ir::Graph> graph(argument->main_graph_ptr());
// Direct using ProgramDesc desc(argument->main_program()) may cause
// incomplete copies of information.
......
......@@ -28,7 +28,7 @@ class IrGraphToProgramPass : public AnalysisPass {
public:
void RunImpl(Argument *argument) override;
std::string repr() const override { return "ir-graph-to-param-pass"; }
std::string repr() const override { return "ir_graph_to_param_pass"; }
};
} // namespace analysis
......
......@@ -169,7 +169,7 @@ void IrParamsSyncAmongDevicesPass::RunImpl(Argument *argument) {
}
std::string IrParamsSyncAmongDevicesPass::repr() const {
return "ir-params-sync-among-devices-pass";
return "ir_params_sync_among_devices_pass";
}
} // namespace analysis
......
......@@ -295,7 +295,7 @@ void UpdateOpDescsByReuse(
}
}
std::string MemoryOptimizePass::repr() const { return "memory optimize pass"; }
std::string MemoryOptimizePass::repr() const { return "memory_optimize_pass"; }
void MemoryOptimizePass::RunImpl(Argument* argument) {
// Memory optimization.
......
......@@ -18,7 +18,6 @@
#include "paddle/fluid/inference/analysis/passes/inference_op_replace_pass.h"
#include "paddle/fluid/inference/analysis/passes/ir_analysis_pass.h"
#include "paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h"
#include "paddle/fluid/inference/analysis/passes/ir_graph_clean_pass.h"
#include "paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.h"
#include "paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h"
#include "paddle/fluid/inference/analysis/passes/memory_optimize_pass.h"
......@@ -34,8 +33,6 @@ PassRegistry::PassRegistry() {
std::unique_ptr<AnalysisPass>(new IrAnalysisPass));
passes_.emplace("ir_graph_build_pass",
std::unique_ptr<AnalysisPass>(new IrGraphBuildPass));
passes_.emplace("ir_graph_clean_pass",
std::unique_ptr<AnalysisPass>(new IrInferCleanGraphPass));
passes_.emplace("memory_optimize_pass",
std::unique_ptr<AnalysisPass>(new MemoryOptimizePass));
passes_.emplace(
......
......@@ -85,15 +85,29 @@ void AnalysisConfig::SetModel(const std::string &prog_file_path,
Update();
}
void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb,
int device_id) {
int device_id,
Precision precision_mode) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
use_gpu_ = true;
memory_pool_init_size_mb_ = memory_pool_init_size_mb;
FLAGS_initial_gpu_memory_in_mb = memory_pool_init_size_mb_;
gpu_device_id_ = device_id;
mixed_precision_mode_ = precision_mode;
if (precision_mode == Precision::kFloat32) {
// default
} else if (precision_mode == Precision::kHalf ||
precision_mode == Precision::kBf16) {
enable_gpu_mixed_ = true;
} else {
LOG(ERROR)
<< "The Paddle-GPU inference currently only supports "
"float32/float16/bfloat16 precision. Please check the parameters "
"you specified in EnableUseGpu or enable_use_gpu function.";
}
#else
LOG(ERROR) << "Please compile with gpu to EnableGpu()";
LOG(ERROR) << "Please use PaddlePaddle with GPU version.";
use_gpu_ = false;
#endif
......@@ -279,7 +293,7 @@ void AnalysisConfig::LoadIpuConfig(const std::string &config_path) {
if (ipu_config_mapper_.find(key) == ipu_config_mapper_.end()) {
PADDLE_THROW(platform::errors::InvalidArgument(
"invalid key {} in IPU config", key));
"invalid key %s in IPU config: ", key));
}
switch (ipu_config_mapper_.at(key)) {
case ipu_config_code::ipu_device_num:
......@@ -315,7 +329,7 @@ void AnalysisConfig::LoadIpuConfig(const std::string &config_path) {
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"invalid key {} in IPU config", key));
"invalid key %s in IPU config", key));
break;
}
}
......@@ -372,8 +386,10 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
CP_MEMBER(gpu_device_id_);
CP_MEMBER(memory_pool_init_size_mb_);
// Mixed related.
// Mixed precision related.
CP_MEMBER(mixed_black_list_);
CP_MEMBER(enable_gpu_mixed_);
CP_MEMBER(mixed_precision_mode_);
CP_MEMBER(enable_memory_optim_);
// TensorRT related.
......@@ -740,13 +756,7 @@ void AnalysisConfig::Update() {
((use_custom_device() ^ pass_builder_->use_custom_device()))) {
if (use_gpu()) {
pass_builder_.reset(new GpuPassStrategy);
if (use_tensorrt_) {
// Append after the Affine_channel_conv_fuse pass.
pass_builder()->InsertPass(3, "tensorrt_subgraph_pass");
}
} else if (use_ipu()) {
VLOG(1) << "IpuPassStrategy has been used for new.";
pass_builder_.reset(new IpuPassStrategy);
} else if (use_xpu()) {
PADDLE_ENFORCE_EQ(
......@@ -946,9 +956,6 @@ void AnalysisConfig::Update() {
"but did not have the option -DWITH_CUSTOM_DEVICE compiled."));
#endif
}
if (ir_debug_) {
pass_builder()->TurnOnDebug();
}
}
std::string AnalysisConfig::SerializeInfoCache() {
......@@ -960,6 +967,7 @@ std::string AnalysisConfig::SerializeInfoCache() {
ss << calibration_file_path_;
ss << use_gpu_;
ss << enable_gpu_mixed_;
ss << use_external_stream_;
ss << exec_stream_;
ss << use_fc_padding_;
......@@ -1167,6 +1175,7 @@ std::string AnalysisConfig::Summary() {
os.InsertRow({"use_gpu", use_gpu_ ? "true" : "false"});
if (use_gpu_) {
os.InsertRow({"gpu_device_id", std::to_string(gpu_device_id_)});
os.InsertRow({"enable_gpu_mixed", std::to_string(enable_gpu_mixed_)});
os.InsertRow({"memory_pool_init_size",
std::to_string(memory_pool_init_size_mb_) + "MB"});
os.InsertRow(
......@@ -1360,7 +1369,7 @@ bool AnalysisConfig::trt_allow_build_at_runtime() {
return trt_allow_build_at_runtime_;
}
void AnalysisConfig::Exp_SetBlackListOpsForMixedModel(
void AnalysisConfig::Exp_DisableMixedPrecisionOps(
const std::unordered_set<std::string> &black_list) {
mixed_black_list_ = black_list;
}
......
......@@ -1065,7 +1065,7 @@ void AnalysisPredictor::PrepareArgument() {
argument_.SetUseGPU(config_.use_gpu());
argument_.SetUseFcPadding(config_.use_fc_padding());
argument_.SetGPUDeviceId(config_.gpu_device_id());
argument_.SetEnableAnalysisOptim(config_.enable_ir_optim_);
argument_.SetEnableIrOptim(config_.enable_ir_optim_);
argument_.SetEnableMemoryOptim(config_.enable_memory_optim());
argument_.SetModelFromMemory(config_.model_from_memory_);
// Analyze inference_program
......@@ -1210,53 +1210,57 @@ void AnalysisPredictor::PrepareArgument() {
}
#endif
auto passes = config_.pass_builder()->AllPasses();
auto *pass_builder = config_.pass_builder();
if (model_precision_ != phi::DataType::FLOAT32) {
LOG(INFO) << "Model is mixed precision type with " << model_precision_
<< ", we will use a new PassStrategy. Note that only the GPU "
"backend is supported for now.";
passes.clear();
pass_builder->ClearPasses();
const auto &deleted_passes = pass_builder->GetAllDeletedPasses();
if (config_.tensorrt_engine_enabled()) {
for (const auto &pass : kTrtLowerPrecisionPasses) {
passes.push_back(pass);
if (deleted_passes.count(pass)) continue;
pass_builder->AppendPass(pass);
}
} else if (config_.use_gpu()) {
for (const auto &pass : kGpuLowerPrecisionPasses) {
passes.push_back(pass);
if (deleted_passes.count(pass)) continue;
pass_builder->AppendPass(pass);
}
}
const auto &deleted_passes = config_.pass_builder()->GetAllDeletedPasses();
for (const auto &it : deleted_passes) {
auto iterator = std::find(passes.begin(), passes.end(), it);
if (iterator != passes.end()) {
passes.erase(iterator);
}
}
if (config_.ir_debug_) {
auto it = std::begin(passes);
while (it != std::end(passes)) {
if (*it != "graph_viz_pass") {
it = passes.insert(it + 1, "graph_viz_pass");
if (!config_.ir_optim()) {
argument_.SetEnableIrOptim(false);
if (config_.enable_gpu_mixed_) {
argument_.SetEnableIrOptim(true);
pass_builder->ClearPasses();
pass_builder->AppendPass("auto_mixed_precision_pass");
LOG(INFO)
<< "This model run in Paddle-GPU mixed precision mode with no ir "
"optimization.";
} else {
++it;
}
LOG(INFO) << "ir_optim is turned off, no IR pass will be executed.";
}
} else {
if (config_.ir_debug_) {
pass_builder->TurnOnDebug();
}
if (config_.enable_gpu_mixed_) {
LOG(INFO) << "This model run in Paddle-GPU mixed precision mode.";
}
if (!config_.ir_optim()) {
passes.clear();
LOG(INFO) << "ir_optim is turned off, no IR pass will be executed";
}
argument_.SetDisableLogs(config_.glog_info_disabled());
argument_.SetIrAnalysisPasses(passes);
argument_.SetAnalysisPasses(config_.pass_builder()->AnalysisPasses());
argument_.SetIrAnalysisPasses(pass_builder->AllPasses());
argument_.SetAnalysisPasses(pass_builder->AnalysisPasses());
argument_.SetScopeNotOwned(scope_.get());
// mixed precison.
argument_.SetModelPrecision(static_cast<int>(model_precision_));
argument_.SetMixedBlackList(config_.mixed_black_list_);
argument_.SetEnableGPUMixed(config_.enable_gpu_mixed_);
argument_.SetMixedPrecisionMode(static_cast<int>(
paddle::ConvertPrecision(config_.mixed_precision_mode_)));
}
// NOTE All the members in AnalysisConfig should be copied to Argument.
......@@ -2107,7 +2111,9 @@ std::unique_ptr<PaddlePredictor> AnalysisPredictor::Clone(void *stream) {
}
x->predictor_stream_ = stream;
x->Init(scope_, inference_program_);
#ifdef PADDLE_WITH_TENSORRT
x->executor_->ResetTrtOps(++AnalysisPredictor::clone_num_);
#endif
return std::unique_ptr<PaddlePredictor>(x);
}
......
......@@ -604,10 +604,8 @@ void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const {
if (predictor_.config_.ir_debug_) builder->TurnOnDebug();
auto passes = builder->AllPasses();
predictor_.argument_.SetIrAnalysisPasses(passes);
predictor_.argument_.SetAnalysisPasses({"ir_graph_clean_pass",
"ir_analysis_pass",
"memory_optimize_pass",
"ir_graph_to_program_pass"});
predictor_.argument_.SetAnalysisPasses(
{"ir_analysis_pass", "memory_optimize_pass", "ir_graph_to_program_pass"});
predictor_.argument_.SetQuantVarScales(scales_);
}
......
......@@ -247,8 +247,12 @@ struct PD_INFER_DECL AnalysisConfig {
///
/// \param memory_pool_init_size_mb initial size of the GPU memory pool in MB.
/// \param device_id device_id the GPU card to use (default is 0).
/// \param precision the precision used in Paddle-GPU inference.
///
void EnableUseGpu(uint64_t memory_pool_init_size_mb, int device_id = 0);
void EnableUseGpu(uint64_t memory_pool_init_size_mb,
int device_id = 0,
Precision precision_mode = Precision::kFloat32);
///
/// \brief Turn off GPU.
///
......@@ -967,7 +971,7 @@ struct PD_INFER_DECL AnalysisConfig {
/// interface is in the experimental stage and may change in the future. Note
/// that the blacklist must be the same as the model conversion blacklist.
///
void Exp_SetBlackListOpsForMixedModel(
void Exp_DisableMixedPrecisionOps(
const std::unordered_set<std::string>& black_list);
void SetApplyOptim(bool value) { apply_optim_ = value; }
......@@ -987,13 +991,15 @@ struct PD_INFER_DECL AnalysisConfig {
mutable std::string params_file_;
mutable std::string calibration_file_path_;
// Mixed precision.
// Mixed precision related.
Precision mixed_precision_mode_{Precision::kFloat32};
std::unordered_set<std::string> mixed_black_list_;
// GPU related.
bool use_gpu_{false};
int gpu_device_id_{0};
uint64_t memory_pool_init_size_mb_{100}; // initial size is 100MB.
bool enable_gpu_mixed_{false};
bool thread_local_stream_{false};
bool use_cudnn_{false};
......
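The new precision argument of EnableUseGpu works together with Exp_DisableMixedPrecisionOps; a hedged usage sketch follows (model paths and the blacklisted op are placeholders).
#include "paddle/fluid/inference/api/paddle_inference_api.h"

void BuildHalfPrecisionPredictor() {
  paddle::AnalysisConfig config;
  config.SetModel("model/inference.pdmodel", "model/inference.pdiparams");
  // Run GPU inference in fp16; blacklisted ops stay in fp32.
  config.EnableUseGpu(512 /*MB*/, 0 /*device_id*/,
                      paddle::AnalysisConfig::Precision::kHalf);
  config.Exp_DisableMixedPrecisionOps({"softmax"});
  auto predictor = paddle::CreatePaddlePredictor(config);
}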
......@@ -227,9 +227,10 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"conv_elementwise_add_fuse_pass", //
#endif //
"transpose_flatten_concat_fuse_pass", //
"constant_folding_pass",
"constant_folding_pass", //
        // The following pass should be placed at the end, since it works on
        // all fused ops.
"auto_mixed_precision_pass", //
"runtime_context_cache_pass"
});
......
......@@ -115,7 +115,6 @@ class PD_INFER_DECL PaddlePassBuilder {
/// \cond Protected
std::vector<std::string> analysis_passes_{
{"ir_graph_build_pass",
"ir_graph_clean_pass",
"ir_analysis_pass",
"ir_params_sync_among_devices_pass",
"adjust_cudnn_workspace_size_pass",
......
......@@ -294,15 +294,6 @@ class TensorRTEngine {
nvinfer1::ICudaEngine* engine() { return infer_engine_.get(); }
nvinfer1::IExecutionContext* context() {
#ifndef PADDLE_WITH_TESTING
PADDLE_ENFORCE_GT(
predictor_id_per_thread,
-1,
platform::errors::InvalidArgument(
"thread local var predictor_id_per_thread must be "
"initialized to >= 0, but now predictor_id_per_thread = %d",
predictor_id_per_thread));
#endif
std::unique_lock<std::mutex> lock(mutex_);
if (infer_context_.find(predictor_id_per_thread) == infer_context_.end()) {
PADDLE_ENFORCE_NOT_NULL(
......@@ -329,15 +320,6 @@ class TensorRTEngine {
int GetProfileIndex() {
if (max_profile_num_ > 1) {
#ifndef PADDLE_WITH_TESTING
PADDLE_ENFORCE_GT(
predictor_id_per_thread,
-1,
platform::errors::InvalidArgument(
"thread local var predictor_id_per_thread must be "
"initialized to >= 0, but now predictor_id_per_thread = %d",
predictor_id_per_thread));
#endif
std::unique_lock<std::mutex> lock(mutex_);
return profile_index_[predictor_id_per_thread];
} else {
......@@ -356,15 +338,6 @@ class TensorRTEngine {
infer_engine_,
platform::errors::InvalidArgument(
"You should build engine first and then set the context."));
#ifndef PADDLE_WITH_TESTING
PADDLE_ENFORCE_GT(
predictor_id_per_thread,
-1,
platform::errors::InvalidArgument(
"thread local var predictor_id_per_thread must be "
"initialized to >= 0, but now predictor_id_per_thread = %d",
predictor_id_per_thread));
#endif
std::unique_lock<std::mutex> lock(mutex_);
infer_context_[predictor_id_per_thread].reset(nullptr);
infer_context_.erase(predictor_id_per_thread);
......
......@@ -416,6 +416,9 @@ download_result(${ERNIE_INSTALL_DIR} "Ernie_result.txt.tar.gz"
if(WITH_GPU)
inference_analysis_api_test(test_analyzer_ernie ${ERNIE_INSTALL_DIR}
analyzer_ernie_tester.cc)
inference_analysis_api_test(gpu_ernie_half_test ${ERNIE_INSTALL_DIR}
gpu_ernie_half_test.cc)
set_tests_properties(gpu_ernie_half_test PROPERTIES TIMEOUT 60)
endif()
inference_analysis_api_int8_test(test_analyzer_ernie_int8 ${ERNIE_INSTALL_DIR}
analyzer_ernie_int8_tester.cc)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/tests/api/tester_helper.h"
namespace paddle {
namespace inference {
using paddle::PaddleTensor;
template <typename T>
void GetValueFromStream(std::stringstream *ss, T *t) {
(*ss) >> (*t);
}
template <>
void GetValueFromStream<std::string>(std::stringstream *ss, std::string *t) {
*t = ss->str();
}
// Split string to vector
template <typename T>
void Split(const std::string &line, char sep, std::vector<T> *v) {
std::stringstream ss;
T t;
for (auto c : line) {
if (c != sep) {
ss << c;
} else {
GetValueFromStream<T>(&ss, &t);
v->push_back(std::move(t));
ss.str({});
ss.clear();
}
}
if (!ss.str().empty()) {
GetValueFromStream<T>(&ss, &t);
v->push_back(std::move(t));
ss.str({});
ss.clear();
}
}
// Parse tensor from string
template <typename T>
bool ParseTensor(const std::string &field, paddle::PaddleTensor *tensor) {
std::vector<std::string> data;
Split(field, ':', &data);
if (data.size() < 2) return false;
std::string shape_str = data[0];
std::vector<int> shape;
Split(shape_str, ' ', &shape);
std::string mat_str = data[1];
std::vector<T> mat;
Split(mat_str, ' ', &mat);
tensor->shape = shape;
auto size =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()) *
sizeof(T);
tensor->data.Resize(size);
std::copy(mat.begin(), mat.end(), static_cast<T *>(tensor->data.data()));
tensor->dtype = GetPaddleDType<T>();
return true;
}
// Parse input tensors from string
bool ParseLine(const std::string &line,
std::vector<paddle::PaddleTensor> *tensors) {
std::vector<std::string> fields;
Split(line, ';', &fields);
tensors->clear();
tensors->reserve(4);
int i = 0;
auto input_name = FLAGS_ernie_large ? "eval_placeholder_" : "placeholder_";
for (; i < 3; i++) {
paddle::PaddleTensor temp;
ParseTensor<int64_t>(fields[i], &temp);
temp.name = input_name + std::to_string(i);
tensors->push_back(temp);
}
// input_mask
paddle::PaddleTensor input_mask;
ParseTensor<float>(fields[i], &input_mask);
input_mask.name = input_name + std::to_string(i);
tensors->push_back(input_mask);
return true;
}
bool LoadInputData(std::vector<std::vector<paddle::PaddleTensor>> *inputs,
int batch_size = 1) {
if (FLAGS_infer_data.empty()) {
LOG(ERROR) << "please set input data path";
return false;
}
std::ifstream fin(FLAGS_infer_data);
std::string line;
int sample = 0;
  // The unit-test dataset only has 10 samples; each sample has 5 feeds.
while (std::getline(fin, line)) {
std::vector<paddle::PaddleTensor> feed_data;
ParseLine(line, &feed_data);
inputs->push_back(std::move(feed_data));
sample++;
if (!FLAGS_test_all_data && sample == batch_size) break;
}
LOG(INFO) << "number of samples: " << sample;
return true;
}
// Compare results
TEST(Ernie_gpu_fp16_no_ir, compare_results) {
AnalysisConfig config;
config.SetModel(FLAGS_infer_model);
config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kHalf);
config.SwitchIrOptim(false);
auto predictor = CreatePaddlePredictor(config);
std::vector<std::vector<PaddleTensor>> input_slots_all;
LoadInputData(&input_slots_all);
std::ifstream fin(FLAGS_refer_result);
std::string line;
std::vector<float> ref;
while (std::getline(fin, line)) {
Split(line, ' ', &ref);
}
std::vector<PaddleTensor> outputs;
for (size_t i = 0; i < input_slots_all.size(); i++) {
outputs.clear();
predictor->Run(input_slots_all[i], &outputs);
auto output = outputs.front();
size_t outputs_size = 1;
for (auto dim : output.shape) {
outputs_size *= dim;
}
float *result = reinterpret_cast<float *>(output.data.data());
for (size_t j = 0; j < outputs_size; ++j) {
EXPECT_NEAR(ref[i * outputs_size + j], result[j], 5e-2);
}
}
}
// Compare results
TEST(Ernie_gpu_fp16_with_ir, compare_results) {
AnalysisConfig config;
config.SetModel(FLAGS_infer_model);
config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kHalf);
config.SwitchIrOptim(true);
  // The fc_fuse_pass causes a precision diff, which will be fixed later.
config.pass_builder()->DeletePass("fc_fuse_pass");
// There is a problem with the model itself, which has nothing to do with
// constant_folding_pass.
config.pass_builder()->DeletePass("constant_folding_pass");
auto predictor = CreatePaddlePredictor(config);
std::vector<std::vector<PaddleTensor>> input_slots_all;
LoadInputData(&input_slots_all);
std::ifstream fin(FLAGS_refer_result);
std::string line;
std::vector<float> ref;
while (std::getline(fin, line)) {
Split(line, ' ', &ref);
}
std::vector<PaddleTensor> outputs;
for (size_t i = 0; i < input_slots_all.size(); i++) {
outputs.clear();
predictor->Run(input_slots_all[i], &outputs);
auto output = outputs.front();
size_t outputs_size = 1;
for (auto dim : output.shape) {
outputs_size *= dim;
}
float *result = reinterpret_cast<float *>(output.data.data());
for (size_t j = 0; j < outputs_size; ++j) {
EXPECT_NEAR(ref[i * outputs_size + j], result[j], 5e-2);
}
}
}
// Compare results
TEST(Ernie_gpu_bf16_no_ir, compare_results) {
AnalysisConfig config;
config.SetModel(FLAGS_infer_model);
config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kBf16);
config.SwitchIrOptim(false);
auto predictor = CreatePaddlePredictor(config);
std::vector<std::vector<PaddleTensor>> input_slots_all;
LoadInputData(&input_slots_all);
std::ifstream fin(FLAGS_refer_result);
std::string line;
std::vector<float> ref;
while (std::getline(fin, line)) {
Split(line, ' ', &ref);
}
std::vector<PaddleTensor> outputs;
for (size_t i = 0; i < input_slots_all.size(); i++) {
outputs.clear();
predictor->Run(input_slots_all[i], &outputs);
auto output = outputs.front();
size_t outputs_size = 1;
for (auto dim : output.shape) {
outputs_size *= dim;
}
float *result = reinterpret_cast<float *>(output.data.data());
for (size_t j = 0; j < outputs_size; ++j) {
EXPECT_NEAR(ref[i * outputs_size + j], result[j], 7e-2);
}
}
}
// Compare results
TEST(Ernie_gpu_bf16_with_ir, compare_results) {
AnalysisConfig config;
config.SetModel(FLAGS_infer_model);
config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kBf16);
config.SwitchIrOptim(true);
  // The fc_fuse_pass causes a precision diff, which will be fixed later.
config.pass_builder()->DeletePass("fc_fuse_pass");
// There is a problem with the model itself, which has nothing to do with
// constant_folding_pass.
config.pass_builder()->DeletePass("constant_folding_pass");
auto predictor = CreatePaddlePredictor(config);
std::vector<std::vector<PaddleTensor>> input_slots_all;
LoadInputData(&input_slots_all);
std::ifstream fin(FLAGS_refer_result);
std::string line;
std::vector<float> ref;
while (std::getline(fin, line)) {
Split(line, ' ', &ref);
}
std::vector<PaddleTensor> outputs;
for (size_t i = 0; i < input_slots_all.size(); i++) {
outputs.clear();
predictor->Run(input_slots_all[i], &outputs);
auto output = outputs.front();
size_t outputs_size = 1;
for (auto dim : output.shape) {
outputs_size *= dim;
}
float *result = reinterpret_cast<float *>(output.data.data());
for (size_t j = 0; j < outputs_size; ++j) {
EXPECT_NEAR(ref[i * outputs_size + j], result[j], 7e-2);
}
}
}
} // namespace inference
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -12,15 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cuda_runtime.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <cstring>
#include <numeric>
#include "gflags/gflags.h"
#include "paddle/fluid/inference/tests/api/trt_test_helper.h"
#include "paddle/fluid/inference/tests/api/tester_helper.h"
namespace paddle_infer {
......
......@@ -262,10 +262,6 @@ if(WITH_PYTHON)
list(APPEND OP_FUNCTION_GENERETOR_DEPS cncl_context)
endif()
if(NOT ((NOT WITH_PYTHON) AND ON_INFER))
list(APPEND OP_FUNCTION_GENERETOR_DEPS ${PYTHON_LIBRARIES})
endif()
add_executable(op_function_generator op_function_generator.cc)
target_link_libraries(op_function_generator ${OP_FUNCTION_GENERETOR_DEPS})
add_executable(eager_legacy_op_function_generator
......@@ -605,13 +601,4 @@ if(WITH_PYTHON)
get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
target_link_libraries(${SHARD_LIB_NAME} ${os_dependency_modules})
add_dependencies(${SHARD_LIB_NAME} op_function_generator_cmd)
if(APPLE)
string(REGEX REPLACE ".+/(.+)" "\\1" PYTHON_LIBRARY_NAME
${PYTHON_LIBRARIES})
# target_link_libraries(${SHARD_LIB_NAME} "-Wl,-rpath,${PYTHON_LIBRARY_NAME}")
else()
target_link_libraries(${SHARD_LIB_NAME} ${PYTHON_LIBRARIES})
endif()
endif()
......@@ -642,7 +642,8 @@ void BindAnalysisConfig(py::module *m) {
.def("enable_use_gpu",
&AnalysisConfig::EnableUseGpu,
py::arg("memory_pool_init_size_mb"),
py::arg("device_id") = 0)
py::arg("device_id") = 0,
py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
.def("set_exec_stream",
[](AnalysisConfig &self, phi::CUDAStream &stream) {
......