Unverified commit 28ea9aad, authored by Yuanle Liu, committed by GitHub

[Paddle Inference] rewrite convert_to_mixed_precision (#48853)

Parent: b9fad5da
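For reference, a minimal usage sketch of the rewritten converter, assuming the ConvertToMixedPrecisionPass constructor and Run() shown further down in this diff; the file paths and black list below are hypothetical placeholders, not part of this commit.

#include <unordered_set>
#include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"

int main() {
  // Ops forced to stay in FP32 (placeholder: empty).
  std::unordered_set<std::string> black_list;
  paddle::inference::analysis::ConvertToMixedPrecisionPass pass(
      "model.pdmodel",         // model_file (placeholder path)
      "model.pdiparams",       // params_file
      "mixed.pdmodel",         // mixed_model_file (output)
      "mixed.pdiparams",       // mixed_params_file (output)
      phi::DataType::FLOAT16,  // or phi::DataType::BFLOAT16
      phi::Backend::GPU,       // only the GPU backend is supported
      true,                    // keep_io_types: keep feed/fetch vars in FP32
      black_list);
  // Loads the model, applies auto_mixed_precision_pass, saves the mixed model.
  pass.Run();
  return 0;
}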
...@@ -103,7 +103,7 @@ pass_library(delete_c_identity_op_pass inference) ...@@ -103,7 +103,7 @@ pass_library(delete_c_identity_op_pass inference)
pass_library(preln_residual_bias_fuse_pass inference) pass_library(preln_residual_bias_fuse_pass inference)
pass_library(delete_fill_constant_op_pass inference) pass_library(delete_fill_constant_op_pass inference)
pass_library(constant_folding_pass inference) pass_library(constant_folding_pass inference)
pass_library(float_to_half_pass inference) pass_library(auto_mixed_precision_pass inference)
pass_library(conv2d_fusion_layout_transfer_pass inference) pass_library(conv2d_fusion_layout_transfer_pass inference)
pass_library(simplify_with_basic_ops_pass base) pass_library(simplify_with_basic_ops_pass base)
pass_library(fc_elementwise_layernorm_fuse_pass base) pass_library(fc_elementwise_layernorm_fuse_pass base)
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/float_to_half_pass.h" #include "paddle/fluid/framework/ir/auto_mixed_precision_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
...@@ -29,7 +29,7 @@ namespace ir { ...@@ -29,7 +29,7 @@ namespace ir {
namespace { namespace {
using VarType = FloatToHalfPass::VarType; using VarType = AutoMixedPrecisionPass::VarType;
bool PhiKernelSupportPrecision( bool PhiKernelSupportPrecision(
const std::string& op_type, const std::string& op_type,
...@@ -71,6 +71,23 @@ bool GpuKernelSupportPrecision( ...@@ -71,6 +71,23 @@ bool GpuKernelSupportPrecision(
return support; return support;
} }
inline bool VarNodeHasDtype(Node* var_node) {
auto type = var_node->Var()->GetType();
return (type == VarType::SELECTED_ROWS) || (type == VarType::LOD_TENSOR) ||
(type == VarType::LOD_TENSOR_ARRAY) || (type == VarType::STRINGS) ||
(type == VarType::VOCAB);
}
inline bool IsFloatType(VarType::Type type) {
return (type == VarType::FP64) || (type == VarType::FP32);
}
inline bool IsHalfType(VarType::Type type) {
return (type == VarType::FP16) || (type == VarType::BF16);
}
}; // namespace
void DoInsertCastOp(Graph* graph, void DoInsertCastOp(Graph* graph,
Node* var_node, Node* var_node,
Node* op_node, Node* op_node,
...@@ -123,27 +140,26 @@ void DoInsertCastOp(Graph* graph, ...@@ -123,27 +140,26 @@ void DoInsertCastOp(Graph* graph,
IR_NODE_UNLINK(var_node, op_node); IR_NODE_UNLINK(var_node, op_node);
} }
inline bool VarNodeHasDtype(Node* var_node) { bool OpSupportPrecision(const std::string& op_type,
auto type = var_node->Var()->GetType(); phi::Backend backend,
return (type == VarType::SELECTED_ROWS) || (type == VarType::LOD_TENSOR) || phi::DataType precision,
(type == VarType::LOD_TENSOR_ARRAY) || (type == VarType::STRINGS) || const std::unordered_set<std::string>& black_list) {
(type == VarType::VOCAB); bool support = false;
} if (black_list.count(op_type) == 0) {
if (backend == phi::Backend::GPU) {
inline bool IsFloatType(VarType::Type type) { support = GpuKernelSupportPrecision(op_type, precision);
return (type == VarType::FP64) || (type == VarType::FP32); } else {
} PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Now, only support backend of GPU."));
inline bool IsHalfType(VarType::Type type) { }
return (type == VarType::FP16) || (type == VarType::BF16); }
return support;
} }
}; // namespace
// The set of ops that support fp16 calculation and are considered // The set of ops that support fp16 calculation and are considered
// numerically-dangerous, slower and whose effects may also be observed in // numerically-dangerous, slower and whose effects may also be observed in
// downstream ops. // downstream ops.
void FloatToHalfPass::SetDefaultBlacklist() const { void AutoMixedPrecisionPass::SetDefaultBlacklist() const {
black_list_.insert({ black_list_.insert({
// numerically-dangerous // numerically-dangerous
"acos", "acos",
...@@ -175,12 +191,27 @@ void FloatToHalfPass::SetDefaultBlacklist() const { ...@@ -175,12 +191,27 @@ void FloatToHalfPass::SetDefaultBlacklist() const {
}); });
} }
void FloatToHalfPass::Init(Graph* graph) const { void AutoMixedPrecisionPass::Init(Graph* graph) const {
keep_io_types_ = true; bool enable_gpu_mixed = Get<bool>("enable_gpu_mixed");
half_precision_ = if (enable_gpu_mixed) {
static_cast<phi::DataType>(Get<int>("mixed_precision_mode")); backend_ = phi::Backend::GPU;
}
skip_pass_ = !enable_gpu_mixed;
low_precision_ = static_cast<phi::DataType>(Get<int>("mixed_precision_mode"));
black_list_ = Get<std::unordered_set<std::string>>("mixed_black_list"); black_list_ = Get<std::unordered_set<std::string>>("mixed_black_list");
SetDefaultBlacklist(); SetDefaultBlacklist();
VLOG(4) << "black_list has ";
for (const auto& name : black_list_) {
VLOG(4) << " - " << name;
}
keep_io_types_ = true;
if (Has("keep_io_types")) {
keep_io_types_ = Get<bool>("keep_io_types");
}
auto graph_size = graph->SubGraphsSize(); auto graph_size = graph->SubGraphsSize();
VLOG(4) << "graph size: " << graph_size; VLOG(4) << "graph size: " << graph_size;
...@@ -204,24 +235,27 @@ void FloatToHalfPass::Init(Graph* graph) const { ...@@ -204,24 +235,27 @@ void FloatToHalfPass::Init(Graph* graph) const {
} }
} }
void FloatToHalfPass::ApplyImpl(Graph* graph) const { void AutoMixedPrecisionPass::ApplyImpl(Graph* graph) const {
auto enable_gpu_half = Get<bool>("enable_gpu_half"); PADDLE_ENFORCE_NOT_NULL(graph,
if (!enable_gpu_half) return;
PADDLE_ENFORCE_NOT_NULL(
graph,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"During the float to half pass, the graph should not be nullptr.")); "During the auto_mixed_precision_pass, the graph "
PADDLE_ENFORCE_EQ( "should not be nullptr."));
graph->IsMainGraph(), PADDLE_ENFORCE_EQ(graph->IsMainGraph(),
true, true,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"During the float to half pass, the graph should be main graph.")); "During the auto_mixed_precision_pass, the graph "
"should be main graph."));
FusePassBase::Init("float_to_half", graph); FusePassBase::Init("auto_mixed_precision", graph);
Init(graph); Init(graph);
VLOG(4) << "Init done"; VLOG(4) << "Init done";
if (skip_pass_) {
VLOG(3) << "Skip auto_mixed_precision_pass.";
return;
}
SetOpUniqueType(); SetOpUniqueType();
VLOG(4) << "SetOpUniqueType done"; VLOG(4) << "SetOpUniqueType done";
GetOpPrecision(); GetOpPrecision();
...@@ -240,19 +274,7 @@ void FloatToHalfPass::ApplyImpl(Graph* graph) const { ...@@ -240,19 +274,7 @@ void FloatToHalfPass::ApplyImpl(Graph* graph) const {
VLOG(4) << "RestoreOpOriginType done"; VLOG(4) << "RestoreOpOriginType done";
} }
bool FloatToHalfPass::OpSupportPrecision(const std::string& op_type, void AutoMixedPrecisionPass::SetOpUniqueType() const {
phi::DataType precision,
phi::Backend backend) const {
bool support = false;
if (black_list_.count(op_type) == 0) {
if (backend == phi::Backend::GPU) {
support = GpuKernelSupportPrecision(op_type, precision);
}
}
return support;
}
void FloatToHalfPass::SetOpUniqueType() const {
int suffix = 0; int suffix = 0;
for (const auto& nodes : all_op_nodes_) { for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) { for (auto* op_node : nodes) {
...@@ -269,7 +291,7 @@ void FloatToHalfPass::SetOpUniqueType() const { ...@@ -269,7 +291,7 @@ void FloatToHalfPass::SetOpUniqueType() const {
} }
} }
void FloatToHalfPass::RestoreOpOriginType() const { void AutoMixedPrecisionPass::RestoreOpOriginType() const {
for (const auto& nodes : all_op_nodes_) { for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) { for (auto* op_node : nodes) {
auto op_type = op_node->Op()->Type(); auto op_type = op_node->Op()->Type();
...@@ -281,7 +303,7 @@ void FloatToHalfPass::RestoreOpOriginType() const { ...@@ -281,7 +303,7 @@ void FloatToHalfPass::RestoreOpOriginType() const {
} }
} }
inline std::string FloatToHalfPass::GetOpOriginalType( inline std::string AutoMixedPrecisionPass::GetOpOriginalType(
const std::string& op_type) const { const std::string& op_type) const {
if (op_original_type_.count(op_type)) { if (op_original_type_.count(op_type)) {
return op_original_type_.at(op_type); return op_original_type_.at(op_type);
...@@ -289,22 +311,21 @@ inline std::string FloatToHalfPass::GetOpOriginalType( ...@@ -289,22 +311,21 @@ inline std::string FloatToHalfPass::GetOpOriginalType(
return op_type; return op_type;
} }
void FloatToHalfPass::ProcessOpWithDtypeAttr() const { void AutoMixedPrecisionPass::ProcessOpWithDtypeAttr() const {
for (const auto& nodes : all_op_nodes_) { for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) { for (auto* op_node : nodes) {
auto op_type = op_node->Op()->Type(); auto op_type = op_node->Op()->Type();
if (op_run_half_.count(op_type) == 0) continue; if (op_run_low_precision_.count(op_type) == 0) continue;
if (op_node->Op()->HasAttr("dtype")) { if (op_node->Op()->HasAttr("dtype")) {
auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype"); auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
if (IsFloatType(static_cast<VarType::Type>(dtype))) { if (IsFloatType(static_cast<VarType::Type>(dtype))) {
op_node->Op()->SetAttr( op_node->Op()->SetAttr(
"dtype", "dtype",
static_cast<int>( static_cast<int>(framework::TransToProtoVarType(low_precision_)));
framework::TransToProtoVarType(half_precision_)));
op_node->Op()->Flush(); op_node->Op()->Flush();
VLOG(4) << "process op with dtype attr: " << op_type << " ( " << dtype VLOG(4) << "process op with dtype attr: " << op_type << " ( " << dtype
<< " --->" << static_cast<int>(half_precision_) << " )"; << " --->" << static_cast<int>(low_precision_) << " )";
} }
} }
if (op_node->Op()->HasAttr("out_dtype")) { if (op_node->Op()->HasAttr("out_dtype")) {
...@@ -312,11 +333,10 @@ void FloatToHalfPass::ProcessOpWithDtypeAttr() const { ...@@ -312,11 +333,10 @@ void FloatToHalfPass::ProcessOpWithDtypeAttr() const {
if (IsFloatType(static_cast<VarType::Type>(out_dtype))) { if (IsFloatType(static_cast<VarType::Type>(out_dtype))) {
op_node->Op()->SetAttr( op_node->Op()->SetAttr(
"out_dtype", "out_dtype",
static_cast<int>( static_cast<int>(framework::TransToProtoVarType(low_precision_)));
framework::TransToProtoVarType(half_precision_)));
op_node->Op()->Flush(); op_node->Op()->Flush();
VLOG(4) << "process op with out_dtype attr: " << op_type << " ( " VLOG(4) << "process op with out_dtype attr: " << op_type << " ( "
<< out_dtype << " --->" << static_cast<int>(half_precision_) << out_dtype << " --->" << static_cast<int>(low_precision_)
<< " )"; << " )";
} }
} }
...@@ -324,37 +344,39 @@ void FloatToHalfPass::ProcessOpWithDtypeAttr() const { ...@@ -324,37 +344,39 @@ void FloatToHalfPass::ProcessOpWithDtypeAttr() const {
} }
} }
void FloatToHalfPass::GetOpPrecision() const { void AutoMixedPrecisionPass::GetOpPrecision() const {
for (const auto& nodes : all_op_nodes_) { for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) { for (auto* op_node : nodes) {
auto op_type = op_node->Op()->Type(); auto op_type = op_node->Op()->Type();
bool support_half = true; bool support_low_precision = true;
if (GetOpOriginalType(op_type) == "feed" || if (GetOpOriginalType(op_type) == "feed" ||
GetOpOriginalType(op_type) == "fetch") { GetOpOriginalType(op_type) == "fetch") {
support_half = !keep_io_types_; support_low_precision = !keep_io_types_;
} else { } else {
support_half = support_low_precision = OpSupportPrecision(
OpSupportPrecision(GetOpOriginalType(op_type), half_precision_); GetOpOriginalType(op_type), backend_, low_precision_, black_list_);
} }
if (op_node->Op()->HasAttr("dtype")) { if (op_node->Op()->HasAttr("dtype")) {
auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype"); auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
support_half = support_low_precision = support_low_precision &&
support_half && IsFloatType(static_cast<VarType::Type>(dtype)); IsFloatType(static_cast<VarType::Type>(dtype));
} else if (op_node->Op()->HasAttr("out_dtype")) { } else if (op_node->Op()->HasAttr("out_dtype")) {
auto out_dtype = op_node->Op()->GetAttrIfExists<int>("out_dtype"); auto out_dtype = op_node->Op()->GetAttrIfExists<int>("out_dtype");
support_half = support_low_precision =
support_half && IsFloatType(static_cast<VarType::Type>(out_dtype)); support_low_precision &&
IsFloatType(static_cast<VarType::Type>(out_dtype));
} else { } else {
// if op's input var and output var is not dense tensor, the op should // if op's input var and output var is not dense tensor, the op should
// not run half. // not run at low precision.
for (auto* in_var_node : op_node->inputs) { for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true); CHECK_EQ(in_var_node->IsVar(), true);
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()]; auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
if (real_in_var_node->Var()->Persistable()) continue; if (real_in_var_node->Var()->Persistable()) continue;
support_half = support_half && (real_in_var_node->Var()->GetType() == support_low_precision =
VarType::LOD_TENSOR); support_low_precision &&
(real_in_var_node->Var()->GetType() == VarType::LOD_TENSOR);
} }
for (auto* out_var_node : op_node->outputs) { for (auto* out_var_node : op_node->outputs) {
...@@ -362,23 +384,25 @@ void FloatToHalfPass::GetOpPrecision() const { ...@@ -362,23 +384,25 @@ void FloatToHalfPass::GetOpPrecision() const {
auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()]; auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
if (real_out_var_node->Var()->Persistable()) continue; if (real_out_var_node->Var()->Persistable()) continue;
support_half = support_half && (real_out_var_node->Var()->GetType() == support_low_precision =
VarType::LOD_TENSOR); support_low_precision &&
(real_out_var_node->Var()->GetType() == VarType::LOD_TENSOR);
} }
} }
if (support_half) { if (support_low_precision) {
op_run_half_.insert(op_type); op_run_low_precision_.insert(op_type);
VLOG(4) << "support precision: " << op_type << " run at half"; VLOG(4) << "support precision: " << op_type << " run at low precision";
} else { } else {
VLOG(4) << "support precision: " << op_type << " not run at half"; VLOG(4) << "support precision: " << op_type
<< " not run at low precision";
} }
} }
} }
} }
void FloatToHalfPass::UpdateOpPrecision() const { void AutoMixedPrecisionPass::UpdateOpPrecision() const {
std::unordered_set<std::string> vars_should_not_half; std::unordered_set<std::string> vars_should_not_low_precision;
// var -> the var's all input op // var -> the var's all input op
std::unordered_map<std::string, std::vector<Node*>> var_input_ops; std::unordered_map<std::string, std::vector<Node*>> var_input_ops;
...@@ -401,30 +425,16 @@ void FloatToHalfPass::UpdateOpPrecision() const { ...@@ -401,30 +425,16 @@ void FloatToHalfPass::UpdateOpPrecision() const {
<< " is output of " << op_type; << " is output of " << op_type;
} }
// the select_input op's input var should not convert to half. when // the select_input op's input var should not convert to low precision.
// op's output var is select_input op's input var, the op should not run // when op's output var is select_input op's input var, the op should
// half. // not run at low precision.
if (GetOpOriginalType(op_node->Op()->Type()) == "select_input") { if (GetOpOriginalType(op_node->Op()->Type()) == "select_input") {
for (auto* in_var_node : op_node->inputs) { for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true); CHECK_EQ(in_var_node->IsVar(), true);
if (in_var_node->Var()->Persistable()) continue; if (in_var_node->Var()->Persistable()) continue;
if (!VarNodeHasDtype(in_var_node)) continue; if (!VarNodeHasDtype(in_var_node)) continue;
vars_should_not_half.insert(in_var_node->Var()->Name()); vars_should_not_low_precision.insert(in_var_node->Var()->Name());
}
}
// when op_1 only support cpu kernel. if op_2's intput var is op_1's
// output var, then op_2 should not run half.
if (GetOpOriginalType(op_type) != "feed" &&
!GpuKernelSupportPrecision(GetOpOriginalType(op_type),
phi::DataType::FLOAT32)) {
for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true);
if (out_var_node->Var()->Persistable()) continue;
if (!VarNodeHasDtype(out_var_node)) continue;
vars_should_not_half.insert(out_var_node->Var()->Name());
} }
} }
} }
...@@ -437,25 +447,7 @@ void FloatToHalfPass::UpdateOpPrecision() const { ...@@ -437,25 +447,7 @@ void FloatToHalfPass::UpdateOpPrecision() const {
precision_updated = false; precision_updated = false;
for (const auto& nodes : all_op_nodes_) { for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) { for (auto* op_node : nodes) {
if (op_run_half_.count(op_node->Op()->Type()) == 0) continue; if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) continue;
for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true);
if (!VarNodeHasDtype(in_var_node)) continue;
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
if (real_in_var_node->Var()->Persistable()) continue;
if (vars_should_not_half.count(real_in_var_node->Var()->Name())) {
op_run_half_.erase(op_node->Op()->Type());
precision_updated = true;
VLOG(4) << op_node->Op()->Type()
<< " should not support half precision.";
break;
}
}
if (op_run_half_.count(op_node->Op()->Type()) == 0) continue;
for (auto* out_var_node : op_node->outputs) { for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true); CHECK_EQ(out_var_node->IsVar(), true);
...@@ -464,24 +456,25 @@ void FloatToHalfPass::UpdateOpPrecision() const { ...@@ -464,24 +456,25 @@ void FloatToHalfPass::UpdateOpPrecision() const {
auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()]; auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
if (real_out_var_node->Var()->Persistable()) continue; if (real_out_var_node->Var()->Persistable()) continue;
bool not_run_half = false; bool not_run_low_precision = false;
const auto& input_op_nodes = const auto& input_op_nodes =
var_input_ops[real_out_var_node->Var()->Name()]; var_input_ops[real_out_var_node->Var()->Name()];
if (vars_should_not_half.count(real_out_var_node->Var()->Name())) { if (vars_should_not_low_precision.count(
not_run_half = true; real_out_var_node->Var()->Name())) {
not_run_low_precision = true;
} else { } else {
for (auto* node : input_op_nodes) { for (auto* node : input_op_nodes) {
if (op_run_half_.count(node->Op()->Type()) == 0) { if (op_run_low_precision_.count(node->Op()->Type()) == 0) {
not_run_half = true; not_run_low_precision = true;
break; break;
} }
} }
} }
if (not_run_half) { if (not_run_low_precision) {
op_run_half_.erase(op_node->Op()->Type()); op_run_low_precision_.erase(op_node->Op()->Type());
precision_updated = true; precision_updated = true;
VLOG(4) << op_node->Op()->Type() VLOG(4) << op_node->Op()->Type()
<< " should not support half precision."; << " should not run at low precision.";
break; break;
} }
} }
...@@ -491,8 +484,8 @@ void FloatToHalfPass::UpdateOpPrecision() const { ...@@ -491,8 +484,8 @@ void FloatToHalfPass::UpdateOpPrecision() const {
} }
// Special ops whose weights should not be converted to low precision. // Special ops whose weights should not be converted to low precision.
bool FloatToHalfPass::InputVarsNotConvert(Node* op_node, bool AutoMixedPrecisionPass::InputVarsNotConvert(
const std::string& var_name) const { Node* op_node, const std::string& var_name) const {
auto* op_desc = op_node->Op(); auto* op_desc = op_node->Op();
if (GetOpOriginalType(op_desc->Type()) == "batch_norm") { if (GetOpOriginalType(op_desc->Type()) == "batch_norm") {
auto vecs = op_desc->Input("Bias"); auto vecs = op_desc->Input("Bias");
...@@ -532,8 +525,8 @@ bool FloatToHalfPass::InputVarsNotConvert(Node* op_node, ...@@ -532,8 +525,8 @@ bool FloatToHalfPass::InputVarsNotConvert(Node* op_node,
return false; return false;
} }
bool FloatToHalfPass::OutputVarsNotConvert(Node* op_node, bool AutoMixedPrecisionPass::OutputVarsNotConvert(
const std::string& var_name) const { Node* op_node, const std::string& var_name) const {
auto* op_desc = op_node->Op(); auto* op_desc = op_node->Op();
// batch_norm's input and output (variance and mean) are the same. // batch_norm's input and output (variance and mean) are the same.
if (GetOpOriginalType(op_desc->Type()) == "batch_norm") { if (GetOpOriginalType(op_desc->Type()) == "batch_norm") {
...@@ -557,10 +550,14 @@ bool FloatToHalfPass::OutputVarsNotConvert(Node* op_node, ...@@ -557,10 +550,14 @@ bool FloatToHalfPass::OutputVarsNotConvert(Node* op_node,
return false; return false;
} }
void FloatToHalfPass::SetVarPrecision() const { void AutoMixedPrecisionPass::SetVarPrecision() const {
for (const auto& nodes : all_op_nodes_) { for (const auto& nodes : all_op_nodes_) {
for (auto* op_node : nodes) { for (auto* op_node : nodes) {
if (op_run_half_.count(op_node->Op()->Type())) { if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) {
continue;
}
if (GetOpOriginalType(op_node->Op()->Type()) != "feed") {
for (auto* in_var_node : op_node->inputs) { for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true); CHECK_EQ(in_var_node->IsVar(), true);
...@@ -573,11 +570,13 @@ void FloatToHalfPass::SetVarPrecision() const { ...@@ -573,11 +570,13 @@ void FloatToHalfPass::SetVarPrecision() const {
if (real_in_var_node->Var()->Persistable()) { if (real_in_var_node->Var()->Persistable()) {
real_in_var_node->Var()->SetDataType( real_in_var_node->Var()->SetDataType(
framework::TransToProtoVarType(half_precision_)); framework::TransToProtoVarType(low_precision_));
vars_convert_to_half_.insert(in_var_name); vars_convert_to_low_precision_.insert(in_var_name);
}
} }
} }
if (GetOpOriginalType(op_node->Op()->Type()) != "fetch") {
for (auto* out_var_node : op_node->outputs) { for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true); CHECK_EQ(out_var_node->IsVar(), true);
...@@ -589,9 +588,9 @@ void FloatToHalfPass::SetVarPrecision() const { ...@@ -589,9 +588,9 @@ void FloatToHalfPass::SetVarPrecision() const {
if (OutputVarsNotConvert(op_node, out_var_name)) continue; if (OutputVarsNotConvert(op_node, out_var_name)) continue;
real_out_var_node->Var()->SetDataType( real_out_var_node->Var()->SetDataType(
framework::TransToProtoVarType(half_precision_)); framework::TransToProtoVarType(low_precision_));
if (real_out_var_node->Var()->Persistable()) { if (real_out_var_node->Var()->Persistable()) {
vars_convert_to_half_.insert(out_var_name); vars_convert_to_low_precision_.insert(out_var_name);
} }
} }
} }
...@@ -606,24 +605,24 @@ void FloatToHalfPass::SetVarPrecision() const { ...@@ -606,24 +605,24 @@ void FloatToHalfPass::SetVarPrecision() const {
if (!VarNodeHasDtype(var_node)) continue; if (!VarNodeHasDtype(var_node)) continue;
auto var_name = var_node->Var()->Name(); auto var_name = var_node->Var()->Name();
if (vars_convert_to_half_.count(var_name)) { if (vars_convert_to_low_precision_.count(var_name)) {
var_node->Var()->SetDataType( var_node->Var()->SetDataType(
framework::TransToProtoVarType(half_precision_)); framework::TransToProtoVarType(low_precision_));
} }
} }
} }
} }
void FloatToHalfPass::ConvertWeightsData() const { void AutoMixedPrecisionPass::ConvertWeightsData() const {
auto* scope = param_scope(); auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(scope,
scope,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"During the float to half pass, the scope should not be null.")); "During the auto_mixed_precision_pass, the scope "
"should not be null."));
auto var_names = scope->LocalVarNames(); auto var_names = scope->LocalVarNames();
for (const auto& var_name : var_names) { for (const auto& var_name : var_names) {
if (vars_convert_to_half_.count(var_name)) { if (vars_convert_to_low_precision_.count(var_name)) {
VLOG(4) << var_name << "'s data type was converted to low precision"; VLOG(4) << var_name << "'s data type was converted to low precision";
auto* var = scope->FindLocalVar(var_name); auto* var = scope->FindLocalVar(var_name);
...@@ -631,25 +630,29 @@ void FloatToHalfPass::ConvertWeightsData() const { ...@@ -631,25 +630,29 @@ void FloatToHalfPass::ConvertWeightsData() const {
auto* origin_tensor = var->GetMutable<phi::DenseTensor>(); auto* origin_tensor = var->GetMutable<phi::DenseTensor>();
phi::DenseTensor half_tensor; phi::DenseTensor low_precision_tensor;
half_tensor.Resize(origin_tensor->dims()); low_precision_tensor.Resize(origin_tensor->dims());
half_tensor.set_type(half_precision_); low_precision_tensor.set_type(low_precision_);
if (half_precision_ == phi::DataType::FLOAT16) { if (low_precision_ == phi::DataType::FLOAT16) {
auto* half_data = auto* low_precision_data =
half_tensor.mutable_data<phi::dtype::float16>(phi::CPUPlace{}); low_precision_tensor.mutable_data<phi::dtype::float16>(
phi::CPUPlace{});
for (int64_t i = 0; i < origin_tensor->numel(); i++) { for (int64_t i = 0; i < origin_tensor->numel(); i++) {
if (origin_tensor->dtype() == phi::DataType::FLOAT64) { if (origin_tensor->dtype() == phi::DataType::FLOAT64) {
auto* origin_data = origin_tensor->data<double>(); auto* origin_data = origin_tensor->data<double>();
half_data[i] = static_cast<phi::dtype::float16>(origin_data[i]); low_precision_data[i] =
static_cast<phi::dtype::float16>(origin_data[i]);
} else if (origin_tensor->dtype() == phi::DataType::FLOAT32) { } else if (origin_tensor->dtype() == phi::DataType::FLOAT32) {
auto* origin_data = origin_tensor->data<float>(); auto* origin_data = origin_tensor->data<float>();
half_data[i] = static_cast<phi::dtype::float16>(origin_data[i]); low_precision_data[i] =
static_cast<phi::dtype::float16>(origin_data[i]);
} }
} }
} else if (half_precision_ == phi::DataType::BFLOAT16) { } else if (low_precision_ == phi::DataType::BFLOAT16) {
auto* half_data = auto* half_data =
half_tensor.mutable_data<phi::dtype::bfloat16>(phi::CPUPlace{}); low_precision_tensor.mutable_data<phi::dtype::bfloat16>(
phi::CPUPlace{});
for (int64_t i = 0; i < origin_tensor->numel(); i++) { for (int64_t i = 0; i < origin_tensor->numel(); i++) {
if (origin_tensor->dtype() == phi::DataType::FLOAT64) { if (origin_tensor->dtype() == phi::DataType::FLOAT64) {
auto* origin_data = origin_tensor->data<double>(); auto* origin_data = origin_tensor->data<double>();
...@@ -662,12 +665,12 @@ void FloatToHalfPass::ConvertWeightsData() const { ...@@ -662,12 +665,12 @@ void FloatToHalfPass::ConvertWeightsData() const {
} }
origin_tensor->clear(); origin_tensor->clear();
paddle::framework::TensorCopySync( paddle::framework::TensorCopySync(
half_tensor, phi::CPUPlace{}, origin_tensor); low_precision_tensor, phi::CPUPlace{}, origin_tensor);
} }
} }
} }
void FloatToHalfPass::InsertCastOp() const { void AutoMixedPrecisionPass::InsertCastOp() const {
int suffix = 0; int suffix = 0;
std::unordered_map<Node*, Node*> cache; std::unordered_map<Node*, Node*> cache;
...@@ -681,7 +684,7 @@ void FloatToHalfPass::InsertCastOp() const { ...@@ -681,7 +684,7 @@ void FloatToHalfPass::InsertCastOp() const {
if (op_node->Op()->HasAttr("sub_block")) continue; if (op_node->Op()->HasAttr("sub_block")) continue;
VLOG(4) << "process op: " << op_type VLOG(4) << "process op: " << op_type
<< " run half: " << op_run_half_.count(op_type); << " run low precision: " << op_run_low_precision_.count(op_type);
auto inputs = op_node->inputs; auto inputs = op_node->inputs;
for (auto* in_var_node : inputs) { for (auto* in_var_node : inputs) {
...@@ -696,17 +699,17 @@ void FloatToHalfPass::InsertCastOp() const { ...@@ -696,17 +699,17 @@ void FloatToHalfPass::InsertCastOp() const {
VLOG(4) << "process var: " << real_in_var_node->Var()->Name() VLOG(4) << "process var: " << real_in_var_node->Var()->Name()
<< " with type " << in_var_type; << " with type " << in_var_type;
if (IsFloatType(in_var_type) && op_run_half_.count(op_type)) { if (IsFloatType(in_var_type) && op_run_low_precision_.count(op_type)) {
DoInsertCastOp(subgraphes_[i], DoInsertCastOp(subgraphes_[i],
in_var_node, in_var_node,
op_node, op_node,
in_var_type, in_var_type,
framework::TransToProtoVarType(half_precision_), framework::TransToProtoVarType(low_precision_),
block_desc, block_desc,
&suffix, &suffix,
&cache); &cache);
} else if (IsHalfType(in_var_type) && } else if (IsHalfType(in_var_type) &&
op_run_half_.count(op_type) == 0) { op_run_low_precision_.count(op_type) == 0) {
DoInsertCastOp(subgraphes_[i], DoInsertCastOp(subgraphes_[i],
in_var_node, in_var_node,
op_node, op_node,
...@@ -738,4 +741,5 @@ void FloatToHalfPass::InsertCastOp() const { ...@@ -738,4 +741,5 @@ void FloatToHalfPass::InsertCastOp() const {
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(float_to_half_pass, paddle::framework::ir::FloatToHalfPass); REGISTER_PASS(auto_mixed_precision_pass,
paddle::framework::ir::AutoMixedPrecisionPass);
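For context, a minimal sketch of applying the registered IR pass directly, assuming a framework::ir::Graph built from a loaded ProgramDesc and a Scope holding the weights; it mirrors ConvertToMixedPrecisionPass::Run() further down in this diff, and the helper name is illustrative only.

#include <unordered_set>
#include "paddle/fluid/framework/ir/auto_mixed_precision_pass.h"
#include "paddle/fluid/framework/scope.h"

void ApplyAutoMixedPrecision(paddle::framework::ir::Graph* graph,
                             paddle::framework::Scope* scope) {
  // The pass reads and converts weights through the param-scope attribute.
  graph->SetNotOwned(paddle::framework::ir::kParamScopeAttr, scope);

  paddle::framework::ir::AutoMixedPrecisionPass pass;
  pass.Set("enable_gpu_mixed", new bool{true});  // otherwise skip_pass_ is set
  pass.Set("mixed_precision_mode",
           new int{static_cast<int>(phi::DataType::FLOAT16)});
  pass.Set("mixed_black_list", new std::unordered_set<std::string>{});
  pass.Set("keep_io_types", new bool{true});  // keep feed/fetch vars in FP32
  pass.Apply(graph);
}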
...@@ -27,13 +27,13 @@ namespace paddle { ...@@ -27,13 +27,13 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
class FloatToHalfPass : public FusePassBase { class AutoMixedPrecisionPass : public FusePassBase {
public: public:
using VarType = framework::proto::VarType; using VarType = framework::proto::VarType;
public: public:
FloatToHalfPass() = default; AutoMixedPrecisionPass() = default;
~FloatToHalfPass() = default; ~AutoMixedPrecisionPass() = default;
protected: protected:
void ApplyImpl(Graph* graph) const override; void ApplyImpl(Graph* graph) const override;
...@@ -43,10 +43,6 @@ class FloatToHalfPass : public FusePassBase { ...@@ -43,10 +43,6 @@ class FloatToHalfPass : public FusePassBase {
void SetDefaultBlacklist() const; void SetDefaultBlacklist() const;
bool OpSupportPrecision(const std::string& op_type,
phi::DataType precision,
phi::Backend backend = phi::Backend::GPU) const;
void SetOpUniqueType() const; void SetOpUniqueType() const;
void RestoreOpOriginType() const; void RestoreOpOriginType() const;
...@@ -70,9 +66,13 @@ class FloatToHalfPass : public FusePassBase { ...@@ -70,9 +66,13 @@ class FloatToHalfPass : public FusePassBase {
void ConvertWeightsData() const; void ConvertWeightsData() const;
private: private:
mutable bool keep_io_types_; mutable bool skip_pass_{false};
mutable bool keep_io_types_{false};
// float16 or bfloat16 now // float16 or bfloat16 now
mutable phi::DataType half_precision_; mutable phi::DataType low_precision_{phi::DataType::FLOAT16};
mutable phi::Backend backend_{phi::Backend::GPU};
mutable std::unordered_set<std::string> black_list_; mutable std::unordered_set<std::string> black_list_;
...@@ -84,12 +84,26 @@ class FloatToHalfPass : public FusePassBase { ...@@ -84,12 +84,26 @@ class FloatToHalfPass : public FusePassBase {
mutable std::vector<std::vector<Node*>> all_op_nodes_; mutable std::vector<std::vector<Node*>> all_op_nodes_;
// op's unique type -> the op's origin type // op's unique type -> the op's origin type
mutable std::unordered_map<std::string, std::string> op_original_type_; mutable std::unordered_map<std::string, std::string> op_original_type_;
// op's unique type -> whether the op run at half precision // op's unique type -> whether the op run at low precision
mutable std::unordered_set<std::string> op_run_half_; mutable std::unordered_set<std::string> op_run_low_precision_;
mutable std::unordered_set<std::string> vars_convert_to_half_; mutable std::unordered_set<std::string> vars_convert_to_low_precision_;
}; };
bool OpSupportPrecision(const std::string& op_type,
phi::Backend backend,
phi::DataType precision,
const std::unordered_set<std::string>& black_list);
void DoInsertCastOp(Graph* graph,
Node* var_node,
Node* op_node,
proto::VarType::Type from_type,
proto::VarType::Type to_type,
framework::BlockDesc* block_desc,
int* suffix,
std::unordered_map<Node*, Node*>* cache);
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
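A small sketch of the free OpSupportPrecision helper declared above: an op in the black list is rejected regardless of kernel support, and only the GPU backend is accepted; the function and op name below are illustrative.

#include <unordered_set>
#include "paddle/fluid/framework/ir/auto_mixed_precision_pass.h"

// Returns whether "softmax" may run in FP16 on GPU given a user black list.
bool SoftmaxMayRunFp16(const std::unordered_set<std::string>& black_list) {
  return paddle::framework::ir::OpSupportPrecision(
      "softmax", phi::Backend::GPU, phi::DataType::FLOAT16, black_list);
}
// SoftmaxMayRunFp16({"softmax"}) is false (black-listed); with an empty black
// list the result follows what the GPU kernel registry reports for FP16.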
...@@ -142,7 +142,7 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const { ...@@ -142,7 +142,7 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
bool is_fp16_precision = bool is_fp16_precision =
static_cast<phi::DataType>(Get<int>("model_precision")) == static_cast<phi::DataType>(Get<int>("model_precision")) ==
phi::DataType::FLOAT16 || phi::DataType::FLOAT16 ||
Get<bool>("enable_gpu_half"); Get<bool>("enable_gpu_mixed");
bool cutlass_enable = false; bool cutlass_enable = false;
#ifdef PADDLE_WITH_CUTLASS #ifdef PADDLE_WITH_CUTLASS
......
...@@ -165,7 +165,7 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -165,7 +165,7 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
bool is_fp16_precision = bool is_fp16_precision =
static_cast<phi::DataType>(Get<int>("model_precision")) == static_cast<phi::DataType>(Get<int>("model_precision")) ==
phi::DataType::FLOAT16 || phi::DataType::FLOAT16 ||
Get<bool>("enable_gpu_half"); Get<bool>("enable_gpu_mixed");
constexpr int CUTLASS_NHWC_ALIGNMENT = 8; constexpr int CUTLASS_NHWC_ALIGNMENT = 8;
if (is_fp16_precision) { if (is_fp16_precision) {
#ifdef PADDLE_WITH_CUTLASS #ifdef PADDLE_WITH_CUTLASS
......
...@@ -365,7 +365,7 @@ struct Argument { ...@@ -365,7 +365,7 @@ struct Argument {
DECL_ARGUMENT_FIELD(mixed_black_list, DECL_ARGUMENT_FIELD(mixed_black_list,
MixedBlackList, MixedBlackList,
std::unordered_set<std::string>); std::unordered_set<std::string>);
DECL_ARGUMENT_FIELD(enable_gpu_half, EnableGPUHalf, bool); DECL_ARGUMENT_FIELD(enable_gpu_mixed, EnableGPUMixed, bool);
DECL_ARGUMENT_FIELD(mixed_precision_mode, MixedPrecisionMode, int); DECL_ARGUMENT_FIELD(mixed_precision_mode, MixedPrecisionMode, int);
// cinn compiler related // cinn compiler related
......
...@@ -45,8 +45,10 @@ IRPassManager::IRPassManager(Argument *argument) { ...@@ -45,8 +45,10 @@ IRPassManager::IRPassManager(Argument *argument) {
void IRPassManager::CreatePasses(Argument *argument, void IRPassManager::CreatePasses(Argument *argument,
const std::vector<std::string> &passes) { const std::vector<std::string> &passes) {
// For graph_viz_pass
std::string pre_pass; std::string pre_pass;
int pass_num = 0; int pass_num = 0;
for (const std::string &pass_name : passes) { for (const std::string &pass_name : passes) {
auto pass = framework::ir::PassRegistry::Instance().Get(pass_name); auto pass = framework::ir::PassRegistry::Instance().Get(pass_name);
pass->Set("use_varseqlen", new bool(argument->tensorrt_use_varseqlen())); pass->Set("use_varseqlen", new bool(argument->tensorrt_use_varseqlen()));
...@@ -87,14 +89,14 @@ void IRPassManager::CreatePasses(Argument *argument, ...@@ -87,14 +89,14 @@ void IRPassManager::CreatePasses(Argument *argument,
argument->tensorrt_tuned_dynamic_shape(); argument->tensorrt_tuned_dynamic_shape();
pass->Set("with_dynamic_shape", new bool(with_dynamic_shape)); pass->Set("with_dynamic_shape", new bool(with_dynamic_shape));
// mixed precision related // Mixed precision related.
pass->Set("model_precision", new int(argument->model_precision()));
pass->Set( pass->Set(
"mixed_black_list", "mixed_black_list",
new std::unordered_set<std::string>(argument->mixed_black_list())); new std::unordered_set<std::string>(argument->mixed_black_list()));
pass->Set("enable_gpu_half", new bool(argument->enable_gpu_half())); pass->Set("enable_gpu_mixed", new bool(argument->enable_gpu_mixed()));
pass->Set("mixed_precision_mode", pass->Set("mixed_precision_mode",
new int(argument->mixed_precision_mode())); new int(argument->mixed_precision_mode()));
pass->Set("model_precision", new int(argument->model_precision()));
if (pass_name == "graph_viz_pass") { if (pass_name == "graph_viz_pass") {
std::string optim_cache_dir = argument->optim_cache_dir(); std::string optim_cache_dir = argument->optim_cache_dir();
...@@ -210,6 +212,7 @@ void IRPassManager::CreatePasses(Argument *argument, ...@@ -210,6 +212,7 @@ void IRPassManager::CreatePasses(Argument *argument,
new std::vector<std::string>(argument->tensorrt_disabled_ops())); new std::vector<std::string>(argument->tensorrt_disabled_ops()));
pass->Set("trt_use_dla", new bool(argument->tensorrt_use_dla())); pass->Set("trt_use_dla", new bool(argument->tensorrt_use_dla()));
pass->Set("trt_dla_core", new int(argument->tensorrt_dla_core())); pass->Set("trt_dla_core", new int(argument->tensorrt_dla_core()));
// Setting the disable_trt_plugin_fp16 to true means that TRT plugin will // Setting the disable_trt_plugin_fp16 to true means that TRT plugin will
// not run fp16. // not run fp16.
pass->Set("disable_trt_plugin_fp16", pass->Set("disable_trt_plugin_fp16",
...@@ -238,8 +241,7 @@ void IRPassManager::CreatePasses(Argument *argument, ...@@ -238,8 +241,7 @@ void IRPassManager::CreatePasses(Argument *argument,
pass->Set("root_predictor_id", new int(argument->root_predictor_id())); pass->Set("root_predictor_id", new int(argument->root_predictor_id()));
} else if (pass_name == "build_cinn_pass") { } else if (pass_name == "build_cinn_pass") {
pass->Set("is_inference_stage", new bool(argument->use_cinn_compiler())); pass->Set("is_inference_stage", new bool(argument->use_cinn_compiler()));
} } else if (pass_name == "lite_subgraph_pass") {
if (pass_name == "lite_subgraph_pass") {
bool lite_enable_int8 = bool lite_enable_int8 =
argument->lite_precision_mode() == AnalysisConfig::Precision::kInt8; argument->lite_precision_mode() == AnalysisConfig::Precision::kInt8;
pass->Set("program", pass->Set("program",
...@@ -287,8 +289,7 @@ void IRPassManager::CreatePasses(Argument *argument, ...@@ -287,8 +289,7 @@ void IRPassManager::CreatePasses(Argument *argument,
pass->Set("nnadapter_model_cache_token", pass->Set("nnadapter_model_cache_token",
new std::vector<std::string>( new std::vector<std::string>(
argument->nnadapter_model_cache_token())); argument->nnadapter_model_cache_token()));
} } else if (pass_name == "fc_fuse_pass") {
if (pass_name == "fc_fuse_pass") {
pass->Set("use_gpu", new bool(argument->use_gpu())); pass->Set("use_gpu", new bool(argument->use_gpu()));
bool fc_mkldnn_pass = 0; bool fc_mkldnn_pass = 0;
for (const std::string &pass_n : passes) { for (const std::string &pass_n : passes) {
......
...@@ -83,13 +83,13 @@ void OutputProcess(framework::ir::Graph *graph, ...@@ -83,13 +83,13 @@ void OutputProcess(framework::ir::Graph *graph,
backend, backend,
precision, precision,
blacklist)) { blacklist)) {
AddCastOp(graph, InsertCastOp(graph,
var_node, var_node,
next_op, next_op,
framework::proto::VarType::FP32, framework::proto::VarType::FP32,
to_type, to_type,
&suffix,
block_desc, block_desc,
&suffix,
&var_to_cast_op_map); &var_to_cast_op_map);
var_node->Var()->SetDataType(framework::proto::VarType::FP32); var_node->Var()->SetDataType(framework::proto::VarType::FP32);
} }
......
...@@ -13,7 +13,7 @@ cc_library( ...@@ -13,7 +13,7 @@ cc_library(
cc_library( cc_library(
convert_to_mixed_precision convert_to_mixed_precision
SRCS convert_to_mixed_precision.cc SRCS convert_to_mixed_precision.cc
DEPS analysis_pass ir_graph_build_pass) DEPS analysis_pass ir_graph_build_pass auto_mixed_precision_pass)
cc_library( cc_library(
ir_params_sync_among_devices_pass ir_params_sync_among_devices_pass
SRCS ir_params_sync_among_devices_pass.cc SRCS ir_params_sync_among_devices_pass.cc
......
...@@ -14,82 +14,17 @@ ...@@ -14,82 +14,17 @@
#include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h" #include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"
#include <algorithm>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/ir/auto_mixed_precision_pass.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/inference/io.h" #include "paddle/fluid/inference/io.h"
#include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/place.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
namespace { ConvertToMixedPrecisionPass::ConvertToMixedPrecisionPass(
using VarType = framework::proto::VarType;
bool PhiKernelSupportPrecision(
const std::string& op_type,
phi::Backend backend,
phi::DataType data_type,
phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) {
auto kernels = phi::KernelFactory::Instance().kernels();
if (kernels.find(op_type) == kernels.end()) {
return false;
}
phi::KernelKey kernel_key(backend, layout, data_type);
return phi::KernelFactory::Instance().HasKernel(op_type, kernel_key);
}
bool GpuKernelSupportPrecision(
const std::string& op_type,
phi::DataType data_type,
phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) {
auto phi_op_type = phi::TransToPhiKernelName(op_type);
bool res = PhiKernelSupportPrecision(
phi_op_type, phi::Backend::GPU, data_type, layout);
res |= PhiKernelSupportPrecision(
phi_op_type, phi::Backend::GPUDNN, data_type, layout);
if (!res) {
auto& all_kernels = framework::OperatorWithKernel::AllOpKernels();
auto it = all_kernels.find(op_type);
if (it != all_kernels.end()) {
for (auto& kern_pair : it->second) {
if (platform::is_gpu_place(kern_pair.first.place_) &&
kern_pair.first.data_type_ == VarType::FP16) {
res = true;
break;
}
}
}
}
return res;
}
class ConvertToMixedPrecisionPass {
using BlockID = size_t;
public:
explicit ConvertToMixedPrecisionPass(
const std::string& model_file, const std::string& model_file,
const std::string& params_file, const std::string& params_file,
const std::string& mixed_model_file, const std::string& mixed_model_file,
...@@ -105,571 +40,46 @@ class ConvertToMixedPrecisionPass { ...@@ -105,571 +40,46 @@ class ConvertToMixedPrecisionPass {
mixed_precision_(mixed_precision), mixed_precision_(mixed_precision),
backend_(backend), backend_(backend),
keep_io_types_(keep_io_types), keep_io_types_(keep_io_types),
black_list_(black_list), black_list_(black_list) {
place_(paddle::CPUPlace()), if (mixed_precision_ != phi::DataType::FLOAT16 &&
executor_(place_) { mixed_precision_ != phi::DataType::BFLOAT16) {
VLOG(4) << "black_list has "; PADDLE_THROW(paddle::platform::errors::InvalidArgument(
for (auto& name : black_list_) { "mixed_precision currently not supported dtype %d, we now only "
VLOG(4) << " - " << name; "support fp16 and bf16.",
} static_cast<int>(mixed_precision_)));
}
void Run();
private:
void LoadAndPrepare();
inline bool VarNodeHasDtype(framework::ir::Node* node);
void ConvertAllFp64ToFp32(framework::ir::Graph* graph);
void FixCastAttr(framework::ir::Graph* graph);
void SaveMixedModel();
void ConvertTensorDtype(BlockID block_idx);
void ProcessInputNode(bool support_precision,
framework::ir::Node* in_node,
framework::ir::Node* op_node,
int* suffix,
framework::BlockDesc* block_desc,
VarType::Type to_type,
BlockID block_idx);
void ProcessOutputNode(BlockID block_idx,
framework::ir::Node* var_node,
VarType::Type to_type);
inline bool IsFloatVarType(VarType::Type type);
bool OutShouldNotConvert(framework::ir::Node* var_node);
// Just process special cases for weights conversion.
bool WeightsShouldNotConvert(framework::ir::Node* var_node);
// Return Node* which first appers in block.
framework::ir::Node* GetRealVarNode(framework::ir::Node* node);
// Fallback to fp32 dtype when encounter circle (Not a DAG graph).
void ProcessCircleCases();
private:
std::string model_file_;
std::string params_file_;
std::string mixed_model_file_;
std::string mixed_params_file_;
phi::DataType mixed_precision_;
phi::Backend backend_;
bool keep_io_types_;
std::unordered_set<std::string> black_list_;
paddle::CPUPlace place_;
framework::Executor executor_;
framework::Scope scope_;
std::unordered_map<std::string, framework::ir::Node*> name2node_;
std::unordered_map<framework::ir::Node*, framework::ir::Node*> cast_map_;
int suffix_{0};
std::set<std::string> var_names_in_circles_;
std::unique_ptr<framework::ProgramDesc> program_desc_{nullptr};
std::unique_ptr<framework::ir::Graph> main_graph_{nullptr};
std::vector<framework::ir::Graph*> graphes_;
};
framework::ir::Node* ConvertToMixedPrecisionPass::GetRealVarNode(
framework::ir::Node* var_node) {
CHECK_EQ(var_node->IsVar(), true);
if (name2node_.count(var_node->Name())) return name2node_[var_node->Name()];
return var_node;
}
inline bool ConvertToMixedPrecisionPass::VarNodeHasDtype(
framework::ir::Node* var_node) {
CHECK_EQ(var_node->IsVar(), true);
auto type = var_node->Var()->GetType();
return (type == VarType::SELECTED_ROWS) || (type == VarType::LOD_TENSOR) ||
(type == VarType::LOD_TENSOR_ARRAY) || (type == VarType::STRINGS) ||
(type == VarType::VOCAB);
}
void ConvertToMixedPrecisionPass::ProcessInputNode(
bool support_precision,
framework::ir::Node* in_node,
framework::ir::Node* op_node,
int* suffix,
framework::BlockDesc* block_desc,
VarType::Type to_type,
BlockID block_idx) {
if (!in_node->IsVar()) return;
auto* real_node = GetRealVarNode(in_node);
if (!VarNodeHasDtype(real_node)) return;
auto* graph = graphes_[block_idx];
auto* in_var = real_node->Var();
auto in_var_type = in_var->GetDataType();
auto prev_type = in_var_type;
if (support_precision) {
if (in_var->Persistable() && in_var_type == VarType::FP32) {
if (WeightsShouldNotConvert(in_node)) return;
in_var->SetDataType(to_type);
in_var_type = to_type;
VLOG(3) << " in_node name " << in_var->Name() << " from " << prev_type
<< " to " << to_type;
} else if (!in_var->Persistable() && IsFloatVarType(in_var_type) &&
in_var_type != to_type) {
AddCastOp(graph,
in_node,
op_node,
in_var_type,
to_type,
suffix,
block_desc,
&cast_map_);
VLOG(3) << " in_node name " << in_var->Name() << "(" << prev_type
<< ") to " << cast_map_[in_node]->Name() << "(" << to_type << ")";
}
} else {
if (!in_var->Persistable() && IsFloatVarType(in_var_type) &&
in_var_type != to_type) {
AddCastOp(graph,
in_node,
op_node,
in_var_type,
to_type,
suffix,
block_desc,
&cast_map_);
VLOG(3) << " in_node name " << in_var->Name() << "(" << prev_type
<< ") to " << cast_map_[in_node]->Name() << "(" << to_type << ")";
}
}
}
void ConvertToMixedPrecisionPass::ProcessOutputNode(
BlockID block_idx, framework::ir::Node* var_node, VarType::Type to_type) {
if (!var_node->IsVar()) return;
auto* real_node = GetRealVarNode(var_node);
if (!VarNodeHasDtype(real_node)) return;
auto* out_var = real_node->Var();
auto prev_type = out_var->GetDataType();
if (out_var->GetDataType() == VarType::FP32) {
if (OutShouldNotConvert(var_node)) return;
out_var->SetDataType(to_type);
}
VLOG(3) << " out_node name " << var_node->Name() << " from dtype "
<< prev_type << " to " << out_var->GetDataType();
}
// Just process special cases.
bool ConvertToMixedPrecisionPass::OutShouldNotConvert(
framework::ir::Node* var_node) {
auto op_node = var_node->inputs[0];
auto* op_desc = op_node->Op();
// batch_norm's input and output (variance and mean) are the same.
if (op_desc->Type() == "batch_norm") {
auto vecs = op_desc->Output("MeanOut");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
vecs = op_desc->Output("VarianceOut");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
vecs = op_desc->Output("SavedMean");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
vecs = op_desc->Output("SavedVariance");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
}
return false;
}
bool ConvertToMixedPrecisionPass::WeightsShouldNotConvert(
framework::ir::Node* var_node) {
auto op_nodes = var_node->outputs;
for (auto* op_node : op_nodes) {
auto* op_desc = op_node->Op();
// batch_norm op's bias, mean, scale and variance just be float32, so we can
// not convert the dtype.
if (op_desc->Type() == "batch_norm") {
auto vecs = op_desc->Input("Bias");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
vecs = op_desc->Input("Mean");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
vecs = op_desc->Input("Scale");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
vecs = op_desc->Input("Variance");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
} else if (op_desc->Type() == "fused_multi_transformer") {
auto vecs = op_desc->Input("LnScale");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
vecs = op_desc->Input("LnBias");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
vecs = op_desc->Input("FFNLnScale");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
vecs = op_desc->Input("FFNLnBias");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
} }
if (backend_ != phi::Backend::GPU) {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"mixed_precision currently not supported place %d, we now only "
"support gpu.",
static_cast<int>(backend_)));
} }
return false;
} }
inline bool ConvertToMixedPrecisionPass::IsFloatVarType(VarType::Type type) { void ConvertToMixedPrecisionPass::LoadModel() {
return (type == VarType::FP16) || (type == VarType::FP32) || framework::Executor exe{platform::CPUPlace{}};
(type == VarType::BF16);
}
void ConvertToMixedPrecisionPass::LoadAndPrepare() { auto program_desc = inference::Load(&exe, &scope_, model_file_, params_file_);
program_desc_ =
inference::Load(&executor_, &scope_, model_file_, params_file_);
main_graph_ = std::unique_ptr<framework::ir::Graph>( main_graph_ = std::unique_ptr<framework::ir::Graph>(
new framework::ir::Graph(*program_desc_)); new framework::ir::Graph(*program_desc));
main_graph_->SetNotOwned(framework::ir::kParamScopeAttr, &scope_);
for (size_t i = 0; i < main_graph_->SubGraphsSize(); ++i) {
auto* graph = main_graph_->GetSubGraph(i);
graphes_.push_back(graph);
for (auto* node : graph->Nodes()) {
if (!node->IsVar()) continue;
if (!name2node_.count(node->Name())) {
name2node_[node->Name()] = node;
}
}
}
ProcessCircleCases();
}
// Find var names which in circles.
void ConvertToMixedPrecisionPass::ProcessCircleCases() {
std::vector<std::string> vars_in_circles;
for (size_t idx = 0; idx < program_desc_->Size(); ++idx) {
for (auto* op : program_desc_->Block(idx).AllOps()) {
// TODO(inference): batch_norm has circle, but we need to fuse it in conv
// op.
if (op->Type() == "batch_norm") continue;
const auto& in_names = op->InputArgumentNames();
const auto& out_names = op->OutputArgumentNames();
std::set<std::string> in_names_set(in_names.begin(), in_names.end());
std::set<std::string> out_names_set(out_names.begin(), out_names.end());
std::set_intersection(in_names_set.begin(),
in_names_set.end(),
out_names_set.begin(),
out_names_set.end(),
std::back_inserter(vars_in_circles));
}
}
for (auto& name : vars_in_circles) {
var_names_in_circles_.insert(name);
}
for (auto& name : var_names_in_circles_) {
LOG(INFO) << name
<< " in circles, so we will skip process those vars and ops.";
}
}
inline void ProcessConstantOpAttr(framework::ir::Node* op_node,
VarType::Type from_type,
VarType::Type to_type) {
if (!op_node->IsOp()) return;
auto op_type = op_node->Op()->Type();
if (op_type == "feed" || op_type == "fetch") return;
if (op_type == "fill_constant") {
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
static_cast<int>(from_type))
op_node->Op()->SetAttr("dtype", static_cast<int>(to_type));
} else if (op_type == "assign_value") {
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
static_cast<int>(from_type))
op_node->Op()->SetAttr("dtype", static_cast<int>(to_type));
} else if (op_type == "eye") {
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
static_cast<int>(from_type))
op_node->Op()->SetAttr("dtype", static_cast<int>(to_type));
} else if (op_type == "fill_any_like") {
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
static_cast<int>(from_type))
op_node->Op()->SetAttr("dtype", static_cast<int>(to_type));
} else if (op_type == "cast") {
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("in_dtype")) ==
static_cast<int>(from_type))
op_node->Op()->SetAttr("in_dtype", static_cast<int>(to_type));
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("out_dtype")) ==
static_cast<int>(from_type))
op_node->Op()->SetAttr("out_dtype", static_cast<int>(to_type));
}
}
void ConvertToMixedPrecisionPass::ConvertAllFp64ToFp32(
framework::ir::Graph* graph) {
auto op_nodes = framework::ir::TopologySortOperations(*graph);
for (auto* op_node : op_nodes) {
if (!op_node->IsOp()) continue;
auto op_type = op_node->Op()->Type();
ProcessConstantOpAttr(op_node, VarType::FP64, VarType::FP32);
auto inputs = op_node->inputs;
for (auto* in_node : inputs) {
auto* in_var = in_node->Var();
if (!in_var->Persistable() && in_var->GetDataType() == VarType::FP64) {
in_var->SetDataType(VarType::FP32);
}
}
}
} }
void ConvertToMixedPrecisionPass::Run() { void ConvertToMixedPrecisionPass::Run() {
LoadAndPrepare(); LoadModel();
for (size_t i = 0; i < graphes_.size(); ++i) { framework::ir::AutoMixedPrecisionPass pass;
auto* graph = graphes_[i]; pass.Set("mixed_precision_mode", new int{static_cast<int>(mixed_precision_)});
VLOG(2) << " -------- handle subgraph " << i << ", has " pass.Set("mixed_black_list",
<< graph->Nodes().size() << " nodes --------"; new std::unordered_set<std::string>{black_list_});
pass.Set("enable_gpu_mixed", new bool{true});
pass.Set("keep_io_types", new bool{keep_io_types_});
ConvertAllFp64ToFp32(graph); pass.Apply(main_graph_.get());
ConvertTensorDtype(i);
FixCastAttr(graph);
CHECK_EQ(framework::ir::VarDescIsConsistency(*graph), true);
}
SaveMixedModel(); SaveMixedModel();
} }
void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
auto* graph = graphes_[block_idx];
VarType::Type to_type;
if (mixed_precision_ == phi::DataType::FLOAT16) {
to_type = VarType::FP16;
} else if (mixed_precision_ == phi::DataType::BFLOAT16) {
to_type = VarType::BF16;
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"mixed_precision currently not supported dtype %d, we now only "
"support fp16 and bf16.",
static_cast<int>(mixed_precision_)));
}
auto op_nodes = framework::ir::TopologySortOperations(*graph);
auto* block_desc = op_nodes[0]->Op()->Block();
int num_low_precision = 0;
std::vector<framework::ir::Node*> output_nodes;
for (auto* op_node : op_nodes) {
if (!op_node->IsOp()) continue;
auto op_type = op_node->Op()->Type();
VLOG(3) << "-------------------- op_type " << op_type << ", phi_type "
<< phi::TransToPhiKernelName(op_type);
// 1. set input dtype.
if (op_type == "feed") {
auto feed_var = op_node->outputs[0]->Var();
if (!keep_io_types_ && feed_var->GetDataType() == VarType::FP32) {
feed_var->SetDataType(to_type);
}
} else if (op_type == "fetch") {
auto* fetch_var = op_node->inputs[0];
output_nodes.push_back(fetch_var);
continue;
} else if (op_type == "cast") {
continue;
}
// We can not add cast operator before ops who have sub_block, as in
// sub_block we may get a var which may be transformer by cast op.
else if (op_node->Op()->HasAttr("sub_block")) { // NOLINT
continue;
}
// 2. if op support fp16/bf16 and not in blacklist.
// - cast weight to fp16/bf16.
// - add cast op if the input dtype is not fp16/bf16.
// - set output dtype.
else if (black_list_.count(op_type) == 0) { // NOLINT
bool support_precision =
OpSupportPrecision(op_type, backend_, mixed_precision_, black_list_);
// If op's output in circle, we should not convert to fp16.
for (auto* out_node : op_node->outputs) {
if (var_names_in_circles_.count(out_node->Name())) {
support_precision = false;
VLOG(2) << " op's output " << out_node->Name()
<< " is in circle, we can not support this case, just skip.";
break;
}
}
// If the op has no input or output of float type, we will not choose the
// low precision kernel.
if (support_precision) {
bool has_float_in_out{false};
for (auto* in_node : op_node->inputs) {
if (!in_node->IsVar()) continue;
if (in_node->Var()->GetType() != VarType::LOD_TENSOR) {
support_precision = false;
VLOG(2) << " op has tensor array input[" << in_node->Name()
<< "], just skip.";
break;
}
auto* real_node = GetRealVarNode(in_node);
if (real_node->Var()->GetDataType() == VarType::FP16 ||
real_node->Var()->GetDataType() == VarType::FP32 ||
real_node->Var()->GetDataType() == VarType::FP64 ||
real_node->Var()->GetDataType() == VarType::BF16) {
has_float_in_out = true;
break;
}
}
for (auto* out_node : op_node->outputs) {
if (!out_node->IsVar()) continue;
auto* real_node = GetRealVarNode(out_node);
if (real_node->Var()->GetDataType() == VarType::FP16 ||
real_node->Var()->GetDataType() == VarType::FP32 ||
real_node->Var()->GetDataType() == VarType::FP64 ||
real_node->Var()->GetDataType() == VarType::BF16) {
has_float_in_out = true;
break;
}
}
if (!has_float_in_out) {
support_precision = false;
VLOG(2) << " op doesn't has float input and output, just skip.";
}
}
VLOG(2) << "op type: " << op_type
<< " support low precision: " << support_precision;
if (support_precision) {
ProcessConstantOpAttr(op_node, VarType::FP32, to_type);
VLOG(2) << " process input nodes:";
++num_low_precision;
auto inputs = op_node->inputs;
for (auto* in_node : inputs) {
ProcessInputNode(
true, in_node, op_node, &suffix_, block_desc, to_type, block_idx);
}
VLOG(2) << " process output nodes:";
auto outputs = op_node->outputs;
for (auto* out_node : outputs) {
ProcessOutputNode(block_idx, out_node, to_type);
}
} else {
auto inputs = op_node->inputs;
for (auto* in_node : inputs) {
ProcessInputNode(false,
in_node,
op_node,
&suffix_,
block_desc,
VarType::FP32,
block_idx);
}
}
}
// 3. check op not support fp16/bf16 or in blacklist.
// - add cast op if the input dtype is not fp32.
else { // NOLINT
VLOG(3) << "not to run fp16 op_type: " << op_type << ", node input size "
<< op_node->inputs.size();
auto in_nodes = op_node->inputs;
for (auto* in_node : in_nodes) {
auto* in_var = in_node->Var();
if (in_var->GetDataType() == to_type) {
AddCastOp(graph,
in_node,
op_node,
to_type,
VarType::FP32,
&suffix_,
block_desc,
&cast_map_);
VLOG(3) << "-- " << in_node->Name() << "(" << to_type << ") to "
<< cast_map_[in_node]->Name() << "(" << VarType::FP32 << ")";
}
}
}
}
// 4. if output_op's dtype is not compatible to output dtype, then just
// insert cast.
for (auto* node : output_nodes) {
framework::ir::Node* fetch_op{nullptr};
for (auto* op_node : node->outputs) {
if (op_node->IsOp() && op_node->Op()->Type() == "fetch") {
fetch_op = op_node;
}
}
CHECK_NOTNULL(fetch_op);
auto* var = node->Var();
if (keep_io_types_ && var->GetDataType() == to_type) {
// fp16/bf16 -> fp32.
AddCastOp(graph,
node,
fetch_op,
to_type,
VarType::FP32,
&suffix_,
block_desc,
&cast_map_);
} else if (!keep_io_types_ && var->GetDataType() == VarType::FP32) {
// fp32 -> fp16/bf16
AddCastOp(graph,
node,
fetch_op,
VarType::FP32,
to_type,
&suffix_,
block_desc,
&cast_map_);
}
}
if (num_low_precision)
LOG(INFO) << "--- detected " << num_low_precision
<< " low precision ops in " << block_idx << " subgraph";
}
// We modify op's input output precision, and we need to fix cast op in_dtype
// and out_dtype attribute.
// TODO(inference): we need a cast elimination pass.
void ConvertToMixedPrecisionPass::FixCastAttr(framework::ir::Graph* graph) {
auto op_nodes = framework::ir::TopologySortOperations(*graph);
for (auto* op_node : op_nodes) {
if (!op_node->IsOp()) continue;
auto op_type = op_node->Op()->Type();
if (op_type != "cast") continue;
auto input = op_node->inputs[0];
auto output = op_node->outputs[0];
op_node->Op()->SetAttr("in_dtype",
static_cast<int>(input->Var()->GetDataType()));
op_node->Op()->SetAttr("out_dtype",
static_cast<int>(output->Var()->GetDataType()));
}
}
void ConvertToMixedPrecisionPass::SaveMixedModel() {
framework::ProgramDesc mixed_program_desc;
framework::ir::GraphToProgram(*main_graph_, &mixed_program_desc);
@@ -677,51 +87,6 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
auto parameters = scope_.LocalVarNames();
std::sort(parameters.begin(), parameters.end());
std::unordered_set<std::string> weights_should_be_fp32;
for (auto* node : main_graph_->Nodes()) {
if (!node->IsVar()) continue;
if (VarNodeHasDtype(node)) {
if (node->Var()->Persistable() &&
node->Var()->GetDataType() == VarType::FP32) {
VLOG(2) << "weights keep to fp32: " << node->Name() << ", ptr "
<< reinterpret_cast<void*>(node->Var());
weights_should_be_fp32.insert(node->Name());
}
}
}
#define CONVERT_TENSOR_DTYPE(DTYPE, dtype) \
mixed_tensor.set_type(DTYPE); \
auto* mixed_data = mixed_tensor.mutable_data<dtype>(platform::CPUPlace()); \
for (int64_t i = 0; i < origin_tensor->numel(); i++) { \
mixed_data[i] = static_cast<dtype>(origin_data[i]); \
} \
origin_tensor->clear(); \
paddle::framework::TensorCopySync( \
mixed_tensor, platform::CPUPlace(), origin_tensor)
for (const auto& param_name : parameters) {
if (weights_should_be_fp32.count(param_name)) continue;
auto* var = scope_.FindLocalVar(param_name);
if (var->IsType<phi::DenseTensor>()) {
auto* origin_tensor = var->GetMutable<phi::DenseTensor>();
if (origin_tensor->dtype() != phi::DataType::FLOAT32) continue;
phi::DenseTensor mixed_tensor;
mixed_tensor.Resize(origin_tensor->dims());
auto* origin_data =
origin_tensor->mutable_data<float>(platform::CPUPlace());
if (mixed_precision_ == phi::DataType::FLOAT16) {
CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::FLOAT16,
phi::dtype::float16);
} else if (mixed_precision_ == phi::DataType::BFLOAT16) {
CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::BFLOAT16,
phi::dtype::bfloat16);
}
}
}
#undef CONVERT_TENSOR_DTYPE
auto SerializeParams = [&]() -> std::string {
std::ostringstream os;
phi::CPUContext ctx;
@@ -746,73 +111,32 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
mixed_program_desc.Proto()->SerializeAsString());
StrToBinary(mixed_params_file_, SerializeParams());
}
} // namespace
void AddCastOp(
framework::ir::Graph* graph,
framework::ir::Node* node,
framework::ir::Node* next_op,
VarType::Type from_type,
VarType::Type to_type,
int* suffix,
framework::BlockDesc* block_desc,
std::unordered_map<framework::ir::Node*, framework::ir::Node*>* map) {
auto update_cast_desc = [&](framework::OpDesc& desc,
const std::string& x_name,
const std::string& out_name,
const int in_dtype,
const int out_dtype) {
desc.SetType("cast");
desc.SetInput("X", {x_name});
desc.SetOutput("Out", {out_name});
desc.SetAttr("in_dtype", in_dtype);
desc.SetAttr("out_dtype", out_dtype);
desc.SetAttr("use_mkldnn", false);
desc.SetAttr("with_quant_attr", false);
desc.Flush();
};
if (map->count(node) == 0) {
// insert cast op before node.
std::string cast_input_name = node->Var()->Name();
std::string cast_output_name =
node->Var()->Name() + "_cast.tmp_" + std::to_string((*suffix)++);
CHECK_NOTNULL(block_desc);
framework::OpDesc cast_op_desc(block_desc);
update_cast_desc(cast_op_desc,
cast_input_name,
cast_output_name,
static_cast<int>(from_type),
static_cast<int>(to_type));
auto* cast_op_node = graph->CreateOpNode(&cast_op_desc);
auto* cast_output_vardesc = block_desc->Var(cast_output_name);
cast_output_vardesc->SetPersistable(false);
cast_output_vardesc->SetDataType(to_type);
cast_output_vardesc->SetShape(node->Var()->GetShape());
auto* cast_output_node = graph->CreateVarNode(cast_output_vardesc);
IR_NODE_LINK_TO(cast_op_node, cast_output_node);
(*map)[node] = cast_output_node;
}
next_op->Op()->Rename(node->Name(), map->at(node)->Name());
IR_NODE_LINK_TO(node, map->at(node)->inputs[0]);
IR_NODE_UNLINK(node, next_op);
IR_NODE_LINK_TO(map->at(node), next_op);
}
bool OpSupportPrecision(const std::string& op_type,
phi::Backend backend,
phi::DataType precision,
- const std::unordered_set<std::string>& blacklist) {
- auto phi_op_type = phi::TransToPhiKernelName(op_type);
- bool support_precision = false;
- if (blacklist.count(op_type) == 0) {
- if (backend == phi::Backend::GPU)
- support_precision = GpuKernelSupportPrecision(op_type, precision);
- else
- support_precision =
- PhiKernelSupportPrecision(phi_op_type, backend, precision);
- }
- return support_precision;
+ const std::unordered_set<std::string>& black_list) {
+ return framework::ir::OpSupportPrecision(
+ op_type, backend, precision, black_list);
+ }
+
+ void InsertCastOp(
+ framework::ir::Graph* graph,
+ framework::ir::Node* var_node,
+ framework::ir::Node* op_node,
+ framework::proto::VarType::Type from_type,
+ framework::proto::VarType::Type to_type,
+ framework::BlockDesc* block_desc,
+ int* suffix,
+ std::unordered_map<framework::ir::Node*, framework::ir::Node*>* visited) {
+ framework::ir::DoInsertCastOp(graph,
+ var_node,
+ op_node,
+ from_type,
+ to_type,
+ block_desc,
+ suffix,
+ visited);
}

void ConvertToMixedPrecision(
......
@@ -15,14 +15,12 @@
#pragma once
#include <string>
- #include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/ir/graph.h"
- #include "paddle/fluid/framework/ir/graph_helper.h"
- #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/program_desc.h"
- #include "paddle/fluid/framework/scope.h"
#include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h"
@@ -30,20 +28,52 @@ namespace paddle {
namespace inference {
namespace analysis {
class ConvertToMixedPrecisionPass {
public:
explicit ConvertToMixedPrecisionPass(
const std::string& model_file,
const std::string& params_file,
const std::string& mixed_model_file,
const std::string& mixed_params_file,
phi::DataType mixed_precision,
phi::Backend backend,
bool keep_io_types,
const std::unordered_set<std::string>& black_list);
void Run();
private:
void LoadModel();
void SaveMixedModel();
private:
std::string model_file_;
std::string params_file_;
std::string mixed_model_file_;
std::string mixed_params_file_;
phi::DataType mixed_precision_;
phi::Backend backend_;
bool keep_io_types_;
std::unordered_set<std::string> black_list_;
framework::Scope scope_;
std::unique_ptr<framework::ir::Graph> main_graph_{nullptr};
};
bool OpSupportPrecision(const std::string& op_type,
phi::Backend backend,
phi::DataType precision,
- const std::unordered_set<std::string>& blacklist);
+ const std::unordered_set<std::string>& black_list);

- void AddCastOp(
+ void InsertCastOp(
framework::ir::Graph* graph,
- framework::ir::Node* node,
- framework::ir::Node* next_op,
+ framework::ir::Node* var_node,
+ framework::ir::Node* op_node,
framework::proto::VarType::Type from_type,
framework::proto::VarType::Type to_type,
- int* suffix,
framework::BlockDesc* block_desc,
- std::unordered_map<framework::ir::Node*, framework::ir::Node*>* map);
+ int* suffix,
+ std::unordered_map<framework::ir::Node*, framework::ir::Node*>* visited);

void ConvertToMixedPrecision(const std::string& model_file,
const std::string& params_file,
......
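// Example (sketch only, not part of this patch): driving the rewritten
// offline converter through the ConvertToMixedPrecisionPass class declared
// above. The include path, the helper name, and the file paths are
// placeholders; ops in black_list are kept in FP32.
#include <unordered_set>
#include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"  // assumed header path

void ConvertSavedModelToFp16() {  // hypothetical helper
  std::unordered_set<std::string> black_list{"softmax"};  // keep these ops in FP32
  paddle::inference::analysis::ConvertToMixedPrecisionPass pass(
      "model/inference.pdmodel",         // model_file
      "model/inference.pdiparams",       // params_file
      "model_fp16/inference.pdmodel",    // mixed_model_file
      "model_fp16/inference.pdiparams",  // mixed_params_file
      phi::DataType::FLOAT16,            // mixed_precision
      phi::Backend::GPU,                 // backend
      true,                              // keep_io_types
      black_list);
  pass.Run();  // LoadModel -> AutoMixedPrecisionPass -> SaveMixedModel
}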
@@ -99,7 +99,7 @@ void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb,
// default
} else if (precision_mode == Precision::kHalf ||
precision_mode == Precision::kBf16) {
- enable_gpu_half_ = true;
+ enable_gpu_mixed_ = true;
} else {
LOG(ERROR)
<< "The Paddle-GPU inference currently only supports "
@@ -396,7 +396,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
// Mixed precision related.
CP_MEMBER(mixed_black_list_);
- CP_MEMBER(enable_gpu_half_);
+ CP_MEMBER(enable_gpu_mixed_);
CP_MEMBER(mixed_precision_mode_);
CP_MEMBER(enable_memory_optim_);
@@ -1017,7 +1017,7 @@ std::string AnalysisConfig::SerializeInfoCache() {
ss << params_file_;
ss << use_gpu_;
- ss << enable_gpu_half_;
+ ss << enable_gpu_mixed_;
ss << use_external_stream_;
ss << exec_stream_;
ss << use_fc_padding_;
@@ -1234,7 +1234,7 @@ std::string AnalysisConfig::Summary() {
os.InsertRow({"use_gpu", use_gpu_ ? "true" : "false"});
if (use_gpu_) {
os.InsertRow({"gpu_device_id", std::to_string(gpu_device_id_)});
- os.InsertRow({"enable_gpu_half_", std::to_string(enable_gpu_half_)});
+ os.InsertRow({"enable_gpu_mixed_", std::to_string(enable_gpu_mixed_)});
os.InsertRow({"memory_pool_init_size",
std::to_string(memory_pool_init_size_mb_) + "MB"});
os.InsertRow(
......
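// Example (sketch only, not part of this patch): how a user would turn on the
// GPU mixed-precision path recorded by enable_gpu_mixed_. It assumes
// EnableUseGpu accepts (memory_pool_init_size_mb, device_id, precision), as the
// precision branch above suggests; the include path, helper name, and model
// paths are placeholders.
#include "paddle/fluid/inference/api/paddle_inference_api.h"  // assumed include path

void BuildMixedPrecisionPredictor() {  // hypothetical helper
  paddle_infer::Config config("model/inference.pdmodel",
                              "model/inference.pdiparams");
  // Precision::kHalf (or kBf16) sets enable_gpu_mixed_, which schedules
  // auto_mixed_precision_pass in the PrepareArgument() hunk below.
  config.EnableUseGpu(256 /* memory_pool_init_size_mb */, 0 /* device_id */,
                      paddle_infer::PrecisionType::kHalf);
  auto predictor = paddle_infer::CreatePredictor(config);
}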
@@ -1277,10 +1277,10 @@ void AnalysisPredictor::PrepareArgument() {
if (!config_.ir_optim()) {
argument_.SetEnableIrOptim(false);
- if (config_.enable_gpu_half_) {
+ if (config_.enable_gpu_mixed_) {
argument_.SetEnableIrOptim(true);
pass_builder->ClearPasses();
- pass_builder->AppendPass("float_to_half_pass");
+ pass_builder->AppendPass("auto_mixed_precision_pass");
LOG(INFO)
<< "This model run in Paddle-GPU mixed precision mode with no ir "
"optimization.";
@@ -1291,7 +1291,7 @@ void AnalysisPredictor::PrepareArgument() {
if (config_.ir_debug_) {
pass_builder->TurnOnDebug();
}
- if (config_.enable_gpu_half_) {
+ if (config_.enable_gpu_mixed_) {
LOG(INFO) << "This model run in Paddle-GPU mixed precision mode.";
}
}
@@ -1303,7 +1303,7 @@ void AnalysisPredictor::PrepareArgument() {
// mixed precision.
argument_.SetModelPrecision(static_cast<int>(model_precision_));
argument_.SetMixedBlackList(config_.mixed_black_list_);
- argument_.SetEnableGPUHalf(config_.enable_gpu_half_);
+ argument_.SetEnableGPUMixed(config_.enable_gpu_mixed_);
argument_.SetMixedPrecisionMode(static_cast<int>(
paddle::ConvertPrecision(config_.mixed_precision_mode_)));
}
......
@@ -1049,7 +1049,7 @@ struct PD_INFER_DECL AnalysisConfig {
bool use_gpu_{false};
int gpu_device_id_{0};
uint64_t memory_pool_init_size_mb_{100};  // initial size is 100MB.
- bool enable_gpu_half_{false};
+ bool enable_gpu_mixed_{false};
bool thread_local_stream_{false};
bool use_cudnn_{false};
......
@@ -245,7 +245,8 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"conv_elementwise_add_fuse_pass",      //
#endif                                       //
"transpose_flatten_concat_fuse_pass",  //
- "float_to_half_pass",                  //
+ "constant_folding_pass",               //
+ "auto_mixed_precision_pass",           //
});
use_gpu_ = true;
......