Unverified commit 28ea9aad, authored by Yuanle Liu, committed by GitHub

[Paddle Inference] rewrite convert_to_mixed_precision (#48853)

Parent b9fad5da
......
@@ -103,7 +103,7 @@ pass_library(delete_c_identity_op_pass inference)
 pass_library(preln_residual_bias_fuse_pass inference)
 pass_library(delete_fill_constant_op_pass inference)
 pass_library(constant_folding_pass inference)
-pass_library(float_to_half_pass inference)
+pass_library(auto_mixed_precision_pass inference)
 pass_library(conv2d_fusion_layout_transfer_pass inference)
 pass_library(simplify_with_basic_ops_pass base)
 pass_library(fc_elementwise_layernorm_fuse_pass base)
......
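The renamed pass_library target makes the pass's string name available to the pass registry; that name comes from the registration macro at the bottom of the pass's .cc file, along these lines (a sketch; the exact invocation lives in auto_mixed_precision_pass.cc, which this page elides):

REGISTER_PASS(auto_mixed_precision_pass,
              paddle::framework::ir::AutoMixedPrecisionPass);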
......
@@ -27,13 +27,13 @@ namespace paddle {
 namespace framework {
 namespace ir {
-class FloatToHalfPass : public FusePassBase {
+class AutoMixedPrecisionPass : public FusePassBase {
  public:
   using VarType = framework::proto::VarType;

  public:
-  FloatToHalfPass() = default;
-  ~FloatToHalfPass() = default;
+  AutoMixedPrecisionPass() = default;
+  ~AutoMixedPrecisionPass() = default;

  protected:
   void ApplyImpl(Graph* graph) const override;
@@ -43,10 +43,6 @@ class FloatToHalfPass : public FusePassBase {
   void SetDefaultBlacklist() const;

-  bool OpSupportPrecision(const std::string& op_type,
-                          phi::DataType precision,
-                          phi::Backend backend = phi::Backend::GPU) const;
-
   void SetOpUniqueType() const;
   void RestoreOpOriginType() const;
@@ -70,9 +66,13 @@ class FloatToHalfPass : public FusePassBase {
   void ConvertWeightsData() const;

  private:
-  mutable bool keep_io_types_;
+  mutable bool skip_pass_{false};
+
+  mutable bool keep_io_types_{false};
   // float16 or bfloat16 now
-  mutable phi::DataType half_precision_;
+  mutable phi::DataType low_precision_{phi::DataType::FLOAT16};
+
+  mutable phi::Backend backend_{phi::Backend::GPU};

   mutable std::unordered_set<std::string> black_list_;
@@ -84,12 +84,26 @@ class FloatToHalfPass : public FusePassBase {
   mutable std::vector<std::vector<Node*>> all_op_nodes_;
   // op's unique type -> the op's origin type
   mutable std::unordered_map<std::string, std::string> op_original_type_;
-  // op's unique type -> whether the op run at half precision
-  mutable std::unordered_set<std::string> op_run_half_;
+  // op's unique type -> whether the op run at low precision
+  mutable std::unordered_set<std::string> op_run_low_precision_;

-  mutable std::unordered_set<std::string> vars_convert_to_half_;
+  mutable std::unordered_set<std::string> vars_convert_to_low_precision_;
 };

+bool OpSupportPrecision(const std::string& op_type,
+                        phi::Backend backend,
+                        phi::DataType precision,
+                        const std::unordered_set<std::string>& black_list);
+
+void DoInsertCastOp(Graph* graph,
+                    Node* var_node,
+                    Node* op_node,
+                    proto::VarType::Type from_type,
+                    proto::VarType::Type to_type,
+                    framework::BlockDesc* block_desc,
+                    int* suffix,
+                    std::unordered_map<Node*, Node*>* cache);
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
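With this change, OpSupportPrecision moves from a private member of the pass to a free function, so other passes (such as the TensorRT subgraph pass further down) can query kernel support directly. A minimal usage sketch; the header path and the blacklist entry are illustrative assumptions:

#include <string>
#include <unordered_set>

#include "paddle/fluid/framework/ir/auto_mixed_precision_pass.h"  // assumed path

// True when `op_type` has a GPU FP16 kernel and is not blacklisted.
bool CanRunFp16(const std::string& op_type) {
  const std::unordered_set<std::string> black_list{"softmax"};  // illustrative
  return paddle::framework::ir::OpSupportPrecision(
      op_type, phi::Backend::GPU, phi::DataType::FLOAT16, black_list);
}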
......
@@ -142,7 +142,7 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
   bool is_fp16_precision =
       static_cast<phi::DataType>(Get<int>("model_precision")) ==
           phi::DataType::FLOAT16 ||
-      Get<bool>("enable_gpu_half");
+      Get<bool>("enable_gpu_mixed");
   bool cutlass_enable = false;
 #ifdef PADDLE_WITH_CUTLASS
......
......
@@ -165,7 +165,7 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
   bool is_fp16_precision =
       static_cast<phi::DataType>(Get<int>("model_precision")) ==
           phi::DataType::FLOAT16 ||
-      Get<bool>("enable_gpu_half");
+      Get<bool>("enable_gpu_mixed");
   constexpr int CUTLASS_NHWC_ALIGNMENT = 8;
   if (is_fp16_precision) {
 #ifdef PADDLE_WITH_CUTLASS
......
......
@@ -365,7 +365,7 @@ struct Argument {
   DECL_ARGUMENT_FIELD(mixed_black_list,
                       MixedBlackList,
                       std::unordered_set<std::string>);
-  DECL_ARGUMENT_FIELD(enable_gpu_half, EnableGPUHalf, bool);
+  DECL_ARGUMENT_FIELD(enable_gpu_mixed, EnableGPUMixed, bool);
   DECL_ARGUMENT_FIELD(mixed_precision_mode, MixedPrecisionMode, int);
 // cinn compiler related
......
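The renamed DECL_ARGUMENT_FIELD(enable_gpu_mixed, EnableGPUMixed, bool) is what generates the argument->enable_gpu_mixed() getter and argument_.SetEnableGPUMixed(...) setter used later in this diff. A simplified sketch of the expansion (the real macro in argument.h also tracks whether a field has been set):

class ArgumentSketch {
 public:
  bool& enable_gpu_mixed() { return enable_gpu_mixed_; }
  void SetEnableGPUMixed(const bool& x) { enable_gpu_mixed_ = x; }

 private:
  bool enable_gpu_mixed_{false};
};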
......
@@ -45,8 +45,10 @@ IRPassManager::IRPassManager(Argument *argument) {
 void IRPassManager::CreatePasses(Argument *argument,
                                  const std::vector<std::string> &passes) {
+  // For graph_viz_pass
   std::string pre_pass;
   int pass_num = 0;
+
   for (const std::string &pass_name : passes) {
     auto pass = framework::ir::PassRegistry::Instance().Get(pass_name);
     pass->Set("use_varseqlen", new bool(argument->tensorrt_use_varseqlen()));
......
@@ -87,14 +89,14 @@ void IRPassManager::CreatePasses(Argument *argument,
         argument->tensorrt_tuned_dynamic_shape();
     pass->Set("with_dynamic_shape", new bool(with_dynamic_shape));

-    // mixed precision related
-    pass->Set("model_precision", new int(argument->model_precision()));
+    // Mixed precision related.
     pass->Set(
         "mixed_black_list",
         new std::unordered_set<std::string>(argument->mixed_black_list()));
-    pass->Set("enable_gpu_half", new bool(argument->enable_gpu_half()));
+    pass->Set("enable_gpu_mixed", new bool(argument->enable_gpu_mixed()));
     pass->Set("mixed_precision_mode",
               new int(argument->mixed_precision_mode()));
+    pass->Set("model_precision", new int(argument->model_precision()));

     if (pass_name == "graph_viz_pass") {
       std::string optim_cache_dir = argument->optim_cache_dir();
......
@@ -210,6 +212,7 @@ void IRPassManager::CreatePasses(Argument *argument,
             new std::vector<std::string>(argument->tensorrt_disabled_ops()));
     pass->Set("trt_use_dla", new bool(argument->tensorrt_use_dla()));
     pass->Set("trt_dla_core", new int(argument->tensorrt_dla_core()));
+
     // Setting the disable_trt_plugin_fp16 to true means that TRT plugin will
     // not run fp16.
     pass->Set("disable_trt_plugin_fp16",
......
@@ -238,8 +241,7 @@ void IRPassManager::CreatePasses(Argument *argument,
       pass->Set("root_predictor_id", new int(argument->root_predictor_id()));
     } else if (pass_name == "build_cinn_pass") {
       pass->Set("is_inference_stage", new bool(argument->use_cinn_compiler()));
-    }
-    if (pass_name == "lite_subgraph_pass") {
+    } else if (pass_name == "lite_subgraph_pass") {
       bool lite_enable_int8 =
           argument->lite_precision_mode() == AnalysisConfig::Precision::kInt8;
       pass->Set("program",
......
@@ -287,8 +289,7 @@ void IRPassManager::CreatePasses(Argument *argument,
       pass->Set("nnadapter_model_cache_token",
                 new std::vector<std::string>(
                     argument->nnadapter_model_cache_token()));
-    }
-    if (pass_name == "fc_fuse_pass") {
+    } else if (pass_name == "fc_fuse_pass") {
       pass->Set("use_gpu", new bool(argument->use_gpu()));
       bool fc_mkldnn_pass = 0;
       for (const std::string &pass_n : passes) {
......
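Each pass->Set(key, new T(...)) above hands an owned, heap-allocated value to the pass, which later reads it back with Get<T>(key), exactly how the fusion passes in this diff consume Get<bool>("enable_gpu_mixed"). A self-contained toy model of that contract (not Paddle's actual Pass class):

#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>

class ToyPass {
 public:
  // Takes ownership of `value`, mirroring framework::ir::Pass::Set.
  template <typename T>
  void Set(const std::string& key, T* value) {
    attrs_[key] = std::shared_ptr<void>(value);  // deleter remembers T
  }

  // Mirrors Pass::Get<T>: returns the stored value by reference.
  template <typename T>
  T& Get(const std::string& key) const {
    return *static_cast<T*>(attrs_.at(key).get());
  }

 private:
  std::unordered_map<std::string, std::shared_ptr<void>> attrs_;
};

int main() {
  ToyPass pass;
  pass.Set("enable_gpu_mixed", new bool(true));
  pass.Set("mixed_precision_mode", new int(1));
  assert(pass.Get<bool>("enable_gpu_mixed"));
  assert(pass.Get<int>("mixed_precision_mode") == 1);
}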
......
@@ -83,14 +83,14 @@ void OutputProcess(framework::ir::Graph *graph,
                             backend,
                             precision,
                             blacklist)) {
-        AddCastOp(graph,
-                  var_node,
-                  next_op,
-                  framework::proto::VarType::FP32,
-                  to_type,
-                  &suffix,
-                  block_desc,
-                  &var_to_cast_op_map);
+        InsertCastOp(graph,
+                     var_node,
+                     next_op,
+                     framework::proto::VarType::FP32,
+                     to_type,
+                     block_desc,
+                     &suffix,
+                     &var_to_cast_op_map);
         var_node->Var()->SetDataType(framework::proto::VarType::FP32);
       }
     }
......
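Besides the rename, InsertCastOp swaps the block_desc and suffix parameters relative to the old AddCastOp; since their types differ (BlockDesc* vs int*), any stale caller fails to compile rather than silently misbehaving. The var_to_cast_op_map argument appears to memoize casts so that multiple consumers of one variable share a single cast op; a self-contained sketch of that reuse pattern (assumed semantics, not Paddle's implementation):

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

struct Node { std::string name; };

// Returns the cast output for `var`, creating it only on first request, so
// later consumers reuse the same node instead of inserting duplicate casts.
Node* GetOrInsertCast(Node* var, int* suffix,
                      std::unordered_map<Node*, Node*>* cache,
                      std::vector<std::unique_ptr<Node>>* owned) {
  auto it = cache->find(var);
  if (it != cache->end()) return it->second;
  owned->push_back(std::make_unique<Node>(
      Node{var->name + ".cast_" + std::to_string((*suffix)++)}));
  (*cache)[var] = owned->back().get();
  return owned->back().get();
}

int main() {
  Node x{"x"};
  int suffix = 0;
  std::unordered_map<Node*, Node*> cache;
  std::vector<std::unique_ptr<Node>> owned;
  Node* a = GetOrInsertCast(&x, &suffix, &cache, &owned);
  Node* b = GetOrInsertCast(&x, &suffix, &cache, &owned);
  std::cout << (a == b) << " " << a->name << "\n";  // prints: 1 x.cast_0
}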
......
@@ -13,7 +13,7 @@ cc_library(
 cc_library(
   convert_to_mixed_precision
   SRCS convert_to_mixed_precision.cc
-  DEPS analysis_pass ir_graph_build_pass)
+  DEPS analysis_pass ir_graph_build_pass auto_mixed_precision_pass)
 cc_library(
   ir_params_sync_among_devices_pass
   SRCS ir_params_sync_among_devices_pass.cc
......
......
@@ -15,14 +15,12 @@
 #pragma once

 #include <string>
 #include <unordered_map>
 #include <unordered_set>

 #include "paddle/fluid/framework/block_desc.h"
 #include "paddle/fluid/framework/ir/graph.h"
-#include "paddle/fluid/framework/ir/graph_helper.h"
-#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/phi/common/backend.h"
 #include "paddle/phi/common/data_type.h"
......
@@ -30,20 +28,52 @@ namespace paddle {
 namespace inference {
 namespace analysis {

+class ConvertToMixedPrecisionPass {
+ public:
+  explicit ConvertToMixedPrecisionPass(
+      const std::string& model_file,
+      const std::string& params_file,
+      const std::string& mixed_model_file,
+      const std::string& mixed_params_file,
+      phi::DataType mixed_precision,
+      phi::Backend backend,
+      bool keep_io_types,
+      const std::unordered_set<std::string>& black_list);
+
+  void Run();
+
+ private:
+  void LoadModel();
+  void SaveMixedModel();
+
+ private:
+  std::string model_file_;
+  std::string params_file_;
+  std::string mixed_model_file_;
+  std::string mixed_params_file_;
+  phi::DataType mixed_precision_;
+  phi::Backend backend_;
+  bool keep_io_types_;
+  std::unordered_set<std::string> black_list_;
+
+  framework::Scope scope_;
+  std::unique_ptr<framework::ir::Graph> main_graph_{nullptr};
+};
+
 bool OpSupportPrecision(const std::string& op_type,
                         phi::Backend backend,
                         phi::DataType precision,
-                        const std::unordered_set<std::string>& blacklist);
+                        const std::unordered_set<std::string>& black_list);

-void AddCastOp(
+void InsertCastOp(
     framework::ir::Graph* graph,
-    framework::ir::Node* node,
-    framework::ir::Node* next_op,
+    framework::ir::Node* var_node,
+    framework::ir::Node* op_node,
     framework::proto::VarType::Type from_type,
     framework::proto::VarType::Type to_type,
-    int* suffix,
     framework::BlockDesc* block_desc,
-    std::unordered_map<framework::ir::Node*, framework::ir::Node*>* map);
+    int* suffix,
+    std::unordered_map<framework::ir::Node*, framework::ir::Node*>* visited);

 void ConvertToMixedPrecision(const std::string& model_file,
                              const std::string& params_file,
......
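The conversion logic is now a class whose constructor captures every option up front. A usage sketch based on the declaration above (file paths, the include path, and the blacklist entry are illustrative):

#include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"  // assumed path

int main() {
  paddle::inference::analysis::ConvertToMixedPrecisionPass pass(
      "model.pdmodel", "model.pdiparams",  // source FP32 model
      "mixed.pdmodel", "mixed.pdiparams",  // destination mixed model
      phi::DataType::FLOAT16,              // or phi::DataType::BFLOAT16
      phi::Backend::GPU,
      /*keep_io_types=*/true,              // keep feeds/fetches in FP32
      /*black_list=*/{"softmax"});         // ops pinned to FP32 (illustrative)
  pass.Run();  // LoadModel -> rewrite graph -> SaveMixedModel
}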
......
@@ -99,7 +99,7 @@ void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb,
     // default
   } else if (precision_mode == Precision::kHalf ||
              precision_mode == Precision::kBf16) {
-    enable_gpu_half_ = true;
+    enable_gpu_mixed_ = true;
   } else {
     LOG(ERROR)
         << "The Paddle-GPU inference currently only supports "
......
@@ -396,7 +396,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
   // Mixed precision related.
   CP_MEMBER(mixed_black_list_);
-  CP_MEMBER(enable_gpu_half_);
+  CP_MEMBER(enable_gpu_mixed_);
   CP_MEMBER(mixed_precision_mode_);
   CP_MEMBER(enable_memory_optim_);
......
@@ -1017,7 +1017,7 @@ std::string AnalysisConfig::SerializeInfoCache() {
   ss << params_file_;
   ss << use_gpu_;
-  ss << enable_gpu_half_;
+  ss << enable_gpu_mixed_;
   ss << use_external_stream_;
   ss << exec_stream_;
   ss << use_fc_padding_;
......
@@ -1234,7 +1234,7 @@ std::string AnalysisConfig::Summary() {
   os.InsertRow({"use_gpu", use_gpu_ ? "true" : "false"});
   if (use_gpu_) {
     os.InsertRow({"gpu_device_id", std::to_string(gpu_device_id_)});
-    os.InsertRow({"enable_gpu_half_", std::to_string(enable_gpu_half_)});
+    os.InsertRow({"enable_gpu_mixed_", std::to_string(enable_gpu_mixed_)});
     os.InsertRow({"memory_pool_init_size",
                   std::to_string(memory_pool_init_size_mb_) + "MB"});
     os.InsertRow(
......
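On the user side nothing new is required: per the EnableUseGpu hunk above, requesting Precision::kHalf (or kBf16) flips enable_gpu_mixed_, which the predictor then forwards to the passes. A sketch (model paths are hypothetical):

#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config("model.pdmodel", "model.pdiparams");
  // 256MB initial GPU memory pool on device 0; kHalf turns on the new
  // mixed precision path (enable_gpu_mixed_ = true).
  config.EnableUseGpu(256, 0, paddle_infer::PrecisionType::kHalf);
  auto predictor = paddle_infer::CreatePredictor(config);
}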
......
@@ -1277,10 +1277,10 @@ void AnalysisPredictor::PrepareArgument() {
   if (!config_.ir_optim()) {
     argument_.SetEnableIrOptim(false);
-    if (config_.enable_gpu_half_) {
+    if (config_.enable_gpu_mixed_) {
       argument_.SetEnableIrOptim(true);
       pass_builder->ClearPasses();
-      pass_builder->AppendPass("float_to_half_pass");
+      pass_builder->AppendPass("auto_mixed_precision_pass");
       LOG(INFO)
           << "This model run in Paddle-GPU mixed precision mode with no ir "
              "optimization.";
......
@@ -1291,7 +1291,7 @@ void AnalysisPredictor::PrepareArgument() {
     if (config_.ir_debug_) {
       pass_builder->TurnOnDebug();
     }
-    if (config_.enable_gpu_half_) {
+    if (config_.enable_gpu_mixed_) {
       LOG(INFO) << "This model run in Paddle-GPU mixed precision mode.";
     }
   }
......
@@ -1303,7 +1303,7 @@ void AnalysisPredictor::PrepareArgument() {
   // mixed precison.
   argument_.SetModelPrecision(static_cast<int>(model_precision_));
   argument_.SetMixedBlackList(config_.mixed_black_list_);
-  argument_.SetEnableGPUHalf(config_.enable_gpu_half_);
+  argument_.SetEnableGPUMixed(config_.enable_gpu_mixed_);
   argument_.SetMixedPrecisionMode(static_cast<int>(
       paddle::ConvertPrecision(config_.mixed_precision_mode_)));
 }
......
......
@@ -1049,7 +1049,7 @@ struct PD_INFER_DECL AnalysisConfig {
   bool use_gpu_{false};
   int gpu_device_id_{0};
   uint64_t memory_pool_init_size_mb_{100};  // initial size is 100MB.
-  bool enable_gpu_half_{false};
+  bool enable_gpu_mixed_{false};
   bool thread_local_stream_{false};
   bool use_cudnn_{false};
......
......
@@ -245,7 +245,8 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
         "conv_elementwise_add_fuse_pass",      //
 #endif                                         //
         "transpose_flatten_concat_fuse_pass",  //
-        "float_to_half_pass",                  //
+        "constant_folding_pass",               //
+        "auto_mixed_precision_pass",           //
   });
   use_gpu_ = true;
......