未验证 提交 12d04ab2 编写于 作者: Y Yuanle Liu 提交者: GitHub

enhance mixed precision check (#54413)

上级 77f62c48
......@@ -109,6 +109,10 @@ void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb,
mixed_precision_mode_ = precision_mode;
} else if (precision_mode == Precision::kHalf ||
precision_mode == Precision::kBf16) {
if (precision_mode == Precision::kBf16) {
LOG(WARNING) << "Some op (matmul, conv, etc.) run at bfloat16 precision "
"requires GPU compute capability >= 80.";
}
enable_gpu_mixed_ = true;
mixed_precision_mode_ = precision_mode;
} else {
......@@ -1244,6 +1248,8 @@ std::string AnalysisConfig::Summary() {
os.InsertRow({"use_cutlass", use_cutlass_ ? "true" : "false"});
os.InsertRow({"gpu_device_id", std::to_string(gpu_device_id_)});
os.InsertRow({"enable_gpu_mixed", std::to_string(enable_gpu_mixed_)});
os.InsertRow({"mixed_precision_mode",
inference::Precision2String(mixed_precision_mode_)});
os.InsertRow({"memory_pool_init_size",
std::to_string(memory_pool_init_size_mb_) + "MB"});
os.InsertRow(
......@@ -1254,16 +1260,6 @@ std::string AnalysisConfig::Summary() {
os.InsertRow({"use_tensorrt", use_tensorrt_ ? "true" : "false"});
if (use_tensorrt_) {
#ifdef PADDLE_WITH_TENSORRT
auto Precision2String = [](Precision prec) -> std::string {
if (prec == Precision::kFloat32)
return "fp32";
else if (prec == Precision::kHalf)
return "fp16";
else if (prec == Precision::kInt8)
return "int8";
else
return "None";
};
auto version2string =
[](const std::tuple<int, int, int> &ver) -> std::string {
std::ostringstream os;
......@@ -1280,7 +1276,7 @@ std::string AnalysisConfig::Summary() {
{"trt_runtime_version",
version2string(inference::tensorrt::GetTrtRuntimeVersion())});
os.InsertRow({"tensorrt_precision_mode",
Precision2String(tensorrt_precision_mode_)});
inference::Precision2String(tensorrt_precision_mode_)});
os.InsertRow({"tensorrt_workspace_size",
std::to_string(tensorrt_workspace_size_)});
os.InsertRow(
......
......@@ -30,6 +30,7 @@
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/inference/api/paddle_analysis_config.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/memory/stats.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -473,5 +474,18 @@ static inline void DisplayMemoryInfo(platform::Place place,
<< "MB]";
}
// Converts an inference precision mode to its short lowercase name:
// kFloat32 -> "fp32", kHalf -> "fp16", kInt8 -> "int8", kBf16 -> "bf16".
// Any other value maps to "none".
//
// Declared `static inline` (matching DisplayMemoryInfo above) so that
// translation units which include this file without calling the function do
// not trigger an unused-function warning.
static inline std::string Precision2String(
    AnalysisConfig::Precision precision) {
  switch (precision) {
    case AnalysisConfig::Precision::kFloat32:
      return "fp32";
    case AnalysisConfig::Precision::kHalf:
      return "fp16";
    case AnalysisConfig::Precision::kInt8:
      return "int8";
    case AnalysisConfig::Precision::kBf16:
      return "bf16";
    default:
      // Unknown or future precision values fall back to a neutral label.
      return "none";
  }
}
} // namespace inference
} // namespace paddle
......@@ -48,7 +48,7 @@ class IRPrinting : public PassInstrumentation {
~IRPrinting() = default;
void RunBeforePass(Pass *pass, Operation *op) override {
if (option_->EnablePrintOnChange()) {
if (option_->print_on_change()) {
// TODO(liuyuanle): support print on change
}
......@@ -56,13 +56,13 @@ class IRPrinting : public PassInstrumentation {
std::string header =
"IRPrinting on " + op->name() + " before " + pass->name() + " pass";
detail::PrintHeader(header, os);
PrintIR(op, option_->EnablePrintModule(), os);
PrintIR(op, option_->print_module(), os);
os << "\n\n";
});
}
void RunAfterPass(Pass *pass, Operation *op) override {
if (option_->EnablePrintOnChange()) {
if (option_->print_on_change()) {
// TODO(liuyuanle): support print on change
}
......@@ -70,7 +70,7 @@ class IRPrinting : public PassInstrumentation {
std::string header =
"IRPrinting on " + op->name() + " after " + pass->name() + " pass";
detail::PrintHeader(header, os);
PrintIR(op, option_->EnablePrintModule(), os);
PrintIR(op, option_->print_module(), os);
os << "\n\n";
});
}
......
......@@ -91,9 +91,9 @@ class PassManager {
}
}
bool EnablePrintModule() const { return print_module_; }
bool print_module() const { return print_module_; }
bool EnablePrintOnChange() const { return print_on_change_; }
bool print_on_change() const { return print_on_change_; }
private:
// The enable_print_before_ and enable_print_after_ can be used to specify
......@@ -102,6 +102,7 @@ class PassManager {
std::function<bool(Pass *, Operation *)> enable_print_after_;
bool print_module_;
bool print_on_change_;
std::ostream &os;
......
......@@ -23,6 +23,7 @@
#include "paddle/ir/pass/pass_instrumentation.h"
#include "paddle/ir/pass/pass_manager.h"
#include "paddle/ir/pass/utils.h"
namespace ir {
namespace {
class Timer {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册