Unverified commit ddcd1b61, authored by Yuanle Liu, committed by GitHub

[cherry-pick][Inference] support mixed precision inference (#49077)

* [Release2.4] Revert python link prs (#48573)

* Revert "Fix mac link python (#48017)"

This reverts commit 3fa7a736.

* Revert "[Cherry-pick] Fix python link error (#47811)"

This reverts commit ff642c68.

* Update config.go

* [Paddle Inference] Add float_to_half_pass to support  inference with mixed precision (#47993)

* [Inference] optimize some code and fix some bug (#48780)

* clean ir_pass_manager and fix map_depthwise_conv_to_conv_pass

* fix unitest timeout

* [Paddle Inference] clean unused code  (#48392)

* fix

* update

* update
Co-authored-by: Chen Weihang <chenweihang@baidu.com>
Parent 9e2ba9b9
...@@ -148,6 +148,7 @@ pass_library(delete_c_identity_op_pass inference)
pass_library(preln_residual_bias_fuse_pass inference)
pass_library(delete_fill_constant_op_pass inference)
pass_library(constant_folding_pass inference)
pass_library(auto_mixed_precision_pass inference)
pass_library(simplify_with_basic_ops_pass base)
pass_library(fc_elementwise_layernorm_fuse_pass base)
pass_library(skip_layernorm_fuse_pass base)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/auto_mixed_precision_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
namespace paddle {
namespace framework {
namespace ir {
namespace {
using VarType = AutoMixedPrecisionPass::VarType;
bool PhiKernelSupportPrecision(
const std::string& op_type,
phi::Backend backend,
phi::DataType data_type,
phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) {
const auto& kernels = phi::KernelFactory::Instance().kernels();
if (kernels.count(op_type) == 0) {
return false;
}
phi::KernelKey kernel_key(backend, layout, data_type);
return phi::KernelFactory::Instance().HasKernel(op_type, kernel_key);
}
bool GpuKernelSupportPrecision(
const std::string& op_type,
phi::DataType precision,
phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) {
auto phi_op_type = phi::TransToPhiKernelName(op_type);
bool support = PhiKernelSupportPrecision(
phi_op_type, phi::Backend::GPU, precision, layout);
support |= PhiKernelSupportPrecision(
phi_op_type, phi::Backend::GPUDNN, precision, layout);
if (!support) {
const auto& all_kernels = framework::OperatorWithKernel::AllOpKernels();
auto it = all_kernels.find(op_type);
if (it != all_kernels.end()) {
for (const auto& kern_pair : it->second) {
if (platform::is_gpu_place(kern_pair.first.place_) &&
kern_pair.first.data_type_ ==
framework::TransToProtoVarType(precision)) {
support = true;
break;
}
}
}
}
return support;
}
inline bool VarNodeHasDtype(Node* var_node) {
auto type = var_node->Var()->GetType();
return (type == VarType::SELECTED_ROWS) || (type == VarType::LOD_TENSOR) ||
(type == VarType::LOD_TENSOR_ARRAY) || (type == VarType::STRINGS) ||
(type == VarType::VOCAB);
}
inline bool IsFloatType(VarType::Type type) {
return (type == VarType::FP64) || (type == VarType::FP32);
}
inline bool IsHalfType(VarType::Type type) {
return (type == VarType::FP16) || (type == VarType::BF16);
}
}; // namespace
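// Insert a cast op between var_node and op_node so that op_node reads a
// to_type tensor. The cast output is cached per var_node, so repeated callers
// reuse one cast op instead of creating a new one for every consumer.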
void DoInsertCastOp(Graph* graph,
Node* var_node,
Node* op_node,
VarType::Type from_type,
VarType::Type to_type,
framework::BlockDesc* block_desc,
int* suffix,
std::unordered_map<Node*, Node*>* cache) {
if (from_type == to_type) return;
auto update_cast_desc = [&](framework::OpDesc& desc,
const std::string& x_name,
const std::string& out_name,
const int in_dtype,
const int out_dtype) {
desc.SetType("cast");
desc.SetInput("X", {x_name});
desc.SetOutput("Out", {out_name});
desc.SetAttr("in_dtype", in_dtype);
desc.SetAttr("out_dtype", out_dtype);
desc.SetAttr("use_mkldnn", false);
desc.SetAttr("with_quant_attr", false);
desc.Flush();
};
if (cache->count(var_node) == 0) {
// insert cast op between var_node and op_node
std::string cast_input_name = var_node->Var()->Name();
std::string cast_output_name =
var_node->Var()->Name() + "_cast.tmp_" + std::to_string((*suffix)++);
framework::OpDesc cast_op_desc(block_desc);
update_cast_desc(cast_op_desc,
cast_input_name,
cast_output_name,
static_cast<int>(from_type),
static_cast<int>(to_type));
auto* cast_op_node = graph->CreateOpNode(&cast_op_desc);
auto* cast_output_vardesc = block_desc->Var(cast_output_name);
cast_output_vardesc->SetPersistable(false);
cast_output_vardesc->SetDataType(to_type);
cast_output_vardesc->SetShape(var_node->Var()->GetShape());
auto* cast_output_node = graph->CreateVarNode(cast_output_vardesc);
IR_NODE_LINK_TO(cast_op_node, cast_output_node);
(*cache)[var_node] = cast_output_node;
}
op_node->Op()->Rename(var_node->Name(), cache->at(var_node)->Name());
IR_NODE_LINK_TO(var_node, cache->at(var_node)->inputs[0]);
IR_NODE_LINK_TO(cache->at(var_node), op_node);
IR_NODE_UNLINK(var_node, op_node);
}
bool OpSupportPrecision(const std::string& op_type,
phi::Backend backend,
phi::DataType precision,
const std::unordered_set<std::string>& black_list) {
bool support = false;
if (black_list.count(op_type) == 0) {
if (backend == phi::Backend::GPU) {
support = GpuKernelSupportPrecision(op_type, precision);
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Now, only support backend of GPU."));
}
}
return support;
}
// The set of ops that support fp16 calculation but are considered numerically
// dangerous, slower than fp32, or whose effects may also be observed in
// downstream ops.
void AutoMixedPrecisionPass::SetDefaultBlacklist() const {
black_list_.insert({
// numerically-dangerous
"acos",
"asin",
"cosh",
"tan",
"exp",
"expm1",
"square",
"log",
"log2",
"log10",
"log1p",
"logsumexp",
"mean",
"rsqrt",
"sum",
"cos_sim",
"softmax",
"softmax_with_cross_entropy",
"sigmoid_cross_entropy_with_logits",
"c_softmax_with_cross_entropy",
"cross_entropy",
"cross_entropy2",
// slower than fp32
"conv2d_transpose",
// keeping fp32 avoids returning inf when the summed value exceeds 65504 (fp16 max)
"reduce_sum",
});
}
void AutoMixedPrecisionPass::Init(Graph* graph) const {
bool enable_gpu_mixed = Get<bool>("enable_gpu_mixed");
if (enable_gpu_mixed) {
backend_ = phi::Backend::GPU;
}
skip_pass_ = !enable_gpu_mixed;
low_precision_ = static_cast<phi::DataType>(Get<int>("mixed_precision_mode"));
black_list_ = Get<std::unordered_set<std::string>>("mixed_black_list");
SetDefaultBlacklist();
VLOG(4) << "black_list has ";
for (const auto& name : black_list_) {
VLOG(4) << " - " << name;
}
keep_io_types_ = true;
if (Has("keep_io_types")) {
keep_io_types_ = Get<bool>("keep_io_types");
}
auto graph_size = graph->SubGraphsSize();
VLOG(4) << "graph size: " << graph_size;
subgraphes_.resize(graph_size);
all_op_nodes_.resize(graph_size);
for (size_t i = 0; i < graph_size; i++) {
subgraphes_[i] = graph->GetSubGraph(i);
all_op_nodes_[i] = TopologySortOperations(*subgraphes_[i]);
VLOG(4) << "subgraph " << i << " has " << all_op_nodes_[i].size()
<< "op nodes";
for (auto* var_node : subgraphes_[i]->Nodes()) {
if (!var_node->IsVar()) continue;
auto var_name = var_node->Var()->Name();
if (real_vars_.count(var_name) == 0) {
real_vars_[var_name] = var_node;
VLOG(4) << var_name << " is in graph " << i;
}
}
}
}
void AutoMixedPrecisionPass::ApplyImpl(Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(graph,
platform::errors::PreconditionNotMet(
"During the auto_mixed_precision_pass, the graph "
"should not be nullptr."));
PADDLE_ENFORCE_EQ(graph->IsMainGraph(),
true,
platform::errors::PreconditionNotMet(
"During the auto_mixed_precision_pass, the graph "
"should be main graph."));
FusePassBase::Init("auto_mixed_precision", graph);
Init(graph);
VLOG(4) << "Init done";
if (skip_pass_) {
VLOG(3) << "Skip auto_mixed_precision_pass.";
return;
}
SetOpUniqueType();
VLOG(4) << "SetOpUniqueType done";
GetOpPrecision();
VLOG(4) << "GetOpPrecision done";
UpdateOpPrecision();
VLOG(4) << "UpdateOpPrecision done";
SetVarPrecision();
VLOG(4) << "SetVarPrecision done";
ConvertWeightsData();
VLOG(4) << "ConvertWeightsData done";
ProcessOpWithDtypeAttr();
VLOG(4) << "ProcessOpWithDtypeAttr done";
InsertCastOp();
VLOG(4) << "InsertCastOp done";
RestoreOpOriginType();
VLOG(4) << "RestoreOpOriginType done";
}
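// Give every op a temporary unique type name ("<op_type>_<n>") so the
// precision decisions below are made per op instance rather than per op type;
// RestoreOpOriginType() undoes the renaming at the end of the pass.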
void AutoMixedPrecisionPass::SetOpUniqueType() const {
int suffix = 0;
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
auto op_type = op_node->Op()->Type();
if (op_type == "feed" || op_type == "fetch") continue;
std::string unique_type = op_type + "_" + std::to_string(suffix++);
op_original_type_[unique_type] = op_type;
op_node->Op()->SetType(unique_type);
op_node->Op()->Flush();
VLOG(4) << "change op type: " << op_type << " ---> " << unique_type;
}
}
}
void AutoMixedPrecisionPass::RestoreOpOriginType() const {
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
auto op_type = op_node->Op()->Type();
op_node->Op()->SetType(GetOpOriginalType(op_type));
op_node->Op()->Flush();
VLOG(4) << "restore op type: " << op_type << " ---> "
<< op_node->Op()->Type();
}
}
}
inline std::string AutoMixedPrecisionPass::GetOpOriginalType(
const std::string& op_type) const {
if (op_original_type_.count(op_type)) {
return op_original_type_.at(op_type);
}
return op_type;
}
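// For ops already marked to run at low precision, rewrite their "dtype" /
// "out_dtype" attributes from a float type to the chosen low-precision type.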
void AutoMixedPrecisionPass::ProcessOpWithDtypeAttr() const {
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
auto op_type = op_node->Op()->Type();
if (op_run_low_precision_.count(op_type) == 0) continue;
if (op_node->Op()->HasAttr("dtype")) {
auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
if (IsFloatType(static_cast<VarType::Type>(dtype))) {
op_node->Op()->SetAttr(
"dtype",
static_cast<int>(framework::TransToProtoVarType(low_precision_)));
op_node->Op()->Flush();
VLOG(4) << "process op with dtype attr: " << op_type << " ( " << dtype
<< " --->" << static_cast<int>(low_precision_) << " )";
}
}
if (op_node->Op()->HasAttr("out_dtype")) {
auto out_dtype = op_node->Op()->GetAttrIfExists<int>("out_dtype");
if (IsFloatType(static_cast<VarType::Type>(out_dtype))) {
op_node->Op()->SetAttr(
"out_dtype",
static_cast<int>(framework::TransToProtoVarType(low_precision_)));
op_node->Op()->Flush();
VLOG(4) << "process op with out_dtype attr: " << op_type << " ( "
<< out_dtype << " --->" << static_cast<int>(low_precision_)
<< " )";
}
}
}
}
}
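// First pass over all ops: decide which ops can run at low precision based on
// the kernel registry and the black list; feed/fetch follow keep_io_types_,
// ops with dtype/out_dtype attrs must currently be float, and the remaining
// ops must only read/write dense (LOD_TENSOR) non-persistable vars.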
void AutoMixedPrecisionPass::GetOpPrecision() const {
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
auto op_type = op_node->Op()->Type();
bool support_low_precision = true;
if (GetOpOriginalType(op_type) == "feed" ||
GetOpOriginalType(op_type) == "fetch") {
support_low_precision = !keep_io_types_;
} else {
support_low_precision = OpSupportPrecision(
GetOpOriginalType(op_type), backend_, low_precision_, black_list_);
}
if (op_node->Op()->HasAttr("dtype")) {
auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
support_low_precision = support_low_precision &&
IsFloatType(static_cast<VarType::Type>(dtype));
} else if (op_node->Op()->HasAttr("out_dtype")) {
auto out_dtype = op_node->Op()->GetAttrIfExists<int>("out_dtype");
support_low_precision =
support_low_precision &&
IsFloatType(static_cast<VarType::Type>(out_dtype));
} else {
// If any of the op's input or output vars is not a dense tensor, the op
// should not run at low precision.
for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true);
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
if (real_in_var_node->Var()->Persistable()) continue;
support_low_precision =
support_low_precision &&
(real_in_var_node->Var()->GetType() == VarType::LOD_TENSOR);
}
for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true);
auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
if (real_out_var_node->Var()->Persistable()) continue;
support_low_precision =
support_low_precision &&
(real_out_var_node->Var()->GetType() == VarType::LOD_TENSOR);
}
}
if (support_low_precision) {
op_run_low_precision_.insert(op_type);
VLOG(4) << "support precision: " << op_type << " run at low precision";
} else {
VLOG(4) << "support precision: " << op_type
<< " not run at low precision";
}
}
}
}
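// Iteratively shrink the low-precision op set: an op falls back to fp32 when
// one of its output vars must stay fp32 (e.g. it feeds a select_input op) or
// when any producer of that var does not run at low precision. Repeat until
// the set no longer changes.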
void AutoMixedPrecisionPass::UpdateOpPrecision() const {
std::unordered_set<std::string> vars_should_not_low_precision;
// var name -> all ops that output (produce) this var
std::unordered_map<std::string, std::vector<Node*>> var_input_ops;
auto GetVarInputOps = [&] {
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
auto op_type = op_node->Op()->Type();
if (GetOpOriginalType(op_type) == "fetch") continue;
if (op_node->Op()->HasAttr("sub_block")) continue;
for (auto* var_node : op_node->outputs) {
CHECK_EQ(var_node->IsVar(), true);
if (var_node->Var()->Persistable()) continue;
if (!VarNodeHasDtype(var_node)) continue;
var_input_ops[var_node->Var()->Name()].push_back(op_node);
VLOG(4) << "var input ops: " << var_node->Var()->Name()
<< " is output of " << op_type;
}
// The select_input op's input vars should not be converted to low precision;
// when an op's output var is a select_input op's input var, that op should
// not run at low precision.
if (GetOpOriginalType(op_node->Op()->Type()) == "select_input") {
for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true);
if (in_var_node->Var()->Persistable()) continue;
if (!VarNodeHasDtype(in_var_node)) continue;
vars_should_not_low_precision.insert(in_var_node->Var()->Name());
}
}
}
}
};
GetVarInputOps();
bool precision_updated = false;
do {
precision_updated = false;
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) continue;
for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true);
if (!VarNodeHasDtype(out_var_node)) continue;
auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
if (real_out_var_node->Var()->Persistable()) continue;
bool not_run_low_precision = false;
const auto& input_op_nodes =
var_input_ops[real_out_var_node->Var()->Name()];
if (vars_should_not_low_precision.count(
real_out_var_node->Var()->Name())) {
not_run_low_precision = true;
} else {
for (auto* node : input_op_nodes) {
if (op_run_low_precision_.count(node->Op()->Type()) == 0) {
not_run_low_precision = true;
break;
}
}
}
if (not_run_low_precision) {
op_run_low_precision_.erase(op_node->Op()->Type());
precision_updated = true;
VLOG(4) << op_node->Op()->Type()
<< " should not run at low precision.";
break;
}
}
}
}
} while (precision_updated);
}
// Special ops whose weights should not be converted to low precision.
bool AutoMixedPrecisionPass::InputVarsNotConvert(
Node* op_node, const std::string& var_name) const {
auto* op_desc = op_node->Op();
if (GetOpOriginalType(op_desc->Type()) == "batch_norm") {
auto vecs = op_desc->Input("Bias");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("Mean");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("Scale");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("Variance");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
} else if (GetOpOriginalType(op_desc->Type()) == "fused_multi_transformer") {
auto vecs = op_desc->Input("LnScale");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("LnBias");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("FFNLnScale");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("FFNLnBias");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
}
return false;
}
bool AutoMixedPrecisionPass::OutputVarsNotConvert(
Node* op_node, const std::string& var_name) const {
auto* op_desc = op_node->Op();
// batch_norm's input and output (variance and mean) are the same.
if (GetOpOriginalType(op_desc->Type()) == "batch_norm") {
auto vecs = op_desc->Output("MeanOut");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Output("VarianceOut");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Output("SavedMean");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Output("SavedVariance");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
}
return false;
}
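// For ops that will run at low precision, switch the dtype of their
// persistable inputs (weights) and of their outputs to low_precision_,
// skipping vars reported by InputVarsNotConvert/OutputVarsNotConvert, and
// record converted persistable vars for the later weights conversion.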
void AutoMixedPrecisionPass::SetVarPrecision() const {
for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) {
if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) {
continue;
}
if (GetOpOriginalType(op_node->Op()->Type()) != "feed") {
for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true);
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
auto in_var_name = real_in_var_node->Var()->Name();
if (!IsFloatType(real_in_var_node->Var()->GetDataType())) continue;
if (!VarNodeHasDtype(real_in_var_node)) continue;
if (InputVarsNotConvert(op_node, in_var_name)) continue;
if (real_in_var_node->Var()->Persistable()) {
real_in_var_node->Var()->SetDataType(
framework::TransToProtoVarType(low_precision_));
vars_convert_to_low_precision_.insert(in_var_name);
}
}
}
if (GetOpOriginalType(op_node->Op()->Type()) != "fetch") {
for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true);
auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
auto out_var_name = real_out_var_node->Var()->Name();
if (!IsFloatType(real_out_var_node->Var()->GetDataType())) continue;
if (!VarNodeHasDtype(real_out_var_node)) continue;
if (OutputVarsNotConvert(op_node, out_var_name)) continue;
real_out_var_node->Var()->SetDataType(
framework::TransToProtoVarType(low_precision_));
if (real_out_var_node->Var()->Persistable()) {
vars_convert_to_low_precision_.insert(out_var_name);
}
}
}
}
}
// This code is used to process vars with the same name. Vars with the same
// name should have the same data type.
for (auto* subgraph : subgraphes_) {
for (auto* var_node : subgraph->Nodes()) {
if (!var_node->IsVar() || !var_node->Var()->Persistable()) continue;
if (!VarNodeHasDtype(var_node)) continue;
auto var_name = var_node->Var()->Name();
if (vars_convert_to_low_precision_.count(var_name)) {
var_node->Var()->SetDataType(
framework::TransToProtoVarType(low_precision_));
}
}
}
}
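// Cast the actual weight data held in the scope (fp32/fp64 elements) to
// float16 or bfloat16 for every persistable var recorded above.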
void AutoMixedPrecisionPass::ConvertWeightsData() const {
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(scope,
platform::errors::PreconditionNotMet(
"During the auto_mixed_precision_pass, the scope "
"should not be null."));
auto var_names = scope->LocalVarNames();
for (const auto& var_name : var_names) {
if (vars_convert_to_low_precision_.count(var_name)) {
VLOG(4) << var_name << "'s data type was converted to low precision";
auto* var = scope->FindLocalVar(var_name);
CHECK_EQ(var->IsType<phi::DenseTensor>(), true);
auto* origin_tensor = var->GetMutable<phi::DenseTensor>();
phi::DenseTensor low_precision_tensor;
low_precision_tensor.Resize(origin_tensor->dims());
low_precision_tensor.set_type(low_precision_);
if (low_precision_ == phi::DataType::FLOAT16) {
auto* low_precision_data =
low_precision_tensor.mutable_data<phi::dtype::float16>(
phi::CPUPlace{});
for (int64_t i = 0; i < origin_tensor->numel(); i++) {
if (origin_tensor->dtype() == phi::DataType::FLOAT64) {
auto* origin_data = origin_tensor->data<double>();
low_precision_data[i] =
static_cast<phi::dtype::float16>(origin_data[i]);
} else if (origin_tensor->dtype() == phi::DataType::FLOAT32) {
auto* origin_data = origin_tensor->data<float>();
low_precision_data[i] =
static_cast<phi::dtype::float16>(origin_data[i]);
}
}
} else if (low_precision_ == phi::DataType::BFLOAT16) {
auto* half_data =
low_precision_tensor.mutable_data<phi::dtype::bfloat16>(
phi::CPUPlace{});
for (int64_t i = 0; i < origin_tensor->numel(); i++) {
if (origin_tensor->dtype() == phi::DataType::FLOAT64) {
auto* origin_data = origin_tensor->data<double>();
half_data[i] = static_cast<phi::dtype::bfloat16>(origin_data[i]);
} else if (origin_tensor->dtype() == phi::DataType::FLOAT32) {
auto* origin_data = origin_tensor->data<float>();
half_data[i] = static_cast<phi::dtype::bfloat16>(origin_data[i]);
}
}
}
origin_tensor->clear();
paddle::framework::TensorCopySync(
low_precision_tensor, phi::CPUPlace{}, origin_tensor);
}
}
}
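// Insert cast ops on the boundary between precision regions: float inputs of
// low-precision ops are cast down to low_precision_, and fp16/bf16 inputs of
// ops kept in fp32 are cast back to FP32.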
void AutoMixedPrecisionPass::InsertCastOp() const {
int suffix = 0;
std::unordered_map<Node*, Node*> cache;
for (size_t i = 0; i < all_op_nodes_.size(); i++) {
auto* block_desc = all_op_nodes_[i][0]->Op()->Block();
CHECK_NOTNULL(block_desc);
for (auto* op_node : all_op_nodes_[i]) {
auto op_type = op_node->Op()->Type();
if (GetOpOriginalType(op_type) == "feed") continue;
if (op_node->Op()->HasAttr("sub_block")) continue;
VLOG(4) << "process op: " << op_type
<< " run low precision: " << op_run_low_precision_.count(op_type);
auto inputs = op_node->inputs;
for (auto* in_var_node : inputs) {
if (!in_var_node->IsVar()) continue;
if (!VarNodeHasDtype(in_var_node)) continue;
if (in_var_node->Var()->Persistable()) continue;
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
auto in_var_type = real_in_var_node->Var()->GetDataType();
VLOG(4) << "process var: " << real_in_var_node->Var()->Name()
<< " with type " << in_var_type;
if (IsFloatType(in_var_type) && op_run_low_precision_.count(op_type)) {
DoInsertCastOp(subgraphes_[i],
in_var_node,
op_node,
in_var_type,
framework::TransToProtoVarType(low_precision_),
block_desc,
&suffix,
&cache);
} else if (IsHalfType(in_var_type) &&
op_run_low_precision_.count(op_type) == 0) {
DoInsertCastOp(subgraphes_[i],
in_var_node,
op_node,
in_var_type,
VarType::FP32,
block_desc,
&suffix,
&cache);
}
}
// Special op.
// fused_multi_transformer's input(CacheKV) and output(CacheKVOut) vars
// have the same name.
if (GetOpOriginalType(op_type) == "fused_multi_transformer") {
auto cache_kv_inputs = op_node->Op()->Input("CacheKV");
auto cache_kv_outputs = op_node->Op()->Output("CacheKVOut");
CHECK_EQ(cache_kv_inputs.size(), cache_kv_outputs.size());
for (size_t i = 0; i < cache_kv_inputs.size(); ++i) {
op_node->Op()->RenameOutput(cache_kv_outputs[i], cache_kv_inputs[i]);
}
}
}
}
VLOG(4) << "insert number of cast op: " << cache.size();
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(auto_mixed_precision_pass,
paddle::framework::ir::AutoMixedPrecisionPass);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h"
namespace paddle {
namespace framework {
namespace ir {
class AutoMixedPrecisionPass : public FusePassBase {
public:
using VarType = framework::proto::VarType;
public:
AutoMixedPrecisionPass() = default;
~AutoMixedPrecisionPass() = default;
protected:
void ApplyImpl(Graph* graph) const override;
private:
void Init(Graph* graph) const;
void SetDefaultBlacklist() const;
void SetOpUniqueType() const;
void RestoreOpOriginType() const;
inline std::string GetOpOriginalType(const std::string& op_type) const;
void GetOpPrecision() const;
void UpdateOpPrecision() const;
void InsertCastOp() const;
void ProcessOpWithDtypeAttr() const;
bool InputVarsNotConvert(Node* op_node, const std::string& var_name) const;
bool OutputVarsNotConvert(Node* op_node, const std::string& var_name) const;
void SetVarPrecision() const;
void ConvertWeightsData() const;
private:
mutable bool skip_pass_{false};
mutable bool keep_io_types_{false};
// float16 or bfloat16 now
mutable phi::DataType low_precision_{phi::DataType::FLOAT16};
mutable phi::Backend backend_{phi::Backend::GPU};
mutable std::unordered_set<std::string> black_list_;
// subgraph id -> pointer to subgraph
mutable std::vector<Graph*> subgraphes_;
// var name -> real var node
mutable std::unordered_map<std::string, Node*> real_vars_;
// subgraph id -> all op nodes in subgraph
mutable std::vector<std::vector<Node*>> all_op_nodes_;
// op's unique type -> the op's origin type
mutable std::unordered_map<std::string, std::string> op_original_type_;
// set of op unique types that run at low precision
mutable std::unordered_set<std::string> op_run_low_precision_;
mutable std::unordered_set<std::string> vars_convert_to_low_precision_;
};
bool OpSupportPrecision(const std::string& op_type,
phi::Backend backend,
phi::DataType precision,
const std::unordered_set<std::string>& black_list);
void DoInsertCastOp(Graph* graph,
Node* var_node,
Node* op_node,
proto::VarType::Type from_type,
proto::VarType::Type to_type,
framework::BlockDesc* block_desc,
int* suffix,
std::unordered_map<Node*, Node*>* cache);
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -29,6 +29,11 @@ void FillConstData(LoDTensor* out_t, T value) {
}
void DeleteFillConstantOpPass::ApplyImpl(ir::Graph* graph) const {
bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
// Not supported under dynamic shape.
if (with_dynamic_shape) {
return;
}
FusePassBase::Init("delete_fill_constant_op_pass", graph);
GraphPatternDetector detector;
auto fill_constant_op =
......
...@@ -75,7 +75,6 @@ Graph::Graph(const ProgramDesc &program,
}
} else {
auto var_nodes = InitFromProgram(program_, start_op_index, end_op_index);
ResolveHazard(var_nodes);
}
}
...@@ -88,7 +87,6 @@ Graph::Graph(const BlockDesc &block,
const int64_t end_op_index)
: main_graph_(main_graph) {
auto var_nodes = InitFromBlock(block, start_op_index, end_op_index);
ResolveHazard(var_nodes);
}
// TODO(levi): delete this interface after when we can convert all
......
...@@ -130,86 +130,6 @@ TEST(GraphTest, Basic) {
ASSERT_EQ(nodes.size(), 5UL);
}
TEST(GraphTest, WriteAfterRead) {
// void Test() {
ProgramDesc prog;
auto *op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum");
op->SetInput("X", {"a"});
op->SetOutput("Out", {"b"});
op->SetAttr("op_role", 1);
op = prog.MutableBlock(0)->AppendOp();
op->SetType("dummy");
op->SetInput("X", {"c"});
op->SetOutput("Out", {"a"});
op->SetAttr("op_role", 1);
prog.MutableBlock(0)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
ir::Node *control_dep1 = nullptr;
ir::Node *control_dep2 = nullptr;
for (ir::Node *n : g->Nodes()) {
if (n->Name() == "sum") {
ASSERT_EQ(n->outputs[0]->Name(), "b");
ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1]));
control_dep1 = n->outputs[1];
ASSERT_EQ(n->outputs.size(), 2UL);
}
if (n->Name() == "dummy") {
ASSERT_EQ(n->inputs[0]->Name(), "c");
ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1]));
control_dep2 = n->inputs[1];
ASSERT_EQ(n->inputs.size(), 2UL);
}
}
ASSERT_EQ(control_dep1, control_dep2);
}
TEST(GraphTest, WriteAfterWrite) {
// void Test() {
ProgramDesc prog;
auto *op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum");
op->SetInput("X", {"a"});
op->SetOutput("Out", {"b"});
op->SetAttr("op_role", 1);
op = prog.MutableBlock(0)->AppendOp();
op->SetType("dummy");
op->SetInput("X", {"c"});
op->SetOutput("Out", {"b"});
op->SetAttr("op_role", 1);
prog.MutableBlock(0)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(0)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
ir::Node *control_dep1 = nullptr;
ir::Node *control_dep2 = nullptr;
for (ir::Node *n : g->Nodes()) {
if (n->Name() == "sum") {
ASSERT_EQ(n->outputs[0]->Name(), "b");
ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1]));
ASSERT_EQ(n->outputs.size(), 2UL);
control_dep1 = n->outputs[1];
}
if (n->Name() == "dummy") {
ASSERT_EQ(n->inputs[0]->Name(), "c");
ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1]));
control_dep2 = n->inputs[1];
ASSERT_EQ(n->inputs.size(), 2UL);
}
}
ASSERT_NE(control_dep1, nullptr);
ASSERT_NE(control_dep2, nullptr);
ASSERT_EQ(control_dep1, control_dep2);
}
TEST(GraphTest, TestException) {
ProgramDesc prog;
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
...@@ -350,12 +270,13 @@ TEST(GraphTest, TestMultiBlock) {
op = prog.MutableBlock(1)->AppendOp();
op->SetType("dummy");
op->SetInput("X", {"c"});
op->SetOutput("Out", {"a"});
op->SetOutput("Out", {"d"});
op->SetAttr("op_role", 1);
prog.MutableBlock(1)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(1)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(1)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(1)->Var("d")->SetType(proto::VarType::LOD_TENSOR);
// Set contents in block_2.
op = prog.MutableBlock(2)->AppendOp();
...@@ -367,12 +288,13 @@ TEST(GraphTest, TestMultiBlock) {
op = prog.MutableBlock(2)->AppendOp();
op->SetType("dummy");
op->SetInput("X", {"c"});
op->SetOutput("Out", {"b"});
op->SetOutput("Out", {"d"});
op->SetAttr("op_role", 1);
prog.MutableBlock(2)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(2)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(2)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
prog.MutableBlock(1)->Var("d")->SetType(proto::VarType::LOD_TENSOR);
// Step2: Convert program into graph, 3 blocks corresponding 3 sub_graphs.
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
...@@ -399,45 +321,29 @@ TEST(GraphTest, TestMultiBlock) {
// Check contents in sub_graph_1.
const ir::Graph *g1 = g->GetSubGraph(1);
ir::Node *control_dep1 = nullptr;
ir::Node *control_dep2 = nullptr;
for (ir::Node *n : g1->Nodes()) {
if (n->Name() == "sum") {
ASSERT_EQ(n->outputs[0]->Name(), "b");
ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1]));
ASSERT_EQ(n->outputs.size(), 1UL);
control_dep1 = n->outputs[1];
ASSERT_EQ(n->outputs.size(), 2UL);
}
if (n->Name() == "dummy") {
ASSERT_EQ(n->inputs[0]->Name(), "c");
ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1]));
ASSERT_EQ(n->inputs.size(), 1UL);
control_dep2 = n->inputs[1];
ASSERT_EQ(n->inputs.size(), 2UL);
}
}
ASSERT_EQ(control_dep1, control_dep2);
// Check contents in sub_graph_2.
const ir::Graph *g2 = g->GetSubGraph(2);
control_dep1 = nullptr;
control_dep2 = nullptr;
for (ir::Node *n : g2->Nodes()) {
if (n->Name() == "sum") {
ASSERT_EQ(n->outputs[0]->Name(), "b");
ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1]));
ASSERT_EQ(n->outputs.size(), 1UL);
ASSERT_EQ(n->outputs.size(), 2UL);
control_dep1 = n->outputs[1];
}
if (n->Name() == "dummy") {
ASSERT_EQ(n->inputs[0]->Name(), "c");
ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1]));
ASSERT_EQ(n->inputs.size(), 1UL);
control_dep2 = n->inputs[1];
ASSERT_EQ(n->inputs.size(), 2UL);
}
}
ASSERT_NE(control_dep1, nullptr);
ASSERT_NE(control_dep2, nullptr);
ASSERT_EQ(control_dep1, control_dep2);
// Step3: Clone graph.
std::shared_ptr<ir::Graph> clone_g = g->Clone();
......
...@@ -331,8 +331,6 @@ void BatchMergePass::ApplyImpl(ir::Graph* graph) const {
copy_node(node);
}
}
result.ResolveHazard(created);
}
} // namespace ir
......
...@@ -183,5 +183,6 @@ void NaiveExecutor::ResetTrtOps(int num) {
}
#endif
}
} // namespace framework
} // namespace paddle
...@@ -38,8 +38,7 @@ void Analyzer::RunAnalysis(Argument *argument) {
if (!disable_logs) {
string::PrettyLogH1("--- Running analysis [%s]", pass);
}
if (!argument->enable_analysis_optim() && pass == "ir_analysis_pass")
continue;
if (!argument->enable_ir_optim() && pass == "ir_analysis_pass") continue;
auto *ptr = PassRegistry::Global().Retreive(pass);
PADDLE_ENFORCE_NOT_NULL(ptr,
......
...@@ -31,7 +31,7 @@ TEST(Analyzer, analysis_without_tensorrt) {
Argument argument;
argument.SetDisableLogs(false);
argument.SetModelDir(FLAGS_inference_model_dir);
argument.SetEnableAnalysisOptim(false);
argument.SetEnableIrOptim(false);
argument.SetUseGPU(false);
argument.SetAnalysisPasses({"ir_graph_build_pass",
"ir_analysis_pass",
...@@ -44,7 +44,7 @@ TEST(Analyzer, analysis_without_tensorrt) {
TEST(Analyzer, analysis_with_tensorrt) {
Argument argument;
argument.SetDisableLogs(false);
argument.SetEnableAnalysisOptim(false);
argument.SetEnableIrOptim(false);
argument.SetTensorRtMaxBatchSize(3);
argument.SetTensorRtWorkspaceSize(1 << 20);
argument.SetModelDir(FLAGS_inference_model_dir);
......
...@@ -42,8 +42,6 @@ namespace paddle {
namespace inference {
namespace analysis {
using framework::ir::Graph;
#ifdef PADDLE_WITH_MKLDNN
using VarQuantScale =
std::unordered_map<std::string, std::pair<bool, framework::LoDTensor>>;
...@@ -148,7 +146,7 @@ struct Argument {
DECL_ARGUMENT_FIELD(model_params_path, ModelParamsPath, std::string);
DECL_ARGUMENT_FIELD(model_from_memory, ModelFromMemory, bool);
DECL_ARGUMENT_FIELD(optim_cache_dir, OptimCacheDir, std::string);
DECL_ARGUMENT_FIELD(enable_analysis_optim, EnableAnalysisOptim, bool);
DECL_ARGUMENT_FIELD(enable_ir_optim, EnableIrOptim, bool);
// For JITLayer
DECL_ARGUMENT_FIELD(skip_load_params, SkipLoadParams, bool);
...@@ -362,6 +360,8 @@ struct Argument {
DECL_ARGUMENT_FIELD(mixed_black_list,
MixedBlackList,
std::unordered_set<std::string>);
DECL_ARGUMENT_FIELD(enable_gpu_mixed, EnableGPUMixed, bool);
DECL_ARGUMENT_FIELD(mixed_precision_mode, MixedPrecisionMode, int);
private:
std::unordered_set<std::string> valid_fields_;
......
...@@ -153,25 +153,6 @@ T &GetFromScope(const framework::Scope &scope, const std::string &name) {
return *var->GetMutable<T>();
}
static framework::proto::ProgramDesc LoadProgramDesc(
const std::string &model_path) {
std::ifstream fin(model_path, std::ios::in | std::ios::binary);
PADDLE_ENFORCE_EQ(
fin.is_open(),
true,
platform::errors::NotFound(
"Cannot open file %s, please confirm whether the file exists",
model_path));
fin.seekg(0, std::ios::end);
std::string buffer(fin.tellg(), ' ');
fin.seekg(0, std::ios::beg);
fin.read(&buffer[0], buffer.size());
fin.close();
framework::proto::ProgramDesc program_desc;
program_desc.ParseFromString(buffer);
return program_desc;
}
static bool FileExists(const std::string &filepath) {
std::ifstream file(filepath);
bool exists = file.is_open();
......
...@@ -27,6 +27,7 @@
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/argument.h"
#include "paddle/fluid/string/pretty_log.h"
#include "paddle/phi/core/errors.h"
namespace paddle {
namespace inference {
...@@ -36,15 +37,6 @@ using string::PrettyLogEndl;
using string::Style;
IRPassManager::IRPassManager(Argument *argument) {
ARGUMENT_CHECK_FIELD(argument, main_program);
graph_ = std::unique_ptr<Graph>(new Graph(argument->main_program()));
if (argument->Has("scope")) {
auto *scope_ptr = argument->scope_ptr();
PADDLE_ENFORCE_NOT_NULL(scope_ptr,
platform::errors::PreconditionNotMet(
"The scope ptr should not be nullptr."));
graph_->SetNotOwned(framework::ir::kParamScopeAttr, scope_ptr);
}
disable_logs_ = argument->disable_logs();
ARGUMENT_CHECK_FIELD(argument, ir_analysis_passes);
...@@ -95,10 +87,14 @@ void IRPassManager::CreatePasses(Argument *argument,
argument->tensorrt_tuned_dynamic_shape();
pass->Set("with_dynamic_shape", new bool(with_dynamic_shape));
// mixed precision related
pass->Set("model_precision", new int(argument->model_precision()));
pass->Set(
"mixed_black_list",
new std::unordered_set<std::string>(argument->mixed_black_list()));
pass->Set("enable_gpu_mixed", new bool(argument->enable_gpu_mixed()));
pass->Set("mixed_precision_mode",
new int(argument->mixed_precision_mode()));
if (pass_name == "graph_viz_pass") {
std::string optim_cache_dir = argument->optim_cache_dir();
...@@ -302,42 +298,18 @@ void IRPassManager::CreatePasses(Argument *argument,
}
std::unique_ptr<Graph> IRPassManager::Apply(std::unique_ptr<Graph> graph) {
if (passes_.empty()) {
return graph;
}
PADDLE_ENFORCE_NOT_NULL(
graph.get(),
platform::errors::PreconditionNotMet("Graph cannot be NULL."));
graph.get(), platform::errors::InvalidArgument("Graph cannot be null."));
// Apply all the passes
for (const auto &pass : passes_) {
if (pass->Type() != "graph_viz_pass" && !disable_logs_) {
PrettyLogEndl(Style::H2(), "--- Running IR pass [%s]", pass->Type());
}
// delete_fill_constant_op_pass is not apply under trt dynamic shape
if (pass->Type() == "delete_fill_constant_op_pass") {
bool use_dynamic = pass->Get<bool>("with_dynamic_shape");
if (use_dynamic) continue;
}
graph.reset(pass->Apply(graph.release()));
}
return graph;
}
framework::proto::ProgramDesc IRPassManager::AcquireProgram(
std::unique_ptr<Graph> *graph, ProgramDesc *program) const {
auto pass =
framework::ir::PassRegistry::Instance().Get("graph_to_program_pass");
// Direct using ProgramDesc desc(argument->main_program()) may cause
// incomplete copies of information.
ProgramDesc desc;
desc.CopyFrom(*program->Proto());
pass->SetNotOwned("program", &desc);
auto *the_graph = graph->release();
graph->reset(pass->Apply(the_graph));
return *desc.Proto();
}
} // namespace analysis
} // namespace inference
} // namespace paddle
...@@ -48,15 +48,9 @@ class IRPassManager final {
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph);
framework::proto::ProgramDesc AcquireProgram(std::unique_ptr<Graph> *graph,
ProgramDesc *program) const;
framework::ir::Graph &graph() const { return *graph_; }
private:
void CreatePasses(Argument *argument, const std::vector<std::string> &passes);
std::unique_ptr<Graph> graph_;
std::vector<std::unique_ptr<Pass>> passes_;
bool disable_logs_{false};
};
......
...@@ -94,14 +94,14 @@ void OutputProcess(framework::ir::Graph *graph,
backend,
precision,
blacklist)) {
AddCastOp(graph,
var_node,
next_op,
framework::proto::VarType::FP32,
to_type,
&suffix,
block_desc,
&var_to_cast_op_map);
InsertCastOp(graph,
var_node,
next_op,
framework::proto::VarType::FP32,
to_type,
block_desc,
&suffix,
&var_to_cast_op_map);
var_node->Var()->SetDataType(framework::proto::VarType::FP32);
}
}
......
...@@ -13,7 +13,7 @@ cc_library(
cc_library(
convert_to_mixed_precision
SRCS convert_to_mixed_precision.cc
DEPS analysis_pass ir_graph_build_pass)
DEPS analysis_pass ir_graph_build_pass auto_mixed_precision_pass)
cc_library(
ir_params_sync_among_devices_pass
SRCS ir_params_sync_among_devices_pass.cc
...@@ -30,17 +30,6 @@ cc_library(
inference_op_replace_pass
SRCS inference_op_replace_pass.cc
DEPS analysis_pass graph_to_program_pass)
if(WITH_TESTING)
cc_library(
ir_graph_clean_pass
SRCS ir_graph_clean_pass.cc
DEPS analysis_pass gtest)
else()
cc_library(
ir_graph_clean_pass
SRCS ir_graph_clean_pass.cc
DEPS analysis_pass)
endif()
cc_library(
analysis_passes
...@@ -52,8 +41,7 @@ cc_library(
memory_optim_pass
convert_to_mixed_precision
inference_op_replace_pass
ir_graph_to_program_pass
ir_graph_to_program_pass)
ir_graph_clean_pass)
set(analysis_deps
${analysis_deps} analysis_passes subgraph_detector
......
...@@ -14,807 +14,88 @@
#include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"
#include <algorithm>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/ir/auto_mixed_precision_pass.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/inference/analysis/argument.h"
#include "paddle/fluid/inference/analysis/passes/ir_graph_clean_pass.h"
#include "paddle/fluid/inference/io.h" #include "paddle/fluid/inference/io.h"
#include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/tensor_meta.h"
using namespace paddle::framework; // NOLINT
namespace paddle {
namespace inference {
namespace analysis {
namespace { ConvertToMixedPrecisionPass::ConvertToMixedPrecisionPass(
bool PhiKernelSupportPrecision( const std::string& model_file,
const std::string& op_type, const std::string& params_file,
const std::string& mixed_model_file,
const std::string& mixed_params_file,
phi::DataType mixed_precision,
phi::Backend backend, phi::Backend backend,
phi::DataType data_type, bool keep_io_types,
phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) { const std::unordered_set<std::string>& black_list)
auto kernels = phi::KernelFactory::Instance().kernels(); : model_file_(model_file),
if (kernels.find(op_type) == kernels.end()) { params_file_(params_file),
return false; mixed_model_file_(mixed_model_file),
} mixed_params_file_(mixed_params_file),
phi::KernelKey kernel_key(backend, layout, data_type); mixed_precision_(mixed_precision),
return phi::KernelFactory::Instance().HasKernel(op_type, kernel_key); backend_(backend),
} keep_io_types_(keep_io_types),
black_list_(black_list) {
bool GpuKernelSupportPrecision( if (mixed_precision_ != phi::DataType::FLOAT16 &&
const std::string& op_type, mixed_precision_ != phi::DataType::BFLOAT16) {
phi::DataType data_type, PADDLE_THROW(paddle::platform::errors::InvalidArgument(
phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) { "mixed_precision currently not supported dtype %d, we now only "
auto phi_op_type = phi::TransToPhiKernelName(op_type); "support fp16 and bf16.",
bool res = PhiKernelSupportPrecision( static_cast<int>(mixed_precision_)));
phi_op_type, phi::Backend::GPU, data_type, layout);
res |= PhiKernelSupportPrecision(
phi_op_type, phi::Backend::GPUDNN, data_type, layout);
if (!res) {
auto& all_kernels = OperatorWithKernel::AllOpKernels();
auto it = all_kernels.find(op_type);
if (it != all_kernels.end()) {
for (auto& kern_pair : it->second) {
if (platform::is_gpu_place(kern_pair.first.place_) &&
kern_pair.first.data_type_ == framework::proto::VarType::FP16) {
res = true;
}
}
}
}
return res;
}
class ConvertToMixedPrecisionPass {
public:
explicit ConvertToMixedPrecisionPass(
const std::string& model_file,
const std::string& params_file,
const std::string& mixed_model_file,
const std::string& mixed_params_file,
phi::DataType mixed_precision,
phi::Backend backend,
bool keep_io_types,
std::unordered_set<std::string> black_list)
: model_file_(model_file),
params_file_(params_file),
mixed_model_file_(mixed_model_file),
mixed_params_file_(mixed_params_file),
mixed_precision_(mixed_precision),
backend_(backend),
keep_io_types_(keep_io_types),
black_list_(black_list),
place_(paddle::CPUPlace()),
executor_(place_) {
black_list_.insert("assign");
black_list_.insert("fill_constant");
black_list_.insert("assign_value");
black_list_.insert("eye");
black_list_.insert("fill_any_like");
black_list_.insert("fill_constant_batch_size_like");
}
void Run();
private:
void LoadAndPrepare();
inline bool NodeVarHasDtype(framework::ir::Node* node);
void ConvertAllFp64ToFp32(framework::ir::Graph* graph);
void FixCastAttr(framework::ir::Graph* graph);
void SaveMixedModel();
void ConvertTensorDtype(int block_idx);
void ProcessInputNode(bool support_precision,
ir::Node* in_node,
ir::Node* op_node,
int* suffix,
framework::BlockDesc* block_desc,
framework::proto::VarType::Type to_type,
int block_idx);
void ProcessOutputNode(int block_idx,
ir::Node* var_node,
framework::proto::VarType::Type to_type);
inline bool IsFloatVarType(framework::proto::VarType::Type type);
bool OutShouldNotConvert(ir::Node* var_node);
// Just process special cases for weights conversion.
bool WeightsShouldNotConvert(ir::Node* var_node);
// To support multi block, we need to consider a lot of special cases.
// Return Node* which first appers in block.
framework::ir::Node* GetRealNode(int block_idx, framework::ir::Node* node);
void FindVarsInMultiBlock();
inline bool VarIsMultiPrecisionOpsOut(int block_idx,
framework::ir::Node* op_node);
private:
// A trick. Patch for strange op, which input name equal to output name, such
// as `fused_multi_transformer`
void PatchForStrangeOp();
private:
std::string model_file_;
std::string params_file_;
std::string mixed_model_file_;
std::string mixed_params_file_;
phi::DataType mixed_precision_;
phi::Backend backend_;
bool keep_io_types_;
std::unordered_set<std::string> black_list_;
paddle::CPUPlace place_;
framework::Executor executor_;
framework::Scope scope_;
std::unordered_map<framework::ir::Node*, framework::ir::Node*> cast_map_;
std::unordered_map<std::string,
std::pair<framework::proto::VarType::Type, int>>
vars_in_multi_block_map_;
std::vector<std::unordered_map<std::string, std::vector<std::string>>>
vars_appear_multi_in_one_block_;
int suffix_{0};
std::unique_ptr<framework::ProgramDesc> program_desc_{nullptr};
std::unique_ptr<framework::ir::Graph> main_graph_{nullptr};
std::vector<framework::ir::Graph*> graphes_;
};
framework::ir::Node* ConvertToMixedPrecisionPass::GetRealNode(
int block_idx, framework::ir::Node* node) {
if (vars_in_multi_block_map_.count(node->Name())) {
int var_origin_block_id = vars_in_multi_block_map_.at(node->Name()).second;
if (block_idx != var_origin_block_id) {
auto graph = graphes_[var_origin_block_id];
for (auto nd : graph->Nodes()) {
if (nd->Name() == node->Name()) {
return nd;
}
}
}
}
return node;
}
inline bool ConvertToMixedPrecisionPass::NodeVarHasDtype(
framework::ir::Node* node) {
if (node->IsVar() &&
(node->Var()->GetType() ==
paddle::framework::proto::VarType::SELECTED_ROWS ||
node->Var()->GetType() ==
paddle::framework::proto::VarType::LOD_TENSOR ||
node->Var()->GetType() ==
paddle::framework::proto::VarType::LOD_TENSOR_ARRAY ||
node->Var()->GetType() == paddle::framework::proto::VarType::STRINGS ||
node->Var()->GetType() == paddle::framework::proto::VarType::VOCAB)) {
return true;
}
return false;
}
// op1(fp32) -> var1, op2(fp16) -> var1
// if and only if op1 and op2 both support fp16, we convert op1 and op2's
// precision.
inline bool ConvertToMixedPrecisionPass::VarIsMultiPrecisionOpsOut(
int block_idx, framework::ir::Node* op_node) {
CHECK_EQ(op_node->IsOp(), true);
bool ret{false};
for (auto* out : op_node->outputs) {
auto* real_node = GetRealNode(block_idx, out);
if (!real_node->Var()->Persistable() &&
vars_appear_multi_in_one_block_[block_idx].count(out->Name())) {
for (auto op_type :
vars_appear_multi_in_one_block_[block_idx].at(out->Name())) {
if (OpSupportPrecision(
op_type, backend_, mixed_precision_, black_list_)) {
ret = true;
VLOG(2) << out->Name()
<< " is multi precision op's out, so we skip convert to fp16";
break;
}
}
}
if (ret) break;
}
return ret;
}
void ConvertToMixedPrecisionPass::ProcessInputNode(
bool support_precision,
ir::Node* in_node,
ir::Node* op_node,
int* suffix,
framework::BlockDesc* block_desc,
framework::proto::VarType::Type to_type,
int block_idx) {
auto* real_node = GetRealNode(block_idx, in_node);
if (!NodeVarHasDtype(real_node)) return;
auto graph = graphes_[block_idx];
bool is_main_block = block_idx == 0;
auto* in_var = real_node->Var();
auto in_var_type = in_var->GetDataType();
auto prev_type = in_var_type;
bool is_in_multi_block = vars_in_multi_block_map_.count(in_var->Name());
if (!is_main_block && is_in_multi_block) {
in_var_type = vars_in_multi_block_map_.at(in_var->Name()).first;
}
if (support_precision) {
if (in_var->Persistable() &&
in_var_type == framework::proto::VarType::FP32) {
if (WeightsShouldNotConvert(in_node)) return;
in_var->SetDataType(to_type);
in_var_type = to_type;
VLOG(3) << " in_node name " << in_var->Name() << " from " << prev_type
<< " to " << to_type;
} else if (!in_var->Persistable() && IsFloatVarType(in_var_type) &&
in_var_type != to_type) {
AddCastOp(graph,
in_node,
op_node,
in_var_type,
to_type,
suffix,
block_desc,
&cast_map_);
VLOG(3) << " in_node name " << in_var->Name() << "(" << prev_type
<< ") to " << cast_map_[in_node]->Name() << "(" << to_type << ")";
}
} else {
if (!in_var->Persistable() && IsFloatVarType(in_var_type) &&
in_var_type != to_type) {
AddCastOp(graph,
in_node,
op_node,
in_var_type,
to_type,
suffix,
block_desc,
&cast_map_);
VLOG(3) << " in_node name " << in_var->Name() << "(" << prev_type
<< ") to " << cast_map_[in_node]->Name() << "(" << to_type << ")";
}
  }
}

  if (backend_ != phi::Backend::GPU) {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "mixed_precision currently not supported place %d, we now only "
        "support gpu.",
        static_cast<int>(backend_)));
  }
}

void ConvertToMixedPrecisionPass::LoadModel() {
  framework::Executor exe{platform::CPUPlace{}};

  auto program_desc = inference::Load(&exe, &scope_, model_file_, params_file_);
  main_graph_ = std::unique_ptr<framework::ir::Graph>(
      new framework::ir::Graph(*program_desc));
  main_graph_->SetNotOwned(framework::ir::kParamScopeAttr, &scope_);
}

void ConvertToMixedPrecisionPass::ProcessOutputNode(
    int block_idx,
    ir::Node* var_node,
framework::proto::VarType::Type to_type) {
auto* real_node = GetRealNode(block_idx, var_node);
if (!NodeVarHasDtype(real_node)) return;
auto* out_var = real_node->Var();
auto prev_type = out_var->GetDataType();
if (out_var->GetDataType() == framework::proto::VarType::FP32) {
if (OutShouldNotConvert(var_node)) return;
out_var->SetDataType(to_type);
  }
VLOG(3) << " out_node name " << var_node->Name() << " from dtype "
<< prev_type << " to " << out_var->GetDataType();
}

// Just process special cases.
bool ConvertToMixedPrecisionPass::OutShouldNotConvert(ir::Node* var_node) {
auto op_node = var_node->inputs[0];
auto* op_desc = op_node->Op();
// batch_norm's input and output (variance and mean) are the same.
if (op_desc->Type() == "batch_norm") {
auto vecs = op_desc->Output("MeanOut");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
vecs = op_desc->Output("VarianceOut");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
vecs = op_desc->Output("SavedMean");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
vecs = op_desc->Output("SavedVariance");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
}
return false;
}
bool ConvertToMixedPrecisionPass::WeightsShouldNotConvert(ir::Node* var_node) {
auto op_nodes = var_node->outputs;
for (auto* op_node : op_nodes) {
auto* op_desc = op_node->Op();
// batch_norm op's bias, mean, scale and variance must stay float32, so we
// cannot convert their dtype.
if (op_desc->Type() == "batch_norm") {
auto vecs = op_desc->Input("Bias");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
vecs = op_desc->Input("Mean");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
vecs = op_desc->Input("Scale");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
vecs = op_desc->Input("Variance");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
} else if (op_desc->Type() == "fused_multi_transformer") {
auto vecs = op_desc->Input("LnScale");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
vecs = op_desc->Input("LnBias");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
vecs = op_desc->Input("FFNLnScale");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
vecs = op_desc->Input("FFNLnBias");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
}
}
return false;
}
inline bool ConvertToMixedPrecisionPass::IsFloatVarType(
framework::proto::VarType::Type type) {
if (type == framework::proto::VarType::FP16 ||
type == framework::proto::VarType::FP32 ||
type == framework::proto::VarType::BF16)
return true;
return false;
}
void ConvertToMixedPrecisionPass::LoadAndPrepare() {
program_desc_ =
inference::Load(&executor_, &scope_, model_file_, params_file_);
main_graph_ = std::unique_ptr<framework::ir::Graph>(
    new framework::ir::Graph(*program_desc_));
main_graph_->SetNotOwned(framework::ir::kParamScopeAttr, &scope_);
// Remove all control var
IrInferCleanGraphPass pass;
Argument arg;
arg.SetMainGraphNotOwned(main_graph_.get());
pass.Run(&arg);
vars_appear_multi_in_one_block_.resize(program_desc_->Size());
FindVarsInMultiBlock();
}
void ConvertToMixedPrecisionPass::FindVarsInMultiBlock() {
std::vector<std::set<std::string>> block_var_names_set(program_desc_->Size());
for (size_t i = 0; i < program_desc_->Size(); ++i) {
for (auto op : program_desc_->Block(i).AllOps()) {
auto in_names = op->InputArgumentNames();
block_var_names_set[i].insert(in_names.begin(), in_names.end());
auto out_names = op->OutputArgumentNames();
if (op->HasAttr("sub_block") == false) {
for (auto& n : out_names) {
if (block_var_names_set[i].count(n)) {
vars_appear_multi_in_one_block_[i][n].push_back(op->Type());
}
}
}
block_var_names_set[i].insert(out_names.begin(), out_names.end());
}
}
for (size_t i = 0; i < program_desc_->Size() - 1; ++i) {
for (size_t j = i + 1; j < program_desc_->Size(); ++j) {
std::set<std::string> vars_in_multi_block;
std::set_intersection(
block_var_names_set[i].begin(),
block_var_names_set[i].end(),
block_var_names_set[j].begin(),
block_var_names_set[j].end(),
std::inserter(vars_in_multi_block, vars_in_multi_block.begin()));
for (auto name : vars_in_multi_block) {
vars_in_multi_block_map_.emplace(
name, std::make_pair(framework::proto::VarType::FP32, i));
}
}
}
}
void ConvertToMixedPrecisionPass::ConvertAllFp64ToFp32(
framework::ir::Graph* graph) {
auto op_nodes = framework::ir::TopologySortOperations(*graph);
for (auto* op_node : op_nodes) {
if (!op_node->IsOp()) continue;
auto op_type = op_node->Op()->Type();
if (op_type == "feed" || op_type == "fetch") continue;
if (op_type == "fill_constant") {
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP64))
op_node->Op()->SetAttr(
"dtype", static_cast<int>(framework::proto::VarType::FP32));
} else if (op_type == "assign_value") {
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP64))
op_node->Op()->SetAttr(
"dtype", static_cast<int>(framework::proto::VarType::FP32));
} else if (op_type == "eye") {
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP64))
op_node->Op()->SetAttr(
"dtype", static_cast<int>(framework::proto::VarType::FP32));
} else if (op_type == "fill_any_like") {
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP64))
op_node->Op()->SetAttr(
"dtype", static_cast<int>(framework::proto::VarType::FP32));
} else if (op_type == "cast") {
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("in_dtype")) ==
static_cast<int>(framework::proto::VarType::FP64))
op_node->Op()->SetAttr(
"in_dtype", static_cast<int>(framework::proto::VarType::FP32));
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("out_dtype")) ==
static_cast<int>(framework::proto::VarType::FP64))
op_node->Op()->SetAttr(
"out_dtype", static_cast<int>(framework::proto::VarType::FP32));
}
auto inputs = op_node->inputs;
for (auto* in_node : inputs) {
auto* in_var = in_node->Var();
if (!in_var->Persistable() &&
in_var->GetDataType() == framework::proto::VarType::FP64) {
in_var->SetDataType(framework::proto::VarType::FP32);
}
}
}
} }
void ConvertToMixedPrecisionPass::Run() {
  LoadAndPrepare();

  for (size_t i = 0; i < main_graph_->SubGraphsSize(); ++i) {
    auto graph = main_graph_->GetSubGraph(i);
    graphes_.push_back(graph);
    VLOG(2) << " -------- handle subgraph " << i << ", has "
            << graph->Nodes().size() << " nodes --------";

    ConvertAllFp64ToFp32(graph);
    ConvertTensorDtype(i);
    FixCastAttr(graph);

    // A trick
    PatchForStrangeOp();

    CHECK_EQ(ir::VarDescIsConsistency(*graph), true);
  }

  SaveMixedModel();
}

void ConvertToMixedPrecisionPass::Run() {
  LoadModel();

  framework::ir::AutoMixedPrecisionPass pass;
  pass.Set("mixed_precision_mode", new int{static_cast<int>(mixed_precision_)});
  pass.Set("mixed_black_list",
           new std::unordered_set<std::string>{black_list_});
  pass.Set("enable_gpu_mixed", new bool{true});
  pass.Set("keep_io_types", new bool{keep_io_types_});
  pass.Apply(main_graph_.get());

  SaveMixedModel();
}
void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
auto graph = graphes_[block_idx];
framework::proto::VarType::Type to_type;
if (mixed_precision_ == phi::DataType::FLOAT16) {
to_type = framework::proto::VarType::FP16;
} else if (mixed_precision_ == phi::DataType::BFLOAT16) {
to_type = framework::proto::VarType::BF16;
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"mixed_precision currently not supported dtype %d, we now only "
"support fp16 and bf16.",
static_cast<int>(mixed_precision_)));
}
auto op_nodes = framework::ir::TopologySortOperations(*graph);
auto* block_desc = op_nodes[0]->Op()->Block();
int num_low_precision = 0;
std::vector<framework::ir::Node*> output_nodes;
for (auto* op_node : op_nodes) {
if (!op_node->IsOp()) continue;
auto op_type = op_node->Op()->Type();
VLOG(3) << "-------------------- op_type " << op_type << ", phi_type "
<< phi::TransToPhiKernelName(op_type);
// 1. set input dtype.
if (op_type == "feed") {
auto feed_var = op_node->outputs[0]->Var();
if (!keep_io_types_ &&
feed_var->GetDataType() == framework::proto::VarType::FP32) {
feed_var->SetDataType(to_type);
}
} else if (op_type == "fetch") {
auto* fetch_var = op_node->inputs[0];
output_nodes.push_back(fetch_var);
continue;
} else if (op_type == "cast") {
continue;
}
else if (op_node->Op()->HasAttr("sub_block")) { // NOLINT
// sub_block op's output dtype should be same as input dtype, if have the
// same name.
std::unordered_map<std::string, framework::ir::Node*> in_name_to_node;
for (auto* in : op_node->inputs) {
auto* real_node = GetRealNode(block_idx, in);
if (NodeVarHasDtype(real_node)) {
in_name_to_node[in->Name()] = in;
}
}
for (auto out : op_node->outputs) {
auto* real_node = GetRealNode(block_idx, out);
if (NodeVarHasDtype(real_node)) {
if (in_name_to_node.count(out->Name()))
real_node->Var()->SetDataType(
in_name_to_node[out->Name()]->Var()->GetDataType());
}
}
continue;
}
// 2. if op support fp16/bf16 and not in blacklist.
// - cast weight to fp16/bf16.
// - add cast op if the input dtype is not fp16/bf16.
// - set output dtype.
//
// If a var(op's out var) appears multiple times in a block, we should not
// convert to fp16.
else if (black_list_.count(op_type) == 0 && // NOLINT
!VarIsMultiPrecisionOpsOut(block_idx, op_node)) {
bool support_precision =
OpSupportPrecision(op_type, backend_, mixed_precision_, black_list_);
// If the op has no float input, we will not choose the low precision kernel.
{
bool has_float_input{false};
for (auto in_node : op_node->inputs) {
auto* real_node = GetRealNode(block_idx, in_node);
if (real_node->Var()->GetDataType() == proto::VarType::FP16 ||
real_node->Var()->GetDataType() == proto::VarType::FP32 ||
real_node->Var()->GetDataType() == proto::VarType::FP64 ||
real_node->Var()->GetDataType() == proto::VarType::BF16) {
has_float_input = true;
break;
}
}
if (!has_float_input) {
support_precision = false;
VLOG(2) << " op doesn't has float input, just skip.";
}
}
VLOG(2) << " support low precision " << support_precision;
if (support_precision) {
VLOG(2) << " process input nodes:";
++num_low_precision;
auto inputs = op_node->inputs;
// Just for Paddle's terrible case: an op's input and output have the same
// name.
std::unordered_map<std::string, std::string> names_map;
for (auto out_node : op_node->outputs) {
for (auto in_node : op_node->inputs) {
if (out_node->Name() == in_node->Name()) {
names_map[out_node->Name()] = in_node->Name();
}
}
}
// Process inputs.
for (auto* in_node : inputs) {
ProcessInputNode(
true, in_node, op_node, &suffix_, block_desc, to_type, block_idx);
if (names_map.count(in_node->Name()) && cast_map_.count(in_node)) {
names_map[in_node->Name()] = cast_map_[in_node]->Name();
}
}
VLOG(2) << " process output nodes:";
// Process outputs.
for (auto* out_node : op_node->outputs) {
ProcessOutputNode(block_idx, out_node, to_type);
}
} else {
auto inputs = op_node->inputs;
for (auto* in_node : inputs) {
ProcessInputNode(false,
in_node,
op_node,
&suffix_,
block_desc,
framework::proto::VarType::FP32,
block_idx);
}
}
}
// 3. check op not support fp16/bf16 or in blacklist.
// - add cast op if the input dtype is not fp32.
else { // NOLINT
VLOG(3) << "not to run fp16 op_type: " << op_type;
auto ins = op_node->inputs;
for (auto* in_node : ins) {
auto* in_var = in_node->Var();
if (in_var->GetDataType() == to_type) {
AddCastOp(graph,
in_node,
op_node,
to_type,
framework::proto::VarType::FP32,
&suffix_,
block_desc,
&cast_map_);
VLOG(3) << "-- " << in_node->Name() << "(" << to_type << ") to "
<< cast_map_[in_node]->Name() << "("
<< framework::proto::VarType::FP32 << ")";
}
}
}
}
// 4. if output_op's dtype is not compatible to output dtype, then just
// insert cast.
for (auto* node : output_nodes) {
ir::Node* fetch_op{nullptr};
for (auto* op_node : node->outputs) {
if (op_node->IsOp() && op_node->Op()->Type() == "fetch") {
fetch_op = op_node;
}
}
CHECK_NOTNULL(fetch_op);
auto var = node->Var();
if (keep_io_types_ && var->GetDataType() == to_type) {
// fp16/bf16 -> fp32.
AddCastOp(graph,
node,
fetch_op,
to_type,
framework::proto::VarType::FP32,
&suffix_,
block_desc,
&cast_map_);
} else if (!keep_io_types_ &&
var->GetDataType() == framework::proto::VarType::FP32) {
// fp32 -> fp16/bf16
AddCastOp(graph,
node,
fetch_op,
framework::proto::VarType::FP32,
to_type,
&suffix_,
block_desc,
&cast_map_);
}
}
for (auto node : graph->Nodes()) {
auto* real_node = GetRealNode(block_idx, node);
if (!NodeVarHasDtype(real_node)) continue;
if (vars_in_multi_block_map_.count(real_node->Name()) &&
vars_in_multi_block_map_.at(real_node->Name()).second == block_idx) {
vars_in_multi_block_map_.at(real_node->Name()).first =
real_node->Var()->GetDataType();
}
}
if (num_low_precision)
LOG(INFO) << "--- detected " << num_low_precision
<< " low precision ops in " << block_idx << " subgraph";
}
// We modify op's input output precision, and we need to fix cast op in_dtype
// and out_dtype attribute.
void ConvertToMixedPrecisionPass::FixCastAttr(framework::ir::Graph* graph) {
auto op_nodes = framework::ir::TopologySortOperations(*graph);
for (auto* op_node : op_nodes) {
if (!op_node->IsOp()) continue;
auto op_type = op_node->Op()->Type();
if (op_type != "cast") continue;
auto input = op_node->inputs[0];
auto output = op_node->outputs[0];
op_node->Op()->SetAttr("in_dtype",
static_cast<int>(input->Var()->GetDataType()));
op_node->Op()->SetAttr("out_dtype",
static_cast<int>(output->Var()->GetDataType()));
}
}
void ConvertToMixedPrecisionPass::SaveMixedModel() { void ConvertToMixedPrecisionPass::SaveMixedModel() {
framework::ProgramDesc mixed_program_desc; framework::ProgramDesc mixed_program_desc;
framework::ir::GraphToProgram(*main_graph_, &mixed_program_desc); framework::ir::GraphToProgram(*main_graph_, &mixed_program_desc);
paddle::CPUPlace place;
auto parameters = scope_.LocalVarNames(); auto parameters = scope_.LocalVarNames();
std::sort(parameters.begin(), parameters.end()); std::sort(parameters.begin(), parameters.end());
std::unordered_set<std::string> weights_should_be_fp32;
for (auto* node : main_graph_->Nodes()) {
if (!(node->IsVar())) continue;
if (NodeVarHasDtype(node)) {
if (node->Var()->Persistable() &&
node->Var()->GetDataType() ==
paddle::framework::proto::VarType::FP32) {
VLOG(2) << "weights keep to fp32: " << node->Name();
weights_should_be_fp32.insert(node->Name());
}
}
}
#define CONVERT_TENSOR_DTYPE(DTYPE, dtype) \
mixed_tensor.set_type(DTYPE); \
auto* mixed_data = mixed_tensor.mutable_data<dtype>(platform::CPUPlace()); \
for (int i = 0; i < t->numel(); i++) { \
mixed_data[i] = static_cast<dtype>(data[i]); \
} \
t->clear(); \
paddle::framework::TensorCopySync(mixed_tensor, place, t)
for (const auto& param_name : parameters) {
auto* var = scope_.FindLocalVar(param_name);
if (var->IsType<phi::DenseTensor>()) {
auto* t = var->GetMutable<phi::DenseTensor>();
if (t->dtype() != phi::DataType::FLOAT32) continue;
phi::DenseTensor mixed_tensor;
mixed_tensor.Resize(t->dims());
auto* data = t->mutable_data<float>(platform::CPUPlace());
if (mixed_precision_ == phi::DataType::FLOAT16 &&
!weights_should_be_fp32.count(param_name)) {
CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::FLOAT16,
phi::dtype::float16);
} else if (mixed_precision_ == phi::DataType::BFLOAT16 &&
!weights_should_be_fp32.count(param_name)) {
CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::BFLOAT16,
phi::dtype::bfloat16);
}
}
}
#undef CONVERT_TENSOR_DTYPE
auto SerializeParams = [&]() -> std::string { auto SerializeParams = [&]() -> std::string {
std::ostringstream os; std::ostringstream os;
phi::CPUContext ctx; phi::CPUContext ctx;
for (const auto& param : parameters) { for (const auto& param : parameters) {
VLOG(3) << "Serialize param: " << param;
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
scope_.FindVar(param), scope_.FindVar(param),
platform::errors::NotFound( platform::errors::NotFound(
"Block should already have a '%s' variable", param)); "Block should already have a '%s' variable", param));
auto* tensor = scope_.FindVar(param)->GetMutable<framework::LoDTensor>(); auto* tensor = scope_.FindVar(param)->GetMutable<phi::DenseTensor>();
framework::SerializeToStream(os, *tensor, ctx); framework::SerializeToStream(os, *tensor, ctx);
} }
return os.str(); return os.str();
...@@ -831,96 +112,42 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() { ...@@ -831,96 +112,42 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
StrToBinary(mixed_params_file_, SerializeParams()); StrToBinary(mixed_params_file_, SerializeParams());
} }
bool OpSupportPrecision(const std::string& op_type,
                        phi::Backend backend,
                        phi::DataType precision,
                        const std::unordered_set<std::string>& black_list) {
  return framework::ir::OpSupportPrecision(
      op_type, backend, precision, black_list);
}

void InsertCastOp(
    framework::ir::Graph* graph,
    framework::ir::Node* var_node,
    framework::ir::Node* op_node,
    framework::proto::VarType::Type from_type,
    framework::proto::VarType::Type to_type,
    framework::BlockDesc* block_desc,
    int* suffix,
    std::unordered_map<framework::ir::Node*, framework::ir::Node*>* visited) {
  framework::ir::DoInsertCastOp(graph,
                                var_node,
                                op_node,
                                from_type,
                                to_type,
                                block_desc,
                                suffix,
                                visited);
}

void ConvertToMixedPrecisionPass::PatchForStrangeOp() {
  for (auto* graph : graphes_) {
    for (auto op_node : framework::ir::TopologySortOperations(*graph)) {
      if (op_node->Name() == "fused_multi_transformer") {
        auto cache_kv_inputs = op_node->Op()->Input("CacheKV");
        auto cache_kv_outputs = op_node->Op()->Output("CacheKVOut");
        CHECK_EQ(cache_kv_inputs.size(), cache_kv_outputs.size());
        for (size_t i = 0; i < cache_kv_inputs.size(); ++i) {
          op_node->Op()->RenameOutput(cache_kv_outputs[i], cache_kv_inputs[i]);
        }
      }
    }
  }
}
}  // namespace

void AddCastOp(
    framework::ir::Graph* graph,
    framework::ir::Node* node,
    framework::ir::Node* next_op,
    framework::proto::VarType::Type from_type,
    framework::proto::VarType::Type to_type,
    int* suffix,
    framework::BlockDesc* block_desc,
    std::unordered_map<framework::ir::Node*, framework::ir::Node*>* map) {
  auto update_cast_desc = [&](framework::OpDesc& desc,
                              const std::string& x_name,
                              const std::string& out_name,
                              const int in_dtype,
                              const int out_dtype) {
    desc.SetType("cast");
    desc.SetInput("X", {x_name});
    desc.SetOutput("Out", {out_name});
    desc.SetAttr("in_dtype", in_dtype);
    desc.SetAttr("out_dtype", out_dtype);
    desc.SetAttr("use_mkldnn", false);
    desc.SetAttr("with_quant_attr", false);
    desc.Flush();
  };

  if (map->count(node) == 0) {
    // insert cast op before node.
    std::string cast_input_name = node->Var()->Name();
    std::string cast_output_name =
        node->Var()->Name() + "_cast.tmp_" + std::to_string((*suffix)++);
CHECK_NOTNULL(block_desc);
framework::OpDesc cast_op_desc(block_desc);
update_cast_desc(cast_op_desc,
cast_input_name,
cast_output_name,
static_cast<int>(from_type),
static_cast<int>(to_type));
auto* cast_op_node = graph->CreateOpNode(&cast_op_desc);
auto* cast_output_vardesc = block_desc->Var(cast_output_name);
cast_output_vardesc->SetPersistable(false);
cast_output_vardesc->SetDataType(to_type);
cast_output_vardesc->SetShape(node->Var()->GetShape());
auto* cast_output_node = graph->CreateVarNode(cast_output_vardesc);
IR_NODE_LINK_TO(cast_op_node, cast_output_node);
(*map)[node] = cast_output_node;
}
next_op->Op()->Rename(node->Name(), map->at(node)->Name());
IR_NODE_LINK_TO(node, map->at(node)->inputs[0]);
IR_NODE_LINK_TO(map->at(node), next_op);
}
bool OpSupportPrecision(const std::string& op_type,
phi::Backend backend,
phi::DataType precision,
const std::unordered_set<std::string>& blacklist) {
auto phi_op_type = phi::TransToPhiKernelName(op_type);
bool support_precision = false;
if (blacklist.count(op_type) == 0) {
if (backend == phi::Backend::GPU)
support_precision = GpuKernelSupportPrecision(op_type, precision);
else
support_precision =
PhiKernelSupportPrecision(phi_op_type, backend, precision);
}
return support_precision;
}
void ConvertToMixedPrecision(const std::string& model_file,
                             const std::string& params_file,
                             const std::string& mixed_model_file,
                             const std::string& mixed_params_file,
                             phi::DataType mixed_precision,
                             phi::Backend backend,
                             bool keep_io_types,
                             std::unordered_set<std::string> black_list) {

void ConvertToMixedPrecision(
    const std::string& model_file,
    const std::string& params_file,
    const std::string& mixed_model_file,
    const std::string& mixed_params_file,
    phi::DataType mixed_precision,
    phi::Backend backend,
    bool keep_io_types,
    const std::unordered_set<std::string>& black_list) {
ConvertToMixedPrecisionPass pass(model_file, ConvertToMixedPrecisionPass pass(model_file,
params_file, params_file,
mixed_model_file, mixed_model_file,
......
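For orientation, a minimal sketch of how the exported OpSupportPrecision helper can be used to probe kernel availability before converting an op; the op name and precision below are illustrative values only, not something this patch prescribes:

  // Hypothetical probe: is a GPU float16 kernel registered for conv2d?
  bool fp16_ok = paddle::inference::analysis::OpSupportPrecision(
      "conv2d", phi::Backend::GPU, phi::DataType::FLOAT16, /*black_list=*/{});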
...@@ -15,14 +15,12 @@ ...@@ -15,14 +15,12 @@
#pragma once #pragma once
#include <string> #include <string>
#include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/phi/common/backend.h" #include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/data_type.h"
...@@ -30,20 +28,52 @@ namespace paddle { ...@@ -30,20 +28,52 @@ namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
class ConvertToMixedPrecisionPass {
public:
explicit ConvertToMixedPrecisionPass(
const std::string& model_file,
const std::string& params_file,
const std::string& mixed_model_file,
const std::string& mixed_params_file,
phi::DataType mixed_precision,
phi::Backend backend,
bool keep_io_types,
const std::unordered_set<std::string>& black_list);
void Run();
private:
void LoadModel();
void SaveMixedModel();
private:
std::string model_file_;
std::string params_file_;
std::string mixed_model_file_;
std::string mixed_params_file_;
phi::DataType mixed_precision_;
phi::Backend backend_;
bool keep_io_types_;
std::unordered_set<std::string> black_list_;
framework::Scope scope_;
std::unique_ptr<framework::ir::Graph> main_graph_{nullptr};
};
bool OpSupportPrecision(const std::string& op_type, bool OpSupportPrecision(const std::string& op_type,
phi::Backend backend, phi::Backend backend,
phi::DataType precision, phi::DataType precision,
const std::unordered_set<std::string>& blacklist); const std::unordered_set<std::string>& black_list);
void AddCastOp( void InsertCastOp(
framework::ir::Graph* graph, framework::ir::Graph* graph,
framework::ir::Node* node, framework::ir::Node* var_node,
framework::ir::Node* next_op, framework::ir::Node* op_node,
framework::proto::VarType::Type from_type, framework::proto::VarType::Type from_type,
framework::proto::VarType::Type to_type, framework::proto::VarType::Type to_type,
int* suffix,
framework::BlockDesc* block_desc, framework::BlockDesc* block_desc,
std::unordered_map<framework::ir::Node*, framework::ir::Node*>* map); int* suffix,
std::unordered_map<framework::ir::Node*, framework::ir::Node*>* visited);
void ConvertToMixedPrecision(const std::string& model_file, void ConvertToMixedPrecision(const std::string& model_file,
const std::string& params_file, const std::string& params_file,
...@@ -51,8 +81,8 @@ void ConvertToMixedPrecision(const std::string& model_file, ...@@ -51,8 +81,8 @@ void ConvertToMixedPrecision(const std::string& model_file,
const std::string& mixed_params_file, const std::string& mixed_params_file,
phi::DataType mixed_precision, phi::DataType mixed_precision,
phi::Backend backend, phi::Backend backend,
bool keep_io_types = true, bool keep_io_types,
std::unordered_set<std::string> black_list = {}); const std::unordered_set<std::string>& black_list);
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
......
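A usage sketch of the offline conversion entry point declared above; the file paths are placeholders:

  // Convert an FP32 inference model to FP16 offline (paths are placeholders).
  paddle::inference::analysis::ConvertToMixedPrecision(
      "model/inference.pdmodel",
      "model/inference.pdiparams",
      "mixed/inference.pdmodel",
      "mixed/inference.pdiparams",
      phi::DataType::FLOAT16,
      phi::Backend::GPU,
      /*keep_io_types=*/true,
      /*black_list=*/{});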
...@@ -40,7 +40,7 @@ void InferenceOpReplacePass::RunImpl(Argument* argument) { ...@@ -40,7 +40,7 @@ void InferenceOpReplacePass::RunImpl(Argument* argument) {
} }
std::string InferenceOpReplacePass::repr() const { std::string InferenceOpReplacePass::repr() const {
return "inference-op-replace-pass"; return "inference_op_replace_pass";
} }
} // namespace analysis } // namespace analysis
......
...@@ -105,7 +105,7 @@ void IrAnalysisPass::CollectFusionStatis(Argument* argument) { ...@@ -105,7 +105,7 @@ void IrAnalysisPass::CollectFusionStatis(Argument* argument) {
framework::ir::kFuseStatisAttr)); framework::ir::kFuseStatisAttr));
} }
std::string IrAnalysisPass::repr() const { return "ir-analysis-pass"; } std::string IrAnalysisPass::repr() const { return "ir_analysis_pass"; }
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
......
...@@ -64,7 +64,8 @@ void IrGraphBuildPass::RunImpl(Argument *argument) { ...@@ -64,7 +64,8 @@ void IrGraphBuildPass::RunImpl(Argument *argument) {
"set.")); "set."));
} }
auto graph = std::unique_ptr<Graph>(new Graph(argument->main_program())); auto graph = std::unique_ptr<framework::ir::Graph>(
new framework::ir::Graph(argument->main_program()));
argument->SetMainGraph(graph.release()); argument->SetMainGraph(graph.release());
auto *scope_ptr = argument->scope_ptr(); auto *scope_ptr = argument->scope_ptr();
PADDLE_ENFORCE_NOT_NULL(scope_ptr, PADDLE_ENFORCE_NOT_NULL(scope_ptr,
...@@ -125,7 +126,7 @@ std::unique_ptr<framework::ProgramDesc> IrGraphBuildPass::LoadModel( ...@@ -125,7 +126,7 @@ std::unique_ptr<framework::ProgramDesc> IrGraphBuildPass::LoadModel(
} }
} }
std::string IrGraphBuildPass::repr() const { return "ir-graph-build-pass"; } std::string IrGraphBuildPass::repr() const { return "ir_graph_build_pass"; }
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/passes/ir_graph_clean_pass.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/node.h"
namespace paddle {
namespace inference {
namespace analysis {
void IrInferCleanGraphPass::RunImpl(Argument* argument) {
auto& graph = argument->main_graph();
auto is_valid_node = [](framework::ir::Node* x) {
return x && IsControlDepVar(*x) && x->IsVar() && !x->Var();
};
std::unordered_set<const framework::ir::Node*> invalid_nodes;
int valid_op = 0;
for (auto* node : graph.Nodes()) {
PADDLE_ENFORCE_NOT_NULL(node,
platform::errors::PreconditionNotMet(
"The node should not be nullptr."));
if (is_valid_node(node)) {
invalid_nodes.insert(node);
} else if (node->IsOp()) {
++valid_op;
}
}
GraphSafeRemoveNodes(&graph, invalid_nodes);
}
} // namespace analysis
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_set>
#include "paddle/fluid/inference/analysis/analysis_pass.h"
namespace paddle {
namespace inference {
namespace analysis {
struct Argument;
class IrInferCleanGraphPass : public AnalysisPass {
public:
void RunImpl(Argument *argument) override;
std::string repr() const override { return "ir_graph_clean_pass"; }
};
} // namespace analysis
} // namespace inference
} // namespace paddle
...@@ -31,7 +31,7 @@ void IrGraphToProgramPass::RunImpl(Argument *argument) { ...@@ -31,7 +31,7 @@ void IrGraphToProgramPass::RunImpl(Argument *argument) {
new int(argument->memory_optim_sort_kind())); new int(argument->memory_optim_sort_kind()));
} }
std::unique_ptr<Graph> graph(argument->main_graph_ptr()); std::unique_ptr<framework::ir::Graph> graph(argument->main_graph_ptr());
// Direct using ProgramDesc desc(argument->main_program()) may cause // Direct using ProgramDesc desc(argument->main_program()) may cause
// incomplete copies of information. // incomplete copies of information.
......
...@@ -28,7 +28,7 @@ class IrGraphToProgramPass : public AnalysisPass { ...@@ -28,7 +28,7 @@ class IrGraphToProgramPass : public AnalysisPass {
public: public:
void RunImpl(Argument *argument) override; void RunImpl(Argument *argument) override;
std::string repr() const override { return "ir-graph-to-param-pass"; } std::string repr() const override { return "ir_graph_to_param_pass"; }
}; };
} // namespace analysis } // namespace analysis
......
...@@ -169,7 +169,7 @@ void IrParamsSyncAmongDevicesPass::RunImpl(Argument *argument) { ...@@ -169,7 +169,7 @@ void IrParamsSyncAmongDevicesPass::RunImpl(Argument *argument) {
} }
std::string IrParamsSyncAmongDevicesPass::repr() const { std::string IrParamsSyncAmongDevicesPass::repr() const {
return "ir-params-sync-among-devices-pass"; return "ir_params_sync_among_devices_pass";
} }
} // namespace analysis } // namespace analysis
......
...@@ -295,7 +295,7 @@ void UpdateOpDescsByReuse( ...@@ -295,7 +295,7 @@ void UpdateOpDescsByReuse(
} }
} }
std::string MemoryOptimizePass::repr() const { return "memory optimize pass"; } std::string MemoryOptimizePass::repr() const { return "memory_optimize_pass"; }
void MemoryOptimizePass::RunImpl(Argument* argument) { void MemoryOptimizePass::RunImpl(Argument* argument) {
// Memory optimization. // Memory optimization.
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
#include "paddle/fluid/inference/analysis/passes/inference_op_replace_pass.h" #include "paddle/fluid/inference/analysis/passes/inference_op_replace_pass.h"
#include "paddle/fluid/inference/analysis/passes/ir_analysis_pass.h" #include "paddle/fluid/inference/analysis/passes/ir_analysis_pass.h"
#include "paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h" #include "paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h"
#include "paddle/fluid/inference/analysis/passes/ir_graph_clean_pass.h"
#include "paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.h" #include "paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.h"
#include "paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h" #include "paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h"
#include "paddle/fluid/inference/analysis/passes/memory_optimize_pass.h" #include "paddle/fluid/inference/analysis/passes/memory_optimize_pass.h"
...@@ -34,8 +33,6 @@ PassRegistry::PassRegistry() { ...@@ -34,8 +33,6 @@ PassRegistry::PassRegistry() {
std::unique_ptr<AnalysisPass>(new IrAnalysisPass)); std::unique_ptr<AnalysisPass>(new IrAnalysisPass));
passes_.emplace("ir_graph_build_pass", passes_.emplace("ir_graph_build_pass",
std::unique_ptr<AnalysisPass>(new IrGraphBuildPass)); std::unique_ptr<AnalysisPass>(new IrGraphBuildPass));
passes_.emplace("ir_graph_clean_pass",
std::unique_ptr<AnalysisPass>(new IrInferCleanGraphPass));
passes_.emplace("memory_optimize_pass", passes_.emplace("memory_optimize_pass",
std::unique_ptr<AnalysisPass>(new MemoryOptimizePass)); std::unique_ptr<AnalysisPass>(new MemoryOptimizePass));
passes_.emplace( passes_.emplace(
......
...@@ -85,15 +85,29 @@ void AnalysisConfig::SetModel(const std::string &prog_file_path, ...@@ -85,15 +85,29 @@ void AnalysisConfig::SetModel(const std::string &prog_file_path,
Update(); Update();
} }
void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb, void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb,
int device_id) { int device_id,
Precision precision_mode) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
use_gpu_ = true; use_gpu_ = true;
memory_pool_init_size_mb_ = memory_pool_init_size_mb; memory_pool_init_size_mb_ = memory_pool_init_size_mb;
FLAGS_initial_gpu_memory_in_mb = memory_pool_init_size_mb_; FLAGS_initial_gpu_memory_in_mb = memory_pool_init_size_mb_;
gpu_device_id_ = device_id; gpu_device_id_ = device_id;
mixed_precision_mode_ = precision_mode;
if (precision_mode == Precision::kFloat32) {
// default
} else if (precision_mode == Precision::kHalf ||
precision_mode == Precision::kBf16) {
enable_gpu_mixed_ = true;
} else {
LOG(ERROR)
<< "The Paddle-GPU inference currently only supports "
"float32/float16/bfloat16 precision. Please check the parameters "
"you specified in EnableUseGpu or enable_use_gpu function.";
}
#else #else
LOG(ERROR) << "Please compile with gpu to EnableGpu()"; LOG(ERROR) << "Please use PaddlePaddle with GPU version.";
use_gpu_ = false; use_gpu_ = false;
#endif #endif
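A minimal sketch of the extended call, mirroring the gpu_ernie_half_test added below; the model path is a placeholder:

  // 512 MB initial GPU memory pool, device 0, run inference in FP16.
  paddle::AnalysisConfig config;
  config.SetModel("./ernie_model");  // placeholder path
  config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kHalf);
  auto predictor = paddle::CreatePaddlePredictor(config);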
...@@ -279,7 +293,7 @@ void AnalysisConfig::LoadIpuConfig(const std::string &config_path) { ...@@ -279,7 +293,7 @@ void AnalysisConfig::LoadIpuConfig(const std::string &config_path) {
if (ipu_config_mapper_.find(key) == ipu_config_mapper_.end()) { if (ipu_config_mapper_.find(key) == ipu_config_mapper_.end()) {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"invalid key {} in IPU config", key)); "invalid key %s in IPU config: ", key));
} }
switch (ipu_config_mapper_.at(key)) { switch (ipu_config_mapper_.at(key)) {
case ipu_config_code::ipu_device_num: case ipu_config_code::ipu_device_num:
...@@ -315,7 +329,7 @@ void AnalysisConfig::LoadIpuConfig(const std::string &config_path) { ...@@ -315,7 +329,7 @@ void AnalysisConfig::LoadIpuConfig(const std::string &config_path) {
default: default:
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"invalid key {} in IPU config", key)); "invalid key %s in IPU config", key));
break; break;
} }
} }
...@@ -372,8 +386,10 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { ...@@ -372,8 +386,10 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
CP_MEMBER(gpu_device_id_); CP_MEMBER(gpu_device_id_);
CP_MEMBER(memory_pool_init_size_mb_); CP_MEMBER(memory_pool_init_size_mb_);
// Mixed related. // Mixed precision related.
CP_MEMBER(mixed_black_list_); CP_MEMBER(mixed_black_list_);
CP_MEMBER(enable_gpu_mixed_);
CP_MEMBER(mixed_precision_mode_);
CP_MEMBER(enable_memory_optim_); CP_MEMBER(enable_memory_optim_);
// TensorRT related. // TensorRT related.
...@@ -740,13 +756,7 @@ void AnalysisConfig::Update() { ...@@ -740,13 +756,7 @@ void AnalysisConfig::Update() {
((use_custom_device() ^ pass_builder_->use_custom_device()))) { ((use_custom_device() ^ pass_builder_->use_custom_device()))) {
if (use_gpu()) { if (use_gpu()) {
pass_builder_.reset(new GpuPassStrategy); pass_builder_.reset(new GpuPassStrategy);
if (use_tensorrt_) {
// Append after the Affine_channel_conv_fuse pass.
pass_builder()->InsertPass(3, "tensorrt_subgraph_pass");
}
} else if (use_ipu()) { } else if (use_ipu()) {
VLOG(1) << "IpuPassStrategy has been used for new.";
pass_builder_.reset(new IpuPassStrategy); pass_builder_.reset(new IpuPassStrategy);
} else if (use_xpu()) { } else if (use_xpu()) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
...@@ -946,9 +956,6 @@ void AnalysisConfig::Update() { ...@@ -946,9 +956,6 @@ void AnalysisConfig::Update() {
"but did not have the option -DWITH_CUSTOM_DEVICE compiled.")); "but did not have the option -DWITH_CUSTOM_DEVICE compiled."));
#endif #endif
} }
if (ir_debug_) {
pass_builder()->TurnOnDebug();
}
} }
std::string AnalysisConfig::SerializeInfoCache() { std::string AnalysisConfig::SerializeInfoCache() {
...@@ -960,6 +967,7 @@ std::string AnalysisConfig::SerializeInfoCache() { ...@@ -960,6 +967,7 @@ std::string AnalysisConfig::SerializeInfoCache() {
ss << calibration_file_path_; ss << calibration_file_path_;
ss << use_gpu_; ss << use_gpu_;
ss << enable_gpu_mixed_;
ss << use_external_stream_; ss << use_external_stream_;
ss << exec_stream_; ss << exec_stream_;
ss << use_fc_padding_; ss << use_fc_padding_;
...@@ -1167,6 +1175,7 @@ std::string AnalysisConfig::Summary() { ...@@ -1167,6 +1175,7 @@ std::string AnalysisConfig::Summary() {
os.InsertRow({"use_gpu", use_gpu_ ? "true" : "false"}); os.InsertRow({"use_gpu", use_gpu_ ? "true" : "false"});
if (use_gpu_) { if (use_gpu_) {
os.InsertRow({"gpu_device_id", std::to_string(gpu_device_id_)}); os.InsertRow({"gpu_device_id", std::to_string(gpu_device_id_)});
os.InsertRow({"enable_gpu_mixed", std::to_string(enable_gpu_mixed_)});
os.InsertRow({"memory_pool_init_size", os.InsertRow({"memory_pool_init_size",
std::to_string(memory_pool_init_size_mb_) + "MB"}); std::to_string(memory_pool_init_size_mb_) + "MB"});
os.InsertRow( os.InsertRow(
...@@ -1360,7 +1369,7 @@ bool AnalysisConfig::trt_allow_build_at_runtime() { ...@@ -1360,7 +1369,7 @@ bool AnalysisConfig::trt_allow_build_at_runtime() {
return trt_allow_build_at_runtime_; return trt_allow_build_at_runtime_;
} }
void AnalysisConfig::Exp_SetBlackListOpsForMixedModel( void AnalysisConfig::Exp_DisableMixedPrecisionOps(
const std::unordered_set<std::string> &black_list) { const std::unordered_set<std::string> &black_list) {
mixed_black_list_ = black_list; mixed_black_list_ = black_list;
} }
......
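A small sketch of the renamed interface, assuming an AnalysisConfig named config as above; "softmax" is only an example op, and the set must match the blacklist used when the mixed precision model was produced:

  // Keep e.g. softmax running in FP32 when GPU mixed precision is enabled.
  config.Exp_DisableMixedPrecisionOps({"softmax"});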
...@@ -1065,7 +1065,7 @@ void AnalysisPredictor::PrepareArgument() { ...@@ -1065,7 +1065,7 @@ void AnalysisPredictor::PrepareArgument() {
argument_.SetUseGPU(config_.use_gpu()); argument_.SetUseGPU(config_.use_gpu());
argument_.SetUseFcPadding(config_.use_fc_padding()); argument_.SetUseFcPadding(config_.use_fc_padding());
argument_.SetGPUDeviceId(config_.gpu_device_id()); argument_.SetGPUDeviceId(config_.gpu_device_id());
argument_.SetEnableAnalysisOptim(config_.enable_ir_optim_); argument_.SetEnableIrOptim(config_.enable_ir_optim_);
argument_.SetEnableMemoryOptim(config_.enable_memory_optim()); argument_.SetEnableMemoryOptim(config_.enable_memory_optim());
argument_.SetModelFromMemory(config_.model_from_memory_); argument_.SetModelFromMemory(config_.model_from_memory_);
// Analyze inference_program // Analyze inference_program
...@@ -1210,53 +1210,57 @@ void AnalysisPredictor::PrepareArgument() { ...@@ -1210,53 +1210,57 @@ void AnalysisPredictor::PrepareArgument() {
} }
#endif #endif
auto passes = config_.pass_builder()->AllPasses(); auto *pass_builder = config_.pass_builder();
if (model_precision_ != phi::DataType::FLOAT32) { if (model_precision_ != phi::DataType::FLOAT32) {
LOG(INFO) << "Model is mixed precision type with " << model_precision_ LOG(INFO) << "Model is mixed precision type with " << model_precision_
<< ", we will use a new PassStrategy. Note that only the GPU " << ", we will use a new PassStrategy. Note that only the GPU "
"backend is supported for now."; "backend is supported for now.";
passes.clear(); pass_builder->ClearPasses();
const auto &deleted_passes = pass_builder->GetAllDeletedPasses();
if (config_.tensorrt_engine_enabled()) { if (config_.tensorrt_engine_enabled()) {
for (const auto &pass : kTrtLowerPrecisionPasses) { for (const auto &pass : kTrtLowerPrecisionPasses) {
passes.push_back(pass); if (deleted_passes.count(pass)) continue;
pass_builder->AppendPass(pass);
} }
} else if (config_.use_gpu()) { } else if (config_.use_gpu()) {
for (const auto &pass : kGpuLowerPrecisionPasses) { for (const auto &pass : kGpuLowerPrecisionPasses) {
passes.push_back(pass); if (deleted_passes.count(pass)) continue;
pass_builder->AppendPass(pass);
} }
} }
  }

  const auto &deleted_passes = config_.pass_builder()->GetAllDeletedPasses();
  for (const auto &it : deleted_passes) {
    auto iterator = std::find(passes.begin(), passes.end(), it);
    if (iterator != passes.end()) {
      passes.erase(iterator);
    }
  }

  if (config_.ir_debug_) {
    auto it = std::begin(passes);
    while (it != std::end(passes)) {
      if (*it != "graph_viz_pass") {
        it = passes.insert(it + 1, "graph_viz_pass");
      } else {
        ++it;
      }
    }
  }

  if (!config_.ir_optim()) {
    passes.clear();
    LOG(INFO) << "ir_optim is turned off, no IR pass will be executed";
  }

  if (!config_.ir_optim()) {
    argument_.SetEnableIrOptim(false);
    if (config_.enable_gpu_mixed_) {
      argument_.SetEnableIrOptim(true);
      pass_builder->ClearPasses();
      pass_builder->AppendPass("auto_mixed_precision_pass");
      LOG(INFO)
          << "This model run in Paddle-GPU mixed precision mode with no ir "
             "optimization.";
    } else {
      LOG(INFO) << "ir_optim is turned off, no IR pass will be executed.";
    }
  } else {
    if (config_.ir_debug_) {
      pass_builder->TurnOnDebug();
    }
    if (config_.enable_gpu_mixed_) {
      LOG(INFO) << "This model run in Paddle-GPU mixed precision mode.";
    }
  }
argument_.SetDisableLogs(config_.glog_info_disabled()); argument_.SetDisableLogs(config_.glog_info_disabled());
argument_.SetIrAnalysisPasses(passes); argument_.SetIrAnalysisPasses(pass_builder->AllPasses());
argument_.SetAnalysisPasses(config_.pass_builder()->AnalysisPasses()); argument_.SetAnalysisPasses(pass_builder->AnalysisPasses());
argument_.SetScopeNotOwned(scope_.get()); argument_.SetScopeNotOwned(scope_.get());
// mixed precision. // mixed precision.
argument_.SetModelPrecision(static_cast<int>(model_precision_)); argument_.SetModelPrecision(static_cast<int>(model_precision_));
argument_.SetMixedBlackList(config_.mixed_black_list_); argument_.SetMixedBlackList(config_.mixed_black_list_);
argument_.SetEnableGPUMixed(config_.enable_gpu_mixed_);
argument_.SetMixedPrecisionMode(static_cast<int>(
paddle::ConvertPrecision(config_.mixed_precision_mode_)));
} }
// NOTE All the members in AnalysisConfig should be copied to Argument. // NOTE All the members in AnalysisConfig should be copied to Argument.
...@@ -2107,7 +2111,9 @@ std::unique_ptr<PaddlePredictor> AnalysisPredictor::Clone(void *stream) { ...@@ -2107,7 +2111,9 @@ std::unique_ptr<PaddlePredictor> AnalysisPredictor::Clone(void *stream) {
} }
x->predictor_stream_ = stream; x->predictor_stream_ = stream;
x->Init(scope_, inference_program_); x->Init(scope_, inference_program_);
#ifdef PADDLE_WITH_TENSORRT
x->executor_->ResetTrtOps(++AnalysisPredictor::clone_num_); x->executor_->ResetTrtOps(++AnalysisPredictor::clone_num_);
#endif
return std::unique_ptr<PaddlePredictor>(x); return std::unique_ptr<PaddlePredictor>(x);
} }
......
...@@ -604,10 +604,8 @@ void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const { ...@@ -604,10 +604,8 @@ void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const {
if (predictor_.config_.ir_debug_) builder->TurnOnDebug(); if (predictor_.config_.ir_debug_) builder->TurnOnDebug();
auto passes = builder->AllPasses(); auto passes = builder->AllPasses();
predictor_.argument_.SetIrAnalysisPasses(passes); predictor_.argument_.SetIrAnalysisPasses(passes);
predictor_.argument_.SetAnalysisPasses({"ir_graph_clean_pass", predictor_.argument_.SetAnalysisPasses(
"ir_analysis_pass", {"ir_analysis_pass", "memory_optimize_pass", "ir_graph_to_program_pass"});
"memory_optimize_pass",
"ir_graph_to_program_pass"});
predictor_.argument_.SetQuantVarScales(scales_); predictor_.argument_.SetQuantVarScales(scales_);
} }
......
...@@ -247,8 +247,12 @@ struct PD_INFER_DECL AnalysisConfig { ...@@ -247,8 +247,12 @@ struct PD_INFER_DECL AnalysisConfig {
/// ///
/// \param memory_pool_init_size_mb initial size of the GPU memory pool in MB. /// \param memory_pool_init_size_mb initial size of the GPU memory pool in MB.
/// \param device_id device_id the GPU card to use (default is 0). /// \param device_id device_id the GPU card to use (default is 0).
/// \param precision_mode the precision used in Paddle-GPU inference.
/// ///
void EnableUseGpu(uint64_t memory_pool_init_size_mb, int device_id = 0); void EnableUseGpu(uint64_t memory_pool_init_size_mb,
int device_id = 0,
Precision precision_mode = Precision::kFloat32);
/// ///
/// \brief Turn off GPU. /// \brief Turn off GPU.
/// ///
...@@ -967,7 +971,7 @@ struct PD_INFER_DECL AnalysisConfig { ...@@ -967,7 +971,7 @@ struct PD_INFER_DECL AnalysisConfig {
/// interface is in the experimental stage and may change in the future. Note /// interface is in the experimental stage and may change in the future. Note
/// that the blacklist must be the same as the model conversion blacklist. /// that the blacklist must be the same as the model conversion blacklist.
/// ///
void Exp_SetBlackListOpsForMixedModel( void Exp_DisableMixedPrecisionOps(
const std::unordered_set<std::string>& black_list); const std::unordered_set<std::string>& black_list);
void SetApplyOptim(bool value) { apply_optim_ = value; } void SetApplyOptim(bool value) { apply_optim_ = value; }
...@@ -987,13 +991,15 @@ struct PD_INFER_DECL AnalysisConfig { ...@@ -987,13 +991,15 @@ struct PD_INFER_DECL AnalysisConfig {
mutable std::string params_file_; mutable std::string params_file_;
mutable std::string calibration_file_path_; mutable std::string calibration_file_path_;
// Mixed precision. // Mixed precision related.
Precision mixed_precision_mode_{Precision::kFloat32};
std::unordered_set<std::string> mixed_black_list_; std::unordered_set<std::string> mixed_black_list_;
// GPU related. // GPU related.
bool use_gpu_{false}; bool use_gpu_{false};
int gpu_device_id_{0}; int gpu_device_id_{0};
uint64_t memory_pool_init_size_mb_{100}; // initial size is 100MB. uint64_t memory_pool_init_size_mb_{100}; // initial size is 100MB.
bool enable_gpu_mixed_{false};
bool thread_local_stream_{false}; bool thread_local_stream_{false};
bool use_cudnn_{false}; bool use_cudnn_{false};
......
...@@ -227,9 +227,10 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { ...@@ -227,9 +227,10 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"conv_elementwise_add_fuse_pass", // "conv_elementwise_add_fuse_pass", //
#endif // #endif //
"transpose_flatten_concat_fuse_pass", // "transpose_flatten_concat_fuse_pass", //
"constant_folding_pass", "constant_folding_pass", //
// following pass should be located in the last, since it will // following pass should be located in the last, since it will
// work on all fused ops. // work on all fused ops.
"auto_mixed_precision_pass", //
"runtime_context_cache_pass" "runtime_context_cache_pass"
}); });
......
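Since auto_mixed_precision_pass is appended near the end of the default GPU strategy, an individual pass can still be dropped through the pass builder, as the new tests below do for fc_fuse_pass; a sketch, assuming an AnalysisConfig named config:

  // Remove a pass from the default GPU strategy before creating the predictor.
  config.pass_builder()->DeletePass("auto_mixed_precision_pass");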
...@@ -115,7 +115,6 @@ class PD_INFER_DECL PaddlePassBuilder { ...@@ -115,7 +115,6 @@ class PD_INFER_DECL PaddlePassBuilder {
/// \cond Protected /// \cond Protected
std::vector<std::string> analysis_passes_{ std::vector<std::string> analysis_passes_{
{"ir_graph_build_pass", {"ir_graph_build_pass",
"ir_graph_clean_pass",
"ir_analysis_pass", "ir_analysis_pass",
"ir_params_sync_among_devices_pass", "ir_params_sync_among_devices_pass",
"adjust_cudnn_workspace_size_pass", "adjust_cudnn_workspace_size_pass",
......
...@@ -294,15 +294,6 @@ class TensorRTEngine { ...@@ -294,15 +294,6 @@ class TensorRTEngine {
nvinfer1::ICudaEngine* engine() { return infer_engine_.get(); } nvinfer1::ICudaEngine* engine() { return infer_engine_.get(); }
nvinfer1::IExecutionContext* context() { nvinfer1::IExecutionContext* context() {
#ifndef PADDLE_WITH_TESTING
PADDLE_ENFORCE_GT(
predictor_id_per_thread,
-1,
platform::errors::InvalidArgument(
"thread local var predictor_id_per_thread must be "
"initialized to >= 0, but now predictor_id_per_thread = %d",
predictor_id_per_thread));
#endif
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
if (infer_context_.find(predictor_id_per_thread) == infer_context_.end()) { if (infer_context_.find(predictor_id_per_thread) == infer_context_.end()) {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
...@@ -329,15 +320,6 @@ class TensorRTEngine { ...@@ -329,15 +320,6 @@ class TensorRTEngine {
int GetProfileIndex() { int GetProfileIndex() {
if (max_profile_num_ > 1) { if (max_profile_num_ > 1) {
#ifndef PADDLE_WITH_TESTING
PADDLE_ENFORCE_GT(
predictor_id_per_thread,
-1,
platform::errors::InvalidArgument(
"thread local var predictor_id_per_thread must be "
"initialized to >= 0, but now predictor_id_per_thread = %d",
predictor_id_per_thread));
#endif
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
return profile_index_[predictor_id_per_thread]; return profile_index_[predictor_id_per_thread];
} else { } else {
...@@ -356,15 +338,6 @@ class TensorRTEngine { ...@@ -356,15 +338,6 @@ class TensorRTEngine {
infer_engine_, infer_engine_,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"You should build engine first and then set the context.")); "You should build engine first and then set the context."));
#ifndef PADDLE_WITH_TESTING
PADDLE_ENFORCE_GT(
predictor_id_per_thread,
-1,
platform::errors::InvalidArgument(
"thread local var predictor_id_per_thread must be "
"initialized to >= 0, but now predictor_id_per_thread = %d",
predictor_id_per_thread));
#endif
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
infer_context_[predictor_id_per_thread].reset(nullptr); infer_context_[predictor_id_per_thread].reset(nullptr);
infer_context_.erase(predictor_id_per_thread); infer_context_.erase(predictor_id_per_thread);
......
...@@ -416,6 +416,9 @@ download_result(${ERNIE_INSTALL_DIR} "Ernie_result.txt.tar.gz" ...@@ -416,6 +416,9 @@ download_result(${ERNIE_INSTALL_DIR} "Ernie_result.txt.tar.gz"
if(WITH_GPU) if(WITH_GPU)
inference_analysis_api_test(test_analyzer_ernie ${ERNIE_INSTALL_DIR} inference_analysis_api_test(test_analyzer_ernie ${ERNIE_INSTALL_DIR}
analyzer_ernie_tester.cc) analyzer_ernie_tester.cc)
inference_analysis_api_test(gpu_ernie_half_test ${ERNIE_INSTALL_DIR}
gpu_ernie_half_test.cc)
set_tests_properties(gpu_ernie_half_test PROPERTIES TIMEOUT 60)
endif() endif()
inference_analysis_api_int8_test(test_analyzer_ernie_int8 ${ERNIE_INSTALL_DIR} inference_analysis_api_int8_test(test_analyzer_ernie_int8 ${ERNIE_INSTALL_DIR}
analyzer_ernie_int8_tester.cc) analyzer_ernie_int8_tester.cc)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/tests/api/tester_helper.h"
namespace paddle {
namespace inference {
using paddle::PaddleTensor;
template <typename T>
void GetValueFromStream(std::stringstream *ss, T *t) {
(*ss) >> (*t);
}
template <>
void GetValueFromStream<std::string>(std::stringstream *ss, std::string *t) {
*t = ss->str();
}
// Split string to vector
template <typename T>
void Split(const std::string &line, char sep, std::vector<T> *v) {
std::stringstream ss;
T t;
for (auto c : line) {
if (c != sep) {
ss << c;
} else {
GetValueFromStream<T>(&ss, &t);
v->push_back(std::move(t));
ss.str({});
ss.clear();
}
}
if (!ss.str().empty()) {
GetValueFromStream<T>(&ss, &t);
v->push_back(std::move(t));
ss.str({});
ss.clear();
}
}
// Parse tensor from string
template <typename T>
bool ParseTensor(const std::string &field, paddle::PaddleTensor *tensor) {
std::vector<std::string> data;
Split(field, ':', &data);
if (data.size() < 2) return false;
std::string shape_str = data[0];
std::vector<int> shape;
Split(shape_str, ' ', &shape);
std::string mat_str = data[1];
std::vector<T> mat;
Split(mat_str, ' ', &mat);
tensor->shape = shape;
auto size =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()) *
sizeof(T);
tensor->data.Resize(size);
std::copy(mat.begin(), mat.end(), static_cast<T *>(tensor->data.data()));
tensor->dtype = GetPaddleDType<T>();
return true;
}
// Parse input tensors from string
bool ParseLine(const std::string &line,
std::vector<paddle::PaddleTensor> *tensors) {
std::vector<std::string> fields;
Split(line, ';', &fields);
tensors->clear();
tensors->reserve(4);
int i = 0;
auto input_name = FLAGS_ernie_large ? "eval_placeholder_" : "placeholder_";
for (; i < 3; i++) {
paddle::PaddleTensor temp;
ParseTensor<int64_t>(fields[i], &temp);
temp.name = input_name + std::to_string(i);
tensors->push_back(temp);
}
// input_mask
paddle::PaddleTensor input_mask;
ParseTensor<float>(fields[i], &input_mask);
input_mask.name = input_name + std::to_string(i);
tensors->push_back(input_mask);
return true;
}
bool LoadInputData(std::vector<std::vector<paddle::PaddleTensor>> *inputs,
int batch_size = 1) {
if (FLAGS_infer_data.empty()) {
LOG(ERROR) << "please set input data path";
return false;
}
std::ifstream fin(FLAGS_infer_data);
std::string line;
int sample = 0;
// The unit-test dataset only has 10 samples; each sample has 5 feeds.
while (std::getline(fin, line)) {
std::vector<paddle::PaddleTensor> feed_data;
ParseLine(line, &feed_data);
inputs->push_back(std::move(feed_data));
sample++;
if (!FLAGS_test_all_data && sample == batch_size) break;
}
LOG(INFO) << "number of samples: " << sample;
return true;
}
// Compare results
TEST(Ernie_gpu_fp16_no_ir, compare_results) {
AnalysisConfig config;
config.SetModel(FLAGS_infer_model);
config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kHalf);
config.SwitchIrOptim(false);
auto predictor = CreatePaddlePredictor(config);
std::vector<std::vector<PaddleTensor>> input_slots_all;
LoadInputData(&input_slots_all);
std::ifstream fin(FLAGS_refer_result);
std::string line;
std::vector<float> ref;
while (std::getline(fin, line)) {
Split(line, ' ', &ref);
}
std::vector<PaddleTensor> outputs;
for (size_t i = 0; i < input_slots_all.size(); i++) {
outputs.clear();
predictor->Run(input_slots_all[i], &outputs);
auto output = outputs.front();
size_t outputs_size = 1;
for (auto dim : output.shape) {
outputs_size *= dim;
}
float *result = reinterpret_cast<float *>(output.data.data());
for (size_t j = 0; j < outputs_size; ++j) {
EXPECT_NEAR(ref[i * outputs_size + j], result[j], 5e-2);
}
}
}
// Compare results
TEST(Ernie_gpu_fp16_with_ir, compare_results) {
AnalysisConfig config;
config.SetModel(FLAGS_infer_model);
config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kHalf);
config.SwitchIrOptim(true);
// The fc_fuse_pass has diff, which will be repaired later.
config.pass_builder()->DeletePass("fc_fuse_pass");
// There is a problem with the model itself, which has nothing to do with
// constant_folding_pass.
config.pass_builder()->DeletePass("constant_folding_pass");
auto predictor = CreatePaddlePredictor(config);
std::vector<std::vector<PaddleTensor>> input_slots_all;
LoadInputData(&input_slots_all);
std::ifstream fin(FLAGS_refer_result);
std::string line;
std::vector<float> ref;
while (std::getline(fin, line)) {
Split(line, ' ', &ref);
}
std::vector<PaddleTensor> outputs;
for (size_t i = 0; i < input_slots_all.size(); i++) {
outputs.clear();
predictor->Run(input_slots_all[i], &outputs);
auto output = outputs.front();
size_t outputs_size = 1;
for (auto dim : output.shape) {
outputs_size *= dim;
}
float *result = reinterpret_cast<float *>(output.data.data());
for (size_t j = 0; j < outputs_size; ++j) {
EXPECT_NEAR(ref[i * outputs_size + j], result[j], 5e-2);
}
}
}
// Compare BF16 inference results (IR optimization disabled) against the reference output.
TEST(Ernie_gpu_bf16_no_ir, compare_results) {
AnalysisConfig config;
config.SetModel(FLAGS_infer_model);
config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kBf16);
config.SwitchIrOptim(false);
auto predictor = CreatePaddlePredictor(config);
std::vector<std::vector<PaddleTensor>> input_slots_all;
LoadInputData(&input_slots_all);
std::ifstream fin(FLAGS_refer_result);
std::string line;
std::vector<float> ref;
while (std::getline(fin, line)) {
Split(line, ' ', &ref);
}
std::vector<PaddleTensor> outputs;
for (size_t i = 0; i < input_slots_all.size(); i++) {
outputs.clear();
predictor->Run(input_slots_all[i], &outputs);
auto output = outputs.front();
size_t outputs_size = 1;
for (auto dim : output.shape) {
outputs_size *= dim;
}
float *result = reinterpret_cast<float *>(output.data.data());
for (size_t j = 0; j < outputs_size; ++j) {
EXPECT_NEAR(ref[i * outputs_size + j], result[j], 7e-2);
}
}
}
// Compare BF16 inference results (IR optimization enabled) against the reference output.
TEST(Ernie_gpu_bf16_with_ir, compare_results) {
AnalysisConfig config;
config.SetModel(FLAGS_infer_model);
config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kBf16);
config.SwitchIrOptim(true);
// The fc_fuse_pass produces a numerical diff, which will be fixed later.
config.pass_builder()->DeletePass("fc_fuse_pass");
// There is a problem with the model itself, which has nothing to do with
// constant_folding_pass.
config.pass_builder()->DeletePass("constant_folding_pass");
auto predictor = CreatePaddlePredictor(config);
std::vector<std::vector<PaddleTensor>> input_slots_all;
LoadInputData(&input_slots_all);
std::ifstream fin(FLAGS_refer_result);
std::string line;
std::vector<float> ref;
while (std::getline(fin, line)) {
Split(line, ' ', &ref);
}
std::vector<PaddleTensor> outputs;
for (size_t i = 0; i < input_slots_all.size(); i++) {
outputs.clear();
predictor->Run(input_slots_all[i], &outputs);
auto output = outputs.front();
size_t outputs_size = 1;
for (auto dim : output.shape) {
outputs_size *= dim;
}
float *result = reinterpret_cast<float *>(output.data.data());
for (size_t j = 0; j < outputs_size; ++j) {
EXPECT_NEAR(ref[i * outputs_size + j], result[j], 7e-2);
}
}
}
} // namespace inference
} // namespace paddle
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -12,15 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include <cuda_runtime.h>
 #include <glog/logging.h>
 #include <gtest/gtest.h>
-#include <cstring>
-#include <numeric>
 #include "gflags/gflags.h"
-#include "paddle/fluid/inference/tests/api/trt_test_helper.h"
+#include "paddle/fluid/inference/tests/api/tester_helper.h"
 namespace paddle_infer {
......
@@ -262,10 +262,6 @@ if(WITH_PYTHON)
 list(APPEND OP_FUNCTION_GENERETOR_DEPS cncl_context)
 endif()
-if(NOT ((NOT WITH_PYTHON) AND ON_INFER))
-  list(APPEND OP_FUNCTION_GENERETOR_DEPS ${PYTHON_LIBRARIES})
-endif()
 add_executable(op_function_generator op_function_generator.cc)
 target_link_libraries(op_function_generator ${OP_FUNCTION_GENERETOR_DEPS})
 add_executable(eager_legacy_op_function_generator
@@ -605,13 +601,4 @@ if(WITH_PYTHON)
 get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
 target_link_libraries(${SHARD_LIB_NAME} ${os_dependency_modules})
 add_dependencies(${SHARD_LIB_NAME} op_function_generator_cmd)
-if(APPLE)
-  string(REGEX REPLACE ".+/(.+)" "\\1" PYTHON_LIBRARY_NAME
-         ${PYTHON_LIBRARIES})
-  # target_link_libraries(${SHARD_LIB_NAME} "-Wl,-rpath,${PYTHON_LIBRARY_NAME}")
-else()
-  target_link_libraries(${SHARD_LIB_NAME} ${PYTHON_LIBRARIES})
-endif()
 endif()
@@ -642,7 +642,8 @@ void BindAnalysisConfig(py::module *m) {
 .def("enable_use_gpu",
      &AnalysisConfig::EnableUseGpu,
      py::arg("memory_pool_init_size_mb"),
-     py::arg("device_id") = 0)
+     py::arg("device_id") = 0,
+     py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32)
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 .def("set_exec_stream",
      [](AnalysisConfig &self, phi::CUDAStream &stream) {
......
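For reference, a minimal sketch of how the extended enable_use_gpu binding above might be used from the Python inference API (not part of this commit); the model path is a placeholder, and the extra precision argument follows the py::arg("precision_mode") added in the hunk above.

# Hedged usage sketch: the third argument mirrors the new precision_mode binding.
import paddle.inference as paddle_infer

config = paddle_infer.Config("./ernie_model")  # placeholder model directory
# memory_pool_init_size_mb=512, device_id=0, precision_mode=Half (FP16)
config.enable_use_gpu(512, 0, paddle_infer.PrecisionType.Half)
predictor = paddle_infer.create_predictor(config)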