Unverified commit 28ea9aad, authored by Yuanle Liu and committed by GitHub

[Paddle Inference] rewrite convert_to_mixed_precision (#48853)

Parent commit: b9fad5da
@@ -103,7 +103,7 @@ pass_library(delete_c_identity_op_pass inference)
 pass_library(preln_residual_bias_fuse_pass inference)
 pass_library(delete_fill_constant_op_pass inference)
 pass_library(constant_folding_pass inference)
-pass_library(float_to_half_pass inference)
+pass_library(auto_mixed_precision_pass inference)
 pass_library(conv2d_fusion_layout_transfer_pass inference)
 pass_library(simplify_with_basic_ops_pass base)
 pass_library(fc_elementwise_layernorm_fuse_pass base)
...
@@ -27,13 +27,13 @@ namespace paddle {
 namespace framework {
 namespace ir {
-class FloatToHalfPass : public FusePassBase {
+class AutoMixedPrecisionPass : public FusePassBase {
  public:
   using VarType = framework::proto::VarType;
  public:
-  FloatToHalfPass() = default;
-  ~FloatToHalfPass() = default;
+  AutoMixedPrecisionPass() = default;
+  ~AutoMixedPrecisionPass() = default;
  protected:
   void ApplyImpl(Graph* graph) const override;
@@ -43,10 +43,6 @@ class FloatToHalfPass : public FusePassBase {
   void SetDefaultBlacklist() const;
-  bool OpSupportPrecision(const std::string& op_type,
-                          phi::DataType precision,
-                          phi::Backend backend = phi::Backend::GPU) const;
   void SetOpUniqueType() const;
   void RestoreOpOriginType() const;
@@ -70,9 +66,13 @@ class FloatToHalfPass : public FusePassBase {
   void ConvertWeightsData() const;
  private:
-  mutable bool keep_io_types_;
+  mutable bool skip_pass_{false};
+  mutable bool keep_io_types_{false};
   // float16 or bfloat16 now
-  mutable phi::DataType half_precision_;
+  mutable phi::DataType low_precision_{phi::DataType::FLOAT16};
+  mutable phi::Backend backend_{phi::Backend::GPU};
   mutable std::unordered_set<std::string> black_list_;
@@ -84,12 +84,26 @@ class FloatToHalfPass : public FusePassBase {
   mutable std::vector<std::vector<Node*>> all_op_nodes_;
   // op's unique type -> the op's origin type
   mutable std::unordered_map<std::string, std::string> op_original_type_;
-  // op's unique type -> whether the op run at half precision
-  mutable std::unordered_set<std::string> op_run_half_;
-  mutable std::unordered_set<std::string> vars_convert_to_half_;
+  // op's unique type -> whether the op run at low precision
+  mutable std::unordered_set<std::string> op_run_low_precision_;
+  mutable std::unordered_set<std::string> vars_convert_to_low_precision_;
 };
+
+bool OpSupportPrecision(const std::string& op_type,
+                        phi::Backend backend,
+                        phi::DataType precision,
+                        const std::unordered_set<std::string>& black_list);
+
+void DoInsertCastOp(Graph* graph,
+                    Node* var_node,
+                    Node* op_node,
+                    proto::VarType::Type from_type,
+                    proto::VarType::Type to_type,
+                    framework::BlockDesc* block_desc,
+                    int* suffix,
+                    std::unordered_map<Node*, Node*>* cache);
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
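The two free functions declared above replace the old private `OpSupportPrecision` member: `OpSupportPrecision` reports whether an op type can run at the requested precision on the given backend and is not black-listed, and `DoInsertCastOp` splices a cast op between a variable node and the op that consumes it, reusing a previously inserted cast for the same variable via the `cache` map. A rough sketch of how a caller inside `paddle::framework::ir` might combine them; this is an illustration based only on the signatures above, not code from the patch, and the graph, block, and counter are assumed to come from the calling pass:

    // Illustrative sketch (not part of the patch): cast one fp32 input of
    // `op_node` down to fp16 when the op can run at low precision on GPU.
    void CastInputIfSupported(Graph* graph,
                              Node* var_node,
                              Node* op_node,
                              framework::BlockDesc* block_desc,
                              int* suffix,
                              std::unordered_map<Node*, Node*>* cast_cache,
                              const std::unordered_set<std::string>& black_list) {
      if (!OpSupportPrecision(op_node->Op()->Type(),
                              phi::Backend::GPU,
                              phi::DataType::FLOAT16,
                              black_list)) {
        return;  // op stays in fp32, no cast needed
      }
      DoInsertCastOp(graph,
                     var_node,
                     op_node,
                     proto::VarType::FP32,  // from_type
                     proto::VarType::FP16,  // to_type
                     block_desc,
                     suffix,
                     cast_cache);  // reuses an existing cast for the same var
    }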
@@ -142,7 +142,7 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
   bool is_fp16_precision =
       static_cast<phi::DataType>(Get<int>("model_precision")) ==
           phi::DataType::FLOAT16 ||
-      Get<bool>("enable_gpu_half");
+      Get<bool>("enable_gpu_mixed");
   bool cutlass_enable = false;
 #ifdef PADDLE_WITH_CUTLASS
...
@@ -165,7 +165,7 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
   bool is_fp16_precision =
       static_cast<phi::DataType>(Get<int>("model_precision")) ==
           phi::DataType::FLOAT16 ||
-      Get<bool>("enable_gpu_half");
+      Get<bool>("enable_gpu_mixed");
   constexpr int CUTLASS_NHWC_ALIGNMENT = 8;
   if (is_fp16_precision) {
 #ifdef PADDLE_WITH_CUTLASS
...
@@ -365,7 +365,7 @@ struct Argument {
   DECL_ARGUMENT_FIELD(mixed_black_list,
                       MixedBlackList,
                       std::unordered_set<std::string>);
-  DECL_ARGUMENT_FIELD(enable_gpu_half, EnableGPUHalf, bool);
+  DECL_ARGUMENT_FIELD(enable_gpu_mixed, EnableGPUMixed, bool);
   DECL_ARGUMENT_FIELD(mixed_precision_mode, MixedPrecisionMode, int);
   // cinn compiler related
...
@@ -45,8 +45,10 @@ IRPassManager::IRPassManager(Argument *argument) {
 void IRPassManager::CreatePasses(Argument *argument,
                                  const std::vector<std::string> &passes) {
+  // For graph_viz_pass
   std::string pre_pass;
   int pass_num = 0;
   for (const std::string &pass_name : passes) {
     auto pass = framework::ir::PassRegistry::Instance().Get(pass_name);
     pass->Set("use_varseqlen", new bool(argument->tensorrt_use_varseqlen()));
@@ -87,14 +89,14 @@ void IRPassManager::CreatePasses(Argument *argument,
         argument->tensorrt_tuned_dynamic_shape();
     pass->Set("with_dynamic_shape", new bool(with_dynamic_shape));
-    // mixed precision related
+    // Mixed precision related.
+    pass->Set("model_precision", new int(argument->model_precision()));
     pass->Set(
         "mixed_black_list",
         new std::unordered_set<std::string>(argument->mixed_black_list()));
-    pass->Set("enable_gpu_half", new bool(argument->enable_gpu_half()));
+    pass->Set("enable_gpu_mixed", new bool(argument->enable_gpu_mixed()));
     pass->Set("mixed_precision_mode",
               new int(argument->mixed_precision_mode()));
-    pass->Set("model_precision", new int(argument->model_precision()));
     if (pass_name == "graph_viz_pass") {
       std::string optim_cache_dir = argument->optim_cache_dir();
@@ -210,6 +212,7 @@ void IRPassManager::CreatePasses(Argument *argument,
           new std::vector<std::string>(argument->tensorrt_disabled_ops()));
       pass->Set("trt_use_dla", new bool(argument->tensorrt_use_dla()));
       pass->Set("trt_dla_core", new int(argument->tensorrt_dla_core()));
       // Setting the disable_trt_plugin_fp16 to true means that TRT plugin will
       // not run fp16.
       pass->Set("disable_trt_plugin_fp16",
@@ -238,8 +241,7 @@ void IRPassManager::CreatePasses(Argument *argument,
       pass->Set("root_predictor_id", new int(argument->root_predictor_id()));
     } else if (pass_name == "build_cinn_pass") {
       pass->Set("is_inference_stage", new bool(argument->use_cinn_compiler()));
-    }
-    if (pass_name == "lite_subgraph_pass") {
+    } else if (pass_name == "lite_subgraph_pass") {
       bool lite_enable_int8 =
           argument->lite_precision_mode() == AnalysisConfig::Precision::kInt8;
       pass->Set("program",
@@ -287,8 +289,7 @@ void IRPassManager::CreatePasses(Argument *argument,
       pass->Set("nnadapter_model_cache_token",
                 new std::vector<std::string>(
                     argument->nnadapter_model_cache_token()));
-    }
-    if (pass_name == "fc_fuse_pass") {
+    } else if (pass_name == "fc_fuse_pass") {
       pass->Set("use_gpu", new bool(argument->use_gpu()));
       bool fc_mkldnn_pass = 0;
       for (const std::string &pass_n : passes) {
...
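The attributes registered here are read back on the pass side through `Pass::Get`, as the `conv2d_fusion_layout_transfer_pass` and `conv_elementwise_add_act_fuse_pass` hunks above show. Condensed, the consumer side looks roughly like this (a sketch, not code from the patch):

    // Inside some pass's ApplyImpl (sketch): read the mixed-precision
    // attributes that IRPassManager::CreatePasses registered above.
    bool is_fp16_precision =
        static_cast<phi::DataType>(Get<int>("model_precision")) ==
            phi::DataType::FLOAT16 ||
        Get<bool>("enable_gpu_mixed");
    auto& black_list =
        Get<std::unordered_set<std::string>>("mixed_black_list");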
@@ -83,13 +83,13 @@ void OutputProcess(framework::ir::Graph *graph,
                             backend,
                             precision,
                             blacklist)) {
-      AddCastOp(graph,
+      InsertCastOp(graph,
                 var_node,
                 next_op,
                 framework::proto::VarType::FP32,
                 to_type,
-                &suffix,
                 block_desc,
+                &suffix,
                 &var_to_cast_op_map);
       var_node->Var()->SetDataType(framework::proto::VarType::FP32);
     }
...
@@ -13,7 +13,7 @@ cc_library(
 cc_library(
   convert_to_mixed_precision
   SRCS convert_to_mixed_precision.cc
-  DEPS analysis_pass ir_graph_build_pass)
+  DEPS analysis_pass ir_graph_build_pass auto_mixed_precision_pass)
 cc_library(
   ir_params_sync_among_devices_pass
   SRCS ir_params_sync_among_devices_pass.cc
...
@@ -15,14 +15,12 @@
 #pragma once
 #include <string>
-#include <unordered_map>
 #include <unordered_set>
 #include "paddle/fluid/framework/block_desc.h"
 #include "paddle/fluid/framework/ir/graph.h"
-#include "paddle/fluid/framework/ir/graph_helper.h"
-#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/framework/scope.h"
 #include "paddle/phi/common/backend.h"
 #include "paddle/phi/common/data_type.h"
@@ -30,20 +28,52 @@ namespace paddle {
 namespace inference {
 namespace analysis {
+class ConvertToMixedPrecisionPass {
+ public:
+  explicit ConvertToMixedPrecisionPass(
+      const std::string& model_file,
+      const std::string& params_file,
+      const std::string& mixed_model_file,
+      const std::string& mixed_params_file,
+      phi::DataType mixed_precision,
+      phi::Backend backend,
+      bool keep_io_types,
+      const std::unordered_set<std::string>& black_list);
+
+  void Run();
+
+ private:
+  void LoadModel();
+  void SaveMixedModel();
+
+ private:
+  std::string model_file_;
+  std::string params_file_;
+  std::string mixed_model_file_;
+  std::string mixed_params_file_;
+  phi::DataType mixed_precision_;
+  phi::Backend backend_;
+  bool keep_io_types_;
+  std::unordered_set<std::string> black_list_;
+
+  framework::Scope scope_;
+  std::unique_ptr<framework::ir::Graph> main_graph_{nullptr};
+};
+
 bool OpSupportPrecision(const std::string& op_type,
                         phi::Backend backend,
                         phi::DataType precision,
-                        const std::unordered_set<std::string>& blacklist);
+                        const std::unordered_set<std::string>& black_list);
-void AddCastOp(
+void InsertCastOp(
     framework::ir::Graph* graph,
-    framework::ir::Node* node,
-    framework::ir::Node* next_op,
+    framework::ir::Node* var_node,
+    framework::ir::Node* op_node,
     framework::proto::VarType::Type from_type,
     framework::proto::VarType::Type to_type,
-    int* suffix,
     framework::BlockDesc* block_desc,
-    std::unordered_map<framework::ir::Node*, framework::ir::Node*>* map);
+    int* suffix,
+    std::unordered_map<framework::ir::Node*, framework::ir::Node*>* visited);
 void ConvertToMixedPrecision(const std::string& model_file,
                              const std::string& params_file,
...
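The new `ConvertToMixedPrecisionPass` wraps the offline conversion flow implied by its declaration: load the fp32 model, rewrite it with the shared auto-mixed-precision machinery this commit adds as a dependency, and save the mixed copy. A hedged usage sketch based on the constructor above; the include path, model file names, and the black-listed op are placeholders, not taken from the patch:

    #include <unordered_set>

    #include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"  // path assumed

    int main() {
      // Placeholder paths; any fp32 Paddle inference model works here.
      std::unordered_set<std::string> black_list{"softmax"};  // example: keep these ops fp32
      paddle::inference::analysis::ConvertToMixedPrecisionPass pass(
          "model.pdmodel",         // model_file
          "model.pdiparams",       // params_file
          "mixed.pdmodel",         // mixed_model_file
          "mixed.pdiparams",       // mixed_params_file
          phi::DataType::FLOAT16,  // mixed_precision
          phi::Backend::GPU,       // backend
          true,                    // keep_io_types: keep fp32 inputs/outputs
          black_list);
      pass.Run();
      return 0;
    }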
@@ -99,7 +99,7 @@ void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb,
     // default
   } else if (precision_mode == Precision::kHalf ||
              precision_mode == Precision::kBf16) {
-    enable_gpu_half_ = true;
+    enable_gpu_mixed_ = true;
   } else {
     LOG(ERROR)
         << "The Paddle-GPU inference currently only supports "
@@ -396,7 +396,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
   // Mixed precision related.
   CP_MEMBER(mixed_black_list_);
-  CP_MEMBER(enable_gpu_half_);
+  CP_MEMBER(enable_gpu_mixed_);
   CP_MEMBER(mixed_precision_mode_);
   CP_MEMBER(enable_memory_optim_);
@@ -1017,7 +1017,7 @@ std::string AnalysisConfig::SerializeInfoCache() {
   ss << params_file_;
   ss << use_gpu_;
-  ss << enable_gpu_half_;
+  ss << enable_gpu_mixed_;
   ss << use_external_stream_;
   ss << exec_stream_;
   ss << use_fc_padding_;
@@ -1234,7 +1234,7 @@ std::string AnalysisConfig::Summary() {
   os.InsertRow({"use_gpu", use_gpu_ ? "true" : "false"});
   if (use_gpu_) {
     os.InsertRow({"gpu_device_id", std::to_string(gpu_device_id_)});
-    os.InsertRow({"enable_gpu_half_", std::to_string(enable_gpu_half_)});
+    os.InsertRow({"enable_gpu_mixed_", std::to_string(enable_gpu_mixed_)});
     os.InsertRow({"memory_pool_init_size",
                   std::to_string(memory_pool_init_size_mb_) + "MB"});
     os.InsertRow(
...
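From the API side the renamed flag is still driven by `AnalysisConfig::EnableUseGpu`: per the first hunk above, passing `Precision::kHalf` or `Precision::kBf16` turns `enable_gpu_mixed_` on. A minimal sketch of the client-side call, assuming the precision-taking overload shown in that hunk; pool size, device id, and model paths are placeholders:

    #include "paddle/fluid/inference/api/paddle_analysis_config.h"

    int main() {
      paddle::AnalysisConfig config;
      config.SetModel("model.pdmodel", "model.pdiparams");  // placeholder paths
      // kHalf (or kBf16) sets enable_gpu_mixed_, which later gates the
      // auto_mixed_precision_pass seen elsewhere in this diff.
      config.EnableUseGpu(/*memory_pool_init_size_mb=*/256,
                          /*device_id=*/0,
                          paddle::AnalysisConfig::Precision::kHalf);
      return 0;
    }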
@@ -1277,10 +1277,10 @@ void AnalysisPredictor::PrepareArgument() {
   if (!config_.ir_optim()) {
     argument_.SetEnableIrOptim(false);
-    if (config_.enable_gpu_half_) {
+    if (config_.enable_gpu_mixed_) {
       argument_.SetEnableIrOptim(true);
       pass_builder->ClearPasses();
-      pass_builder->AppendPass("float_to_half_pass");
+      pass_builder->AppendPass("auto_mixed_precision_pass");
       LOG(INFO)
           << "This model run in Paddle-GPU mixed precision mode with no ir "
              "optimization.";
@@ -1291,7 +1291,7 @@ void AnalysisPredictor::PrepareArgument() {
     if (config_.ir_debug_) {
       pass_builder->TurnOnDebug();
     }
-    if (config_.enable_gpu_half_) {
+    if (config_.enable_gpu_mixed_) {
       LOG(INFO) << "This model run in Paddle-GPU mixed precision mode.";
     }
   }
@@ -1303,7 +1303,7 @@ void AnalysisPredictor::PrepareArgument() {
   // mixed precison.
   argument_.SetModelPrecision(static_cast<int>(model_precision_));
   argument_.SetMixedBlackList(config_.mixed_black_list_);
-  argument_.SetEnableGPUHalf(config_.enable_gpu_half_);
+  argument_.SetEnableGPUMixed(config_.enable_gpu_mixed_);
   argument_.SetMixedPrecisionMode(static_cast<int>(
       paddle::ConvertPrecision(config_.mixed_precision_mode_)));
 }
...
@@ -1049,7 +1049,7 @@ struct PD_INFER_DECL AnalysisConfig {
   bool use_gpu_{false};
   int gpu_device_id_{0};
   uint64_t memory_pool_init_size_mb_{100};  // initial size is 100MB.
-  bool enable_gpu_half_{false};
+  bool enable_gpu_mixed_{false};
   bool thread_local_stream_{false};
   bool use_cudnn_{false};
...
@@ -245,7 +245,8 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
         "conv_elementwise_add_fuse_pass",      //
 #endif                                         //
         "transpose_flatten_concat_fuse_pass",  //
-        "float_to_half_pass",                  //
+        "constant_folding_pass",               //
+        "auto_mixed_precision_pass",           //
   });
   use_gpu_ = true;
...