未验证 提交 12d04ab2 编写于 作者: Y Yuanle Liu 提交者: GitHub

enhance mixed precision check (#54413)

上级 77f62c48
...@@ -109,6 +109,10 @@ void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb, ...@@ -109,6 +109,10 @@ void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb,
mixed_precision_mode_ = precision_mode; mixed_precision_mode_ = precision_mode;
} else if (precision_mode == Precision::kHalf || } else if (precision_mode == Precision::kHalf ||
precision_mode == Precision::kBf16) { precision_mode == Precision::kBf16) {
if (precision_mode == Precision::kBf16) {
LOG(WARNING) << "Some op (matmul, conv, etc.) run at bfloat16 precision "
"requires GPU compute capability >= 80.";
}
enable_gpu_mixed_ = true; enable_gpu_mixed_ = true;
mixed_precision_mode_ = precision_mode; mixed_precision_mode_ = precision_mode;
} else { } else {
...@@ -1244,6 +1248,8 @@ std::string AnalysisConfig::Summary() { ...@@ -1244,6 +1248,8 @@ std::string AnalysisConfig::Summary() {
os.InsertRow({"use_cutlass", use_cutlass_ ? "true" : "false"}); os.InsertRow({"use_cutlass", use_cutlass_ ? "true" : "false"});
os.InsertRow({"gpu_device_id", std::to_string(gpu_device_id_)}); os.InsertRow({"gpu_device_id", std::to_string(gpu_device_id_)});
os.InsertRow({"enable_gpu_mixed", std::to_string(enable_gpu_mixed_)}); os.InsertRow({"enable_gpu_mixed", std::to_string(enable_gpu_mixed_)});
os.InsertRow({"mixed_precision_mode",
inference::Precision2String(mixed_precision_mode_)});
os.InsertRow({"memory_pool_init_size", os.InsertRow({"memory_pool_init_size",
std::to_string(memory_pool_init_size_mb_) + "MB"}); std::to_string(memory_pool_init_size_mb_) + "MB"});
os.InsertRow( os.InsertRow(
...@@ -1254,16 +1260,6 @@ std::string AnalysisConfig::Summary() { ...@@ -1254,16 +1260,6 @@ std::string AnalysisConfig::Summary() {
os.InsertRow({"use_tensorrt", use_tensorrt_ ? "true" : "false"}); os.InsertRow({"use_tensorrt", use_tensorrt_ ? "true" : "false"});
if (use_tensorrt_) { if (use_tensorrt_) {
#ifdef PADDLE_WITH_TENSORRT #ifdef PADDLE_WITH_TENSORRT
auto Precision2String = [](Precision prec) -> std::string {
if (prec == Precision::kFloat32)
return "fp32";
else if (prec == Precision::kHalf)
return "fp16";
else if (prec == Precision::kInt8)
return "int8";
else
return "None";
};
auto version2string = auto version2string =
[](const std::tuple<int, int, int> &ver) -> std::string { [](const std::tuple<int, int, int> &ver) -> std::string {
std::ostringstream os; std::ostringstream os;
...@@ -1280,7 +1276,7 @@ std::string AnalysisConfig::Summary() { ...@@ -1280,7 +1276,7 @@ std::string AnalysisConfig::Summary() {
{"trt_runtime_version", {"trt_runtime_version",
version2string(inference::tensorrt::GetTrtRuntimeVersion())}); version2string(inference::tensorrt::GetTrtRuntimeVersion())});
os.InsertRow({"tensorrt_precision_mode", os.InsertRow({"tensorrt_precision_mode",
Precision2String(tensorrt_precision_mode_)}); inference::Precision2String(tensorrt_precision_mode_)});
os.InsertRow({"tensorrt_workspace_size", os.InsertRow({"tensorrt_workspace_size",
std::to_string(tensorrt_workspace_size_)}); std::to_string(tensorrt_workspace_size_)});
os.InsertRow( os.InsertRow(
......
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/inference/api/paddle_analysis_config.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/memory/stats.h" #include "paddle/fluid/memory/stats.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -473,5 +474,18 @@ static inline void DisplayMemoryInfo(platform::Place place, ...@@ -473,5 +474,18 @@ static inline void DisplayMemoryInfo(platform::Place place,
<< "MB]"; << "MB]";
} }
// Maps an AnalysisConfig::Precision value to a short lowercase label.
//
// Returns "fp32" for kFloat32, "fp16" for kHalf, "int8" for kInt8 and "bf16"
// for kBf16; every other value yields "none".
//
// Declared `static inline` (like the other helpers defined alongside it) so
// that a translation unit which sees this definition without calling it does
// not get an unused-static-function warning.
static inline std::string Precision2String(
    AnalysisConfig::Precision precision) {
  switch (precision) {
    case AnalysisConfig::Precision::kFloat32:
      return "fp32";
    case AnalysisConfig::Precision::kHalf:
      return "fp16";
    case AnalysisConfig::Precision::kInt8:
      return "int8";
    case AnalysisConfig::Precision::kBf16:
      return "bf16";
    default:
      // Any precision without a dedicated label.
      return "none";
  }
}
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -48,7 +48,7 @@ class IRPrinting : public PassInstrumentation { ...@@ -48,7 +48,7 @@ class IRPrinting : public PassInstrumentation {
~IRPrinting() = default; ~IRPrinting() = default;
void RunBeforePass(Pass *pass, Operation *op) override { void RunBeforePass(Pass *pass, Operation *op) override {
if (option_->EnablePrintOnChange()) { if (option_->print_on_change()) {
// TODO(liuyuanle): support print on change // TODO(liuyuanle): support print on change
} }
...@@ -56,13 +56,13 @@ class IRPrinting : public PassInstrumentation { ...@@ -56,13 +56,13 @@ class IRPrinting : public PassInstrumentation {
std::string header = std::string header =
"IRPrinting on " + op->name() + " before " + pass->name() + " pass"; "IRPrinting on " + op->name() + " before " + pass->name() + " pass";
detail::PrintHeader(header, os); detail::PrintHeader(header, os);
PrintIR(op, option_->EnablePrintModule(), os); PrintIR(op, option_->print_module(), os);
os << "\n\n"; os << "\n\n";
}); });
} }
void RunAfterPass(Pass *pass, Operation *op) override { void RunAfterPass(Pass *pass, Operation *op) override {
if (option_->EnablePrintOnChange()) { if (option_->print_on_change()) {
// TODO(liuyuanle): support print on change // TODO(liuyuanle): support print on change
} }
...@@ -70,7 +70,7 @@ class IRPrinting : public PassInstrumentation { ...@@ -70,7 +70,7 @@ class IRPrinting : public PassInstrumentation {
std::string header = std::string header =
"IRPrinting on " + op->name() + " after " + pass->name() + " pass"; "IRPrinting on " + op->name() + " after " + pass->name() + " pass";
detail::PrintHeader(header, os); detail::PrintHeader(header, os);
PrintIR(op, option_->EnablePrintModule(), os); PrintIR(op, option_->print_module(), os);
os << "\n\n"; os << "\n\n";
}); });
} }
......
...@@ -91,9 +91,9 @@ class PassManager { ...@@ -91,9 +91,9 @@ class PassManager {
} }
} }
bool EnablePrintModule() const { return print_module_; } bool print_module() const { return print_module_; }
bool EnablePrintOnChange() const { return print_on_change_; } bool print_on_change() const { return print_on_change_; }
private: private:
// The enable_print_before_ and enable_print_after_ can be used to specify // The enable_print_before_ and enable_print_after_ can be used to specify
...@@ -102,6 +102,7 @@ class PassManager { ...@@ -102,6 +102,7 @@ class PassManager {
std::function<bool(Pass *, Operation *)> enable_print_after_; std::function<bool(Pass *, Operation *)> enable_print_after_;
bool print_module_; bool print_module_;
bool print_on_change_; bool print_on_change_;
std::ostream &os; std::ostream &os;
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "paddle/ir/pass/pass_instrumentation.h" #include "paddle/ir/pass/pass_instrumentation.h"
#include "paddle/ir/pass/pass_manager.h" #include "paddle/ir/pass/pass_manager.h"
#include "paddle/ir/pass/utils.h" #include "paddle/ir/pass/utils.h"
namespace ir { namespace ir {
namespace { namespace {
class Timer { class Timer {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册