未验证 提交 ecff21e7 编写于 作者: Y Yuanle Liu 提交者: GitHub

[Inference] auto mixed precision inference support white list (#56535)

* auto mixed precision inference support white list

* update

* update

* update

* move down identity_op_clean_pass

* fix code style
上级 5f9d6d68
...@@ -165,7 +165,9 @@ void DoInsertCastOp(Graph* graph, ...@@ -165,7 +165,9 @@ void DoInsertCastOp(Graph* graph,
bool OpSupportPrecision(const std::string& op_type, bool OpSupportPrecision(const std::string& op_type,
phi::Backend backend, phi::Backend backend,
phi::DataType precision, phi::DataType precision,
const std::unordered_set<std::string>& black_list) { const std::unordered_set<std::string>& black_list,
const std::unordered_set<std::string>& white_list) {
if (white_list.count(op_type)) return true;
return black_list.count(op_type) == 0 && return black_list.count(op_type) == 0 &&
KernelSupportPrecision(op_type, backend, precision); KernelSupportPrecision(op_type, backend, precision);
} }
...@@ -230,11 +232,16 @@ void AutoMixedPrecisionPass::Init(Graph* graph) const { ...@@ -230,11 +232,16 @@ void AutoMixedPrecisionPass::Init(Graph* graph) const {
if (skip_pass_) return; if (skip_pass_) return;
black_list_ = Get<std::unordered_set<std::string>>("mixed_black_list"); black_list_ = Get<std::unordered_set<std::string>>("mixed_black_list");
white_list_ = Get<std::unordered_set<std::string>>("mixed_white_list");
SetDefaultBlacklist(); SetDefaultBlacklist();
VLOG(4) << "black_list has "; VLOG(4) << "black_list has ";
for (const auto& name : black_list_) { for (const auto& name : black_list_) {
VLOG(4) << " - " << name; VLOG(4) << " - " << name;
} }
VLOG(4) << "white_list has ";
for (const auto& name : white_list_) {
VLOG(4) << " - " << name;
}
if (Has("enable_low_precision_io")) { if (Has("enable_low_precision_io")) {
enable_low_precision_io_ = Get<bool>("enable_low_precision_io"); enable_low_precision_io_ = Get<bool>("enable_low_precision_io");
...@@ -403,8 +410,11 @@ void AutoMixedPrecisionPass::GetOpPrecision() const { ...@@ -403,8 +410,11 @@ void AutoMixedPrecisionPass::GetOpPrecision() const {
op_node->Op()->GetAttrIfExists<bool>("enable_low_precision_io"); op_node->Op()->GetAttrIfExists<bool>("enable_low_precision_io");
support_low_precision = enable_fp16 && !enable_int8 && low_precision_io; support_low_precision = enable_fp16 && !enable_int8 && low_precision_io;
} else { } else {
support_low_precision = OpSupportPrecision( support_low_precision = OpSupportPrecision(GetOpOriginalType(op_type),
GetOpOriginalType(op_type), backend_, low_precision_, black_list_); backend_,
low_precision_,
black_list_,
white_list_);
std::unordered_set<std::string> check_dtype_op_blacklist( std::unordered_set<std::string> check_dtype_op_blacklist(
{"arg_max", "arg_min"}); {"arg_max", "arg_min"});
...@@ -422,8 +432,8 @@ void AutoMixedPrecisionPass::GetOpPrecision() const { ...@@ -422,8 +432,8 @@ void AutoMixedPrecisionPass::GetOpPrecision() const {
out_dtype == -1); out_dtype == -1);
} }
// If scale op's "scale" and "bias" attr value exceed the range of fp16 // If scale op's "scale" and "bias" attr value exceed the range of
// and bf16, it cannot run at low precision. // fp16 and bf16, it cannot run at low precision.
if (GetOpOriginalType(op_node->Op()->Type()) == "scale") { if (GetOpOriginalType(op_node->Op()->Type()) == "scale") {
auto scale = op_node->Op()->GetAttrIfExists<float>("scale"); auto scale = op_node->Op()->GetAttrIfExists<float>("scale");
auto bias = op_node->Op()->GetAttrIfExists<float>("bias"); auto bias = op_node->Op()->GetAttrIfExists<float>("bias");
...@@ -500,9 +510,9 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const { ...@@ -500,9 +510,9 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const {
<< " is output of " << op_type; << " is output of " << op_type;
} }
// the select_input op's input var should not convert to low precision. // the select_input op's input var should not convert to low
// when op's output var is select_input op's input var, the op should // precision. when op's output var is select_input op's input var, the
// not run at low precision. // op should not run at low precision.
if (GetOpOriginalType(op_node->Op()->Type()) == "select_input") { if (GetOpOriginalType(op_node->Op()->Type()) == "select_input") {
for (auto* in_var_node : op_node->inputs) { for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true); CHECK_EQ(in_var_node->IsVar(), true);
...@@ -517,6 +527,7 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const { ...@@ -517,6 +527,7 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const {
// output var, then op_2 should not run at low precision. // output var, then op_2 should not run at low precision.
if (GetOpOriginalType(op_type) != "feed" && if (GetOpOriginalType(op_type) != "feed" &&
GetOpOriginalType(op_type) != "tensorrt_engine" && GetOpOriginalType(op_type) != "tensorrt_engine" &&
white_list_.count(GetOpOriginalType(op_type)) == 0 &&
!KernelSupportPrecision( !KernelSupportPrecision(
GetOpOriginalType(op_type), backend_, phi::DataType::FLOAT32)) { GetOpOriginalType(op_type), backend_, phi::DataType::FLOAT32)) {
for (auto* out_var_node : op_node->outputs) { for (auto* out_var_node : op_node->outputs) {
......
...@@ -75,6 +75,7 @@ class AutoMixedPrecisionPass : public FusePassBase { ...@@ -75,6 +75,7 @@ class AutoMixedPrecisionPass : public FusePassBase {
mutable phi::Backend backend_{phi::Backend::UNDEFINED}; mutable phi::Backend backend_{phi::Backend::UNDEFINED};
mutable std::unordered_set<std::string> black_list_; mutable std::unordered_set<std::string> black_list_;
mutable std::unordered_set<std::string> white_list_;
// subgraph id -> pointer to subgraph // subgraph id -> pointer to subgraph
mutable std::vector<Graph*> subgraphes_; mutable std::vector<Graph*> subgraphes_;
...@@ -93,7 +94,8 @@ class AutoMixedPrecisionPass : public FusePassBase { ...@@ -93,7 +94,8 @@ class AutoMixedPrecisionPass : public FusePassBase {
bool OpSupportPrecision(const std::string& op_type, bool OpSupportPrecision(const std::string& op_type,
phi::Backend backend, phi::Backend backend,
phi::DataType precision, phi::DataType precision,
const std::unordered_set<std::string>& black_list); const std::unordered_set<std::string>& black_list,
const std::unordered_set<std::string>& white_list);
void DoInsertCastOp(Graph* graph, void DoInsertCastOp(Graph* graph,
Node* var_node, Node* var_node,
......
...@@ -89,11 +89,52 @@ FindUselessOpPattern::FindUselessOpPattern(PDPattern* pattern, ...@@ -89,11 +89,52 @@ FindUselessOpPattern::FindUselessOpPattern(PDPattern* pattern,
useless_op->LinksFrom({useless_op_in}).LinksTo({useless_op_out}); useless_op->LinksFrom({useless_op_in}).LinksTo({useless_op_out});
} }
} // namespace patterns // pre_op -> pre_op_out -> cast_op_1 -> cast_op_1_out -> cast_op_2 ->
// cast_op_2_out
// ->
// pre_op -> cast_op_2_out
struct FindTwoCastOpPattern : public PatternBase {
FindTwoCastOpPattern(PDPattern* pattern, const std::string& name_scope);
void IdentityOpCleanPass::ApplyImpl(ir::Graph* graph) const { // declare operator node's name
Init(name_scope_, graph); PATTERN_DECL_NODE(pre_op_out);
PATTERN_DECL_NODE(cast_op_1);
PATTERN_DECL_NODE(cast_op_1_out);
PATTERN_DECL_NODE(cast_op_2);
PATTERN_DECL_NODE(cast_op_2_out);
};
FindTwoCastOpPattern::FindTwoCastOpPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto* pre_op_out = pattern->NewNode(pre_op_out_repr())
->assert_is_var()
->assert_var_not_persistable()
->assert_has_n_outputs(1)
->assert_more([](Node* x) {
for (auto* op : x->inputs) {
CHECK_EQ(op->IsOp(), true);
const auto& op_type = op->Op()->Type();
if (op_type == "conditional_block" ||
op_type == "while" || op_type == "feed") {
return false;
}
}
return true;
});
auto* cast_op_1 = pattern->NewNode(cast_op_1_repr())->assert_is_op("cast");
auto* cast_op_1_out = pattern->NewNode(cast_op_1_out_repr())->assert_is_var();
auto* cast_op_2 = pattern->NewNode(cast_op_2_repr())->assert_is_op("cast");
auto* cast_op_2_out = pattern->NewNode(cast_op_2_out_repr())->assert_is_var();
cast_op_1->LinksFrom({pre_op_out}).LinksTo({cast_op_1_out});
cast_op_2->LinksFrom({cast_op_1_out}).LinksTo({cast_op_2_out});
}
} // namespace patterns
int IdentityOpCleanPass::CleanUselessOp(ir::Graph* graph) const {
GraphPatternDetector gpd; GraphPatternDetector gpd;
patterns::FindUselessOpPattern pattern(gpd.mutable_pattern(), name_scope_); patterns::FindUselessOpPattern pattern(gpd.mutable_pattern(), name_scope_);
...@@ -119,6 +160,48 @@ void IdentityOpCleanPass::ApplyImpl(ir::Graph* graph) const { ...@@ -119,6 +160,48 @@ void IdentityOpCleanPass::ApplyImpl(ir::Graph* graph) const {
}; };
gpd(graph, handler); gpd(graph, handler);
return found_count;
}
int IdentityOpCleanPass::CleanTwoCastOp(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::FindTwoCastOpPattern pattern(gpd.mutable_pattern(), name_scope_);
int found_count = 0;
GraphPatternDetector::handle_t handler =
[&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
GET_IR_NODE_FROM_SUBGRAPH(pre_op_out, pre_op_out, pattern);
GET_IR_NODE_FROM_SUBGRAPH(cast_op_1, cast_op_1, pattern);
GET_IR_NODE_FROM_SUBGRAPH(cast_op_1_out, cast_op_1_out, pattern);
GET_IR_NODE_FROM_SUBGRAPH(cast_op_2, cast_op_2, pattern);
GET_IR_NODE_FROM_SUBGRAPH(cast_op_2_out, cast_op_2_out, pattern);
CHECK_EQ(pre_op_out->IsVar(), true);
CHECK_EQ(cast_op_1_out->IsVar(), true);
CHECK_EQ(cast_op_2_out->IsVar(), true);
CHECK_EQ(cast_op_1->IsOp(), true);
CHECK_EQ(cast_op_2->IsOp(), true);
if (pre_op_out->Var()->GetDataType() ==
cast_op_2_out->Var()->GetDataType()) {
for (auto* prev_op : pre_op_out->inputs) {
CHECK_EQ(prev_op->IsOp(), true);
prev_op->Op()->RenameOutput(pre_op_out->Var()->Name(),
cast_op_2_out->Var()->Name());
IR_NODE_LINK_TO(prev_op, cast_op_2_out);
}
GraphSafeRemoveNodes(
graph, {pre_op_out, cast_op_1, cast_op_1_out, cast_op_2});
found_count++;
}
};
gpd(graph, handler);
return found_count;
}
void IdentityOpCleanPass::ApplyImpl(ir::Graph* graph) const {
Init(name_scope_, graph);
int found_count = CleanUselessOp(graph) + CleanTwoCastOp(graph);
AddStatis(found_count); AddStatis(found_count);
} }
......
...@@ -27,6 +27,10 @@ class IdentityOpCleanPass : public FusePassBase { ...@@ -27,6 +27,10 @@ class IdentityOpCleanPass : public FusePassBase {
void ApplyImpl(ir::Graph* graph) const override; void ApplyImpl(ir::Graph* graph) const override;
private: private:
int CleanUselessOp(ir::Graph* graph) const;
int CleanTwoCastOp(ir::Graph* graph) const;
const std::string name_scope_{"identity_op_clean_pass"}; const std::string name_scope_{"identity_op_clean_pass"};
}; };
......
...@@ -419,6 +419,9 @@ struct Argument { ...@@ -419,6 +419,9 @@ struct Argument {
DECL_ARGUMENT_FIELD(mixed_black_list, DECL_ARGUMENT_FIELD(mixed_black_list,
MixedBlackList, MixedBlackList,
std::unordered_set<std::string>); std::unordered_set<std::string>);
DECL_ARGUMENT_FIELD(mixed_white_list,
MixedWhiteList,
std::unordered_set<std::string>);
DECL_ARGUMENT_FIELD(enable_gpu_mixed, EnableGPUMixed, bool); DECL_ARGUMENT_FIELD(enable_gpu_mixed, EnableGPUMixed, bool);
DECL_ARGUMENT_FIELD(mixed_precision_mode, MixedPrecisionMode, int); DECL_ARGUMENT_FIELD(mixed_precision_mode, MixedPrecisionMode, int);
DECL_ARGUMENT_FIELD(enable_low_precision_io, EnableLowPrecisionIO, bool); DECL_ARGUMENT_FIELD(enable_low_precision_io, EnableLowPrecisionIO, bool);
......
...@@ -100,6 +100,9 @@ void IRPassManager::CreatePasses(Argument *argument, ...@@ -100,6 +100,9 @@ void IRPassManager::CreatePasses(Argument *argument,
pass->Set( pass->Set(
"mixed_black_list", "mixed_black_list",
new std::unordered_set<std::string>(argument->mixed_black_list())); new std::unordered_set<std::string>(argument->mixed_black_list()));
pass->Set(
"mixed_white_list",
new std::unordered_set<std::string>(argument->mixed_white_list()));
pass->Set("enable_gpu_mixed", new bool(argument->enable_gpu_mixed())); pass->Set("enable_gpu_mixed", new bool(argument->enable_gpu_mixed()));
pass->Set("enable_custom_device_mixed", pass->Set("enable_custom_device_mixed",
new bool(argument->enable_custom_device_mixed())); new bool(argument->enable_custom_device_mixed()));
......
...@@ -50,7 +50,8 @@ void OutputProcess(framework::ir::Graph *graph, ...@@ -50,7 +50,8 @@ void OutputProcess(framework::ir::Graph *graph,
const std::unordered_set<framework::ir::Node *> &trt_outputs, const std::unordered_set<framework::ir::Node *> &trt_outputs,
phi::Backend backend, phi::Backend backend,
phi::DataType precision, phi::DataType precision,
const std::unordered_set<std::string> &blacklist) { const std::unordered_set<std::string> &blacklist,
const std::unordered_set<std::string> &whitelist) {
framework::BlockDesc *block_desc{nullptr}; framework::BlockDesc *block_desc{nullptr};
int suffix = 0; int suffix = 0;
std::unordered_map<framework::ir::Node *, framework::ir::Node *> std::unordered_map<framework::ir::Node *, framework::ir::Node *>
...@@ -86,7 +87,8 @@ void OutputProcess(framework::ir::Graph *graph, ...@@ -86,7 +87,8 @@ void OutputProcess(framework::ir::Graph *graph,
phi::TransToPhiKernelName(next_op->Op()->Type()), phi::TransToPhiKernelName(next_op->Op()->Type()),
backend, backend,
precision, precision,
blacklist)) { blacklist,
whitelist)) {
InsertCastOp(graph, InsertCastOp(graph,
var_node, var_node,
next_op, next_op,
...@@ -363,6 +365,8 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp( ...@@ -363,6 +365,8 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
static_cast<phi::DataType>(Get<int>("model_precision")); static_cast<phi::DataType>(Get<int>("model_precision"));
auto mixed_black_list = auto mixed_black_list =
Get<std::unordered_set<std::string>>("mixed_black_list"); Get<std::unordered_set<std::string>>("mixed_black_list");
auto mixed_white_list =
Get<std::unordered_set<std::string>>("mixed_white_list");
std::set<std::string> output_names; std::set<std::string> output_names;
std::set<std::string> output_names_with_id; std::set<std::string> output_names_with_id;
...@@ -414,8 +418,12 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp( ...@@ -414,8 +418,12 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
static_cast<int>(x->Var()->GetDataType()); static_cast<int>(x->Var()->GetDataType());
} }
OutputProcess( OutputProcess(graph,
graph, trt_outputs, phi::Backend::GPU, model_precision, mixed_black_list); trt_outputs,
phi::Backend::GPU,
model_precision,
mixed_black_list,
mixed_white_list);
std::unordered_map<std::string, std::string> output_name_map; std::unordered_map<std::string, std::string> output_name_map;
std::unordered_map<std::string, framework::ir::Node *> graph_var_map; std::unordered_map<std::string, framework::ir::Node *> graph_var_map;
......
...@@ -14,7 +14,7 @@ cc_library( ...@@ -14,7 +14,7 @@ cc_library(
convert_to_mixed_precision convert_to_mixed_precision
SRCS convert_to_mixed_precision.cc SRCS convert_to_mixed_precision.cc
DEPS analysis_pass ir_graph_build_pass auto_mixed_precision_pass DEPS analysis_pass ir_graph_build_pass auto_mixed_precision_pass
constant_folding_pass) constant_folding_pass identity_op_clean_pass)
cc_library( cc_library(
ir_params_sync_among_devices_pass ir_params_sync_among_devices_pass
SRCS ir_params_sync_among_devices_pass.cc SRCS ir_params_sync_among_devices_pass.cc
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "paddle/fluid/framework/ir/auto_mixed_precision_pass.h" #include "paddle/fluid/framework/ir/auto_mixed_precision_pass.h"
#include "paddle/fluid/framework/ir/constant_folding_pass.h" #include "paddle/fluid/framework/ir/constant_folding_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/identity_op_clean_pass.h"
#include "paddle/fluid/inference/io.h" #include "paddle/fluid/inference/io.h"
#include "paddle/phi/common/backend.h" #include "paddle/phi/common/backend.h"
...@@ -33,7 +34,8 @@ ConvertToMixedPrecisionPass::ConvertToMixedPrecisionPass( ...@@ -33,7 +34,8 @@ ConvertToMixedPrecisionPass::ConvertToMixedPrecisionPass(
phi::DataType mixed_precision, phi::DataType mixed_precision,
phi::Backend backend, phi::Backend backend,
bool keep_io_types, bool keep_io_types,
const std::unordered_set<std::string>& black_list) const std::unordered_set<std::string>& black_list,
const std::unordered_set<std::string>& white_list)
: model_file_(model_file), : model_file_(model_file),
params_file_(params_file), params_file_(params_file),
mixed_model_file_(mixed_model_file), mixed_model_file_(mixed_model_file),
...@@ -41,7 +43,8 @@ ConvertToMixedPrecisionPass::ConvertToMixedPrecisionPass( ...@@ -41,7 +43,8 @@ ConvertToMixedPrecisionPass::ConvertToMixedPrecisionPass(
mixed_precision_(mixed_precision), mixed_precision_(mixed_precision),
backend_(backend), backend_(backend),
keep_io_types_(keep_io_types), keep_io_types_(keep_io_types),
black_list_(black_list) { black_list_(black_list),
white_list_(white_list) {
switch (backend_) { switch (backend_) {
case phi::Backend::GPU: case phi::Backend::GPU:
PADDLE_ENFORCE(mixed_precision_ == phi::DataType::FLOAT16 || PADDLE_ENFORCE(mixed_precision_ == phi::DataType::FLOAT16 ||
...@@ -88,19 +91,27 @@ void ConvertToMixedPrecisionPass::Run() { ...@@ -88,19 +91,27 @@ void ConvertToMixedPrecisionPass::Run() {
framework::ir::ConstantFoldingPass constant_folding_pass; framework::ir::ConstantFoldingPass constant_folding_pass;
constant_folding_pass.Apply(main_graph_.get()); constant_folding_pass.Apply(main_graph_.get());
framework::ir::AutoMixedPrecisionPass pass;
pass.Set("mixed_precision_mode", new int{static_cast<int>(mixed_precision_)}); framework::ir::AutoMixedPrecisionPass auto_mixed_precision_pass;
auto_mixed_precision_pass.Set("mixed_precision_mode",
new int{static_cast<int>(mixed_precision_)});
if (backend_ == phi::Backend::GPU) { if (backend_ == phi::Backend::GPU) {
pass.Set("enable_gpu_mixed", new bool{true}); auto_mixed_precision_pass.Set("enable_gpu_mixed", new bool{true});
} else if (backend_ == phi::Backend::XPU) { } else if (backend_ == phi::Backend::XPU) {
pass.Set("enable_xpu_mixed", new bool{true}); auto_mixed_precision_pass.Set("enable_xpu_mixed", new bool{true});
} else if (backend_ == phi::Backend::CUSTOM) { } else if (backend_ == phi::Backend::CUSTOM) {
pass.Set("enable_custom_device_mixed", new bool{true}); auto_mixed_precision_pass.Set("enable_custom_device_mixed", new bool{true});
} }
pass.Set("mixed_black_list", auto_mixed_precision_pass.Set(
new std::unordered_set<std::string>{black_list_}); "mixed_black_list", new std::unordered_set<std::string>{black_list_});
pass.Set("enable_low_precision_io", new bool{!keep_io_types_}); auto_mixed_precision_pass.Set(
pass.Apply(main_graph_.get()); "mixed_white_list", new std::unordered_set<std::string>{white_list_});
auto_mixed_precision_pass.Set("enable_low_precision_io",
new bool{!keep_io_types_});
auto_mixed_precision_pass.Apply(main_graph_.get());
framework::ir::IdentityOpCleanPass identity_op_clean_pass;
identity_op_clean_pass.Apply(main_graph_.get());
SaveMixedModel(); SaveMixedModel();
} }
...@@ -184,9 +195,10 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() { ...@@ -184,9 +195,10 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
bool OpSupportPrecision(const std::string& op_type, bool OpSupportPrecision(const std::string& op_type,
phi::Backend backend, phi::Backend backend,
phi::DataType precision, phi::DataType precision,
const std::unordered_set<std::string>& black_list) { const std::unordered_set<std::string>& black_list,
const std::unordered_set<std::string>& white_list) {
return framework::ir::OpSupportPrecision( return framework::ir::OpSupportPrecision(
op_type, backend, precision, black_list); op_type, backend, precision, black_list, white_list);
} }
void InsertCastOp( void InsertCastOp(
...@@ -216,7 +228,8 @@ void ConvertToMixedPrecision( ...@@ -216,7 +228,8 @@ void ConvertToMixedPrecision(
phi::DataType mixed_precision, phi::DataType mixed_precision,
phi::Backend backend, phi::Backend backend,
bool keep_io_types, bool keep_io_types,
const std::unordered_set<std::string>& black_list) { const std::unordered_set<std::string>& black_list,
const std::unordered_set<std::string>& white_list) {
ConvertToMixedPrecisionPass pass(model_file, ConvertToMixedPrecisionPass pass(model_file,
params_file, params_file,
mixed_model_file, mixed_model_file,
...@@ -224,7 +237,8 @@ void ConvertToMixedPrecision( ...@@ -224,7 +237,8 @@ void ConvertToMixedPrecision(
mixed_precision, mixed_precision,
backend, backend,
keep_io_types, keep_io_types,
black_list); black_list,
white_list);
pass.Run(); pass.Run();
} }
......
...@@ -38,7 +38,8 @@ class ConvertToMixedPrecisionPass { ...@@ -38,7 +38,8 @@ class ConvertToMixedPrecisionPass {
phi::DataType mixed_precision, phi::DataType mixed_precision,
phi::Backend backend, phi::Backend backend,
bool keep_io_types, bool keep_io_types,
const std::unordered_set<std::string>& black_list); const std::unordered_set<std::string>& black_list,
const std::unordered_set<std::string>& white_list);
void Run(); void Run();
...@@ -55,6 +56,7 @@ class ConvertToMixedPrecisionPass { ...@@ -55,6 +56,7 @@ class ConvertToMixedPrecisionPass {
phi::Backend backend_; phi::Backend backend_;
bool keep_io_types_; bool keep_io_types_;
std::unordered_set<std::string> black_list_; std::unordered_set<std::string> black_list_;
std::unordered_set<std::string> white_list_;
framework::Scope scope_; framework::Scope scope_;
std::unique_ptr<framework::ir::Graph> main_graph_{nullptr}; std::unique_ptr<framework::ir::Graph> main_graph_{nullptr};
...@@ -63,7 +65,8 @@ class ConvertToMixedPrecisionPass { ...@@ -63,7 +65,8 @@ class ConvertToMixedPrecisionPass {
bool OpSupportPrecision(const std::string& op_type, bool OpSupportPrecision(const std::string& op_type,
phi::Backend backend, phi::Backend backend,
phi::DataType precision, phi::DataType precision,
const std::unordered_set<std::string>& black_list); const std::unordered_set<std::string>& black_list,
const std::unordered_set<std::string>& white_list);
void InsertCastOp( void InsertCastOp(
framework::ir::Graph* graph, framework::ir::Graph* graph,
...@@ -82,7 +85,8 @@ void ConvertToMixedPrecision(const std::string& model_file, ...@@ -82,7 +85,8 @@ void ConvertToMixedPrecision(const std::string& model_file,
phi::DataType mixed_precision, phi::DataType mixed_precision,
phi::Backend backend, phi::Backend backend,
bool keep_io_types, bool keep_io_types,
const std::unordered_set<std::string>& black_list); const std::unordered_set<std::string>& black_list,
const std::unordered_set<std::string>& white_list);
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
......
...@@ -448,6 +448,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { ...@@ -448,6 +448,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
// Mixed precision related. // Mixed precision related.
CP_MEMBER(mixed_black_list_); CP_MEMBER(mixed_black_list_);
CP_MEMBER(mixed_white_list_);
CP_MEMBER(enable_gpu_mixed_); CP_MEMBER(enable_gpu_mixed_);
CP_MEMBER(mixed_precision_mode_); CP_MEMBER(mixed_precision_mode_);
CP_MEMBER(enable_low_precision_io_); CP_MEMBER(enable_low_precision_io_);
...@@ -1154,6 +1155,7 @@ std::string AnalysisConfig::SerializeInfoCache() { ...@@ -1154,6 +1155,7 @@ std::string AnalysisConfig::SerializeInfoCache() {
for (auto attr : pattern) ss << attr; for (auto attr : pattern) ss << attr;
ss << ";"; ss << ";";
for (auto &op : mixed_black_list_) ss << op.c_str(); for (auto &op : mixed_black_list_) ss << op.c_str();
for (auto &op : mixed_white_list_) ss << op.c_str();
return ss.str(); return ss.str();
} }
...@@ -1535,6 +1537,11 @@ void AnalysisConfig::Exp_DisableMixedPrecisionOps( ...@@ -1535,6 +1537,11 @@ void AnalysisConfig::Exp_DisableMixedPrecisionOps(
mixed_black_list_ = black_list; mixed_black_list_ = black_list;
} }
void AnalysisConfig::Exp_EnableMixedPrecisionOps(
const std::unordered_set<std::string> &white_list) {
mixed_white_list_ = white_list;
}
void AnalysisConfig::Exp_EnableCINNCompiler() { void AnalysisConfig::Exp_EnableCINNCompiler() {
#ifdef PADDLE_WITH_CINN #ifdef PADDLE_WITH_CINN
use_cinn_compiler_ = true; use_cinn_compiler_ = true;
......
...@@ -1616,6 +1616,7 @@ void AnalysisPredictor::PrepareArgument() { ...@@ -1616,6 +1616,7 @@ void AnalysisPredictor::PrepareArgument() {
// mixed precison. // mixed precison.
argument_->SetModelPrecision(static_cast<int>(model_precision_)); argument_->SetModelPrecision(static_cast<int>(model_precision_));
argument_->SetMixedBlackList(config_.mixed_black_list_); argument_->SetMixedBlackList(config_.mixed_black_list_);
argument_->SetMixedWhiteList(config_.mixed_white_list_);
argument_->SetEnableGPUMixed(config_.enable_gpu_mixed_); argument_->SetEnableGPUMixed(config_.enable_gpu_mixed_);
argument_->SetMixedPrecisionMode(static_cast<int>( argument_->SetMixedPrecisionMode(static_cast<int>(
paddle::ConvertPrecision(config_.mixed_precision_mode_))); paddle::ConvertPrecision(config_.mixed_precision_mode_)));
...@@ -3097,7 +3098,8 @@ void ConvertToMixedPrecision(const std::string &model_file, ...@@ -3097,7 +3098,8 @@ void ConvertToMixedPrecision(const std::string &model_file,
PrecisionType mixed_precision, PrecisionType mixed_precision,
paddle_infer::PlaceType backend, paddle_infer::PlaceType backend,
bool keep_io_types, bool keep_io_types,
std::unordered_set<std::string> black_list) { std::unordered_set<std::string> black_list,
std::unordered_set<std::string> white_list) {
auto phi_backend = paddle::ConvertBackend(backend); auto phi_backend = paddle::ConvertBackend(backend);
auto phi_precision = paddle::ConvertPrecision(mixed_precision); auto phi_precision = paddle::ConvertPrecision(mixed_precision);
paddle::inference::analysis::ConvertToMixedPrecision(model_file, paddle::inference::analysis::ConvertToMixedPrecision(model_file,
...@@ -3107,7 +3109,8 @@ void ConvertToMixedPrecision(const std::string &model_file, ...@@ -3107,7 +3109,8 @@ void ConvertToMixedPrecision(const std::string &model_file,
phi_precision, phi_precision,
phi_backend, phi_backend,
keep_io_types, keep_io_types,
black_list); black_list,
white_list);
} }
} // namespace paddle_infer } // namespace paddle_infer
......
...@@ -1147,6 +1147,14 @@ struct PD_INFER_DECL AnalysisConfig { ...@@ -1147,6 +1147,14 @@ struct PD_INFER_DECL AnalysisConfig {
void Exp_DisableMixedPrecisionOps( void Exp_DisableMixedPrecisionOps(
const std::unordered_set<std::string>& black_list); const std::unordered_set<std::string>& black_list);
///
/// \brief Set a list of operators that do support mixed precision. This
/// interface is in the experimental stage and may change in the future. Note
/// that the whitelist must be the same as the model conversion whitelist.
///
void Exp_EnableMixedPrecisionOps(
const std::unordered_set<std::string>& white_list);
void SetApplyOptim(bool value) { apply_optim_ = value; } void SetApplyOptim(bool value) { apply_optim_ = value; }
void SetSkipLoadParams(bool value) { skip_load_params_ = value; } void SetSkipLoadParams(bool value) { skip_load_params_ = value; }
...@@ -1179,6 +1187,7 @@ struct PD_INFER_DECL AnalysisConfig { ...@@ -1179,6 +1187,7 @@ struct PD_INFER_DECL AnalysisConfig {
// Mixed precision related. // Mixed precision related.
Precision mixed_precision_mode_{Precision::kFloat32}; Precision mixed_precision_mode_{Precision::kFloat32};
std::unordered_set<std::string> mixed_black_list_; std::unordered_set<std::string> mixed_black_list_;
std::unordered_set<std::string> mixed_white_list_;
bool enable_low_precision_io_{false}; bool enable_low_precision_io_{false};
// GPU related. // GPU related.
......
...@@ -245,7 +245,8 @@ PD_INFER_DECL void ConvertToMixedPrecision( ...@@ -245,7 +245,8 @@ PD_INFER_DECL void ConvertToMixedPrecision(
PrecisionType mixed_precision, PrecisionType mixed_precision,
PlaceType backend, PlaceType backend,
bool keep_io_types = true, bool keep_io_types = true,
std::unordered_set<std::string> black_list = {}); std::unordered_set<std::string> black_list = {},
std::unordered_set<std::string> white_list = {});
namespace services { namespace services {
/// ///
......
...@@ -268,11 +268,11 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { ...@@ -268,11 +268,11 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"conv_elementwise_add_fuse_pass", // "conv_elementwise_add_fuse_pass", //
#endif // #endif //
"transpose_flatten_concat_fuse_pass", // "transpose_flatten_concat_fuse_pass", //
"identity_op_clean_pass", //
"conv2d_fusion_layout_transfer_pass", // "conv2d_fusion_layout_transfer_pass", //
"transfer_layout_elim_pass", "transfer_layout_elim_pass",
"auto_mixed_precision_pass", // "auto_mixed_precision_pass", //
"inplace_op_var_pass", // should be the last pass. "identity_op_clean_pass", // should be after auto_mixed_precision_pass.
"inplace_op_var_pass", // should be the last pass.
}); });
use_gpu_ = true; use_gpu_ = true;
......
...@@ -544,7 +544,8 @@ void BindInferenceApi(py::module *m) { ...@@ -544,7 +544,8 @@ void BindInferenceApi(py::module *m) {
py::arg("mixed_precision"), py::arg("mixed_precision"),
py::arg("backend"), py::arg("backend"),
py::arg("keep_io_types") = true, py::arg("keep_io_types") = true,
py::arg("black_list") = std::unordered_set<std::string>()); py::arg("black_list") = std::unordered_set<std::string>(),
py::arg("white_list") = std::unordered_set<std::string>());
} }
namespace { namespace {
...@@ -777,6 +778,8 @@ void BindAnalysisConfig(py::module *m) { ...@@ -777,6 +778,8 @@ void BindAnalysisConfig(py::module *m) {
.def("exp_enable_use_cutlass", &AnalysisConfig::Exp_EnableUseCutlass) .def("exp_enable_use_cutlass", &AnalysisConfig::Exp_EnableUseCutlass)
.def("exp_disable_mixed_precision_ops", .def("exp_disable_mixed_precision_ops",
&AnalysisConfig::Exp_DisableMixedPrecisionOps) &AnalysisConfig::Exp_DisableMixedPrecisionOps)
.def("exp_enable_mixed_precision_ops",
&AnalysisConfig::Exp_EnableMixedPrecisionOps)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
.def("set_exec_stream", .def("set_exec_stream",
[](AnalysisConfig &self, phi::CUDAStream &stream) { [](AnalysisConfig &self, phi::CUDAStream &stream) {
......
...@@ -78,7 +78,8 @@ def convert_to_mixed_precision( ...@@ -78,7 +78,8 @@ def convert_to_mixed_precision(
mixed_precision: PrecisionType, mixed_precision: PrecisionType,
backend: PlaceType, backend: PlaceType,
keep_io_types: bool = True, keep_io_types: bool = True,
black_list: Set = set(), black_list: Set[str] = set(),
**kwargs,
): ):
''' '''
Convert a fp32 model to mixed precision model. Convert a fp32 model to mixed precision model.
...@@ -92,6 +93,8 @@ def convert_to_mixed_precision( ...@@ -92,6 +93,8 @@ def convert_to_mixed_precision(
backend: The backend, e.g. PlaceType.GPU. backend: The backend, e.g. PlaceType.GPU.
keep_io_types: Whether the model input and output dtype remains unchanged. keep_io_types: Whether the model input and output dtype remains unchanged.
black_list: Operators that do not convert precision. black_list: Operators that do not convert precision.
kwargs: Supported keys including 'white_list'.
- white_list: Operators that do convert precision.
''' '''
mixed_model_dirname = os.path.dirname(mixed_model_file) mixed_model_dirname = os.path.dirname(mixed_model_file)
# Support mixed_params_file is empty, because some models don't have params, but convert_to_mixed_precision will call # Support mixed_params_file is empty, because some models don't have params, but convert_to_mixed_precision will call
...@@ -104,6 +107,7 @@ def convert_to_mixed_precision( ...@@ -104,6 +107,7 @@ def convert_to_mixed_precision(
) )
if not os.path.exists(mixed_params_dirname): if not os.path.exists(mixed_params_dirname):
os.makedirs(mixed_params_dirname) os.makedirs(mixed_params_dirname)
white_list = kwargs.get('white_list', set())
convert_to_mixed_precision_bind( convert_to_mixed_precision_bind(
model_file, model_file,
params_file, params_file,
...@@ -113,6 +117,7 @@ def convert_to_mixed_precision( ...@@ -113,6 +117,7 @@ def convert_to_mixed_precision(
backend, backend,
keep_io_types, keep_io_types,
black_list, black_list,
white_list,
) )
......
...@@ -155,5 +155,60 @@ class TestIdentityScaleCleanPass_V2(PassAutoScanTest): ...@@ -155,5 +155,60 @@ class TestIdentityScaleCleanPass_V2(PassAutoScanTest):
self.run_and_statis(max_examples=25, passes=["identity_op_clean_pass"]) self.run_and_statis(max_examples=25, passes=["identity_op_clean_pass"])
class TestIdentityCastCleanPass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_gpu=True)
yield config, ['relu', 'relu'], (1e-2, 1e-2)
def sample_program_config(self, draw):
n = draw(st.integers(min_value=1, max_value=4))
c = draw(st.integers(min_value=1, max_value=20))
h = draw(st.integers(min_value=1, max_value=20))
w = draw(st.integers(min_value=1, max_value=20))
relu_op_1 = OpConfig(
"relu",
inputs={"X": ["relu_op_1_in"]},
outputs={"Out": ["relu_op_1_out"]},
)
cast_op_1 = OpConfig(
"cast",
inputs={"X": ["relu_op_1_out"]},
outputs={"Out": ["cast_op_1_out"]},
in_dtype=5,
out_dtype=5,
)
relu_op_2 = OpConfig(
"relu",
inputs={"X": ["cast_op_1_out"]},
outputs={"Out": ["relu_op_2_out"]},
)
cast_op_2 = OpConfig(
"cast",
inputs={"X": ["relu_op_2_out"]},
outputs={"Out": ["cast_op_2_out"]},
in_dtype=5,
out_dtype=4,
)
cast_op_3 = OpConfig(
"cast",
inputs={"X": ["cast_op_2_out"]},
outputs={"Out": ["cast_op_3_out"]},
in_dtype=4,
out_dtype=5,
)
program_config = ProgramConfig(
ops=[relu_op_1, cast_op_1, relu_op_2, cast_op_2, cast_op_3],
weights={},
inputs={"relu_op_1_in": TensorConfig(shape=[n, c, h, w])},
outputs=["cast_op_3_out"],
)
return program_config
def test(self):
self.run_and_statis(max_examples=25, passes=["identity_op_clean_pass"])
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册