Unverified commit 06fee998, authored by baoachun, committed by GitHub

support gpu mixed precision inference (#40531)

Parent e3a67782
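This commit adds an experimental GPU FP16 (mixed precision) mode to the inference API. A minimal usage sketch based on the interfaces added or touched in this commit; the model paths and the extra blacklisted op type are placeholders, and the umbrella header name is assumed to follow the standard install layout:

#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config;
  config.SetModel("./model/__model__", "./model/params");  // placeholder paths
  config.SwitchIrOptim();       // fp16 mode only takes effect with IR optimization
  config.EnableUseGpu(100, 0);  // 100 MB initial memory pool on GPU device 0
  // Op types passed here (plus the built-in defaults: conv2d_fusion, conv2d,
  // roll, strided_slice) are kept in fp32; "reduce_mean" is just a placeholder.
  config.Exp_EnableUseGpuFp16({"reduce_mean"});
  auto predictor = paddle_infer::CreatePredictor(config);
  return 0;
}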
......@@ -97,6 +97,7 @@ pass_library(layer_norm_fuse_pass inference)
pass_library(add_support_int8_pass inference)
pass_library(matmul_scale_fuse_pass inference)
pass_library(gpu_cpu_map_matmul_to_mul_pass inference)
pass_library(mixed_precision_configure_pass inference)
pass_library(generate_pass DEPS pass_desc_proto)
target_link_libraries(generate_pass pass_desc_proto)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mixed_precision_configure_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
void MixedPrecisionConfigurePass::InsertCastOps(
Graph* graph, const StringSet& blacklist) const {
VLOG(3) << "Insert the cast op before and after the kernel that does not "
"supports fp16 precision";
auto update_cast_desc = [&](
framework::OpDesc& desc, const std::string& x_name,
const std::string& out_name, const int in_dtype, const int out_dtype) {
desc.SetType("cast");
desc.SetInput("X", {x_name});
desc.SetOutput("Out", {out_name});
desc.SetAttr("in_dtype", in_dtype);
desc.SetAttr("out_dtype", out_dtype);
desc.SetAttr("use_mkldnn", false);
desc.SetAttr("with_quant_attr", false);
desc.Flush();
};
auto cast_input = [&](Graph* graph, Node* op_node,
const StringSet& cast_list) {
auto inlinks = op_node->inputs;
for (auto* pre_node : inlinks) {
if (pre_node->IsVar()) {
const auto is_persistable = pre_node->Var()->Persistable();
const auto is_float =
pre_node->Var()->GetDataType() == proto::VarType::FP16 ||
pre_node->Var()->GetDataType() == proto::VarType::FP32 ||
pre_node->Var()->GetDataType() == proto::VarType::FP64;
if (!is_persistable && is_float) {
int suffix = 0;
for (auto* pre_node_input : pre_node->inputs) {
if (!pre_node_input->IsOp()) continue;
const auto& type = pre_node_input->Op()->Type();
if (!cast_list.count(type) && type != "cast") {
std::string old_name = pre_node->Name();
std::string new_name =
old_name + "_cast.tmp_" + std::to_string(suffix);
suffix++;
framework::OpDesc new_op_desc(op_node->Op()->Block());
// proto::VarType dtype codes: 4 = fp16, 5 = fp32 (cast the fp16 input to fp32)
update_cast_desc(new_op_desc, old_name, new_name, 4, 5);
auto* new_op = graph->CreateOpNode(&new_op_desc);
VarDesc out_var(new_name);
out_var.SetPersistable(false);
auto* node_var = graph->CreateVarNode(&out_var);
op_node->Op()->RenameInput(old_name, new_name);
IR_NODE_LINK_TO(pre_node, new_op);
IR_NODE_LINK_TO(new_op, node_var);
IR_NODE_LINK_TO(node_var, op_node);
}
}
}
}
}
};
auto cast_output = [&](Graph* graph, Node* op_node,
const StringSet& cast_list) {
auto outlinks = op_node->outputs;
for (auto* next_node : outlinks) {
if (next_node->IsVar()) {
const auto is_persistable = next_node->Var()->Persistable();
const auto is_float =
next_node->Var()->GetDataType() == proto::VarType::FP16 ||
next_node->Var()->GetDataType() == proto::VarType::FP32 ||
next_node->Var()->GetDataType() == proto::VarType::FP64;
if (!is_persistable && is_float) {
int suffix = 0;
for (auto* next_node_output : next_node->outputs) {
if (!next_node_output->IsOp()) continue;
const auto& type = next_node_output->Op()->Type();
if (!cast_list.count(type) && type != "cast") {
std::string old_name = next_node->Name();
std::string new_name =
old_name + "_cast.tmp_" + std::to_string(suffix);
suffix++;
framework::OpDesc new_op_desc(op_node->Op()->Block());
// proto::VarType dtype codes: 4 = fp16, 5 = fp32 (cast the fp32 output back to fp16)
update_cast_desc(new_op_desc, old_name, new_name, 5, 4);
auto* new_op = graph->CreateOpNode(&new_op_desc);
VarDesc out_var(new_name);
out_var.SetPersistable(false);
auto* node_var = graph->CreateVarNode(&out_var);
next_node_output->Op()->RenameInput(old_name, new_name);
IR_NODE_LINK_TO(next_node, new_op);
IR_NODE_LINK_TO(new_op, node_var);
IR_NODE_LINK_TO(node_var, next_node_output);
}
}
}
}
}
};
for (auto* op_node :
ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
if (!op_node->IsOp() || op_node->Op()->Type() == "feed" ||
op_node->Op()->Type() == "fetch")
continue;
const auto& type = op_node->Op()->Type();
if (blacklist.count(type)) {
cast_input(graph, op_node, blacklist);
cast_output(graph, op_node, blacklist);
}
}
}
void MixedPrecisionConfigurePass::ApplyImpl(Graph* graph) const {
const auto blacklist =
Get<std::unordered_set<std::string>>("gpu_fp16_disabled_op_types");
InsertCastOps(graph, blacklist);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(mixed_precision_configure_pass,
paddle::framework::ir::MixedPrecisionConfigurePass);
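For orientation, a rough sketch of the rewrite InsertCastOps performs around a blacklisted (fp16-unsupported) op, using the dtype codes above (4 = fp16, 5 = fp32). Persistable vars (weights), non-float vars, and neighbors that are themselves blacklisted or cast ops are left untouched; names are illustrative only:

// before:  fp16_producer -> x -> blacklisted_op -> y -> fp16_consumer
//
// after:   fp16_producer -> x -> cast(fp16->fp32) -> x_cast.tmp_0 -> blacklisted_op
//          blacklisted_op -> y -> cast(fp32->fp16) -> y_cast.tmp_0 -> fp16_consumer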
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
using StringSet = std::unordered_set<std::string>;
class MixedPrecisionConfigurePass : public FusePassBase {
public:
MixedPrecisionConfigurePass() = default;
virtual ~MixedPrecisionConfigurePass() {}
protected:
void ApplyImpl(Graph* graph) const override;
private:
void InsertCastOps(Graph* graph, const StringSet& blacklist) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -188,6 +188,9 @@ struct Argument {
DECL_ARGUMENT_FIELD(use_gpu, UseGPU, bool);
DECL_ARGUMENT_FIELD(use_fc_padding, UseFcPadding, bool);
DECL_ARGUMENT_FIELD(gpu_device_id, GPUDeviceId, int);
DECL_ARGUMENT_FIELD(use_gpu_fp16, UseGPUFp16, bool);
DECL_ARGUMENT_FIELD(gpu_fp16_disabled_op_types, GpuFp16DisabledOpTypes,
std::unordered_set<std::string>);
// Usually use for trt dynamic shape.
// TRT will select the best kernel according to opt shape
......
......@@ -189,6 +189,10 @@ void IRPassManager::CreatePasses(Argument *argument,
new int(argument->dlnne_min_subgraph_size()));
pass->Set("program",
new framework::ProgramDesc *(&argument->main_program()));
} else if (pass_name == "mixed_precision_configure_pass") {
pass->Set("gpu_fp16_disabled_op_types",
new std::unordered_set<std::string>(
argument->gpu_fp16_disabled_op_types()));
}
if (pass_name == "lite_subgraph_pass") {
bool lite_enable_int8 =
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -65,6 +66,26 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToNpu(Argument *argument) {
#else
void IrParamsSyncAmongDevicesPass::GetVarNameToOpTypeMap(
const framework::ir::Graph &graph,
std::unordered_map<std::string, std::string> *var_name_op_type_map) {
std::vector<framework::ir::Node *> node_list =
framework::ir::TopologyVarientSort(
graph, static_cast<framework::ir::SortKind>(0));
for (auto *op_node : node_list) {
if (!op_node->IsOp() || op_node->Op()->Type() == "feed" ||
op_node->Op()->Type() == "fetch")
continue;
for (auto *pre_node : op_node->inputs) {
if (pre_node->IsVar() && pre_node->Var()->Persistable()) {
var_name_op_type_map->insert(std::pair<std::string, std::string>(
pre_node->Var()->Name(), op_node->Op()->Type()));
}
}
}
}
void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) {
// Without GPU, the parameters stay on the CPU, so no synchronization is necessary.
if (!argument->use_gpu()) return;
......@@ -102,6 +123,16 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) {
if (with_dynamic_shape) {
reserve_cpu_weights = true;
}
bool mixed_precision_mode =
argument->Has("use_gpu_fp16") && argument->use_gpu_fp16();
std::unordered_map<std::string, std::string> var_name_op_type_map{};
std::unordered_set<std::string> blacklist{};
if (mixed_precision_mode) {
GetVarNameToOpTypeMap(graph, &var_name_op_type_map);
blacklist = argument->gpu_fp16_disabled_op_types();
}
for (auto &var_name : all_vars) {
if (std::count(repetitive_params.begin(), repetitive_params.end(),
var_name)) {
......@@ -117,18 +148,29 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) {
var->IsType<framework::Tensor>()) {
auto *t = var->GetMutable<framework::LoDTensor>();
platform::CPUPlace cpu_place;
framework::LoDTensor temp_tensor;
temp_tensor.Resize(t->dims());
temp_tensor.mutable_data<float>(cpu_place);
// Copy the parameter data to a tmp tensor.
paddle::framework::TensorCopySync(*t, cpu_place, &temp_tensor);
// Reallocate the space on the GPU.
t->clear();
// Copy parameter data to newly allocated GPU space.
paddle::framework::TensorCopySync(temp_tensor, place, t);
bool is_float = t->dtype() == paddle::experimental::DataType::FLOAT32 ||
t->dtype() == paddle::experimental::DataType::FLOAT64;
if (mixed_precision_mode &&
!blacklist.count(var_name_op_type_map[var_name]) && is_float) {
framework::Tensor half_tensor;
half_tensor.set_type(paddle::experimental::DataType::FLOAT16);
half_tensor.Resize(t->dims());
auto *half_data =
half_tensor.mutable_data<float16>(platform::CPUPlace());
for (int i = 0; i < t->numel(); i++) {
auto *data = t->mutable_data<float>(platform::CPUPlace());
half_data[i] = static_cast<float16>(data[i]);
}
t->clear();
paddle::framework::TensorCopySync(half_tensor, place, t);
} else {
platform::CPUPlace cpu_place;
framework::LoDTensor temp_tensor;
temp_tensor.Resize(t->dims());
paddle::framework::TensorCopySync(*t, cpu_place, &temp_tensor);
t->clear();
paddle::framework::TensorCopySync(temp_tensor, place, t);
}
}
}
}
......
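The fp16 branch above converts eligible weights element by element before copying them to the GPU. A condensed standalone sketch of that narrowing step, assuming float16 is paddle::platform::float16 as in this file (the header path is assumed); the committed loop fetches the source pointer on every iteration, hoisting it as done here is equivalent:

#include "paddle/fluid/platform/float16.h"

// Element-wise fp32 -> fp16 narrowing, as used when syncing weights to the GPU.
void CastFp32ToFp16(const float* src, paddle::platform::float16* dst,
                    int64_t numel) {
  for (int64_t i = 0; i < numel; ++i) {
    dst[i] = static_cast<paddle::platform::float16>(src[i]);
  }
}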
......@@ -38,7 +38,12 @@ class IrParamsSyncAmongDevicesPass : public AnalysisPass {
#ifdef PADDLE_WITH_ASCEND_CL
void CopyParamsToNpu(Argument *argument);
#else
void CopyParamsToGpu(Argument *argument);
void GetVarNameToOpTypeMap(
const framework::ir::Graph& graph,
std::unordered_map<std::string, std::string>* var_name_op_type_map);
void CopyParamsToGpu(Argument* argument);
#endif
};
......
......@@ -83,6 +83,7 @@ void AnalysisConfig::SetModel(const std::string &prog_file_path,
Update();
}
void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb,
int device_id) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
......@@ -97,12 +98,26 @@ void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb,
Update();
}
void AnalysisConfig::DisableGpu() {
use_gpu_ = false;
Update();
}
void AnalysisConfig::Exp_EnableUseGpuFp16(
std::unordered_set<std::string> op_list) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
use_gpu_fp16_ = true;
gpu_fp16_disabled_op_types_.insert(op_list.begin(), op_list.end());
#else
LOG(ERROR) << "Please compile with gpu to Exp_EnableUseGpuFp16()";
use_gpu_fp16_ = false;
#endif
Update();
}
void AnalysisConfig::DisableFCPadding() {
use_fc_padding_ = false;
......@@ -213,6 +228,8 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
CP_MEMBER(use_cudnn_);
CP_MEMBER(gpu_device_id_);
CP_MEMBER(memory_pool_init_size_mb_);
CP_MEMBER(use_gpu_fp16_);
CP_MEMBER(gpu_fp16_disabled_op_types_);
CP_MEMBER(enable_memory_optim_);
// TensorRT related.
......@@ -573,6 +590,20 @@ void AnalysisConfig::Update() {
#endif
}
if (use_gpu_fp16_) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (!enable_ir_optim_) {
LOG(ERROR) << "Exp_EnableUseGpuFp16() only works when IR optimization is "
"enabled.";
} else if (!use_gpu()) {
LOG(ERROR)
<< "Exp_EnableUseGpuFp16() only works when use_gpu is enabled.";
} else {
pass_builder()->Exp_EnableUseGpuFp16();
}
#endif
}
if (use_mkldnn_) {
#ifdef PADDLE_WITH_MKLDNN
if (!enable_ir_optim_) {
......@@ -669,6 +700,8 @@ std::string AnalysisConfig::SerializeInfoCache() {
ss << params_file_;
ss << use_gpu_;
ss << use_gpu_fp16_;
for (auto &item : gpu_fp16_disabled_op_types_) ss << item;
ss << use_fc_padding_;
ss << gpu_device_id_;
ss << xpu_device_id_;
......
......@@ -872,6 +872,11 @@ void AnalysisPredictor::PrepareArgument() {
argument_.SetDlnneMinSubgraphSize(config_.dlnne_min_subgraph_size_);
}
if (config_.gpu_fp16_enabled()) {
argument_.SetUseGPUFp16(true);
argument_.SetGpuFp16DisabledOpTypes(config_.gpu_fp16_disabled_op_types_);
}
if (config_.lite_engine_enabled()) {
argument_.SetCpuMathLibraryNumThreads(
config_.cpu_math_library_num_threads());
......
......@@ -375,6 +375,19 @@ TEST(AnalysisPredictor, enable_onnxruntime) {
ASSERT_TRUE(!config.use_onnxruntime());
}
TEST(AnalysisPredictor, exp_enable_use_gpu_fp16) {
AnalysisConfig config;
config.SwitchIrOptim();
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
config.EnableUseGpu(100, 0);
config.Exp_EnableUseGpuFp16();
ASSERT_TRUE(config.gpu_fp16_enabled());
#else
config.DisableGpu();
#endif
LOG(INFO) << config.Summary();
}
} // namespace paddle
namespace paddle_infer {
......@@ -434,6 +447,19 @@ TEST(Predictor, EnableONNXRuntime) {
auto predictor = CreatePredictor(config);
}
TEST(Predictor, Exp_EnableUseGpuFp16) {
Config config;
config.SetModel(FLAGS_dirname);
config.SwitchIrOptim();
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
config.EnableUseGpu(100, 0);
config.Exp_EnableUseGpuFp16();
#else
config.DisableGpu();
#endif
auto predictor = CreatePredictor(config);
}
TEST(Tensor, CpuShareExternalData) {
Config config;
config.SetModel(FLAGS_dirname);
......
......@@ -253,6 +253,19 @@ struct PD_INFER_DECL AnalysisConfig {
///
///
void DisableGpu();
///
/// \brief Enable GPU fp16 precision computation. This interface is in an
/// experimental state.
///
/// \param op_list The list of operator types that are excluded from fp16 and
/// kept in fp32 precision.
///
void Exp_EnableUseGpuFp16(std::unordered_set<std::string> op_list = {});
///
/// \brief A boolean state telling whether the GPU fp16 precision is turned
/// on.
///
/// \return bool Whether the GPU fp16 precision is turned on.
///
bool gpu_fp16_enabled() const { return use_gpu_fp16_; }
///
/// \brief Turn on XPU.
......@@ -859,6 +872,9 @@ struct PD_INFER_DECL AnalysisConfig {
int gpu_device_id_{0};
uint64_t memory_pool_init_size_mb_{100}; // initial size is 100MB.
bool thread_local_stream_{false};
bool use_gpu_fp16_{false};
std::unordered_set<std::string> gpu_fp16_disabled_op_types_{
"conv2d_fusion", "conv2d", "roll", "strided_slice"};
bool use_cudnn_{false};
......
......@@ -172,6 +172,40 @@ void GpuPassStrategy::EnableCUDNN() {
use_cudnn_ = true;
}
void GpuPassStrategy::Exp_EnableUseGpuFp16() {
passes_.assign({
"is_test_pass", //
"simplify_with_basic_ops_pass", //
"conv_bn_fuse_pass", //
"conv_eltwiseadd_bn_fuse_pass", //
"embedding_eltwise_layernorm_fuse_pass", //
"multihead_matmul_fuse_pass_v2", //
"gpu_cpu_squeeze2_matmul_fuse_pass", //
"gpu_cpu_reshape2_matmul_fuse_pass", //
"gpu_cpu_flatten2_matmul_fuse_pass", //
"gpu_cpu_map_matmul_v2_to_mul_pass", //
"gpu_cpu_map_matmul_v2_to_matmul_pass", //
"gpu_cpu_map_matmul_to_mul_pass", //
// "fc_fuse_pass", //
"fc_elementwise_layernorm_fuse_pass", //
#if CUDNN_VERSION >= 7100 // To run conv_fusion, the cuDNN version must be
// at least v7.1.
// cuDNN 8.0 has a memory leak problem in conv + eltwise + act, so we
// disable the pass.
#if !(CUDNN_VERSION >= 8000 && CUDNN_VERSION < 8100)
"conv_elementwise_add_act_fuse_pass", //
"conv_elementwise_add2_act_fuse_pass", //
#endif
"conv_elementwise_add_fuse_pass", //
#endif //
"transpose_flatten_concat_fuse_pass", //
"mixed_precision_configure_pass", //
"runtime_context_cache_pass" //
});
use_gpu_fp16_ = true;
}
void GpuPassStrategy::EnableMKLDNN() {
LOG(ERROR) << "GPU not support MKLDNN yet";
}
......
......@@ -125,6 +125,9 @@ class PD_INFER_DECL PassStrategy : public PaddlePassBuilder {
/// \brief Enable the use of cuDNN kernel.
virtual void EnableCUDNN() {}
/// \brief Enable the use of GPU fp16 kernels.
virtual void Exp_EnableUseGpuFp16() {}
/// \brief Enable the use of MKLDNN.
/// The MKLDNN control exists in both CPU and GPU mode, because there can
/// still be some CPU kernels running in GPU mode.
......@@ -140,6 +143,10 @@ class PD_INFER_DECL PassStrategy : public PaddlePassBuilder {
/// \return A bool variable implying whether we are in gpu mode.
bool use_gpu() const { return use_gpu_; }
/// \brief Check if we are using GPU fp16 kernels.
/// \return A bool variable implying whether we are in gpu fp16 mode.
bool use_gpu_fp16() const { return use_gpu_fp16_; }
/// \brief Check if we are using xpu.
/// \return A bool variable implying whether we are in xpu mode.
bool use_xpu() const { return use_xpu_; }
......@@ -162,6 +169,7 @@ class PD_INFER_DECL PassStrategy : public PaddlePassBuilder {
bool use_npu_{false};
bool use_ipu_{false};
bool use_mkldnn_{false};
bool use_gpu_fp16_{false};
/// \endcond
};
......@@ -223,6 +231,9 @@ class PD_INFER_DECL GpuPassStrategy : public PassStrategy {
/// \brief Enable the use of cuDNN kernel.
void EnableCUDNN() override;
/// \brief Enable the use of GPU fp16 kernels.
void Exp_EnableUseGpuFp16() override;
/// \brief Not supported in GPU mode yet.
void EnableMKLDNN() override;
......@@ -238,6 +249,7 @@ class PD_INFER_DECL GpuPassStrategy : public PassStrategy {
protected:
/// \cond Protected
bool use_cudnn_{false};
bool use_gpu_fp16_{false};
/// \endcond
};
......
......@@ -551,6 +551,9 @@ void BindAnalysisConfig(py::module *m) {
.def("params_file", &AnalysisConfig::params_file)
.def("enable_use_gpu", &AnalysisConfig::EnableUseGpu,
py::arg("memory_pool_init_size_mb"), py::arg("device_id") = 0)
.def("exp_enable_use_gpu_fp16", &AnalysisConfig::Exp_EnableUseGpuFp16,
py::arg("gpu_fp16_disabled_op_types") =
std::unordered_set<std::string>({}))
.def("enable_xpu", &AnalysisConfig::EnableXpu,
py::arg("l3_workspace_size") = 16 * 1024 * 1024,
py::arg("locked") = false, py::arg("autotune") = true,
......