未验证 提交 a6449634 编写于 作者: E engineer1109 提交者: GitHub

add custom device mixed precision inference api (#50884)

fix bug

remove useless

fix bug

add pybind

remove log

fix style

fix style

change api
上级 62bff0e0
......@@ -376,6 +376,9 @@ struct Argument {
DECL_ARGUMENT_FIELD(use_custom_device, UseCustomDevice, bool);
DECL_ARGUMENT_FIELD(custom_device_type, CustomDeviceType, std::string);
DECL_ARGUMENT_FIELD(custom_device_id, CustomDeviceId, int);
DECL_ARGUMENT_FIELD(enable_custom_device_mixed,
EnableCustomDeviceMixed,
bool);
private:
std::unordered_set<std::string> valid_fields_;
......
......@@ -99,6 +99,8 @@ void IRPassManager::CreatePasses(Argument *argument,
"mixed_black_list",
new std::unordered_set<std::string>(argument->mixed_black_list()));
pass->Set("enable_gpu_mixed", new bool(argument->enable_gpu_mixed()));
pass->Set("enable_custom_device_mixed",
new bool(argument->enable_custom_device_mixed()));
pass->Set("mixed_precision_mode",
new int(argument->mixed_precision_mode()));
pass->Set("model_precision", new int(argument->model_precision()));
......
......@@ -211,11 +211,25 @@ void AnalysisConfig::EnableNpu(int device_id) {
}
void AnalysisConfig::EnableCustomDevice(const std::string &device_type,
int device_id) {
int device_id,
Precision precision_mode) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
use_custom_device_ = true;
custom_device_id_ = device_id;
custom_device_type_ = device_type;
mixed_precision_mode_ = precision_mode;
if (precision_mode == Precision::kFloat32) {
// default
} else if (precision_mode == Precision::kHalf ||
precision_mode == Precision::kBf16) {
enable_custom_device_mixed_ = true;
LOG(INFO) << "enable_custom_device_mixed_";
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The Paddle-CustomDevice inference currently only supports "
"float32/float16/bfloat16 precision. Please check the parameters "
"you specified in EnableCustomDevice function."));
}
#else
LOG(ERROR) << "Please compile with CustomDevice to EnableCustomDevice()";
use_custom_device_ = false;
......@@ -540,6 +554,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
CP_MEMBER(use_custom_device_);
CP_MEMBER(custom_device_type_);
CP_MEMBER(custom_device_id_);
CP_MEMBER(enable_custom_device_mixed_);
// JITLayer relate
CP_MEMBER(apply_optim_);
......
......@@ -1299,7 +1299,6 @@ void AnalysisPredictor::PrepareArgument() {
argument_->SetCustomDeviceId(config_.custom_device_id());
}
#endif
#ifdef PADDLE_WITH_XPU
argument_->SetUseXpu(config_.use_xpu_);
argument_->SetXpuL3WorkspaceSize(config_.xpu_l3_workspace_size_);
......@@ -1361,6 +1360,15 @@ void AnalysisPredictor::PrepareArgument() {
LOG(INFO) << "This model run in Paddle-GPU mixed precision mode.";
}
}
argument_->SetEnableCustomDeviceMixed(config_.enable_custom_device_mixed());
if (config_.enable_custom_device_mixed_) {
argument_->SetEnableIrOptim(true);
pass_builder->ClearPasses();
pass_builder->AppendPass("auto_mixed_precision_pass");
LOG(INFO) << "This model run in Custom Device mixed precision mode.";
}
argument_->SetDisableLogs(config_.glog_info_disabled());
argument_->SetIrAnalysisPasses(pass_builder->AllPasses());
argument_->SetAnalysisPasses(pass_builder->AnalysisPasses());
......
......@@ -375,7 +375,9 @@ struct PD_INFER_DECL AnalysisConfig {
///
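/// \param device_id the id of the custom device to use (default is 0).
/// \param precision_mode the inference precision to use (default is
/// Precision::kFloat32); Precision::kHalf and Precision::kBf16 turn on
/// custom device mixed precision, any other value is rejected.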
///
void EnableCustomDevice(const std::string& device_type, int device_id = 0);
void EnableCustomDevice(const std::string& device_type,
int device_id = 0,
Precision precision_mode = Precision::kFloat32);
///
/// \brief Turn on ONNXRuntime.
///
......@@ -475,6 +477,13 @@ struct PD_INFER_DECL AnalysisConfig {
/// \return string The custom device type.
///
std::string custom_device_type() const { return custom_device_type_; }
/// \brief Get whether custom device mixed precision is enabled.
///
/// The flag defaults to false and is switched on by EnableCustomDevice()
/// when it is called with Precision::kHalf or Precision::kBf16.
///
/// \return bool Whether custom device mixed precision is enabled.
///
bool enable_custom_device_mixed() const {
return enable_custom_device_mixed_;
}
///
/// \brief Get the initial size in MB of the GPU memory pool.
///
......@@ -1071,6 +1080,7 @@ struct PD_INFER_DECL AnalysisConfig {
bool use_custom_device_{false};
int custom_device_id_{0};
std::string custom_device_type_;
bool enable_custom_device_mixed_{false};
// ONNXRuntime related
bool use_onnxruntime_{false};
......
......@@ -770,7 +770,8 @@ void BindAnalysisConfig(py::module *m) {
.def("enable_custom_device",
&AnalysisConfig::EnableCustomDevice,
py::arg("device_type"),
py::arg("device_id") = 0)
py::arg("device_id") = 0,
py::arg("precision") = AnalysisConfig::Precision::kFloat32)
.def("enable_npu", &AnalysisConfig::EnableNpu, py::arg("device_id") = 0)
.def("enable_ipu",
&AnalysisConfig::EnableIpu,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册