未验证 提交 c5a45cc6 编写于 作者: Y Yuanle Liu 提交者: GitHub

[Paddle Inference] Add float_to_half_pass to support inference with mixed precision (#47993)

上级 54b756e2
......@@ -104,6 +104,7 @@ pass_library(delete_c_identity_op_pass inference)
pass_library(preln_residual_bias_fuse_pass inference)
pass_library(delete_fill_constant_op_pass inference)
pass_library(constant_folding_pass inference)
pass_library(float_to_half_pass inference)
pass_library(conv2d_fusion_layout_transfer_pass inference)
pass_library(simplify_with_basic_ops_pass base)
pass_library(fc_elementwise_layernorm_fuse_pass base)
......
此差异已折叠。
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/place.h"
namespace paddle {
namespace framework {
namespace ir {
class FloatToHalfPass : public FusePassBase {
public:
using VarType = framework::proto::VarType;
public:
FloatToHalfPass() = default;
~FloatToHalfPass() = default;
protected:
void ApplyImpl(Graph* graph) const override;
private:
void Init(Graph* graph) const;
void SetDefaultBlacklist() const;
bool OpSupportPrecision(const std::string& op_type,
phi::DataType precision,
phi::Backend backend = phi::Backend::GPU) const;
void SetOpUniqueType() const;
void RestoreOpOriginType() const;
inline std::string GetOpOriginalType(const std::string& op_type) const;
void GetOpPrecision() const;
void UpdateOpPrecision() const;
void InsertCastOp() const;
void ProcessOpWithDtypeAttr() const;
bool InputVarsNotConvert(Node* op_node, const std::string& var_name) const;
bool OutputVarsNotConvert(Node* op_node, const std::string& var_name) const;
void SetVarPrecision() const;
void ConvertWeightsData() const;
private:
mutable bool keep_io_types_;
// float16 or bfloat16 now
mutable phi::DataType half_precision_;
mutable std::unordered_set<std::string> black_list_;
// subgraph id -> pointer to subgraph
mutable std::vector<Graph*> subgraphes_;
// var name -> real var node
mutable std::unordered_map<std::string, Node*> real_vars_;
// subgraph id -> all op nodes in subgraph
mutable std::vector<std::vector<Node*>> all_op_nodes_;
// op's unique type -> the op's origin type
mutable std::unordered_map<std::string, std::string> op_original_type_;
// op's unique type -> whether the op run at half precision
mutable std::unordered_set<std::string> op_run_half_;
mutable std::unordered_set<std::string> vars_convert_to_half_;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -365,6 +365,8 @@ struct Argument {
DECL_ARGUMENT_FIELD(mixed_black_list,
MixedBlackList,
std::unordered_set<std::string>);
DECL_ARGUMENT_FIELD(enable_gpu_half, EnableGPUHalf, bool);
DECL_ARGUMENT_FIELD(mixed_precision_mode, MixedPrecisionMode, int);
private:
std::unordered_set<std::string> valid_fields_;
......
......@@ -86,10 +86,14 @@ void IRPassManager::CreatePasses(Argument *argument,
argument->tensorrt_tuned_dynamic_shape();
pass->Set("with_dynamic_shape", new bool(with_dynamic_shape));
// mixed precision related
pass->Set("model_precision", new int(argument->model_precision()));
pass->Set(
"mixed_black_list",
new std::unordered_set<std::string>(argument->mixed_black_list()));
pass->Set("enable_gpu_half", new bool(argument->enable_gpu_half()));
pass->Set("mixed_precision_mode",
new int(argument->mixed_precision_mode()));
if (pass_name == "graph_viz_pass") {
std::string optim_cache_dir = argument->optim_cache_dir();
......
......@@ -85,16 +85,29 @@ void AnalysisConfig::SetModel(const std::string &prog_file_path,
Update();
}
void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb,
int device_id) {
int device_id,
Precision precision_mode) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
use_gpu_ = true;
memory_pool_init_size_mb_ = memory_pool_init_size_mb;
FLAGS_initial_gpu_memory_in_mb = memory_pool_init_size_mb_;
gpu_device_id_ = device_id;
mixed_precision_mode_ = precision_mode;
if (precision_mode == Precision::kFloat32) {
// default
} else if (precision_mode == Precision::kHalf ||
precision_mode == Precision::kBf16) {
enable_gpu_half_ = true;
} else {
LOG(ERROR)
<< "The Paddle-GPU inference currently only supports "
"float32/float16/bfloat16 precision. Please check the parameters "
"you specified in EnableUseGpu or enable_use_gpu function.";
}
#else
LOG(ERROR) << "Please compile with gpu to EnableGpu()";
use_gpu_ = false;
LOG(ERROR) << "Please use PaddlePaddle with GPU version.";
#endif
Update();
......@@ -381,8 +394,10 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
CP_MEMBER(gpu_device_id_);
CP_MEMBER(memory_pool_init_size_mb_);
// Mixed related.
// Mixed precision related.
CP_MEMBER(mixed_black_list_);
CP_MEMBER(enable_gpu_half_);
CP_MEMBER(mixed_precision_mode_);
CP_MEMBER(enable_memory_optim_);
// TensorRT related.
......@@ -996,6 +1011,7 @@ std::string AnalysisConfig::SerializeInfoCache() {
ss << params_file_;
ss << use_gpu_;
ss << enable_gpu_half_;
ss << use_external_stream_;
ss << exec_stream_;
ss << use_fc_padding_;
......@@ -1212,6 +1228,7 @@ std::string AnalysisConfig::Summary() {
os.InsertRow({"use_gpu", use_gpu_ ? "true" : "false"});
if (use_gpu_) {
os.InsertRow({"gpu_device_id", std::to_string(gpu_device_id_)});
os.InsertRow({"enable_gpu_half_", std::to_string(enable_gpu_half_)});
os.InsertRow({"memory_pool_init_size",
std::to_string(memory_pool_init_size_mb_) + "MB"});
os.InsertRow(
......@@ -1407,7 +1424,7 @@ bool AnalysisConfig::trt_allow_build_at_runtime() const {
return trt_allow_build_at_runtime_;
}
void AnalysisConfig::Exp_SetBlackListOpsForMixedModel(
void AnalysisConfig::Exp_DisableMixedInferOps(
const std::unordered_set<std::string> &black_list) {
mixed_black_list_ = black_list;
}
......
......@@ -1257,12 +1257,26 @@ void AnalysisPredictor::PrepareArgument() {
}
}
}
if (config_.ir_debug_) {
pass_builder->TurnOnDebug();
}
if (!config_.ir_optim()) {
argument_.SetEnableIrOptim(false);
LOG(INFO) << "ir_optim is turned off, no IR pass will be executed";
if (config_.enable_gpu_half_) {
argument_.SetEnableIrOptim(true);
pass_builder->ClearPasses();
pass_builder->AppendPass("float_to_half_pass");
LOG(INFO)
<< "This model run in Paddle-GPU mixed precision mode with no ir "
"optimization.";
} else {
LOG(INFO) << "ir_optim is turned off, no IR pass will be executed.";
}
} else {
if (config_.ir_debug_) {
pass_builder->TurnOnDebug();
}
if (config_.enable_gpu_half_) {
LOG(INFO) << "This model run in Paddle-GPU mixed precision mode.";
}
}
argument_.SetDisableLogs(config_.glog_info_disabled());
argument_.SetIrAnalysisPasses(pass_builder->AllPasses());
......@@ -1272,6 +1286,9 @@ void AnalysisPredictor::PrepareArgument() {
// mixed precison.
argument_.SetModelPrecision(static_cast<int>(model_precision_));
argument_.SetMixedBlackList(config_.mixed_black_list_);
argument_.SetEnableGPUHalf(config_.enable_gpu_half_);
argument_.SetMixedPrecisionMode(static_cast<int>(
paddle::ConvertPrecision(config_.mixed_precision_mode_)));
}
// NOTE All the members in AnalysisConfig should be copied to Argument.
......
......@@ -247,8 +247,12 @@ struct PD_INFER_DECL AnalysisConfig {
///
/// \param memory_pool_init_size_mb initial size of the GPU memory pool in MB.
/// \param device_id device_id the GPU card to use (default is 0).
/// \param precision the precision used in Paddle-GPU inference.
///
void EnableUseGpu(uint64_t memory_pool_init_size_mb, int device_id = 0);
void EnableUseGpu(uint64_t memory_pool_init_size_mb,
int device_id = 0,
Precision precision_mode = Precision::kFloat32);
///
/// \brief Turn off GPU.
///
......@@ -1005,7 +1009,7 @@ struct PD_INFER_DECL AnalysisConfig {
/// interface is in the experimental stage and may change in the future. Note
/// that the blacklist must be the same as the model conversion blacklist.
///
void Exp_SetBlackListOpsForMixedModel(
void Exp_DisableMixedInferOps(
const std::unordered_set<std::string>& black_list);
void SetApplyOptim(bool value) { apply_optim_ = value; }
......@@ -1024,13 +1028,15 @@ struct PD_INFER_DECL AnalysisConfig {
mutable std::string prog_file_;
mutable std::string params_file_;
// Mixed precision.
// Mixed precision related.
Precision mixed_precision_mode_{Precision::kFloat32};
std::unordered_set<std::string> mixed_black_list_;
// GPU related.
bool use_gpu_{false};
int gpu_device_id_{0};
uint64_t memory_pool_init_size_mb_{100}; // initial size is 100MB.
bool enable_gpu_half_{false};
bool thread_local_stream_{false};
bool use_cudnn_{false};
......
......@@ -246,9 +246,10 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"conv_elementwise_add_fuse_pass", //
#endif //
"transpose_flatten_concat_fuse_pass", //
"constant_folding_pass",
"constant_folding_pass", //
// following pass should be located in the last, since it will
// work on all fused ops.
"float_to_half_pass", //
"runtime_context_cache_pass"
});
......
......@@ -416,6 +416,9 @@ download_result(${ERNIE_INSTALL_DIR} "Ernie_result.txt.tar.gz"
if(WITH_GPU)
inference_analysis_api_test(test_analyzer_ernie ${ERNIE_INSTALL_DIR}
analyzer_ernie_tester.cc)
inference_analysis_api_test(gpu_ernie_half_test ${ERNIE_INSTALL_DIR}
gpu_ernie_half_test.cc)
set_tests_properties(gpu_ernie_half_test PROPERTIES TIMEOUT 40)
endif()
inference_analysis_api_int8_test(test_analyzer_ernie_int8 ${ERNIE_INSTALL_DIR}
analyzer_ernie_int8_tester.cc)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/tests/api/tester_helper.h"
namespace paddle {
namespace inference {
using paddle::PaddleTensor;
template <typename T>
void GetValueFromStream(std::stringstream *ss, T *t) {
(*ss) >> (*t);
}
template <>
void GetValueFromStream<std::string>(std::stringstream *ss, std::string *t) {
*t = ss->str();
}
// Split string to vector
template <typename T>
void Split(const std::string &line, char sep, std::vector<T> *v) {
std::stringstream ss;
T t;
for (auto c : line) {
if (c != sep) {
ss << c;
} else {
GetValueFromStream<T>(&ss, &t);
v->push_back(std::move(t));
ss.str({});
ss.clear();
}
}
if (!ss.str().empty()) {
GetValueFromStream<T>(&ss, &t);
v->push_back(std::move(t));
ss.str({});
ss.clear();
}
}
// Parse tensor from string
template <typename T>
bool ParseTensor(const std::string &field, paddle::PaddleTensor *tensor) {
std::vector<std::string> data;
Split(field, ':', &data);
if (data.size() < 2) return false;
std::string shape_str = data[0];
std::vector<int> shape;
Split(shape_str, ' ', &shape);
std::string mat_str = data[1];
std::vector<T> mat;
Split(mat_str, ' ', &mat);
tensor->shape = shape;
auto size =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()) *
sizeof(T);
tensor->data.Resize(size);
std::copy(mat.begin(), mat.end(), static_cast<T *>(tensor->data.data()));
tensor->dtype = GetPaddleDType<T>();
return true;
}
// Parse input tensors from string
bool ParseLine(const std::string &line,
std::vector<paddle::PaddleTensor> *tensors) {
std::vector<std::string> fields;
Split(line, ';', &fields);
tensors->clear();
tensors->reserve(4);
int i = 0;
auto input_name = FLAGS_ernie_large ? "eval_placeholder_" : "placeholder_";
for (; i < 3; i++) {
paddle::PaddleTensor temp;
ParseTensor<int64_t>(fields[i], &temp);
temp.name = input_name + std::to_string(i);
tensors->push_back(temp);
}
// input_mask
paddle::PaddleTensor input_mask;
ParseTensor<float>(fields[i], &input_mask);
input_mask.name = input_name + std::to_string(i);
tensors->push_back(input_mask);
return true;
}
bool LoadInputData(std::vector<std::vector<paddle::PaddleTensor>> *inputs,
int batch_size = 1) {
if (FLAGS_infer_data.empty()) {
LOG(ERROR) << "please set input data path";
return false;
}
std::ifstream fin(FLAGS_infer_data);
std::string line;
int sample = 0;
// The unit-test dataset only have 10 samples, each sample have 5 feeds.
while (std::getline(fin, line)) {
std::vector<paddle::PaddleTensor> feed_data;
ParseLine(line, &feed_data);
inputs->push_back(std::move(feed_data));
sample++;
if (!FLAGS_test_all_data && sample == batch_size) break;
}
LOG(INFO) << "number of samples: " << sample;
return true;
}
// Compare results
TEST(Ernie_gpu_fp16_no_ir, compare_results) {
AnalysisConfig config;
config.SetModel(FLAGS_infer_model);
config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kHalf);
config.SwitchIrOptim(false);
auto predictor = CreatePaddlePredictor(config);
std::vector<std::vector<PaddleTensor>> input_slots_all;
LoadInputData(&input_slots_all);
std::ifstream fin(FLAGS_refer_result);
std::string line;
std::vector<float> ref;
while (std::getline(fin, line)) {
Split(line, ' ', &ref);
}
std::vector<PaddleTensor> outputs;
for (size_t i = 0; i < input_slots_all.size(); i++) {
outputs.clear();
predictor->Run(input_slots_all[i], &outputs);
auto output = outputs.front();
size_t outputs_size = 1;
for (auto dim : output.shape) {
outputs_size *= dim;
}
float *result = reinterpret_cast<float *>(output.data.data());
for (size_t j = 0; j < outputs_size; ++j) {
EXPECT_NEAR(ref[i * outputs_size + j], result[j], 5e-2);
}
}
}
// Compare results
TEST(Ernie_gpu_fp16_with_ir, compare_results) {
AnalysisConfig config;
config.SetModel(FLAGS_infer_model);
config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kHalf);
config.SwitchIrOptim(true);
// The fc_fuse_pass has diff, which will be repaired later.
config.pass_builder()->DeletePass("fc_fuse_pass");
// There is a problem with the model itself, which has nothing to do with
// constant_folding_pass.
config.pass_builder()->DeletePass("constant_folding_pass");
auto predictor = CreatePaddlePredictor(config);
std::vector<std::vector<PaddleTensor>> input_slots_all;
LoadInputData(&input_slots_all);
std::ifstream fin(FLAGS_refer_result);
std::string line;
std::vector<float> ref;
while (std::getline(fin, line)) {
Split(line, ' ', &ref);
}
std::vector<PaddleTensor> outputs;
for (size_t i = 0; i < input_slots_all.size(); i++) {
outputs.clear();
predictor->Run(input_slots_all[i], &outputs);
auto output = outputs.front();
size_t outputs_size = 1;
for (auto dim : output.shape) {
outputs_size *= dim;
}
float *result = reinterpret_cast<float *>(output.data.data());
for (size_t j = 0; j < outputs_size; ++j) {
EXPECT_NEAR(ref[i * outputs_size + j], result[j], 5e-2);
}
}
}
// Compare results
TEST(Ernie_gpu_bf16_no_ir, compare_results) {
AnalysisConfig config;
config.SetModel(FLAGS_infer_model);
config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kBf16);
config.SwitchIrOptim(false);
auto predictor = CreatePaddlePredictor(config);
std::vector<std::vector<PaddleTensor>> input_slots_all;
LoadInputData(&input_slots_all);
std::ifstream fin(FLAGS_refer_result);
std::string line;
std::vector<float> ref;
while (std::getline(fin, line)) {
Split(line, ' ', &ref);
}
std::vector<PaddleTensor> outputs;
for (size_t i = 0; i < input_slots_all.size(); i++) {
outputs.clear();
predictor->Run(input_slots_all[i], &outputs);
auto output = outputs.front();
size_t outputs_size = 1;
for (auto dim : output.shape) {
outputs_size *= dim;
}
float *result = reinterpret_cast<float *>(output.data.data());
for (size_t j = 0; j < outputs_size; ++j) {
EXPECT_NEAR(ref[i * outputs_size + j], result[j], 7e-2);
}
}
}
// Compare results
TEST(Ernie_gpu_bf16_with_ir, compare_results) {
AnalysisConfig config;
config.SetModel(FLAGS_infer_model);
config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kBf16);
config.SwitchIrOptim(true);
// The fc_fuse_pass has diff, which will be repaired later.
config.pass_builder()->DeletePass("fc_fuse_pass");
// There is a problem with the model itself, which has nothing to do with
// constant_folding_pass.
config.pass_builder()->DeletePass("constant_folding_pass");
auto predictor = CreatePaddlePredictor(config);
std::vector<std::vector<PaddleTensor>> input_slots_all;
LoadInputData(&input_slots_all);
std::ifstream fin(FLAGS_refer_result);
std::string line;
std::vector<float> ref;
while (std::getline(fin, line)) {
Split(line, ' ', &ref);
}
std::vector<PaddleTensor> outputs;
for (size_t i = 0; i < input_slots_all.size(); i++) {
outputs.clear();
predictor->Run(input_slots_all[i], &outputs);
auto output = outputs.front();
size_t outputs_size = 1;
for (auto dim : output.shape) {
outputs_size *= dim;
}
float *result = reinterpret_cast<float *>(output.data.data());
for (size_t j = 0; j < outputs_size; ++j) {
EXPECT_NEAR(ref[i * outputs_size + j], result[j], 7e-2);
}
}
}
} // namespace inference
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -12,15 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cuda_runtime.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <cstring>
#include <numeric>
#include "gflags/gflags.h"
#include "paddle/fluid/inference/tests/api/trt_test_helper.h"
#include "paddle/fluid/inference/tests/api/tester_helper.h"
namespace paddle_infer {
......
......@@ -644,7 +644,8 @@ void BindAnalysisConfig(py::module *m) {
.def("enable_use_gpu",
&AnalysisConfig::EnableUseGpu,
py::arg("memory_pool_init_size_mb"),
py::arg("device_id") = 0)
py::arg("device_id") = 0,
py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
.def("set_exec_stream",
[](AnalysisConfig &self, phi::CUDAStream &stream) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册