Unverified commit a6449634, authored by engineer1109, committed via GitHub

add custom device mixed precision inference api (#50884)

fix bug

remove useless

fix bug

add pybind

remove log

fix style

fix style

change api
Parent 62bff0e0
@@ -376,6 +376,9 @@ struct Argument {
   DECL_ARGUMENT_FIELD(use_custom_device, UseCustomDevice, bool);
   DECL_ARGUMENT_FIELD(custom_device_type, CustomDeviceType, std::string);
   DECL_ARGUMENT_FIELD(custom_device_id, CustomDeviceId, int);
+  DECL_ARGUMENT_FIELD(enable_custom_device_mixed,
+                      EnableCustomDeviceMixed,
+                      bool);

  private:
   std::unordered_set<std::string> valid_fields_;
......
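Judging from the call sites added later in this commit, the DECL_ARGUMENT_FIELD line above is expected to declare roughly the following accessors on Argument. This is a simplified sketch only, not the actual macro expansion (which also tracks field validity):

// Hedged sketch of the accessors implied by
// DECL_ARGUMENT_FIELD(enable_custom_device_mixed, EnableCustomDeviceMixed, bool):
bool enable_custom_device_mixed();        // read in IRPassManager::CreatePasses below
void SetEnableCustomDeviceMixed(bool v);  // set in AnalysisPredictor::PrepareArgument below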
@@ -99,6 +99,8 @@ void IRPassManager::CreatePasses(Argument *argument,
           "mixed_black_list",
           new std::unordered_set<std::string>(argument->mixed_black_list()));
       pass->Set("enable_gpu_mixed", new bool(argument->enable_gpu_mixed()));
+      pass->Set("enable_custom_device_mixed",
+                new bool(argument->enable_custom_device_mixed()));
       pass->Set("mixed_precision_mode",
                 new int(argument->mixed_precision_mode()));
       pass->Set("model_precision", new int(argument->model_precision()));
......
@@ -211,11 +211,25 @@ void AnalysisConfig::EnableNpu(int device_id) {
 }

 void AnalysisConfig::EnableCustomDevice(const std::string &device_type,
-                                        int device_id) {
+                                        int device_id,
+                                        Precision precision_mode) {
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
   use_custom_device_ = true;
   custom_device_id_ = device_id;
   custom_device_type_ = device_type;
+  mixed_precision_mode_ = precision_mode;
+  if (precision_mode == Precision::kFloat32) {
+    // default
+  } else if (precision_mode == Precision::kHalf ||
+             precision_mode == Precision::kBf16) {
+    enable_custom_device_mixed_ = true;
+    LOG(INFO) << "enable_custom_device_mixed_";
+  } else {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "The Paddle-CustomDevice inference currently only supports "
+        "float32/float16/bfloat16 precision. Please check the parameters "
+        "you specified in EnableCustomDevice function."));
+  }
 #else
   LOG(ERROR) << "Please compile with CustomDevice to EnableCustomDevice()";
   use_custom_device_ = false;
@@ -540,6 +554,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
   CP_MEMBER(use_custom_device_);
   CP_MEMBER(custom_device_type_);
   CP_MEMBER(custom_device_id_);
+  CP_MEMBER(enable_custom_device_mixed_);

   // JITLayer relate
   CP_MEMBER(apply_optim_);
......
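For orientation, here is a minimal C++ sketch of calling the new overload end to end. It is an illustration only: the include path, the "npu" device-type string, and the model directory are placeholders, while the EnableCustomDevice signature and the Precision values are taken from the diff above (kHalf or kBf16 switches mixed precision on, kFloat32 keeps the previous behavior, and any other value throws InvalidArgument).

#include <memory>

#include "paddle_inference_api.h"  // header name/path depends on the inference install layout

int main() {
  paddle::AnalysisConfig config;
  config.SetModel("./model_dir");  // placeholder model directory
  // "npu" stands in for whatever custom device plugin is registered; kHalf
  // enables the custom-device mixed precision path added in this commit.
  config.EnableCustomDevice("npu", /*device_id=*/0,
                            paddle::AnalysisConfig::Precision::kHalf);
  auto predictor = paddle::CreatePaddlePredictor(config);
  return predictor != nullptr ? 0 : 1;
}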
@@ -1299,7 +1299,6 @@ void AnalysisPredictor::PrepareArgument() {
     argument_->SetCustomDeviceId(config_.custom_device_id());
   }
 #endif
-
 #ifdef PADDLE_WITH_XPU
   argument_->SetUseXpu(config_.use_xpu_);
   argument_->SetXpuL3WorkspaceSize(config_.xpu_l3_workspace_size_);
@@ -1361,6 +1360,15 @@ void AnalysisPredictor::PrepareArgument() {
       LOG(INFO) << "This model run in Paddle-GPU mixed precision mode.";
     }
   }
+  argument_->SetEnableCustomDeviceMixed(config_.enable_custom_device_mixed());
+  if (config_.enable_custom_device_mixed_) {
+    argument_->SetEnableIrOptim(true);
+    pass_builder->ClearPasses();
+    pass_builder->AppendPass("auto_mixed_precision_pass");
+    LOG(INFO) << "This model run in Custom Device mixed precision mode.";
+  }

   argument_->SetDisableLogs(config_.glog_info_disabled());
   argument_->SetIrAnalysisPasses(pass_builder->AllPasses());
   argument_->SetAnalysisPasses(pass_builder->AnalysisPasses());
......
@@ -375,7 +375,9 @@ struct PD_INFER_DECL AnalysisConfig {
   ///
   /// \param device_id device_id the custom device to use (default is 0).
   ///
-  void EnableCustomDevice(const std::string& device_type, int device_id = 0);
+  void EnableCustomDevice(const std::string& device_type,
+                          int device_id = 0,
+                          Precision precision_mode = Precision::kFloat32);
   ///
   /// \brief Turn on ONNXRuntime.
   ///
@@ -475,6 +477,13 @@ struct PD_INFER_DECL AnalysisConfig {
   /// \return string The custom device type.
   ///
   std::string custom_device_type() const { return custom_device_type_; }
+  /// \brief Get whether the custom device mixed precision is enabled.
+  ///
+  /// \return bool custom device mixed is enabled.
+  ///
+  bool enable_custom_device_mixed() const {
+    return enable_custom_device_mixed_;
+  }
   ///
   /// \brief Get the initial size in MB of the GPU memory pool.
   ///
@@ -1071,6 +1080,7 @@ struct PD_INFER_DECL AnalysisConfig {
   bool use_custom_device_{false};
   int custom_device_id_{0};
   std::string custom_device_type_;
+  bool enable_custom_device_mixed_{false};

   // ONNXRuntime related
   bool use_onnxruntime_{false};
......
@@ -770,7 +770,8 @@ void BindAnalysisConfig(py::module *m) {
       .def("enable_custom_device",
           &AnalysisConfig::EnableCustomDevice,
           py::arg("device_type"),
-          py::arg("device_id") = 0)
+          py::arg("device_id") = 0,
+          py::arg("precision") = AnalysisConfig::Precision::kFloat32)
      .def("enable_npu", &AnalysisConfig::EnableNpu, py::arg("device_id") = 0)
      .def("enable_ipu",
           &AnalysisConfig::EnableIpu,
......